refactor: 使插件更新过程全异步

This commit is contained in:
Junyan Qin 2024-11-03 22:27:31 +08:00
parent 68184b0e47
commit 709b86b724
No known key found for this signature in database
GPG Key ID: 22FE3AFADC710CEB
4 changed files with 77 additions and 55 deletions

View File

@ -19,6 +19,8 @@ required_deps = {
"quart_cors": "quart-cors", "quart_cors": "quart-cors",
"sqlalchemy": "sqlalchemy[asyncio]", "sqlalchemy": "sqlalchemy[asyncio]",
"aiosqlite": "aiosqlite", "aiosqlite": "aiosqlite",
"aiofiles": "aiofiles",
"aioshutil": "aioshutil",
} }

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
import typing import typing
import datetime import datetime
import traceback
from . import app from . import app
@ -85,6 +86,9 @@ class TaskWrapper:
task: asyncio.Task task: asyncio.Task
"""任务""" """任务"""
task_stack: list = None
"""任务堆栈"""
ap: app.Application ap: app.Application
"""应用实例""" """应用实例"""
@ -111,7 +115,10 @@ class TaskWrapper:
def assume_exception(self): def assume_exception(self):
try: try:
return self.task.exception() exception = self.task.exception()
if self.task_stack is None:
self.task_stack = self.task.get_stack()
return exception
except: except:
return None return None
@ -123,6 +130,13 @@ class TaskWrapper:
def to_dict(self) -> dict: def to_dict(self) -> dict:
exception_traceback = None
if self.assume_exception() is not None:
exception_traceback = 'Traceback (most recent call last):\n'
for frame in self.task_stack:
exception_traceback += f" File \"{frame.f_code.co_filename}\", line {frame.f_lineno}, in {frame.f_code.co_name}\n"
return { return {
"id": self.id, "id": self.id,
"task_type": self.task_type, "task_type": self.task_type,
@ -133,8 +147,9 @@ class TaskWrapper:
"runtime": { "runtime": {
"done": self.task.done(), "done": self.task.done(),
"state": self.task._state, "state": self.task._state,
"exception": self.assume_exception(), "exception": self.assume_exception().__str__() if self.assume_exception() is not None else None,
"result": self.assume_result(), "exception_traceback": exception_traceback,
"result": self.assume_result().__str__() if self.assume_result() is not None else None,
}, },
} }

View File

@ -5,7 +5,10 @@ import os
import shutil import shutil
import zipfile import zipfile
import requests import aiohttp
import aiofiles
import aiofiles.os as aiofiles_os
import aioshutil
from .. import installer, errors from .. import installer, errors
from ...utils import pkgmgr from ...utils import pkgmgr
@ -29,65 +32,65 @@ class GitHubRepoInstaller(installer.PluginInstaller):
return repo[0].split("/") return repo[0].split("/")
else: else:
return None return None
async def download_plugin_source_code(self, repo_url: str, target_path: str, task_context: taskmgr.TaskContext = taskmgr.TaskContext.placeholder()) -> str:
"""下载插件源码(全异步)"""
async def download_plugin_source_code(self, repo_url: str, target_path: str) -> str:
"""下载插件源码"""
# 检查源类型
# 提取 username/repo , 正则表达式 # 提取 username/repo , 正则表达式
repo = self.get_github_plugin_repo_label(repo_url) repo = self.get_github_plugin_repo_label(repo_url)
target_path += repo[1] target_path += repo[1]
if repo is not None: # github if repo is None:
self.ap.logger.debug("正在下载源码...")
zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD"
zip_resp = requests.get(
url=zipball_url, proxies=self.ap.proxy_mgr.get_forward_proxies(), stream=True
)
if zip_resp.status_code != 200:
raise Exception("下载源码失败: {}".format(zip_resp.text))
if os.path.exists("temp/" + target_path):
shutil.rmtree("temp/" + target_path)
if os.path.exists(target_path):
shutil.rmtree(target_path)
os.makedirs("temp/" + target_path)
with open("temp/" + target_path + "/source.zip", "wb") as f:
for chunk in zip_resp.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
self.ap.logger.debug("解压中...")
with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref:
zip_ref.extractall("temp/" + target_path)
os.remove("temp/" + target_path + "/source.zip")
# 目标是 username-repo-hash , 用正则表达式提取完整的文件夹名,复制到 plugins/repo
import glob
# 获取解压后的文件夹名
unzip_dir = glob.glob("temp/" + target_path + "/*")[0]
# 复制到 plugins/repo
shutil.copytree(unzip_dir, target_path + "/")
# 删除解压后的文件夹
shutil.rmtree(unzip_dir)
self.ap.logger.debug("源码下载完成。")
else:
raise errors.PluginInstallerError('仅支持GitHub仓库地址') raise errors.PluginInstallerError('仅支持GitHub仓库地址')
self.ap.logger.debug("正在下载源码...")
task_context.trace("下载源码...", "download-plugin-source-code")
zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD"
zip_resp: bytes = None
async with aiohttp.ClientSession() as session:
async with session.get(
url=zipball_url,
timeout=aiohttp.ClientTimeout(total=300)
) as resp:
if resp.status != 200:
raise errors.PluginInstallerError(f"下载源码失败: {resp.text}")
zip_resp = await resp.read()
if await aiofiles_os.path.exists("temp/" + target_path):
await aioshutil.rmtree("temp/" + target_path)
if await aiofiles_os.path.exists(target_path):
await aioshutil.rmtree(target_path)
await aiofiles_os.makedirs("temp/" + target_path)
async with aiofiles.open("temp/" + target_path + "/source.zip", "wb") as f:
await f.write(zip_resp)
self.ap.logger.debug("解压中...")
task_context.trace("解压中...", "unzip-plugin-source-code")
with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref:
zip_ref.extractall("temp/" + target_path)
await aiofiles_os.remove("temp/" + target_path + "/source.zip")
import glob
unzip_dir = glob.glob("temp/" + target_path + "/*")[0]
await aioshutil.copytree(unzip_dir, target_path + "/")
await aioshutil.rmtree(unzip_dir)
self.ap.logger.debug("源码下载完成。")
return repo[1] return repo[1]
async def install_requirements(self, path: str): async def install_requirements(self, path: str):
if os.path.exists(path + "/requirements.txt"): if os.path.exists(path + "/requirements.txt"):
pkgmgr.install_requirements(path + "/requirements.txt") pkgmgr.install_requirements(path + "/requirements.txt")
@ -101,7 +104,7 @@ class GitHubRepoInstaller(installer.PluginInstaller):
""" """
task_context.trace("下载插件源码...", "install-plugin") task_context.trace("下载插件源码...", "install-plugin")
repo_label = await self.download_plugin_source_code(plugin_source, "plugins/") repo_label = await self.download_plugin_source_code(plugin_source, "plugins/", task_context)
task_context.trace("安装插件依赖...", "install-plugin") task_context.trace("安装插件依赖...", "install-plugin")

View File

@ -19,3 +19,5 @@ quart
sqlalchemy[asyncio] sqlalchemy[asyncio]
aiosqlite aiosqlite
quart-cors quart-cors
aiofiles
aioshutil