mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 03:32:33 +08:00
commit
9d91c13b12
|
@ -8,7 +8,7 @@ from . import entities, operator, errors
|
|||
from ..config import manager as cfg_mgr
|
||||
|
||||
# 引入所有算子以便注册
|
||||
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update
|
||||
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama
|
||||
|
||||
|
||||
class CommandManager:
|
||||
|
|
115
pkg/command/operators/ollama.py
Normal file
115
pkg/command/operators/ollama.py
Normal file
|
@ -0,0 +1,115 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
|
||||
import ollama
|
||||
from .. import operator, entities
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="ollama",
|
||||
help="ollama平台操作",
|
||||
usage="!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>"
|
||||
)
|
||||
class OllamaOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
content: str = '模型列表:\n'
|
||||
model_list: list = ollama.list().get('models', [])
|
||||
for model in model_list:
|
||||
content += f"name: {model['name']}\n"
|
||||
content += f"modified_at: {model['modified_at']}\n"
|
||||
content += f"size: {bytes_to_mb(model['size'])}MB\n\n"
|
||||
yield entities.CommandReturn(text=f"{content.strip()}")
|
||||
|
||||
|
||||
def bytes_to_mb(num_bytes):
|
||||
mb: float = num_bytes / 1024 / 1024
|
||||
return format(mb, '.2f')
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="show",
|
||||
help="ollama模型详情",
|
||||
privilege=2,
|
||||
parent_class=OllamaOperator
|
||||
)
|
||||
class OllamaShowOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
content: str = '模型详情:\n'
|
||||
try:
|
||||
show: dict = ollama.show(model=context.crt_params[0])
|
||||
model_info: dict = show.get('model_info', {})
|
||||
ignore_show: str = 'too long to show...'
|
||||
|
||||
for key in ['license', 'modelfile']:
|
||||
show[key] = ignore_show
|
||||
|
||||
for key in ['tokenizer.chat_template.rag', 'tokenizer.chat_template.tool_use']:
|
||||
model_info[key] = ignore_show
|
||||
|
||||
content += json.dumps(show, indent=4)
|
||||
except ollama.ResponseError as e:
|
||||
content = f"{e.error}"
|
||||
|
||||
yield entities.CommandReturn(text=content.strip())
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="pull",
|
||||
help="ollama模型拉取",
|
||||
privilege=2,
|
||||
parent_class=OllamaOperator
|
||||
)
|
||||
class OllamaPullOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
model_list: list = ollama.list().get('models', [])
|
||||
if context.crt_params[0] in [model['name'] for model in model_list]:
|
||||
yield entities.CommandReturn(text="模型已存在")
|
||||
return
|
||||
|
||||
on_progress: bool = False
|
||||
progress_count: int = 0
|
||||
try:
|
||||
for resp in ollama.pull(model=context.crt_params[0], stream=True):
|
||||
total: typing.Any = resp.get('total')
|
||||
if not on_progress:
|
||||
if total is not None:
|
||||
on_progress = True
|
||||
yield entities.CommandReturn(text=resp.get('status'))
|
||||
else:
|
||||
if total is None:
|
||||
on_progress = False
|
||||
|
||||
completed: typing.Any = resp.get('completed')
|
||||
if isinstance(completed, int) and isinstance(total, int):
|
||||
percentage_completed = (completed / total) * 100
|
||||
if percentage_completed > progress_count:
|
||||
progress_count += 10
|
||||
yield entities.CommandReturn(
|
||||
text=f"下载进度: {completed}/{total} ({percentage_completed:.2f}%)")
|
||||
except ollama.ResponseError as e:
|
||||
yield entities.CommandReturn(text=f"拉取失败: {e.error}")
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="del",
|
||||
help="ollama模型删除",
|
||||
privilege=2,
|
||||
parent_class=OllamaOperator
|
||||
)
|
||||
class OllamaDelOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
try:
|
||||
ret: str = ollama.delete(model=context.crt_params[0])['status']
|
||||
except ollama.ResponseError as e:
|
||||
ret = f"{e.error}"
|
||||
yield entities.CommandReturn(text=ret)
|
23
pkg/config/migrations/m010_ollama_requester_config.py
Normal file
23
pkg/config/migrations/m010_ollama_requester_config.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class("ollama-requester-config", 10)
|
||||
class MsgTruncatorConfigMigration(migration.Migration):
|
||||
"""迁移"""
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
return 'ollama-chat' not in self.ap.provider_cfg.data['requester']
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移"""
|
||||
|
||||
self.ap.provider_cfg.data['requester']['ollama-chat'] = {
|
||||
"base-url": "http://127.0.0.1:11434",
|
||||
"args": {},
|
||||
"timeout": 600
|
||||
}
|
||||
|
||||
await self.ap.provider_cfg.dump_config()
|
|
@ -6,6 +6,7 @@ from .. import stage, app
|
|||
from ...config import migration
|
||||
from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
|
||||
from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
|
||||
from ...config.migrations import m010_ollama_requester_config
|
||||
|
||||
|
||||
@stage.stage_class("MigrationStage")
|
||||
|
|
105
pkg/provider/modelmgr/apis/ollamachat.py
Normal file
105
pkg/provider/modelmgr/apis/ollamachat.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import typing
|
||||
from typing import Union, Mapping, Any, AsyncIterator
|
||||
|
||||
import async_lru
|
||||
import ollama
|
||||
|
||||
from .. import api, entities, errors
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
from ....core import app
|
||||
from ....utils import image
|
||||
|
||||
REQUESTER_NAME: str = "ollama-chat"
|
||||
|
||||
|
||||
@api.requester_class(REQUESTER_NAME)
|
||||
class OllamaChatCompletions(api.LLMAPIRequester):
|
||||
"""Ollama平台 ChatCompletion API请求器"""
|
||||
client: ollama.AsyncClient
|
||||
request_cfg: dict
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
super().__init__(ap)
|
||||
self.ap = ap
|
||||
self.request_cfg = self.ap.provider_cfg.data['requester'][REQUESTER_NAME]
|
||||
|
||||
async def initialize(self):
|
||||
os.environ['OLLAMA_HOST'] = self.request_cfg['base-url']
|
||||
self.client = ollama.AsyncClient(
|
||||
timeout=self.request_cfg['timeout']
|
||||
)
|
||||
|
||||
async def _req(self,
|
||||
args: dict,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
return await self.client.chat(
|
||||
**args
|
||||
)
|
||||
|
||||
async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelInfo,
|
||||
user_funcs: list[tools_entities.LLMFunction] = None) -> (
|
||||
llm_entities.Message):
|
||||
args: Any = self.request_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
|
||||
messages: list[dict] = req_messages.copy()
|
||||
for msg in messages:
|
||||
if 'content' in msg and isinstance(msg["content"], list):
|
||||
text_content: list = []
|
||||
image_urls: list = []
|
||||
for me in msg["content"]:
|
||||
if me["type"] == "text":
|
||||
text_content.append(me["text"])
|
||||
elif me["type"] == "image_url":
|
||||
image_url = await self.get_base64_str(me["image_url"]['url'])
|
||||
image_urls.append(image_url)
|
||||
msg["content"] = "\n".join(text_content)
|
||||
msg["images"] = [url.split(',')[1] for url in image_urls]
|
||||
args["messages"] = messages
|
||||
|
||||
resp: Mapping[str, Any] | AsyncIterator[Mapping[str, Any]] = await self._req(args)
|
||||
message: llm_entities.Message = await self._make_msg(resp)
|
||||
return message
|
||||
|
||||
async def _make_msg(
|
||||
self,
|
||||
chat_completions: Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]) -> llm_entities.Message:
|
||||
message: Any = chat_completions.pop('message', None)
|
||||
if message is None:
|
||||
raise ValueError("chat_completions must contain a 'message' field")
|
||||
|
||||
message.update(chat_completions)
|
||||
ret_msg: llm_entities.Message = llm_entities.Message(**message)
|
||||
return ret_msg
|
||||
|
||||
async def call(
|
||||
self,
|
||||
model: entities.LLMModelInfo,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
) -> llm_entities.Message:
|
||||
req_messages: list = []
|
||||
for m in messages:
|
||||
msg_dict: dict = m.dict(exclude_none=True)
|
||||
content: Any = msg_dict.get("content")
|
||||
if isinstance(content, list):
|
||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
||||
msg_dict["content"] = "\n".join(part["text"] for part in content)
|
||||
req_messages.append(msg_dict)
|
||||
try:
|
||||
return await self._closure(req_messages, model)
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
|
||||
@async_lru.alru_cache(maxsize=128)
|
||||
async def get_base64_str(
|
||||
self,
|
||||
original_url: str,
|
||||
) -> str:
|
||||
base64_image: str = await image.qq_image_url_to_base64(original_url)
|
||||
return f"data:image/jpeg;base64,{base64_image}"
|
|
@ -6,7 +6,7 @@ from . import entities
|
|||
from ...core import app
|
||||
|
||||
from . import token, api
|
||||
from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl
|
||||
from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat
|
||||
|
||||
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
|
||||
|
||||
|
|
|
@ -14,4 +14,5 @@ pydantic
|
|||
websockets
|
||||
urllib3
|
||||
psutil
|
||||
async-lru
|
||||
async-lru
|
||||
ollama
|
|
@ -37,6 +37,11 @@
|
|||
"base-url": "https://api.deepseek.com",
|
||||
"args": {},
|
||||
"timeout": 120
|
||||
},
|
||||
"ollama-chat": {
|
||||
"base-url": "http://127.0.0.1:11434",
|
||||
"args": {},
|
||||
"timeout": 600
|
||||
}
|
||||
},
|
||||
"model": "gpt-3.5-turbo",
|
||||
|
|
Loading…
Reference in New Issue
Block a user