From 987b3dc4ef4ba069932217874f044070d53432da Mon Sep 17 00:00:00 2001 From: canyuan Date: Thu, 4 Jul 2024 11:08:46 +0800 Subject: [PATCH 01/13] add ollama chat --- pkg/provider/modelmgr/apis/ollamachatcmpl.py | 105 +++++++++++++++++++ pkg/provider/modelmgr/modelmgr.py | 2 +- requirements.txt | 3 +- 3 files changed, 108 insertions(+), 2 deletions(-) create mode 100644 pkg/provider/modelmgr/apis/ollamachatcmpl.py diff --git a/pkg/provider/modelmgr/apis/ollamachatcmpl.py b/pkg/provider/modelmgr/apis/ollamachatcmpl.py new file mode 100644 index 0000000..150a2af --- /dev/null +++ b/pkg/provider/modelmgr/apis/ollamachatcmpl.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import asyncio +import os +import typing +from typing import Union, Mapping, Any, AsyncIterator + +import async_lru +import ollama + +from .. import api, entities, errors +from ... import entities as llm_entities +from ...tools import entities as tools_entities +from ....core import app +from ....utils import image + +REQUESTER_NAME: str = "ollama-chat-completions" + + +@api.requester_class(REQUESTER_NAME) +class OllamaChatCompletions(api.LLMAPIRequester): + """Ollama平台 ChatCompletion API请求器""" + client: ollama.AsyncClient + request_cfg: dict + + def __init__(self, ap: app.Application): + super().__init__(ap) + self.ap = ap + self.request_cfg = self.ap.provider_cfg.data['requester'][REQUESTER_NAME] + + async def initialize(self): + os.environ['OLLAMA_HOST'] = self.request_cfg['base-url'] + self.client = ollama.AsyncClient( + timeout=self.request_cfg['timeout'] + ) + + async def _req(self, + args: dict, + ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + return await self.client.chat( + **args + ) + + async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelInfo, + user_funcs: list[tools_entities.LLMFunction] = None) -> ( + llm_entities.Message): + args: Any = self.request_cfg['args'].copy() + args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + + messages: list[dict] = req_messages.copy() + for msg in messages: + if 'content' in msg and isinstance(msg["content"], list): + text_content: list = [] + image_urls: list = [] + for me in msg["content"]: + if me["type"] == "text": + text_content.append(me["text"]) + elif me["type"] == "image_url": + image_url = await self.get_base64_str(me["image_url"]['url']) + image_urls.append(image_url) + msg["content"] = "\n".join(text_content) + msg["images"] = [url.split(',')[1] for url in image_urls] + args["messages"] = messages + + resp: Mapping[str, Any] | AsyncIterator[Mapping[str, Any]] = await self._req(args) + message: llm_entities.Message = await self._make_msg(resp) + return message + + async def _make_msg( + self, + chat_completions: Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]) -> llm_entities.Message: + message: Any = chat_completions.pop('message', None) + if message is None: + raise ValueError("chat_completions must contain a 'message' field") + + message.update(chat_completions) + ret_msg: llm_entities.Message = llm_entities.Message(**message) + return ret_msg + + async def call( + self, + model: entities.LLMModelInfo, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + ) -> llm_entities.Message: + req_messages: list = [] + for m in messages: + msg_dict: dict = m.dict(exclude_none=True) + content: Any = msg_dict.get("content") + if isinstance(content, list): + if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): + msg_dict["content"] = "\n".join(part["text"] for part in content) + req_messages.append(msg_dict) + try: + return await self._closure(req_messages, model) + except asyncio.TimeoutError: + raise errors.RequesterError('请求超时') + + @async_lru.alru_cache(maxsize=128) + async def get_base64_str( + self, + original_url: str, + ) -> str: + base64_image: str = await image.qq_image_url_to_base64(original_url) + return f"data:image/jpeg;base64,{base64_image}" diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index 79e467a..6a221d2 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -6,7 +6,7 @@ from . import entities from ...core import app from . import token, api -from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl +from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachatcmpl FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list" diff --git a/requirements.txt b/requirements.txt index 44bc285..1b554d2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,5 @@ pydantic websockets urllib3 psutil -async-lru \ No newline at end of file +async-lru +ollama \ No newline at end of file From 2bdc3468d1727d3d6288fe20ad4e3e1e93bc2576 Mon Sep 17 00:00:00 2001 From: canyuan Date: Mon, 8 Jul 2024 21:08:07 +0800 Subject: [PATCH 02/13] add ollama cmd --- pkg/command/cmdmgr.py | 2 +- pkg/command/operators/ollama.py | 112 ++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 pkg/command/operators/ollama.py diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index 93ed8f8..1d7b92f 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -8,7 +8,7 @@ from . import entities, operator, errors from ..config import manager as cfg_mgr # 引入所有算子以便注册 -from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update +from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama class CommandManager: diff --git a/pkg/command/operators/ollama.py b/pkg/command/operators/ollama.py new file mode 100644 index 0000000..3932b17 --- /dev/null +++ b/pkg/command/operators/ollama.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +import json +import typing + +import ollama +from .. import operator, entities + + +@operator.operator_class( + name="ollama_list", + help="ollama模型列表", + usage="!ollama_list" +) +class OllamaListOperator(operator.CommandOperator): + async def execute( + self, context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + content: str = '模型列表:\n' + model_list: list = ollama.list().get('models', []) + for model in model_list: + content += f"name: {model['name']}\n" + content += f"modified_at: {model['modified_at']}\n" + content += f"size: {bytes_to_mb(model['size'])}mb\n\n" + yield entities.CommandReturn(text=f"{content}") + + +def bytes_to_mb(num_bytes): + mb: float = num_bytes / 1024 / 1024 + return format(mb, '.2f') + + +@operator.operator_class( + name="ollama_show", + help="ollama模型详情", + usage="!ollama_show <模型名>" +) +class OllamaShowOperator(operator.CommandOperator): + async def execute( + self, context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + content: str = '模型详情:\n' + try: + show: dict = ollama.show(model=context.crt_params[0]) + model_info: dict = show.get('model_info', {}) + ignore_show: str = 'too long to show...' + + for key in ['license', 'modelfile']: + show[key] = ignore_show + + for key in ['tokenizer.chat_template.rag', 'tokenizer.chat_template.tool_use']: + model_info[key] = ignore_show + + content += json.dumps(show, indent=4) + except ollama.ResponseError as e: + content = f"{e.error}" + + yield entities.CommandReturn(text=content) + + +@operator.operator_class( + name="ollama_pull", + help="ollama模型拉取", + usage="!ollama_pull <模型名>" +) +class OllamaPullOperator(operator.CommandOperator): + async def execute( + self, context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + model_list: list = ollama.list().get('models', []) + if context.crt_params[0] in [model['name'] for model in model_list]: + yield entities.CommandReturn(text="模型已存在") + return + + on_progress: bool = False + progress_count: int = 0 + try: + for resp in ollama.pull(model=context.crt_params[0], stream=True): + total: typing.Any = resp.get('total') + if not on_progress: + if total is not None: + on_progress = True + yield entities.CommandReturn(text=resp.get('status')) + else: + if total is None: + on_progress = False + + completed: typing.Any = resp.get('completed') + if isinstance(completed, int) and isinstance(total, int): + percentage_completed = (completed / total) * 100 + if percentage_completed > progress_count: + progress_count += 10 + yield entities.CommandReturn( + text=f"下载进度: {completed}/{total} = {percentage_completed:.2f}%") + except ollama.ResponseError as e: + yield entities.CommandReturn(text=f"拉取失败: {e.error}") + + +@operator.operator_class( + name="ollama_del", + help="ollama模型删除", + usage="!ollama_del <模型名>" +) +class OllamaDelOperator(operator.CommandOperator): + async def execute( + self, context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + try: + ret: str = ollama.delete(model=context.crt_params[0])['status'] + except ollama.ResponseError as e: + ret = f"{e.error}" + yield entities.CommandReturn(text=ret) From e78c82e9994bc9f5baabd17c4ce8cec2a6a19fdc Mon Sep 17 00:00:00 2001 From: canyuan Date: Tue, 9 Jul 2024 16:18:34 +0800 Subject: [PATCH 03/13] mod: merge ollama cmd --- pkg/command/operators/ollama.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/pkg/command/operators/ollama.py b/pkg/command/operators/ollama.py index 3932b17..05c12d0 100644 --- a/pkg/command/operators/ollama.py +++ b/pkg/command/operators/ollama.py @@ -8,11 +8,11 @@ from .. import operator, entities @operator.operator_class( - name="ollama_list", - help="ollama模型列表", - usage="!ollama_list" + name="ollama", + help="ollama平台操作", + usage="!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>" ) -class OllamaListOperator(operator.CommandOperator): +class OllamaOperator(operator.CommandOperator): async def execute( self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: @@ -31,9 +31,10 @@ def bytes_to_mb(num_bytes): @operator.operator_class( - name="ollama_show", + name="show", help="ollama模型详情", - usage="!ollama_show <模型名>" + privilege=2, + parent_class=OllamaOperator ) class OllamaShowOperator(operator.CommandOperator): async def execute( @@ -59,9 +60,10 @@ class OllamaShowOperator(operator.CommandOperator): @operator.operator_class( - name="ollama_pull", + name="pull", help="ollama模型拉取", - usage="!ollama_pull <模型名>" + privilege=2, + parent_class=OllamaOperator ) class OllamaPullOperator(operator.CommandOperator): async def execute( @@ -97,9 +99,10 @@ class OllamaPullOperator(operator.CommandOperator): @operator.operator_class( - name="ollama_del", + name="del", help="ollama模型删除", - usage="!ollama_del <模型名>" + privilege=2, + parent_class=OllamaOperator ) class OllamaDelOperator(operator.CommandOperator): async def execute( From 21966bfb6973f942c16ec256ebb6bedaf696c060 Mon Sep 17 00:00:00 2001 From: ElvisChenML Date: Tue, 9 Jul 2024 17:04:11 +0800 Subject: [PATCH 04/13] =?UTF-8?q?fixed=20pkg\provider\entities.py\get=5Fco?= =?UTF-8?q?ntent=5Fmirai=5Fmessage=5Fchain=E4=B8=ADce.type=E5=9B=BE?= =?UTF-8?q?=E7=89=87=E7=B1=BB=E5=9E=8B=E4=B8=8D=E6=AD=A3=E7=A1=AE=E7=9A=84?= =?UTF-8?q?=E5=BC=82=E5=B8=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/provider/entities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index 82c68ad..5000072 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -95,7 +95,7 @@ class Message(pydantic.BaseModel): for ce in self.content: if ce.type == 'text': mc.append(mirai.Plain(ce.text)) - elif ce.type == 'image': + elif ce.type == 'image_url': if ce.image_url.url.startswith("http"): mc.append(mirai.Image(url=ce.image_url.url)) else: # base64 From bdb8baeddd2ef872cbe4ccdc699832efc77160d8 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Tue, 9 Jul 2024 23:37:19 +0800 Subject: [PATCH 05/13] =?UTF-8?q?perf(ollama):=20=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=E8=AF=B7=E6=B1=82=E5=99=A8=E5=90=8D=E7=A7=B0=E4=BB=A5=E9=80=82?= =?UTF-8?q?=E9=85=8D=E8=AF=B7=E6=B1=82=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/provider/modelmgr/apis/{ollamachatcmpl.py => ollamachat.py} | 2 +- pkg/provider/modelmgr/modelmgr.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename pkg/provider/modelmgr/apis/{ollamachatcmpl.py => ollamachat.py} (98%) diff --git a/pkg/provider/modelmgr/apis/ollamachatcmpl.py b/pkg/provider/modelmgr/apis/ollamachat.py similarity index 98% rename from pkg/provider/modelmgr/apis/ollamachatcmpl.py rename to pkg/provider/modelmgr/apis/ollamachat.py index 150a2af..88edfe7 100644 --- a/pkg/provider/modelmgr/apis/ollamachatcmpl.py +++ b/pkg/provider/modelmgr/apis/ollamachat.py @@ -14,7 +14,7 @@ from ...tools import entities as tools_entities from ....core import app from ....utils import image -REQUESTER_NAME: str = "ollama-chat-completions" +REQUESTER_NAME: str = "ollama-chat" @api.requester_class(REQUESTER_NAME) diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index 6a221d2..cf78230 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -6,7 +6,7 @@ from . import entities from ...core import app from . import token, api -from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachatcmpl +from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list" From 3dc413638bc22ae68e3b0157b55cbcb8a2262954 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Tue, 9 Jul 2024 23:37:34 +0800 Subject: [PATCH 06/13] =?UTF-8?q?feat(ollama):=20=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../m010_ollama_requester_config.py | 23 +++++++++++++++++++ pkg/core/stages/migrate.py | 1 + templates/provider.json | 5 ++++ 3 files changed, 29 insertions(+) create mode 100644 pkg/config/migrations/m010_ollama_requester_config.py diff --git a/pkg/config/migrations/m010_ollama_requester_config.py b/pkg/config/migrations/m010_ollama_requester_config.py new file mode 100644 index 0000000..56e4966 --- /dev/null +++ b/pkg/config/migrations/m010_ollama_requester_config.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from .. import migration + + +@migration.migration_class("ollama-requester-config", 10) +class MsgTruncatorConfigMigration(migration.Migration): + """迁移""" + + async def need_migrate(self) -> bool: + """判断当前环境是否需要运行此迁移""" + return 'ollama-chat' not in self.ap.provider_cfg.data['requester'] + + async def run(self): + """执行迁移""" + + self.ap.provider_cfg.data['requester']['ollama-chat'] = { + "base-url": "http://127.0.0.1:11434", + "args": {}, + "timeout": 600 + } + + await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/stages/migrate.py b/pkg/core/stages/migrate.py index 2ad1e97..862d90a 100644 --- a/pkg/core/stages/migrate.py +++ b/pkg/core/stages/migrate.py @@ -6,6 +6,7 @@ from .. import stage, app from ...config import migration from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg +from ...config.migrations import m010_ollama_requester_config @stage.stage_class("MigrationStage") diff --git a/templates/provider.json b/templates/provider.json index 309fb82..32878fe 100644 --- a/templates/provider.json +++ b/templates/provider.json @@ -37,6 +37,11 @@ "base-url": "https://api.deepseek.com", "args": {}, "timeout": 120 + }, + "ollama-chat": { + "base-url": "http://127.0.0.1:11434", + "args": {}, + "timeout": 600 } }, "model": "gpt-3.5-turbo", From 7c06141ce2e0e6238a89cd3113b203049611b586 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 10 Jul 2024 00:07:32 +0800 Subject: [PATCH 07/13] =?UTF-8?q?perf(ollama):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=91=BD=E4=BB=A4=E6=98=BE=E7=A4=BA=E7=BB=86=E8=8A=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/operators/ollama.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/command/operators/ollama.py b/pkg/command/operators/ollama.py index 05c12d0..db47918 100644 --- a/pkg/command/operators/ollama.py +++ b/pkg/command/operators/ollama.py @@ -21,8 +21,8 @@ class OllamaOperator(operator.CommandOperator): for model in model_list: content += f"name: {model['name']}\n" content += f"modified_at: {model['modified_at']}\n" - content += f"size: {bytes_to_mb(model['size'])}mb\n\n" - yield entities.CommandReturn(text=f"{content}") + content += f"size: {bytes_to_mb(model['size'])}MB\n\n" + yield entities.CommandReturn(text=f"{content.strip()}") def bytes_to_mb(num_bytes): @@ -56,7 +56,7 @@ class OllamaShowOperator(operator.CommandOperator): except ollama.ResponseError as e: content = f"{e.error}" - yield entities.CommandReturn(text=content) + yield entities.CommandReturn(text=content.strip()) @operator.operator_class( @@ -93,7 +93,7 @@ class OllamaPullOperator(operator.CommandOperator): if percentage_completed > progress_count: progress_count += 10 yield entities.CommandReturn( - text=f"下载进度: {completed}/{total} = {percentage_completed:.2f}%") + text=f"下载进度: {completed}/{total} ({percentage_completed:.2f}%)") except ollama.ResponseError as e: yield entities.CommandReturn(text=f"拉取失败: {e.error}") From 5bebe01dd068de2c1b697677a6a5bf8ef82df48d Mon Sep 17 00:00:00 2001 From: Junyan Qin <1010553892@qq.com> Date: Sat, 13 Jul 2024 09:15:18 +0800 Subject: [PATCH 08/13] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1af6121..b638f58 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Static Badge - + Static Badge From 70583f5ba0c2b8f2f1e79b883c3b89d346687e58 Mon Sep 17 00:00:00 2001 From: ElvisChenML Date: Thu, 25 Jul 2024 16:14:24 +0800 Subject: [PATCH 09/13] =?UTF-8?q?Fixed=20aiocqhttp=20mirai.Voice=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E6=97=A0=E6=B3=95=E6=AD=A3=E7=A1=AE=E4=BC=A0=E9=80=92?= =?UTF-8?q?url=E5=8F=8Abase64=E7=9A=84=E5=BC=82=E5=B8=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/platform/sources/aiocqhttp.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index 4946a1c..ebdd56e 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -47,7 +47,16 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): elif type(msg) is mirai.Face: msg_list.append(aiocqhttp.MessageSegment.face(msg.face_id)) elif type(msg) is mirai.Voice: - msg_list.append(aiocqhttp.MessageSegment.record(msg.path)) + arg = '' + if msg.base64: + arg = msg.base64 + msg_list.append(aiocqhttp.MessageSegment.record(f"base64://{arg}")) + elif msg.url: + arg = msg.url + msg_list.append(aiocqhttp.MessageSegment.record(arg)) + elif msg.path: + arg = msg.path + msg_list.append(aiocqhttp.MessageSegment.record(msg.path)) elif type(msg) is forward.Forward: for node in msg.node_list: From 68ddb3a6e1d75572eff772c1409511113a87c8ef Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 28 Jul 2024 15:46:09 +0800 Subject: [PATCH 10/13] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20model=20?= =?UTF-8?q?=E5=91=BD=E4=BB=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/cmdmgr.py | 2 +- pkg/command/operators/model.py | 86 +++++++++++++++++++++++++++++++++ pkg/command/operators/ollama.py | 36 ++++++++------ 3 files changed, 108 insertions(+), 16 deletions(-) create mode 100644 pkg/command/operators/model.py diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index 1d7b92f..8d442fd 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -8,7 +8,7 @@ from . import entities, operator, errors from ..config import manager as cfg_mgr # 引入所有算子以便注册 -from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama +from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model class CommandManager: diff --git a/pkg/command/operators/model.py b/pkg/command/operators/model.py new file mode 100644 index 0000000..692e272 --- /dev/null +++ b/pkg/command/operators/model.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import typing + +from .. import operator, entities, cmdmgr, errors + +@operator.operator_class( + name="model", + help='显示和切换模型列表', + usage='!model\n!model show <模型名>\n!model set <模型名>', + privilege=2 +) +class ModelOperator(operator.CommandOperator): + """Model命令""" + + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + content = '模型列表:\n' + + model_list = self.ap.model_mgr.model_list + + for model in model_list: + content += f"\n名称: {model.name}\n" + content += f"请求器: {model.requester.name}\n" + + content += f"\n当前对话使用模型: {context.query.use_model.name}\n" + content += f"新对话默认使用模型: {self.ap.provider_cfg.data.get('model')}\n" + + yield entities.CommandReturn(text=content.strip()) + + +@operator.operator_class( + name="show", + help='显示模型详情', + privilege=2, + parent_class=ModelOperator +) +class ModelShowOperator(operator.CommandOperator): + """Model Show命令""" + + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + model_name = context.crt_params[0] + + model = None + for _model in self.ap.model_mgr.model_list: + if model_name == _model.name: + model = _model + break + + if model is None: + yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}")) + else: + content = f"模型详情\n" + content += f"名称: {model.name}\n" + if model.model_name is not None: + content += f"请求模型名称: {model.model_name}\n" + content += f"请求器: {model.requester.name}\n" + content += f"密钥组: {model.token_mgr.provider}\n" + content += f"支持视觉: {model.vision_supported}\n" + content += f"支持工具: {model.tool_call_supported}\n" + + yield entities.CommandReturn(text=content.strip()) + +@operator.operator_class( + name="set", + help='设置默认使用模型', + privilege=2, + parent_class=ModelOperator +) +class ModelSetOperator(operator.CommandOperator): + """Model Set命令""" + + async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: + model_name = context.crt_params[0] + + model = None + for _model in self.ap.model_mgr.model_list: + if model_name == _model.name: + model = _model + break + + if model is None: + yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}")) + else: + self.ap.provider_cfg.data['model'] = model_name + await self.ap.provider_cfg.dump_config() + yield entities.CommandReturn(text=f"已设置当前使用模型为 {model_name},重置会话以生效") diff --git a/pkg/command/operators/ollama.py b/pkg/command/operators/ollama.py index db47918..f5ed382 100644 --- a/pkg/command/operators/ollama.py +++ b/pkg/command/operators/ollama.py @@ -2,9 +2,10 @@ from __future__ import annotations import json import typing +import traceback import ollama -from .. import operator, entities +from .. import operator, entities, errors @operator.operator_class( @@ -16,13 +17,16 @@ class OllamaOperator(operator.CommandOperator): async def execute( self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - content: str = '模型列表:\n' - model_list: list = ollama.list().get('models', []) - for model in model_list: - content += f"name: {model['name']}\n" - content += f"modified_at: {model['modified_at']}\n" - content += f"size: {bytes_to_mb(model['size'])}MB\n\n" - yield entities.CommandReturn(text=f"{content.strip()}") + try: + content: str = '模型列表:\n' + model_list: list = ollama.list().get('models', []) + for model in model_list: + content += f"名称: {model['name']}\n" + content += f"修改时间: {model['modified_at']}\n" + content += f"大小: {bytes_to_mb(model['size'])}MB\n\n" + yield entities.CommandReturn(text=f"{content.strip()}") + except ollama.ResponseError as e: + yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常")) def bytes_to_mb(num_bytes): @@ -53,11 +57,9 @@ class OllamaShowOperator(operator.CommandOperator): model_info[key] = ignore_show content += json.dumps(show, indent=4) + yield entities.CommandReturn(text=content.strip()) except ollama.ResponseError as e: - content = f"{e.error}" - - yield entities.CommandReturn(text=content.strip()) - + yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型详情,请确认 Ollama 服务正常")) @operator.operator_class( name="pull", @@ -69,9 +71,13 @@ class OllamaPullOperator(operator.CommandOperator): async def execute( self, context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - model_list: list = ollama.list().get('models', []) - if context.crt_params[0] in [model['name'] for model in model_list]: - yield entities.CommandReturn(text="模型已存在") + try: + model_list: list = ollama.list().get('models', []) + if context.crt_params[0] in [model['name'] for model in model_list]: + yield entities.CommandReturn(text="模型已存在") + return + except ollama.ResponseError as e: + yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常")) return on_progress: bool = False From 48cc3656bdd85148eb70253e8205e41a62e8b80b Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 28 Jul 2024 16:01:58 +0800 Subject: [PATCH 11/13] =?UTF-8?q?feat:=20=E5=85=81=E8=AE=B8=E8=87=AA?= =?UTF-8?q?=E5=AE=9A=E4=B9=89=E5=91=BD=E4=BB=A4=E5=89=8D=E7=BC=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/operators/version.py | 2 +- pkg/{config => core}/migration.py | 2 +- pkg/{config => core}/migrations/__init__.py | 0 .../m001_sensitive_word_migration.py | 0 .../m002_openai_config_migration.py | 0 ...m003_anthropic_requester_cfg_completion.py | 0 .../m004_moonshot_cfg_completion.py | 0 .../m005_deepseek_cfg_completion.py | 0 .../migrations/m006_vision_config.py | 0 .../migrations/m007_qcg_center_url.py | 0 .../m008_ad_fixwin_config_migrate.py | 0 .../migrations/m009_msg_truncator_cfg.py | 0 .../m010_ollama_requester_config.py | 0 .../migrations/m011_command_prefix_config.py | 21 +++++++++++++++++++ pkg/core/stages/migrate.py | 8 +++---- pkg/pipeline/process/process.py | 4 +++- templates/command.json | 6 +++++- 17 files changed, 35 insertions(+), 8 deletions(-) rename pkg/{config => core}/migration.py (97%) rename pkg/{config => core}/migrations/__init__.py (100%) rename pkg/{config => core}/migrations/m001_sensitive_word_migration.py (100%) rename pkg/{config => core}/migrations/m002_openai_config_migration.py (100%) rename pkg/{config => core}/migrations/m003_anthropic_requester_cfg_completion.py (100%) rename pkg/{config => core}/migrations/m004_moonshot_cfg_completion.py (100%) rename pkg/{config => core}/migrations/m005_deepseek_cfg_completion.py (100%) rename pkg/{config => core}/migrations/m006_vision_config.py (100%) rename pkg/{config => core}/migrations/m007_qcg_center_url.py (100%) rename pkg/{config => core}/migrations/m008_ad_fixwin_config_migrate.py (100%) rename pkg/{config => core}/migrations/m009_msg_truncator_cfg.py (100%) rename pkg/{config => core}/migrations/m010_ollama_requester_config.py (100%) create mode 100644 pkg/core/migrations/m011_command_prefix_config.py diff --git a/pkg/command/operators/version.py b/pkg/command/operators/version.py index ed248db..a5d7a81 100644 --- a/pkg/command/operators/version.py +++ b/pkg/command/operators/version.py @@ -20,7 +20,7 @@ class VersionCommand(operator.CommandOperator): try: if await self.ap.ver_mgr.is_new_version_available(): - reply_str += "\n\n有新版本可用, 使用 !update 更新" + reply_str += "\n\n有新版本可用。" except: pass diff --git a/pkg/config/migration.py b/pkg/core/migration.py similarity index 97% rename from pkg/config/migration.py rename to pkg/core/migration.py index e84a59c..2c5c759 100644 --- a/pkg/config/migration.py +++ b/pkg/core/migration.py @@ -3,7 +3,7 @@ from __future__ import annotations import abc import typing -from ..core import app +from . import app preregistered_migrations: list[typing.Type[Migration]] = [] diff --git a/pkg/config/migrations/__init__.py b/pkg/core/migrations/__init__.py similarity index 100% rename from pkg/config/migrations/__init__.py rename to pkg/core/migrations/__init__.py diff --git a/pkg/config/migrations/m001_sensitive_word_migration.py b/pkg/core/migrations/m001_sensitive_word_migration.py similarity index 100% rename from pkg/config/migrations/m001_sensitive_word_migration.py rename to pkg/core/migrations/m001_sensitive_word_migration.py diff --git a/pkg/config/migrations/m002_openai_config_migration.py b/pkg/core/migrations/m002_openai_config_migration.py similarity index 100% rename from pkg/config/migrations/m002_openai_config_migration.py rename to pkg/core/migrations/m002_openai_config_migration.py diff --git a/pkg/config/migrations/m003_anthropic_requester_cfg_completion.py b/pkg/core/migrations/m003_anthropic_requester_cfg_completion.py similarity index 100% rename from pkg/config/migrations/m003_anthropic_requester_cfg_completion.py rename to pkg/core/migrations/m003_anthropic_requester_cfg_completion.py diff --git a/pkg/config/migrations/m004_moonshot_cfg_completion.py b/pkg/core/migrations/m004_moonshot_cfg_completion.py similarity index 100% rename from pkg/config/migrations/m004_moonshot_cfg_completion.py rename to pkg/core/migrations/m004_moonshot_cfg_completion.py diff --git a/pkg/config/migrations/m005_deepseek_cfg_completion.py b/pkg/core/migrations/m005_deepseek_cfg_completion.py similarity index 100% rename from pkg/config/migrations/m005_deepseek_cfg_completion.py rename to pkg/core/migrations/m005_deepseek_cfg_completion.py diff --git a/pkg/config/migrations/m006_vision_config.py b/pkg/core/migrations/m006_vision_config.py similarity index 100% rename from pkg/config/migrations/m006_vision_config.py rename to pkg/core/migrations/m006_vision_config.py diff --git a/pkg/config/migrations/m007_qcg_center_url.py b/pkg/core/migrations/m007_qcg_center_url.py similarity index 100% rename from pkg/config/migrations/m007_qcg_center_url.py rename to pkg/core/migrations/m007_qcg_center_url.py diff --git a/pkg/config/migrations/m008_ad_fixwin_config_migrate.py b/pkg/core/migrations/m008_ad_fixwin_config_migrate.py similarity index 100% rename from pkg/config/migrations/m008_ad_fixwin_config_migrate.py rename to pkg/core/migrations/m008_ad_fixwin_config_migrate.py diff --git a/pkg/config/migrations/m009_msg_truncator_cfg.py b/pkg/core/migrations/m009_msg_truncator_cfg.py similarity index 100% rename from pkg/config/migrations/m009_msg_truncator_cfg.py rename to pkg/core/migrations/m009_msg_truncator_cfg.py diff --git a/pkg/config/migrations/m010_ollama_requester_config.py b/pkg/core/migrations/m010_ollama_requester_config.py similarity index 100% rename from pkg/config/migrations/m010_ollama_requester_config.py rename to pkg/core/migrations/m010_ollama_requester_config.py diff --git a/pkg/core/migrations/m011_command_prefix_config.py b/pkg/core/migrations/m011_command_prefix_config.py new file mode 100644 index 0000000..6a9e111 --- /dev/null +++ b/pkg/core/migrations/m011_command_prefix_config.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from .. import migration + + +@migration.migration_class("command-prefix-config", 11) +class CommandPrefixConfigMigration(migration.Migration): + """迁移""" + + async def need_migrate(self) -> bool: + """判断当前环境是否需要运行此迁移""" + return 'command-prefix' not in self.ap.command_cfg.data + + async def run(self): + """执行迁移""" + + self.ap.command_cfg.data['command-prefix'] = [ + "!", "!" + ] + + await self.ap.command_cfg.dump_config() diff --git a/pkg/core/stages/migrate.py b/pkg/core/stages/migrate.py index 862d90a..54fcf60 100644 --- a/pkg/core/stages/migrate.py +++ b/pkg/core/stages/migrate.py @@ -3,10 +3,10 @@ from __future__ import annotations import importlib from .. import stage, app -from ...config import migration -from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion -from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg -from ...config.migrations import m010_ollama_requester_config +from .. import migration +from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion +from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg +from ..migrations import m010_ollama_requester_config, m011_command_prefix_config @stage.stage_class("MigrationStage") diff --git a/pkg/pipeline/process/process.py b/pkg/pipeline/process/process.py index e58d15e..362ece0 100644 --- a/pkg/pipeline/process/process.py +++ b/pkg/pipeline/process/process.py @@ -42,7 +42,9 @@ class Processor(stage.PipelineStage): self.ap.logger.info(f"处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}") async def generator(): - if message_text.startswith('!') or message_text.startswith('!'): + cmd_prefix = self.ap.command_cfg.data['command-prefix'] + + if any(message_text.startswith(prefix) for prefix in cmd_prefix): async for result in self.cmd_handler.handle(query): yield result else: diff --git a/templates/command.json b/templates/command.json index 55360fc..7c93f64 100644 --- a/templates/command.json +++ b/templates/command.json @@ -1,3 +1,7 @@ { - "privilege": {} + "privilege": {}, + "command-prefix": [ + "!", + "!" + ] } \ No newline at end of file From 8cad4089a78625c5c874eb7809b9f799b4b5350b Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 28 Jul 2024 18:45:27 +0800 Subject: [PATCH 12/13] =?UTF-8?q?feat:=20runner=20=E5=B1=82=E6=8A=BD?= =?UTF-8?q?=E8=B1=A1=20(#839)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/app.py | 3 + pkg/core/migrations/m012_runner_config.py | 19 ++++++ pkg/core/stages/build_app.py | 6 ++ pkg/core/stages/migrate.py | 2 +- pkg/pipeline/process/handlers/chat.py | 67 ++-------------------- pkg/provider/runner.py | 40 +++++++++++++ pkg/provider/runnermgr.py | 27 +++++++++ pkg/provider/runners/__init__.py | 0 pkg/provider/runners/localagent.py | 70 +++++++++++++++++++++++ templates/provider.json | 3 +- 10 files changed, 172 insertions(+), 65 deletions(-) create mode 100644 pkg/core/migrations/m012_runner_config.py create mode 100644 pkg/provider/runner.py create mode 100644 pkg/provider/runnermgr.py create mode 100644 pkg/provider/runners/__init__.py create mode 100644 pkg/provider/runners/localagent.py diff --git a/pkg/core/app.py b/pkg/core/app.py index 25fc49f..e6e25ea 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -9,6 +9,7 @@ from ..provider.session import sessionmgr as llm_session_mgr from ..provider.modelmgr import modelmgr as llm_model_mgr from ..provider.sysprompt import sysprompt as llm_prompt_mgr from ..provider.tools import toolmgr as llm_tool_mgr +from ..provider import runnermgr from ..config import manager as config_mgr from ..audit.center import v2 as center_mgr from ..command import cmdmgr @@ -33,6 +34,8 @@ class Application: tool_mgr: llm_tool_mgr.ToolManager = None + runner_mgr: runnermgr.RunnerManager = None + # ======= 配置管理器 ======= command_cfg: config_mgr.ConfigManager = None diff --git a/pkg/core/migrations/m012_runner_config.py b/pkg/core/migrations/m012_runner_config.py new file mode 100644 index 0000000..fa236bb --- /dev/null +++ b/pkg/core/migrations/m012_runner_config.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from .. import migration + + +@migration.migration_class("runner-config", 12) +class RunnerConfigMigration(migration.Migration): + """迁移""" + + async def need_migrate(self) -> bool: + """判断当前环境是否需要运行此迁移""" + return 'runner' not in self.ap.provider_cfg.data + + async def run(self): + """执行迁移""" + + self.ap.provider_cfg.data['runner'] = 'local-agent' + + await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index 85bd48a..c0e731c 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -13,6 +13,7 @@ from ...provider.session import sessionmgr as llm_session_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.sysprompt import sysprompt as llm_prompt_mgr from ...provider.tools import toolmgr as llm_tool_mgr +from ...provider import runnermgr from ...platform import manager as im_mgr @stage.stage_class("BuildAppStage") @@ -81,6 +82,11 @@ class BuildAppStage(stage.BootingStage): llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap) await llm_tool_mgr_inst.initialize() ap.tool_mgr = llm_tool_mgr_inst + + runner_mgr_inst = runnermgr.RunnerManager(ap) + await runner_mgr_inst.initialize() + ap.runner_mgr = runner_mgr_inst + im_mgr_inst = im_mgr.PlatformManager(ap=ap) await im_mgr_inst.initialize() ap.platform_mgr = im_mgr_inst diff --git a/pkg/core/stages/migrate.py b/pkg/core/stages/migrate.py index 54fcf60..92735c9 100644 --- a/pkg/core/stages/migrate.py +++ b/pkg/core/stages/migrate.py @@ -6,7 +6,7 @@ from .. import stage, app from .. import migration from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg -from ..migrations import m010_ollama_requester_config, m011_command_prefix_config +from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config @stage.stage_class("MigrationStage") diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 8a98e30..cb8899b 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -10,7 +10,7 @@ import mirai from .. import handler from ... import entities from ....core import entities as core_entities -from ....provider import entities as llm_entities +from ....provider import entities as llm_entities, runnermgr from ....plugin import events @@ -71,7 +71,9 @@ class ChatMessageHandler(handler.MessageHandler): try: - async for result in self.runner(query): + runner = self.ap.runner_mgr.get_runner() + + async for result in runner.run(query): query.resp_messages.append(result) self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') @@ -108,64 +110,3 @@ class ChatMessageHandler(handler.MessageHandler): response_seconds=int(time.time() - start_time), retry_times=-1, ) - - async def runner( - self, - query: core_entities.Query, - ) -> typing.AsyncGenerator[llm_entities.Message, None]: - """执行一个请求处理过程中的LLM接口请求、函数调用的循环 - - 这是临时处理方案,后续可能改为使用LangChain或者自研的工作流处理器 - """ - await query.use_model.requester.preprocess(query) - - pending_tool_calls = [] - - req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message] - - # 首次请求 - msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs) - - yield msg - - pending_tool_calls = msg.tool_calls - - req_messages.append(msg) - - # 持续请求,只要还有待处理的工具调用就继续处理调用 - while pending_tool_calls: - for tool_call in pending_tool_calls: - try: - func = tool_call.function - - parameters = json.loads(func.arguments) - - func_ret = await self.ap.tool_mgr.execute_func_call( - query, func.name, parameters - ) - - msg = llm_entities.Message( - role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id - ) - - yield msg - - req_messages.append(msg) - except Exception as e: - # 工具调用出错,添加一个报错信息到 req_messages - err_msg = llm_entities.Message( - role="tool", content=f"err: {e}", tool_call_id=tool_call.id - ) - - yield err_msg - - req_messages.append(err_msg) - - # 处理完所有调用,再次请求 - msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs) - - yield msg - - pending_tool_calls = msg.tool_calls - - req_messages.append(msg) diff --git a/pkg/provider/runner.py b/pkg/provider/runner.py new file mode 100644 index 0000000..5a5cf6e --- /dev/null +++ b/pkg/provider/runner.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import abc +import typing + +from ..core import app, entities as core_entities +from . import entities as llm_entities + + +preregistered_runners: list[typing.Type[RequestRunner]] = [] + +def runner_class(name: str): + """注册一个请求运行器 + """ + def decorator(cls: typing.Type[RequestRunner]) -> typing.Type[RequestRunner]: + cls.name = name + preregistered_runners.append(cls) + return cls + + return decorator + + +class RequestRunner(abc.ABC): + """请求运行器 + """ + name: str = None + + ap: app.Application + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + pass + + @abc.abstractmethod + async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + """运行请求 + """ + pass diff --git a/pkg/provider/runnermgr.py b/pkg/provider/runnermgr.py new file mode 100644 index 0000000..c1c1a45 --- /dev/null +++ b/pkg/provider/runnermgr.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from . import runner +from ..core import app + +from .runners import localagent + + +class RunnerManager: + + ap: app.Application + + using_runner: runner.RequestRunner + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + + for r in runner.preregistered_runners: + if r.name == self.ap.provider_cfg.data['runner']: + self.using_runner = r(self.ap) + await self.using_runner.initialize() + break + + def get_runner(self) -> runner.RequestRunner: + return self.using_runner diff --git a/pkg/provider/runners/__init__.py b/pkg/provider/runners/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py new file mode 100644 index 0000000..84cda56 --- /dev/null +++ b/pkg/provider/runners/localagent.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import json +import typing + +from .. import runner +from ...core import app, entities as core_entities +from .. import entities as llm_entities + + +@runner.runner_class("local-agent") +class LocalAgentRunner(runner.RequestRunner): + """本地Agent请求运行器 + """ + + async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]: + """运行请求 + """ + await query.use_model.requester.preprocess(query) + + pending_tool_calls = [] + + req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message] + + # 首次请求 + msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs) + + yield msg + + pending_tool_calls = msg.tool_calls + + req_messages.append(msg) + + # 持续请求,只要还有待处理的工具调用就继续处理调用 + while pending_tool_calls: + for tool_call in pending_tool_calls: + try: + func = tool_call.function + + parameters = json.loads(func.arguments) + + func_ret = await self.ap.tool_mgr.execute_func_call( + query, func.name, parameters + ) + + msg = llm_entities.Message( + role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id + ) + + yield msg + + req_messages.append(msg) + except Exception as e: + # 工具调用出错,添加一个报错信息到 req_messages + err_msg = llm_entities.Message( + role="tool", content=f"err: {e}", tool_call_id=tool_call.id + ) + + yield err_msg + + req_messages.append(err_msg) + + # 处理完所有调用,再次请求 + msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs) + + yield msg + + pending_tool_calls = msg.tool_calls + + req_messages.append(msg) diff --git a/templates/provider.json b/templates/provider.json index 32878fe..b7ec7fb 100644 --- a/templates/provider.json +++ b/templates/provider.json @@ -48,5 +48,6 @@ "prompt-mode": "normal", "prompt": { "default": "" - } + }, + "runner": "local-agent" } \ No newline at end of file From 1c5f06d9a916c1386a05ecf6fefc951765036153 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 28 Jul 2024 20:23:52 +0800 Subject: [PATCH 13/13] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20reply=20?= =?UTF-8?q?=E5=92=8C=20send=5Fmessage=20=E4=B8=A4=E4=B8=AA=E6=8F=92?= =?UTF-8?q?=E4=BB=B6api=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/platform/manager.py | 6 ++--- pkg/plugin/context.py | 53 ++++++++++++++++++++++++++++++++++------- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index b5f6986..b969642 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -146,9 +146,9 @@ class PlatformManager: if len(self.adapters) == 0: self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。') - async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True): + async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter): - if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage): + if self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage): msg.insert( 0, @@ -160,7 +160,7 @@ class PlatformManager: await adapter.reply_message( event, msg, - quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False + quote_origin=True if self.ap.platform_cfg.data['quote-origin'] else False ) async def run(self): diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py index b0e2ef9..42cb6be 100644 --- a/pkg/plugin/context.py +++ b/pkg/plugin/context.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing import abc import pydantic +import mirai from . import events from ..provider.tools import entities as tools_entities @@ -165,11 +166,54 @@ class EventContext: } """ + # ========== 插件可调用的 API ========== + def add_return(self, key: str, ret): """添加返回值""" if key not in self.__return_value__: self.__return_value__[key] = [] self.__return_value__[key].append(ret) + + async def reply(self, message_chain: mirai.MessageChain): + """回复此次消息请求 + + Args: + message_chain (mirai.MessageChain): YiriMirai库的消息链,若用户使用的不是 YiriMirai 适配器,程序也能自动转换为目标消息链 + """ + await self.host.ap.platform_mgr.send( + event=self.event.query.message_event, + msg=message_chain, + adapter=self.event.query.adapter, + ) + + async def send_message( + self, + target_type: str, + target_id: str, + message: mirai.MessageChain + ): + """主动发送消息 + + Args: + target_type (str): 目标类型,`person`或`group` + target_id (str): 目标ID + message (mirai.MessageChain): YiriMirai库的消息链,若用户使用的不是 YiriMirai 适配器,程序也能自动转换为目标消息链 + """ + await self.event.query.adapter.send_message( + target_type=target_type, + target_id=target_id, + message=message + ) + + def prevent_postorder(self): + """阻止后续插件执行""" + self.__prevent_postorder__ = True + + def prevent_default(self): + """阻止默认行为""" + self.__prevent_default__ = True + + # ========== 以下是内部保留方法,插件不应调用 ========== def get_return(self, key: str) -> list: """获取key的所有返回值""" @@ -183,14 +227,6 @@ class EventContext: return self.__return_value__[key][0] return None - def prevent_default(self): - """阻止默认行为""" - self.__prevent_default__ = True - - def prevent_postorder(self): - """阻止后续插件执行""" - self.__prevent_postorder__ = True - def is_prevented_default(self): """是否阻止默认行为""" return self.__prevent_default__ @@ -198,6 +234,7 @@ class EventContext: def is_prevented_postorder(self): """是否阻止后序插件执行""" return self.__prevent_postorder__ + def __init__(self, host: APIHost, event: events.BaseEventModel):