mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
refactor: 使插件更新过程全异步
This commit is contained in:
parent
68184b0e47
commit
709b86b724
|
@ -19,6 +19,8 @@ required_deps = {
|
||||||
"quart_cors": "quart-cors",
|
"quart_cors": "quart-cors",
|
||||||
"sqlalchemy": "sqlalchemy[asyncio]",
|
"sqlalchemy": "sqlalchemy[asyncio]",
|
||||||
"aiosqlite": "aiosqlite",
|
"aiosqlite": "aiosqlite",
|
||||||
|
"aiofiles": "aiofiles",
|
||||||
|
"aioshutil": "aioshutil",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import typing
|
import typing
|
||||||
import datetime
|
import datetime
|
||||||
|
import traceback
|
||||||
|
|
||||||
from . import app
|
from . import app
|
||||||
|
|
||||||
|
@ -85,6 +86,9 @@ class TaskWrapper:
|
||||||
task: asyncio.Task
|
task: asyncio.Task
|
||||||
"""任务"""
|
"""任务"""
|
||||||
|
|
||||||
|
task_stack: list = None
|
||||||
|
"""任务堆栈"""
|
||||||
|
|
||||||
ap: app.Application
|
ap: app.Application
|
||||||
"""应用实例"""
|
"""应用实例"""
|
||||||
|
|
||||||
|
@ -111,7 +115,10 @@ class TaskWrapper:
|
||||||
|
|
||||||
def assume_exception(self):
|
def assume_exception(self):
|
||||||
try:
|
try:
|
||||||
return self.task.exception()
|
exception = self.task.exception()
|
||||||
|
if self.task_stack is None:
|
||||||
|
self.task_stack = self.task.get_stack()
|
||||||
|
return exception
|
||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -123,6 +130,13 @@ class TaskWrapper:
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
|
|
||||||
|
exception_traceback = None
|
||||||
|
if self.assume_exception() is not None:
|
||||||
|
exception_traceback = 'Traceback (most recent call last):\n'
|
||||||
|
|
||||||
|
for frame in self.task_stack:
|
||||||
|
exception_traceback += f" File \"{frame.f_code.co_filename}\", line {frame.f_lineno}, in {frame.f_code.co_name}\n"
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"task_type": self.task_type,
|
"task_type": self.task_type,
|
||||||
|
@ -133,8 +147,9 @@ class TaskWrapper:
|
||||||
"runtime": {
|
"runtime": {
|
||||||
"done": self.task.done(),
|
"done": self.task.done(),
|
||||||
"state": self.task._state,
|
"state": self.task._state,
|
||||||
"exception": self.assume_exception(),
|
"exception": self.assume_exception().__str__() if self.assume_exception() is not None else None,
|
||||||
"result": self.assume_result(),
|
"exception_traceback": exception_traceback,
|
||||||
|
"result": self.assume_result().__str__() if self.assume_result() is not None else None,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,10 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
import requests
|
import aiohttp
|
||||||
|
import aiofiles
|
||||||
|
import aiofiles.os as aiofiles_os
|
||||||
|
import aioshutil
|
||||||
|
|
||||||
from .. import installer, errors
|
from .. import installer, errors
|
||||||
from ...utils import pkgmgr
|
from ...utils import pkgmgr
|
||||||
|
@ -29,65 +32,65 @@ class GitHubRepoInstaller(installer.PluginInstaller):
|
||||||
return repo[0].split("/")
|
return repo[0].split("/")
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def download_plugin_source_code(self, repo_url: str, target_path: str, task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder()) -> str:
|
||||||
|
"""下载插件源码(全异步)"""
|
||||||
|
|
||||||
async def download_plugin_source_code(self, repo_url: str, target_path: str) -> str:
|
|
||||||
"""下载插件源码"""
|
|
||||||
# 检查源类型
|
|
||||||
|
|
||||||
# 提取 username/repo , 正则表达式
|
# 提取 username/repo , 正则表达式
|
||||||
repo = self.get_github_plugin_repo_label(repo_url)
|
repo = self.get_github_plugin_repo_label(repo_url)
|
||||||
|
|
||||||
target_path += repo[1]
|
target_path += repo[1]
|
||||||
|
|
||||||
if repo is not None: # github
|
if repo is None:
|
||||||
self.ap.logger.debug("正在下载源码...")
|
|
||||||
|
|
||||||
zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD"
|
|
||||||
|
|
||||||
zip_resp = requests.get(
|
|
||||||
url=zipball_url, proxies=self.ap.proxy_mgr.get_forward_proxies(), stream=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if zip_resp.status_code != 200:
|
|
||||||
raise Exception("下载源码失败: {}".format(zip_resp.text))
|
|
||||||
|
|
||||||
if os.path.exists("temp/" + target_path):
|
|
||||||
shutil.rmtree("temp/" + target_path)
|
|
||||||
|
|
||||||
if os.path.exists(target_path):
|
|
||||||
shutil.rmtree(target_path)
|
|
||||||
|
|
||||||
os.makedirs("temp/" + target_path)
|
|
||||||
|
|
||||||
with open("temp/" + target_path + "/source.zip", "wb") as f:
|
|
||||||
for chunk in zip_resp.iter_content(chunk_size=1024):
|
|
||||||
if chunk:
|
|
||||||
f.write(chunk)
|
|
||||||
|
|
||||||
self.ap.logger.debug("解压中...")
|
|
||||||
|
|
||||||
with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref:
|
|
||||||
zip_ref.extractall("temp/" + target_path)
|
|
||||||
os.remove("temp/" + target_path + "/source.zip")
|
|
||||||
|
|
||||||
# 目标是 username-repo-hash , 用正则表达式提取完整的文件夹名,复制到 plugins/repo
|
|
||||||
import glob
|
|
||||||
|
|
||||||
# 获取解压后的文件夹名
|
|
||||||
unzip_dir = glob.glob("temp/" + target_path + "/*")[0]
|
|
||||||
|
|
||||||
# 复制到 plugins/repo
|
|
||||||
shutil.copytree(unzip_dir, target_path + "/")
|
|
||||||
|
|
||||||
# 删除解压后的文件夹
|
|
||||||
shutil.rmtree(unzip_dir)
|
|
||||||
|
|
||||||
self.ap.logger.debug("源码下载完成。")
|
|
||||||
else:
|
|
||||||
raise errors.PluginInstallerError('仅支持GitHub仓库地址')
|
raise errors.PluginInstallerError('仅支持GitHub仓库地址')
|
||||||
|
|
||||||
|
self.ap.logger.debug("正在下载源码...")
|
||||||
|
task_context.trace("下载源码...", "download-plugin-source-code")
|
||||||
|
|
||||||
|
zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD"
|
||||||
|
|
||||||
|
zip_resp: bytes = None
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(
|
||||||
|
url=zipball_url,
|
||||||
|
timeout=aiohttp.ClientTimeout(total=300)
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise errors.PluginInstallerError(f"下载源码失败: {resp.text}")
|
||||||
|
|
||||||
|
zip_resp = await resp.read()
|
||||||
|
|
||||||
|
if await aiofiles_os.path.exists("temp/" + target_path):
|
||||||
|
await aioshutil.rmtree("temp/" + target_path)
|
||||||
|
|
||||||
|
if await aiofiles_os.path.exists(target_path):
|
||||||
|
await aioshutil.rmtree(target_path)
|
||||||
|
|
||||||
|
await aiofiles_os.makedirs("temp/" + target_path)
|
||||||
|
|
||||||
|
async with aiofiles.open("temp/" + target_path + "/source.zip", "wb") as f:
|
||||||
|
await f.write(zip_resp)
|
||||||
|
|
||||||
|
self.ap.logger.debug("解压中...")
|
||||||
|
task_context.trace("解压中...", "unzip-plugin-source-code")
|
||||||
|
|
||||||
|
with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref:
|
||||||
|
zip_ref.extractall("temp/" + target_path)
|
||||||
|
await aiofiles_os.remove("temp/" + target_path + "/source.zip")
|
||||||
|
|
||||||
|
import glob
|
||||||
|
|
||||||
|
unzip_dir = glob.glob("temp/" + target_path + "/*")[0]
|
||||||
|
|
||||||
|
await aioshutil.copytree(unzip_dir, target_path + "/")
|
||||||
|
|
||||||
|
await aioshutil.rmtree(unzip_dir)
|
||||||
|
|
||||||
|
self.ap.logger.debug("源码下载完成。")
|
||||||
|
|
||||||
return repo[1]
|
return repo[1]
|
||||||
|
|
||||||
async def install_requirements(self, path: str):
|
async def install_requirements(self, path: str):
|
||||||
if os.path.exists(path + "/requirements.txt"):
|
if os.path.exists(path + "/requirements.txt"):
|
||||||
pkgmgr.install_requirements(path + "/requirements.txt")
|
pkgmgr.install_requirements(path + "/requirements.txt")
|
||||||
|
@ -101,7 +104,7 @@ class GitHubRepoInstaller(installer.PluginInstaller):
|
||||||
"""
|
"""
|
||||||
task_context.trace("下载插件源码...", "install-plugin")
|
task_context.trace("下载插件源码...", "install-plugin")
|
||||||
|
|
||||||
repo_label = await self.download_plugin_source_code(plugin_source, "plugins/")
|
repo_label = await self.download_plugin_source_code(plugin_source, "plugins/", task_context)
|
||||||
|
|
||||||
task_context.trace("安装插件依赖...", "install-plugin")
|
task_context.trace("安装插件依赖...", "install-plugin")
|
||||||
|
|
||||||
|
|
|
@ -19,3 +19,5 @@ quart
|
||||||
sqlalchemy[asyncio]
|
sqlalchemy[asyncio]
|
||||||
aiosqlite
|
aiosqlite
|
||||||
quart-cors
|
quart-cors
|
||||||
|
aiofiles
|
||||||
|
aioshutil
|
Loading…
Reference in New Issue
Block a user