diff --git a/pkg/command/entities.py b/pkg/command/entities.py
index 27cb596..8697551 100644
--- a/pkg/command/entities.py
+++ b/pkg/command/entities.py
@@ -13,11 +13,16 @@ class CommandReturn(pydantic.BaseModel):
     """命令返回值
     """
 
-    text: typing.Optional[str]
+    text: typing.Optional[str] = None
     """文本
    """
 
-    image: typing.Optional[mirai.Image]
+    image: typing.Optional[mirai.Image] = None
+    """弃用"""
+
+    image_url: typing.Optional[str] = None
+    """图片链接
+    """
 
     error: typing.Optional[errors.CommandError]= None
     """错误
diff --git a/pkg/command/operators/default.py b/pkg/command/operators/default.py
index ca7e404..ee46c7d 100644
--- a/pkg/command/operators/default.py
+++ b/pkg/command/operators/default.py
@@ -24,7 +24,7 @@ class DefaultOperator(operator.CommandOperator):
                 content = ""
 
                 for msg in prompt.messages:
-                    content += f" {msg.role}: {msg.content}"
+                    content += f" {msg.readable_str()}\n"
 
                 reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
 
@@ -45,18 +45,18 @@ class DefaultSetOperator(operator.CommandOperator):
         context: entities.ExecuteContext
     ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
 
-        if len(context.crt_params) == 0:
-            yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
-        else:
-            prompt_name = context.crt_params[0]
-
-            try:
-                prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
-                if prompt is None:
-                    yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
-                else:
-                    context.session.use_prompt_name = prompt.name
-                    yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
-            except Exception as e:
-                traceback.print_exc()
-                yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))
+        if len(context.crt_params) == 0:
+            yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
+        else:
+            prompt_name = context.crt_params[0]
+
+            try:
+                prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
+                if prompt is None:
+                    yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
+                else:
+                    context.session.use_prompt_name = prompt.name
+                    yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
+            except Exception as e:
+                traceback.print_exc()
+                yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))
diff --git a/pkg/command/operators/last.py b/pkg/command/operators/last.py
index 8e3a523..e7a14c8 100644
--- a/pkg/command/operators/last.py
+++ b/pkg/command/operators/last.py
@@ -30,7 +30,7 @@ class LastOperator(operator.CommandOperator):
                     context.session.using_conversation = context.session.conversations[index-1]
                     time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
 
-                    yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
+                    yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}")
                     return
         else:
             yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
\ No newline at end of file
diff --git a/pkg/command/operators/list.py b/pkg/command/operators/list.py
index 258e0ee..ff90d4d 100644
--- a/pkg/command/operators/list.py
+++ b/pkg/command/operators/list.py
@@ -42,7 +42,7 @@ class ListOperator(operator.CommandOperator):
                     using_conv_index = index
 
                 if index >= page * record_per_page and index < (page + 1) * record_per_page:
-                    content += f"{index} {time_str}: {conv.messages[0].content if len(conv.messages) > 0 else '无内容'}\n"
+                    content += f"{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else '无内容'}\n"
 
                 index += 1
 
             if content == '':
@@ -51,6 +51,6 @@ class ListOperator(operator.CommandOperator):
             if context.session.using_conversation is None:
                 content += "\n当前处于新会话"
             else:
-                content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content if len(context.session.using_conversation.messages) > 0 else '无内容'}"
+                content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else '无内容'}"
 
             yield entities.CommandReturn(text=f"第 {page + 1} 页 (时间倒序):\n{content}")
diff --git a/pkg/config/impls/json.py b/pkg/config/impls/json.py
index 754bfa5..362bc78 100644
--- a/pkg/config/impls/json.py
+++ b/pkg/config/impls/json.py
@@ -27,7 +27,7 @@ class JSONConfigFile(file_model.ConfigFile):
         else:
             raise ValueError("template_file_name or template_data must be provided")
 
-    async def load(self) -> dict:
+    async def load(self, completion: bool=True) -> dict:
 
         if not self.exists():
             await self.create()
@@ -39,9 +39,11 @@ class JSONConfigFile(file_model.ConfigFile):
         with open(self.config_file_name, "r", encoding="utf-8") as f:
             cfg = json.load(f)
 
-        for key in self.template_data:
-            if key not in cfg:
-                cfg[key] = self.template_data[key]
+        if completion:
+
+            for key in self.template_data:
+                if key not in cfg:
+                    cfg[key] = self.template_data[key]
 
         return cfg
diff --git a/pkg/config/impls/pymodule.py b/pkg/config/impls/pymodule.py
index ceeebad..67e5867 100644
--- a/pkg/config/impls/pymodule.py
+++ b/pkg/config/impls/pymodule.py
@@ -25,7 +25,7 @@ class PythonModuleConfigFile(file_model.ConfigFile):
     async def create(self):
         shutil.copyfile(self.template_file_name, self.config_file_name)
 
-    async def load(self) -> dict:
+    async def load(self, completion: bool=True) -> dict:
         module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
 
         module = importlib.import_module(module_name)
@@ -43,18 +43,19 @@ class PythonModuleConfigFile(file_model.ConfigFile):
                 cfg[key] = getattr(module, key)
 
         # 从模板模块文件中进行补全
-        module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
-        module = importlib.import_module(module_name)
+        if completion:
+            module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
+            module = importlib.import_module(module_name)
 
-        for key in dir(module):
-            if key.startswith('__'):
-                continue
+            for key in dir(module):
+                if key.startswith('__'):
+                    continue
 
-            if not isinstance(getattr(module, key), allowed_types):
-                continue
+                if not isinstance(getattr(module, key), allowed_types):
+                    continue
 
-            if key not in cfg:
-                cfg[key] = getattr(module, key)
+                if key not in cfg:
+                    cfg[key] = getattr(module, key)
 
         return cfg
diff --git a/pkg/config/manager.py b/pkg/config/manager.py
index f9e93c8..7983407 100644
--- a/pkg/config/manager.py
+++ b/pkg/config/manager.py
@@ -20,8 +20,8 @@ class ConfigManager:
         self.file = cfg_file
         self.data = {}
 
-    async def load_config(self):
-        self.data = await self.file.load()
+    async def load_config(self, completion: bool=True):
+        self.data = await self.file.load(completion=completion)
 
     async def dump_config(self):
         await self.file.save(self.data)
@@ -30,7 +30,7 @@ class ConfigManager:
         self.file.save_sync(self.data)
 
 
-async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
+async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager:
     """加载Python模块配置文件"""
     cfg_inst = pymodule.PythonModuleConfigFile(
         config_name,
@@ -38,12 +38,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> Con
     )
 
     cfg_mgr = ConfigManager(cfg_inst)
-    await cfg_mgr.load_config()
+    await cfg_mgr.load_config(completion=completion)
 
     return cfg_mgr
 
 
-async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager:
+async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
    """加载JSON配置文件"""
     cfg_inst = json_file.JSONConfigFile(
         config_name,
@@ -52,6 +52,6 @@ async def load_json_config(config_name: str, template_name: str=None, template_d
     )
 
     cfg_mgr = ConfigManager(cfg_inst)
-    await cfg_mgr.load_config()
+    await cfg_mgr.load_config(completion=completion)
 
     return cfg_mgr
\ No newline at end of file
diff --git a/pkg/config/migrations/m006_vision_config.py b/pkg/config/migrations/m006_vision_config.py
new file mode 100644
index 0000000..8084611
--- /dev/null
+++ b/pkg/config/migrations/m006_vision_config.py
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+from .. import migration
+
+
+@migration.migration_class("vision-config", 6)
+class VisionConfigMigration(migration.Migration):
+    """迁移"""
+
+    async def need_migrate(self) -> bool:
+        """判断当前环境是否需要运行此迁移"""
+        return "enable-vision" not in self.ap.provider_cfg.data
+
+    async def run(self):
+        """执行迁移"""
+        if "enable-vision" not in self.ap.provider_cfg.data:
+            self.ap.provider_cfg.data["enable-vision"] = False
+
+        await self.ap.provider_cfg.dump_config()
diff --git a/pkg/config/model.py b/pkg/config/model.py
index d209093..153123e 100644
--- a/pkg/config/model.py
+++ b/pkg/config/model.py
@@ -22,7 +22,7 @@ class ConfigFile(metaclass=abc.ABCMeta):
         pass
 
     @abc.abstractmethod
-    async def load(self) -> dict:
+    async def load(self, completion: bool=True) -> dict:
         pass
 
     @abc.abstractmethod
diff --git a/pkg/core/bootutils/deps.py b/pkg/core/bootutils/deps.py
index 4adf132..da0ae07 100644
--- a/pkg/core/bootutils/deps.py
+++ b/pkg/core/bootutils/deps.py
@@ -14,6 +14,7 @@ required_deps = {
     "yaml": "pyyaml",
     "aiohttp": "aiohttp",
     "psutil": "psutil",
+    "async_lru": "async-lru",
 }
diff --git a/pkg/core/entities.py b/pkg/core/entities.py
index 2e7d0b1..30b983a 100644
--- a/pkg/core/entities.py
+++ b/pkg/core/entities.py
@@ -70,7 +70,7 @@ class Query(pydantic.BaseModel):
     resp_messages: typing.Optional[list[llm_entities.Message]] = []
     """由Process阶段生成的回复消息对象列表"""
 
-    resp_message_chain: typing.Optional[mirai.MessageChain] = None
+    resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None
     """回复消息链,从resp_messages包装而得"""
 
     class Config:
diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py
index d536582..39ecc02 100644
--- a/pkg/core/stages/build_app.py
+++ b/pkg/core/stages/build_app.py
@@ -15,7 +15,6 @@ from ...provider.sysprompt import sysprompt as llm_prompt_mgr
 from ...provider.tools import toolmgr as llm_tool_mgr
 from ...platform import manager as im_mgr
 
-
 @stage.stage_class("BuildAppStage")
 class BuildAppStage(stage.BootingStage):
     """构建应用阶段
@@ -83,7 +82,6 @@ class BuildAppStage(stage.BootingStage):
         llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
         await llm_tool_mgr_inst.initialize()
         ap.tool_mgr = llm_tool_mgr_inst
-
         im_mgr_inst = im_mgr.PlatformManager(ap=ap)
         await im_mgr_inst.initialize()
         ap.platform_mgr = im_mgr_inst
@@ -92,5 +90,6 @@ class BuildAppStage(stage.BootingStage):
         await stage_mgr.initialize()
         ap.stage_mgr = stage_mgr
 
+
         ctrl = controller.Controller(ap)
         ap.ctrl = ctrl
diff --git a/pkg/core/stages/load_config.py b/pkg/core/stages/load_config.py
index 9e61c1c..cb6e1ed 100644
--- a/pkg/core/stages/load_config.py
+++ b/pkg/core/stages/load_config.py
@@ -12,11 +12,11 @@ class LoadConfigStage(stage.BootingStage):
     async def run(self, ap: app.Application):
         """启动
         """
-        ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json")
-        ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json")
-        ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json")
-        ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json")
-        ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json")
+        ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json", completion=False)
+        ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json", completion=False)
+        ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json", completion=False)
+        ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json", completion=False)
+        ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json", completion=False)
 
         ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json")
         await ap.plugin_setting_meta.dump_config()
diff --git a/pkg/core/stages/migrate.py b/pkg/core/stages/migrate.py
index cef3b42..4d5b8d8 100644
--- a/pkg/core/stages/migrate.py
+++ b/pkg/core/stages/migrate.py
@@ -4,7 +4,7 @@ import importlib
 
 from .. import stage, app
 from ...config import migration
-from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
+from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion, m006_vision_config
 from ...config.migrations import m005_deepseek_cfg_completion
diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py
index 95a7cff..9c04138 100644
--- a/pkg/pipeline/bansess/bansess.py
+++ b/pkg/pipeline/bansess/bansess.py
@@ -8,7 +8,10 @@ from ...config import manager as cfg_mgr
 
 @stage.stage_class('BanSessionCheckStage')
 class BanSessionCheckStage(stage.PipelineStage):
-    """访问控制处理阶段"""
+    """访问控制处理阶段
+
+    仅检查query中群号或个人号是否在访问控制列表中。
+    """
 
     async def initialize(self):
         pass
diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py
index 21b6c25..a669e31 100644
--- a/pkg/pipeline/cntfilter/cntfilter.py
+++ b/pkg/pipeline/cntfilter/cntfilter.py
@@ -9,12 +9,24 @@ from ...core import entities as core_entities
 from ...config import manager as cfg_mgr
 from . import filter as filter_model, entities as filter_entities
 from .filters import cntignore, banwords, baiduexamine
+from ...provider import entities as llm_entities
 
 
 @stage.stage_class('PostContentFilterStage')
 @stage.stage_class('PreContentFilterStage')
 class ContentFilterStage(stage.PipelineStage):
-    """内容过滤阶段"""
+    """内容过滤阶段
+
+    前置:
+        检查消息是否符合规则,不符合则拦截。
+        改写:
+            message_chain
+
+    后置:
+        检查AI回复消息是否符合规则,可能进行改写,不符合则拦截。
+        改写:
+            query.resp_messages
+    """
 
     filter_chain: list[filter_model.ContentFilter]
 
@@ -130,6 +142,21 @@ class ContentFilterStage(stage.PipelineStage):
         """处理
         """
         if stage_inst_name == 'PreContentFilterStage':
+
+            contain_non_text = False
+
+            for me in query.message_chain:
+                if not isinstance(me, mirai.Plain):
+                    contain_non_text = True
+                    break
+
+            if contain_non_text:
+                self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。")
+                return entities.StageProcessResult(
+                    result_type=entities.ResultType.CONTINUE,
+                    new_query=query
+                )
+
             return await self._pre_process(
                 str(query.message_chain).strip(),
                 query
diff --git a/pkg/pipeline/cntfilter/entities.py b/pkg/pipeline/cntfilter/entities.py
index 8ff581f..af60a59 100644
--- a/pkg/pipeline/cntfilter/entities.py
+++ b/pkg/pipeline/cntfilter/entities.py
@@ -4,6 +4,8 @@ import enum
 
 import pydantic
 
+from ...provider import entities as llm_entities
+
 
 class ResultLevel(enum.Enum):
     """结果等级"""
@@ -38,7 +40,7 @@ class FilterResult(pydantic.BaseModel):
     """
 
     replacement: str
-    """替换后的消息
+    """替换后的文本消息
 
    内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。
    若没有修改内容,也需要返回原消息。
diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py
index 8b34e0c..8eceb87 100644
--- a/pkg/pipeline/cntfilter/filter.py
+++ b/pkg/pipeline/cntfilter/filter.py
@@ -5,6 +5,7 @@ import typing
 
 from ...core import app
 from . import entities
+from ...provider import entities as llm_entities
 
 
 preregistered_filters: list[typing.Type[ContentFilter]] = []
@@ -63,7 +64,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
         pass
 
     @abc.abstractmethod
-    async def process(self, message: str) -> entities.FilterResult:
+    async def process(self, message: str=None, image_url=None) -> entities.FilterResult:
         """处理消息
 
         分为前后阶段,具体取决于 enable_stages 的值。
@@ -71,6 +72,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
 
         Args:
             message (str): 需要检查的内容
+            image_url (str): 要检查的图片的 URL
 
         Returns:
             entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档
diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py
index 5cd7dcf..1430c2e 100644
--- a/pkg/pipeline/cntfilter/filters/banwords.py
+++ b/pkg/pipeline/cntfilter/filters/banwords.py
@@ -8,7 +8,7 @@ from ....config import manager as cfg_mgr
 
 @filter_model.filter_class("ban-word-filter")
 class BanWordFilter(filter_model.ContentFilter):
-    """根据内容禁言"""
+    """根据内容过滤"""
 
     async def initialize(self):
         pass
diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py
index 28c2814..756df44 100644
--- a/pkg/pipeline/longtext/longtext.py
+++ b/pkg/pipeline/longtext/longtext.py
@@ -16,6 +16,9 @@ from ...config import manager as cfg_mgr
 @stage.stage_class("LongTextProcessStage")
 class LongTextProcessStage(stage.PipelineStage):
     """长消息处理阶段
+
+    改写:
+    - resp_message_chain
     """
 
     strategy_impl: strategy.LongTextStrategy
@@ -59,15 +62,15 @@ class LongTextProcessStage(stage.PipelineStage):
         # 检查是否包含非 Plain 组件
         contains_non_plain = False
 
-        for msg in query.resp_message_chain:
+        for msg in query.resp_message_chain[-1]:
             if not isinstance(msg, Plain):
                 contains_non_plain = True
                 break
 
         if contains_non_plain:
             self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
-        elif len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
-            query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query))
+        elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
+            query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
 
         return entities.StageProcessResult(
             result_type=entities.ResultType.CONTINUE,
diff --git a/pkg/pipeline/pool.py b/pkg/pipeline/pool.py
index bd48c48..ba7f999 100644
--- a/pkg/pipeline/pool.py
+++ b/pkg/pipeline/pool.py
@@ -43,7 +43,7 @@ class QueryPool:
                 message_event=message_event,
                 message_chain=message_chain,
                 resp_messages=[],
-                resp_message_chain=None,
+                resp_message_chain=[],
                 adapter=adapter
             )
             self.queries.append(query)
diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py
index cedc030..ebe4d31 100644
--- a/pkg/pipeline/preproc/preproc.py
+++ b/pkg/pipeline/preproc/preproc.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import mirai
+
 from .. import stage, entities, stagemgr
 from ...core import entities as core_entities
 from ...provider import entities as llm_entities
@@ -9,6 +11,16 @@ from ...plugin import events
 @stage.stage_class("PreProcessor")
 class PreProcessor(stage.PipelineStage):
     """请求预处理阶段
+
+    签出会话、prompt、上文、模型、内容函数。
+
+    改写:
+    - session
+    - prompt
+    - messages
+    - user_message
+    - use_model
+    - use_funcs
     """
 
     async def process(
@@ -27,21 +39,42 @@ class PreProcessor(stage.PipelineStage):
         query.prompt = conversation.prompt.copy()
         query.messages = conversation.messages.copy()
 
-        query.user_message = llm_entities.Message(
-            role='user',
-            content=str(query.message_chain).strip()
-        )
-
         query.use_model = conversation.use_model
 
-        query.use_funcs = conversation.use_funcs
+        query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None
+
+        # 检查vision是否启用,没启用就删除所有图片
+        if not self.ap.provider_cfg.data['enable-vision'] or not query.use_model.vision_supported:
+            for msg in query.messages:
+                if isinstance(msg.content, list):
+                    for me in msg.content.copy():  # 遍历副本,避免在迭代中删除元素导致漏删
+                        if me.type == 'image_url':
+                            msg.content.remove(me)
+
+        content_list = []
+
+        for me in query.message_chain:
+            if isinstance(me, mirai.Plain):
+                content_list.append(
+                    llm_entities.ContentElement.from_text(me.text)
+                )
+            elif isinstance(me, mirai.Image):
+                if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported:
+                    if me.url is not None:
+                        content_list.append(
+                            llm_entities.ContentElement.from_image_url(str(me.url))
+                        )
+
+        query.user_message = llm_entities.Message( # TODO 适配多模态输入
+            role='user',
+            content=content_list
+        )
 
         # =========== 触发事件 PromptPreProcessing
-        session = query.session
 
         event_ctx = await self.ap.plugin_mgr.emit_event(
             event=events.PromptPreProcessing(
-                session_name=f'{session.launcher_type.value}_{session.launcher_id}',
+                session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
                 default_prompt=query.prompt.messages,
                 prompt=query.messages,
                 query=query
diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py
index f38ee34..26f73b6 100644
--- a/pkg/pipeline/process/handlers/chat.py
+++ b/pkg/pipeline/process/handlers/chat.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import typing
 import time
 import traceback
+import json
 
 import mirai
 
@@ -70,17 +71,13 @@ class ChatMessageHandler(handler.MessageHandler):
                     mirai.Plain(event_ctx.event.alter)
                 ])
 
-        query.messages.append(
-            query.user_message
-        )
-
         text_length = 0
 
         start_time = time.time()
 
         try:
 
-            async for result in query.use_model.requester.request(query):
+            async for result in self.runner(query):
                 query.resp_messages.append(result)
 
                 self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
 
@@ -92,6 +89,9 @@ class ChatMessageHandler(handler.MessageHandler):
                     result_type=entities.ResultType.CONTINUE,
                     new_query=query
                 )
+
+            query.session.using_conversation.messages.append(query.user_message)
+            query.session.using_conversation.messages.extend(query.resp_messages)
         except Exception as e:
 
             self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}')
@@ -104,8 +104,6 @@ class ChatMessageHandler(handler.MessageHandler):
                 debug_notice=traceback.format_exc()
             )
         finally:
-            query.session.using_conversation.messages.append(query.user_message)
-            query.session.using_conversation.messages.extend(query.resp_messages)
 
             await self.ap.ctr_mgr.usage.post_query_record(
                 session_type=query.session.launcher_type.value,
@@ -115,4 +113,65 @@ class ChatMessageHandler(handler.MessageHandler):
                 model_name=query.use_model.name,
                 response_seconds=int(time.time() - start_time),
                 retry_times=-1,
-            )
\ No newline at end of file
+            )
+
+    async def runner(
+        self,
+        query: core_entities.Query,
+    ) -> typing.AsyncGenerator[llm_entities.Message, None]:
+        """执行一个请求处理过程中的LLM接口请求、函数调用的循环
+
+        这是临时处理方案,后续可能改为使用LangChain或者自研的工作流处理器
+        """
+        await query.use_model.requester.preprocess(query)
+
+        pending_tool_calls = []
+
+        req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
+
+        # 首次请求
+        msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
+
+        yield msg
+
+        pending_tool_calls = msg.tool_calls
+
+        req_messages.append(msg)
+
+        # 持续请求,只要还有待处理的工具调用就继续处理调用
+        while pending_tool_calls:
+            for tool_call in pending_tool_calls:
+                try:
+                    func = tool_call.function
+
+                    parameters = json.loads(func.arguments)
+
+                    func_ret = await self.ap.tool_mgr.execute_func_call(
+                        query, func.name, parameters
+                    )
+
+                    msg = llm_entities.Message(
+                        role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
+                    )
+
+                    yield msg
+
+                    req_messages.append(msg)
+                except Exception as e:
+                    # 工具调用出错,添加一个报错信息到 req_messages
+                    err_msg = llm_entities.Message(
+                        role="tool", content=f"err: {e}", tool_call_id=tool_call.id
+                    )
+
+                    yield err_msg
+
+                    req_messages.append(err_msg)
+
+            # 处理完所有调用,再次请求
+            msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
+
+            yield msg
+
+            pending_tool_calls = msg.tool_calls
+
+            req_messages.append(msg)
diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py
index 7179fd3..02ff269 100644
--- a/pkg/pipeline/process/handlers/command.py
+++ b/pkg/pipeline/process/handlers/command.py
@@ -80,9 +80,6 @@ class CommandHandler(handler.MessageHandler):
                 session=session
             ):
                 if ret.error is not None:
-                    # query.resp_message_chain = mirai.MessageChain([
-                    #     mirai.Plain(str(ret.error))
-                    # ])
                     query.resp_messages.append(
                         llm_entities.Message(
                             role='command',
@@ -96,18 +93,28 @@ class CommandHandler(handler.MessageHandler):
                         result_type=entities.ResultType.CONTINUE,
                         new_query=query
                     )
-                elif ret.text is not None:
-                    # query.resp_message_chain = mirai.MessageChain([
-                    #     mirai.Plain(ret.text)
-                    # ])
+                elif ret.text is not None or ret.image_url is not None:
+
+                    content: list[llm_entities.ContentElement] = []
+
+                    if ret.text is not None:
+                        content.append(
+                            llm_entities.ContentElement.from_text(ret.text)
+                        )
+
+                    if ret.image_url is not None:
+                        content.append(
+                            llm_entities.ContentElement.from_image_url(ret.image_url)
+                        )
+
                     query.resp_messages.append(
                         llm_entities.Message(
                             role='command',
-                            content=ret.text,
+                            content=content,
                         )
                     )
 
-                    self.ap.logger.info(f'命令返回: {self.cut_str(ret.text)}')
+                    self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
 
                     yield entities.StageProcessResult(
                         result_type=entities.ResultType.CONTINUE,
diff --git a/pkg/pipeline/process/process.py b/pkg/pipeline/process/process.py
index ddf8809..e58d15e 100644
--- a/pkg/pipeline/process/process.py
+++ b/pkg/pipeline/process/process.py
@@ -11,7 +11,13 @@ from ...config import manager as cfg_mgr
 
 @stage.stage_class("MessageProcessor")
 class Processor(stage.PipelineStage):
-    """请求实际处理阶段"""
+    """请求实际处理阶段
+
+    通过命令处理器和聊天处理器处理消息。
+
+    改写:
+    - resp_messages
+    """
 
     cmd_handler: handler.MessageHandler
diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py
index 2622247..cd39b85 100644
--- a/pkg/pipeline/ratelimit/ratelimit.py
+++ b/pkg/pipeline/ratelimit/ratelimit.py
@@ -11,7 +11,10 @@ from ...core import entities as core_entities
 @stage.stage_class("RequireRateLimitOccupancy")
 @stage.stage_class("ReleaseRateLimitOccupancy")
 class RateLimit(stage.PipelineStage):
-    """限速器控制阶段"""
+    """限速器控制阶段
+
+    不改写query,只检查是否需要限速。
+    """
 
     algo: algo.ReteLimitAlgo
diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py
index 36a7329..d3af14e 100644
--- a/pkg/pipeline/respback/respback.py
+++ b/pkg/pipeline/respback/respback.py
@@ -31,7 +31,7 @@ class SendResponseBackStage(stage.PipelineStage):
 
         await self.ap.platform_mgr.send(
             query.message_event,
-            query.resp_message_chain,
+            query.resp_message_chain[-1],
             adapter=query.adapter
         )
diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py
index d795d05..fce0c4e 100644
--- a/pkg/pipeline/resprule/resprule.py
+++ b/pkg/pipeline/resprule/resprule.py
@@ -14,9 +14,12 @@ from ...config import manager as cfg_mgr
 @stage.stage_class("GroupRespondRuleCheckStage")
 class GroupRespondRuleCheckStage(stage.PipelineStage):
     """群组响应规则检查器
+
+    仅检查群消息是否符合规则。
     """
 
     rule_matchers: list[rule.GroupRespondRule]
+    """检查器实例"""
 
     async def initialize(self):
         """初始化检查器
@@ -31,7 +34,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
 
     async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
 
-        if query.launcher_type.value != 'group':
+        if query.launcher_type.value != 'group':  # 只处理群消息
             return entities.StageProcessResult(
                 result_type=entities.ResultType.CONTINUE,
                 new_query=query
diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py
index 23c7897..46957aa 100644
--- a/pkg/pipeline/stagemgr.py
+++ b/pkg/pipeline/stagemgr.py
@@ -17,17 +17,17 @@ from .ratelimit import ratelimit
 
 # 请求处理阶段顺序
 stage_order = [
-    "GroupRespondRuleCheckStage",
-    "BanSessionCheckStage",
-    "PreContentFilterStage",
-    "PreProcessor",
-    "RequireRateLimitOccupancy",
-    "MessageProcessor",
-    "ReleaseRateLimitOccupancy",
-    "PostContentFilterStage",
-    "ResponseWrapper",
-    "LongTextProcessStage",
-    "SendResponseBackStage",
+    "GroupRespondRuleCheckStage",  # 群响应规则检查
+    "BanSessionCheckStage",  # 封禁会话检查
+    "PreContentFilterStage",  # 内容过滤前置阶段
+    "PreProcessor",  # 预处理器
+    "RequireRateLimitOccupancy",  # 请求速率限制占用
+    "MessageProcessor",  # 处理器
+    "ReleaseRateLimitOccupancy",  # 释放速率限制占用
+    "PostContentFilterStage",  # 内容过滤后置阶段
+    "ResponseWrapper",  # 响应包装器
+    "LongTextProcessStage",  # 长文本处理
+    "SendResponseBackStage",  # 发送响应
 ]
diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py
index a500d7c..acf0549 100644
--- a/pkg/pipeline/wrapper/wrapper.py
+++ b/pkg/pipeline/wrapper/wrapper.py
@@ -14,6 +14,13 @@ from ...plugin import events
 
 @stage.stage_class("ResponseWrapper")
 class ResponseWrapper(stage.PipelineStage):
+    """回复包装阶段
+
+    把回复的 message 包装成人类识读的形式。
+
+    改写:
+    - resp_message_chain
+    """
 
     async def initialize(self):
         pass
@@ -27,17 +34,19 @@ class ResponseWrapper(stage.PipelineStage):
         """
 
         if query.resp_messages[-1].role == 'command':
-            query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content)
+            # query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
+            query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))
 
             yield entities.StageProcessResult(
                 result_type=entities.ResultType.CONTINUE,
                 new_query=query
             )
         elif query.resp_messages[-1].role == 'plugin':
-            if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
-                query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
-            else:
-                query.resp_message_chain = query.resp_messages[-1].content
+            # if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
+            #     query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
+            # else:
+            #     query.resp_message_chain.append(query.resp_messages[-1].content)
+            query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())
 
             yield entities.StageProcessResult(
                 result_type=entities.ResultType.CONTINUE,
@@ -52,7 +61,7 @@ class ResponseWrapper(stage.PipelineStage):
                 reply_text = ''
 
                 if result.content is not None:  # 有内容
-                    reply_text = result.content
+                    reply_text = str(result.get_content_mirai_message_chain())
 
                     # ============= 触发插件事件 ===============
                     event_ctx = await self.ap.plugin_mgr.emit_event(
@@ -76,11 +85,11 @@ class ResponseWrapper(stage.PipelineStage):
 
                     else:
 
                         if event_ctx.event.reply is not None:
-                            query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
+                            query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
 
                         else:
 
-                            query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
+                            query.resp_message_chain.append(result.get_content_mirai_message_chain())
 
                     yield entities.StageProcessResult(
                         result_type=entities.ResultType.CONTINUE,
                         new_query=query
@@ -93,7 +102,7 @@ class ResponseWrapper(stage.PipelineStage):
 
                 reply_text = f'调用函数 {".".join(function_names)}...'
 
-                query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
+                query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
 
                 if self.ap.platform_cfg.data['track-function-calls']:
 
@@ -119,13 +128,13 @@ class ResponseWrapper(stage.PipelineStage):
 
                     else:
 
                         if event_ctx.event.reply is not None:
-                            query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
+                            query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
 
                         else:
 
-                            query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
+                            query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
 
             yield entities.StageProcessResult(
                 result_type=entities.ResultType.CONTINUE,
                 new_query=query
-            )
\ No newline at end of file
+            )
diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py
index a30d4e3..3b87f5c 100644
--- a/pkg/provider/entities.py
+++ b/pkg/provider/entities.py
@@ -21,6 +21,39 @@ class ToolCall(pydantic.BaseModel):
     function: FunctionCall
 
 
+class ImageURLContentObject(pydantic.BaseModel):
+    url: str
+
+    def __str__(self):
+        return self.url[:128] + ('...' if len(self.url) > 128 else '')
+
+
+class ContentElement(pydantic.BaseModel):
+
+    type: str
+    """内容类型"""
+
+    text: typing.Optional[str] = None
+
+    image_url: typing.Optional[ImageURLContentObject] = None
+
+    def __str__(self):
+        if self.type == 'text':
+            return self.text
+        elif self.type == 'image_url':
+            return f'[图片]({self.image_url})'
+        else:
+            return '未知内容'
+
+    @classmethod
+    def from_text(cls, text: str):
+        return cls(type='text', text=text)
+
+    @classmethod
+    def from_image_url(cls, image_url: str):
+        return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))
+
+
 class Message(pydantic.BaseModel):
     """消息"""
 
@@ -30,12 +63,9 @@ class Message(pydantic.BaseModel):
     name: typing.Optional[str] = None
     """名称,仅函数调用返回时设置"""
 
-    content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None
+    content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None
     """内容"""
 
-    function_call: typing.Optional[FunctionCall] = None
-    """函数调用,不再受支持,请使用tool_calls"""
-
     tool_calls: typing.Optional[list[ToolCall]] = None
     """工具调用"""
 
@@ -43,10 +73,38 @@ class Message(pydantic.BaseModel):
 
     def readable_str(self) -> str:
         if self.content is not None:
-            return str(self.content)
-        elif self.function_call is not None:
-            return f'{self.function_call.name}({self.function_call.arguments})'
+            return str(self.role) + ": " + str(self.get_content_mirai_message_chain())
         elif self.tool_calls is not None:
             return f'调用工具: {self.tool_calls[0].id}'
         else:
             return '未知消息'
+
+    def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None:
+        """将内容转换为 Mirai MessageChain 对象
+
+        Args:
+            prefix_text (str): 首个文字组件的前缀文本
+        """
+
+        if self.content is None:
+            return None
+        elif isinstance(self.content, str):
+            return mirai.MessageChain([mirai.Plain(prefix_text+self.content)])
+        elif isinstance(self.content, list):
+            mc = []
+            for ce in self.content:
+                if ce.type == 'text':
+                    mc.append(mirai.Plain(ce.text))
+                elif ce.type == 'image_url':  # ContentElement 的图片类型为 image_url
+                    mc.append(mirai.Image(url=ce.image_url.url))
+
+            # 找第一个文字组件
+            if prefix_text:
+                for i, c in enumerate(mc):
+                    if isinstance(c, mirai.Plain):
+                        mc[i] = mirai.Plain(prefix_text+c.text)
+                        break
+                else:
+                    mc.insert(0, mirai.Plain(prefix_text))
+
+            return mirai.MessageChain(mc)
diff --git a/pkg/provider/modelmgr/api.py b/pkg/provider/modelmgr/api.py
index 63021be..930cf9e 100644
--- a/pkg/provider/modelmgr/api.py
+++ b/pkg/provider/modelmgr/api.py
@@ -6,6 +6,8 @@ import typing
 from ...core import app
 from ...core import entities as core_entities
 from .. import entities as llm_entities
+from . import entities as modelmgr_entities
+from ..tools import entities as tools_entities
 
 
 preregistered_requesters: list[typing.Type[LLMAPIRequester]] = []
@@ -33,20 +35,31 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
     async def initialize(self):
         pass
 
-    @abc.abstractmethod
-    async def request(
+    async def preprocess(
         self,
         query: core_entities.Query,
-    ) -> typing.AsyncGenerator[llm_entities.Message, None]:
-        """请求API
+    ):
+        """预处理
+
+        在这里处理特定API对Query对象的兼容性问题。
+        """
+        pass
 
-        对话前文可以从 query 对象中获取。
-        可以多次yield消息对象。
+    @abc.abstractmethod
+    async def call(
+        self,
+        model: modelmgr_entities.LLMModelInfo,
+        messages: typing.List[llm_entities.Message],
+        funcs: typing.List[tools_entities.LLMFunction] = None,
+    ) -> llm_entities.Message:
+        """调用API
 
         Args:
-            query (core_entities.Query): 本次请求的上下文对象
+            model (modelmgr_entities.LLMModelInfo): 使用的模型信息
+            messages (typing.List[llm_entities.Message]): 消息对象列表
+            funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None.
 
-        Yields:
-            pkg.provider.entities.Message: 返回消息对象
+        Returns:
+            llm_entities.Message: 返回消息对象
         """
-        raise NotImplementedError
+        pass
diff --git a/pkg/provider/modelmgr/apis/anthropicmsgs.py b/pkg/provider/modelmgr/apis/anthropicmsgs.py
index 42bd385..ee2c51a 100644
--- a/pkg/provider/modelmgr/apis/anthropicmsgs.py
+++ b/pkg/provider/modelmgr/apis/anthropicmsgs.py
@@ -27,47 +27,60 @@ class AnthropicMessages(api.LLMAPIRequester):
             proxies=self.ap.proxy_mgr.get_forward_proxies()
         )
 
-    async def request(
+    async def call(
         self,
-        query: core_entities.Query,
-    ) -> typing.AsyncGenerator[llm_entities.Message, None]:
-        self.client.api_key = query.use_model.token_mgr.get_token()
+        model: entities.LLMModelInfo,
+        messages: typing.List[llm_entities.Message],
+        funcs: typing.List[tools_entities.LLMFunction] = None,
+    ) -> llm_entities.Message:
+        self.client.api_key = model.token_mgr.get_token()
 
         args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
-        args["model"] = query.use_model.name if query.use_model.model_name is None else query.use_model.model_name
+        args["model"] = model.name if model.model_name is None else model.model_name
 
-        req_messages = [  # req_messages 仅用于类内,外部同步由 query.messages 进行
-            m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != ""
-        ] + [m.dict(exclude_none=True) for m in query.messages]
+        # 处理消息
 
-        # 删除所有 role=system & content='' 的消息
-        req_messages = [
-            m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
-        ]
+        # system
+        system_role_message = None
 
-        # 检查是否有 role=system 的消息,若有,改为 role=user,并在后面加一个 role=assistant 的消息
-        system_role_index = []
-        for i, m in enumerate(req_messages):
-            if m["role"] == "system":
-                system_role_index.append(i)
-                m["role"] = "user"
+        for i, m in enumerate(messages):
+            if m.role == "system":
+                system_role_message = m
 
-        if system_role_index:
-            for i in system_role_index[::-1]:
-                req_messages.insert(i + 1, {"role": "assistant", "content": "Okay, I'll follow."})
+                messages.pop(i)
+
+                break
 
-        # 忽略掉空消息,用户可能发送空消息,而上层未过滤
-        req_messages = [
-            m for m in req_messages if m["content"].strip() != ""
-        ]
+        if isinstance(system_role_message, llm_entities.Message) \
+            and isinstance(system_role_message.content, str):
+            args['system'] = system_role_message.content
+
+        # 其他消息
+        # req_messages = [
+        #     m.dict(exclude_none=True) for m in messages \
+        #     if (isinstance(m.content, str) and m.content.strip() != "") \
+        #     or (isinstance(m.content, list) and )
+        # ]
+        # 暂时不支持vision,仅保留纯文字的content
+        req_messages = []
+
+        for m in messages:
+            if isinstance(m.content, str) and m.content.strip() != "":
+                req_messages.append(m.dict(exclude_none=True))
+            elif isinstance(m.content, list):
+                # 删除m.content中的type!=text的元素(content为ContentElement对象列表)
+                m.content = [
+                    c for c in m.content if c.type == "text"
+                ]
+
+                if len(m.content) > 0:
+                    req_messages.append(m.dict(exclude_none=True))
 
         args["messages"] = req_messages
 
         try:
             resp = await self.client.messages.create(**args)
 
-            yield llm_entities.Message(
+            return llm_entities.Message(
                 content=resp.content[0].text,
                 role=resp.role
             )
@@ -79,4 +92,4 @@ class AnthropicMessages(api.LLMAPIRequester):
             if 'model: ' in str(e):
                 raise errors.RequesterError(f'模型无效: {e.message}')
             else:
-                raise errors.RequesterError(f'请求地址无效: {e.message}')
\ No newline at end of file
+                raise errors.RequesterError(f'请求地址无效: {e.message}')
diff --git a/pkg/provider/modelmgr/apis/chatcmpl.py b/pkg/provider/modelmgr/apis/chatcmpl.py
index e3901de..028b208 100644
--- a/pkg/provider/modelmgr/apis/chatcmpl.py
+++ b/pkg/provider/modelmgr/apis/chatcmpl.py
@@ -3,16 +3,20 @@ from __future__ import annotations
 import asyncio
 import typing
 import json
+import base64
 
 from typing import AsyncGenerator
 
 import openai
 import openai.types.chat.chat_completion as chat_completion
 import httpx
+import aiohttp
+import async_lru
 
 from .. import api, entities, errors
 from ....core import entities as core_entities, app
 from ... import entities as llm_entities
 from ...tools import entities as tools_entities
+from ....utils import image
 
 
 @api.requester_class("openai-chat-completions")
@@ -43,7 +47,6 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
         self,
         args: dict,
     ) -> chat_completion.ChatCompletion:
-        self.ap.logger.debug(f"req chat_completion with args {args}")
         return await self.client.chat.completions.create(**args)
 
     async def _make_msg(
@@ -67,14 +70,22 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
         args = self.requester_cfg['args'].copy()
         args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
 
-        if use_model.tool_call_supported:
+        if use_funcs:
             tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
 
             if tools:
                 args["tools"] = tools
 
         # 设置此次请求中的messages
-        messages = req_messages
+        messages = req_messages.copy()
+
+        # 检查vision
+        for msg in messages:
+            if 'content' in msg and isinstance(msg["content"], list):
+                for me in msg["content"]:
+                    if me["type"] == "image_url":
+                        me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url'])
+
         args["messages"] = messages
 
         # 发送请求
@@ -84,73 +95,19 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
         message = await self._make_msg(resp)
 
         return message
-
-    async def _request(
-        self, query: core_entities.Query
-    ) -> typing.AsyncGenerator[llm_entities.Message, None]:
-        """请求"""
-
-        pending_tool_calls = []
-
+
+    async def call(
+        self,
+        model: entities.LLMModelInfo,
+        messages: typing.List[llm_entities.Message],
+        funcs: typing.List[tools_entities.LLMFunction] = None,
+    ) -> llm_entities.Message:
         req_messages = [  # req_messages 仅用于类内,外部同步由 query.messages 进行
-            m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != ""
-        ] + [m.dict(exclude_none=True) for m in query.messages]
+            m.dict(exclude_none=True) for m in messages
+        ]
 
-        # req_messages.append({"role": "user", "content": str(query.message_chain)})
-
-        # 首次请求
-        msg = await self._closure(req_messages, query.use_model, query.use_funcs)
-
-        yield msg
-
-        pending_tool_calls = msg.tool_calls
-
-        req_messages.append(msg.dict(exclude_none=True))
-
-        # 持续请求,只要还有待处理的工具调用就继续处理调用
-        while pending_tool_calls:
-            for tool_call in pending_tool_calls:
-                try:
-                    func = tool_call.function
-
-                    parameters = json.loads(func.arguments)
-
-                    func_ret = await self.ap.tool_mgr.execute_func_call(
-                        query, func.name, parameters
-                    )
-
-                    msg = llm_entities.Message(
-                        role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
-                    )
-
-                    yield msg
-
-                    req_messages.append(msg.dict(exclude_none=True))
-                except Exception as e:
-                    # 出错,添加一个报错信息到 req_messages
-                    err_msg = llm_entities.Message(
-                        role="tool", content=f"err: {e}", tool_call_id=tool_call.id
-                    )
-
-                    yield err_msg
-
-                    req_messages.append(
-                        err_msg.dict(exclude_none=True)
-                    )
-
-            # 处理完所有调用,继续请求
-            msg = await self._closure(req_messages, query.use_model, query.use_funcs)
-
-            yield msg
-
-            pending_tool_calls = msg.tool_calls
-
-            req_messages.append(msg.dict(exclude_none=True))
-
-    async def request(self, query: core_entities.Query) -> AsyncGenerator[llm_entities.Message, None]:
         try:
-            async for msg in self._request(query):
-                yield msg
+            return await self._closure(req_messages, model, funcs)
         except asyncio.TimeoutError:
             raise errors.RequesterError('请求超时')
         except openai.BadRequestError as e:
@@ -163,6 +120,16 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
         except openai.NotFoundError as e:
             raise errors.RequesterError(f'请求路径错误: {e.message}')
         except openai.RateLimitError as e:
-            raise errors.RequesterError(f'请求过于频繁: {e.message}')
+            raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
         except openai.APIError as e:
             raise errors.RequesterError(f'请求错误: {e.message}')
+
+    @async_lru.alru_cache(maxsize=128)
+    async def get_base64_str(
+        self,
+        original_url: str,
+    ) -> str:
+
+        base64_image = await image.qq_image_url_to_base64(original_url)
+
+        return f"data:image/jpeg;base64,{base64_image}"
diff --git a/pkg/provider/modelmgr/apis/deepseekchatcmpl.py b/pkg/provider/modelmgr/apis/deepseekchatcmpl.py
index dd8ddc6..4edc2cb 100644
--- a/pkg/provider/modelmgr/apis/deepseekchatcmpl.py
+++ b/pkg/provider/modelmgr/apis/deepseekchatcmpl.py
@@ -3,7 +3,10 @@ from __future__ import annotations
 
 from ....core import app
 
 from . import chatcmpl
-from .. import api
+from .. import api, entities, errors
+from ....core import entities as core_entities, app
+from ... import entities as llm_entities
+from ...tools import entities as tools_entities
 
 
 @api.requester_class("deepseek-chat-completions")
@@ -12,4 +15,39 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
 
     def __init__(self, ap: app.Application):
         self.requester_cfg = ap.provider_cfg.data['requester']['deepseek-chat-completions']
-        self.ap = ap
\ No newline at end of file
+        self.ap = ap
+
+    async def _closure(
+        self,
+        req_messages: list[dict],
+        use_model: entities.LLMModelInfo,
+        use_funcs: list[tools_entities.LLMFunction] = None,
+    ) -> llm_entities.Message:
+        self.client.api_key = use_model.token_mgr.get_token()
+
+        args = self.requester_cfg['args'].copy()
+        args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
+
+        if use_funcs:
+            tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
+
+            if tools:
+                args["tools"] = tools
+
+        # 设置此次请求中的messages
+        messages = req_messages
+
+        # deepseek 不支持多模态,把content都转换成纯文字
+        for m in messages:
+            if 'content' in m and isinstance(m["content"], list):
+                m["content"] = " ".join([c["text"] for c in m["content"]])
+
+        args["messages"] = messages
+
+        # 发送请求
+        resp = await self._req(args)
+
+        # 处理请求结果
+        message = await self._make_msg(resp)
+
+        return message
\ No newline at end of file
diff --git a/pkg/provider/modelmgr/apis/moonshotchatcmpl.py b/pkg/provider/modelmgr/apis/moonshotchatcmpl.py
index cb9fd93..2f299b8 100644
--- a/pkg/provider/modelmgr/apis/moonshotchatcmpl.py
+++ b/pkg/provider/modelmgr/apis/moonshotchatcmpl.py
@@ -3,7 +3,10 @@ from __future__ import annotations
 
 from ....core import app
 
 from . import chatcmpl
-from .. import api
+from .. import api, entities, errors
+from ....core import entities as core_entities, app
+from ... import entities as llm_entities
+from ...tools import entities as tools_entities
 
 
 @api.requester_class("moonshot-chat-completions")
@@ -13,3 +16,41 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
     def __init__(self, ap: app.Application):
         self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions']
         self.ap = ap
+
+    async def _closure(
+        self,
+        req_messages: list[dict],
+        use_model: entities.LLMModelInfo,
+        use_funcs: list[tools_entities.LLMFunction] = None,
+    ) -> llm_entities.Message:
+        self.client.api_key = use_model.token_mgr.get_token()
+
+        args = self.requester_cfg['args'].copy()
+        args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
+
+        if use_funcs:
+            tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
+
+            if tools:
+                args["tools"] = tools
+
+        # 设置此次请求中的messages
+        messages = req_messages
+
+        # moonshot 不支持多模态,把content都转换成纯文字
+        for m in messages:
+            if 'content' in m and isinstance(m["content"], list):
+                m["content"] = " ".join([c["text"] for c in m["content"]])
+
+        # 删除空的
+        messages = [m for m in messages if m["content"].strip() != ""]
+
+        args["messages"] = messages
+
+        # 发送请求
+        resp = await self._req(args)
+
+        # 处理请求结果
+        message = await self._make_msg(resp)
+
+        return message
\ No newline at end of file
diff --git a/pkg/provider/modelmgr/entities.py b/pkg/provider/modelmgr/entities.py
index 277f125..79cb544 100644
--- a/pkg/provider/modelmgr/entities.py
+++ b/pkg/provider/modelmgr/entities.py
@@ -21,5 +21,7 @@ class LLMModelInfo(pydantic.BaseModel):
 
     tool_call_supported: typing.Optional[bool] = False
 
+    vision_supported: typing.Optional[bool] = False
+
     class Config:
         arbitrary_types_allowed = True
diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py
index 3fffd78..79e467a 100644
--- a/pkg/provider/modelmgr/modelmgr.py
+++ b/pkg/provider/modelmgr/modelmgr.py
@@ -37,7 +37,7 @@ class ModelManager:
         raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
 
     async def initialize(self):
-
+        # 初始化token_mgr, requester
         for k, v in self.ap.provider_cfg.data['keys'].items():
             self.token_mgrs[k] = token.TokenManager(k, v)
 
@@ -83,7 +83,8 @@ class ModelManager:
                     model_name=None,
                     token_mgr=self.token_mgrs[model['token_mgr']],
                     requester=self.requesters[model['requester']],
-                    tool_call_supported=model['tool_call_supported']
+                    tool_call_supported=model['tool_call_supported'],
+                    vision_supported=model['vision_supported']
                 )
                 break
 
@@ -95,13 +96,15 @@ class ModelManager:
                 token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr
                 requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
                 tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported)
+                vision_supported = model.get('vision_supported', default_model_info.vision_supported)
 
                 model_info = entities.LLMModelInfo(
                     name=model['name'],
                     model_name=model_name,
                     token_mgr=token_mgr,
                     requester=requester,
-                    tool_call_supported=tool_call_supported
+                    tool_call_supported=tool_call_supported,
+                    vision_supported=vision_supported
                 )
 
                 self.model_list.append(model_info)
diff --git a/pkg/utils/image.py b/pkg/utils/image.py
new file mode 100644
index 0000000..34acc2f
--- /dev/null
+++ b/pkg/utils/image.py
@@ -0,0 +1,41 @@
+import base64
+import typing
+from urllib.parse import urlparse, parse_qs
+import ssl
+
+import aiohttp
+
+
+async def qq_image_url_to_base64(
+    image_url: str
+) -> str:
+    """将QQ图片URL转为base64
+
+    Args:
+        image_url (str): QQ图片URL
+
+    Returns:
+        str: base64编码
+    """
+    parsed = urlparse(image_url)
+    query = parse_qs(parsed.query)
+
+    # Flatten the query dictionary
+    query = {k: v[0] for k, v in query.items()}
+
+    ssl_context = ssl.create_default_context()
+    ssl_context.check_hostname = False
+    ssl_context.verify_mode = ssl.CERT_NONE
+
+    async with aiohttp.ClientSession(trust_env=False) as session:
+        async with session.get(
+            f"http://{parsed.netloc}{parsed.path}",
+            params=query,
+            ssl=ssl_context
+        ) as resp:
+            resp.raise_for_status()  # 检查HTTP错误
+            file_bytes = await resp.read()
+
+    base64_str = base64.b64encode(file_bytes).decode()
+
+    return base64_str
diff --git a/requirements.txt b/requirements.txt
index f04bdc9..44bc285 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,4 +13,5 @@ aiohttp
 pydantic
 websockets
 urllib3
-psutil
\ No newline at end of file
+psutil
+async-lru
\ No newline at end of file
diff --git a/templates/metadata/llm-models.json b/templates/metadata/llm-models.json
index 9787223..235eea7 100644
--- a/templates/metadata/llm-models.json
+++ b/templates/metadata/llm-models.json
@@ -4,23 +4,73 @@
             "name": "default",
             "requester": "openai-chat-completions",
             "token_mgr": "openai",
-            "tool_call_supported": false
+            "tool_call_supported": false,
+            "vision_supported": false
+        },
+        {
+            "name": "gpt-3.5-turbo-0125",
+            "tool_call_supported": true,
+            "vision_supported": false
         },
         {
             "name": "gpt-3.5-turbo",
-            "tool_call_supported": true
+            "tool_call_supported": true,
+            "vision_supported": false
         },
         {
-            "name": "gpt-4",
-            "tool_call_supported": true
+            "name": "gpt-3.5-turbo-1106",
+            "tool_call_supported": true,
+            "vision_supported": false
+        },
+        {
+            "name": "gpt-4-turbo",
+            "tool_call_supported": true,
+            "vision_supported": true
+        },
+        {
+            "name": "gpt-4-turbo-2024-04-09",
+            "tool_call_supported": true,
+            "vision_supported": true
         },
         {
             "name": "gpt-4-turbo-preview",
-            "tool_call_supported": true
+            "tool_call_supported": true,
+            "vision_supported": true
+        },
+        {
+            "name": "gpt-4-0125-preview",
+            "tool_call_supported": true,
+            "vision_supported": true
+        },
+        {
+            "name": "gpt-4-1106-preview",
+            "tool_call_supported": true,
+            "vision_supported": true
+        },
+        {
+            "name": "gpt-4",
+            "tool_call_supported": true,
+            "vision_supported": true
+        },
+        {
+            "name": "gpt-4o",
+            "tool_call_supported": true,
+            "vision_supported": true
+        },
+        {
+            "name": "gpt-4-0613",
+            "tool_call_supported": true,
+            "vision_supported": true
         },
         {
             "name": "gpt-4-32k",
-            "tool_call_supported": true
+            "tool_call_supported": true,
+            "vision_supported": true
+        },
+        {
+            "name": "gpt-4-32k-0613",
+            "tool_call_supported": true,
+            "vision_supported": true
         },
         {
             "model_name": "SparkDesk",
diff --git a/templates/provider.json b/templates/provider.json
index e537156..309fb82 100644
--- a/templates/provider.json
+++ b/templates/provider.json
@@ -1,5 +1,6 @@
 {
     "enable-chat": true,
+    "enable-vision": true,
     "keys": {
         "openai": [
             "sk-1234567890"