From 792366e221cbb70c7ec429fc0bb53f4363034baa Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Sun, 5 Mar 2023 11:56:40 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=9F=BA=E4=BA=8E?= =?UTF-8?q?=E8=AF=AD=E4=B9=89=E5=8C=96=E7=89=88=E6=9C=AC=E7=9A=84=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- main.py | 6 ++- pkg/utils/updater.py | 109 +++++++++++++++++++++++++++++++++---------- requirements.txt | 1 + 4 files changed, 93 insertions(+), 26 deletions(-) diff --git a/.gitignore b/.gitignore index c992502..362973b 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ plugins/ prompts/ logs/ sensitive.json -temp/ \ No newline at end of file +temp/ +current_tag \ No newline at end of file diff --git a/main.py b/main.py index dae29ae..c4ff6ab 100644 --- a/main.py +++ b/main.py @@ -314,10 +314,14 @@ if __name__ == '__main__': if not os.path.exists('banlist.py'): shutil.copy('banlist-template.py', 'banlist.py') - # 检查是否有sensitive.json, + # 检查是否有sensitive.json if not os.path.exists("sensitive.json"): shutil.copy("sensitive-template.json", "sensitive.json") + # 检查temp目录 + if not os.path.exists("temp/"): + os.mkdir("temp/") + if len(sys.argv) > 1 and sys.argv[1] == 'init_db': init_db() sys.exit(0) diff --git a/pkg/utils/updater.py b/pkg/utils/updater.py index 7c55dd7..d053503 100644 --- a/pkg/utils/updater.py +++ b/pkg/utils/updater.py @@ -1,4 +1,9 @@ import datetime +import logging +import os.path + +import requests +import json import pkg.utils.context @@ -29,33 +34,85 @@ def pull_latest(repo_path: str) -> bool: def update_all() -> bool: - """使用dulwich更新源码""" - check_dulwich_closure() - import dulwich - try: - before_commit_id = get_current_commit_id() - from dulwich import porcelain - repo = porcelain.open_repo('.') - porcelain.pull(repo) + """检查更新并下载源码""" + current_tag = "v0.1.0" + if os.path.exists("current_tag"): + with open("current_tag", "r") as f: + current_tag = f.read() - change_log = "" + rls_list_resp = requests.get( + url="https://api.github.com/repos/RockChinQ/QChatGPT/releases" + ) - for entry in repo.get_walker(): - if str(entry.commit.id)[2:-1] == before_commit_id: - break - tz = datetime.timezone(datetime.timedelta(hours=entry.commit.commit_timezone // 3600)) - dt = datetime.datetime.fromtimestamp(entry.commit.commit_time, tz) - change_log += dt.strftime('%Y-%m-%d %H:%M:%S') + " [" + str(entry.commit.message, encoding="utf-8").strip()+"]\n" + rls_list = rls_list_resp.json() - if change_log != "": - pkg.utils.context.get_qqbot_manager().notify_admin("代码拉取完成,更新内容如下:\n"+change_log) - return True - else: - return False - except ModuleNotFoundError: - raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77") - except dulwich.porcelain.DivergedBranches: - raise Exception("分支不一致,自动更新仅支持master分支,请手动更新(https://github.com/RockChinQ/QChatGPT/issues/76)") + latest_rls = {} + rls_notes = [] + for rls in rls_list: + rls_notes.append(rls['name']) # 使用发行名称作为note + if rls['tag_name'] == current_tag: + break + + if latest_rls == {}: + latest_rls = rls + print(rls_notes) + if latest_rls == {}: # 没有新版本 + return False + + # 下载最新版本的zip到temp目录 + logging.info("开始下载最新版本: {}".format(latest_rls['zipball_url'])) + zip_url = latest_rls['zipball_url'] + zip_resp = requests.get(url=zip_url) + zip_data = zip_resp.content + + # 检查temp/updater目录 + if not os.path.exists("temp"): + os.mkdir("temp") + if not os.path.exists("temp/updater"): + os.mkdir("temp/updater") + with open("temp/updater/{}.zip".format(latest_rls['tag_name']), "wb") as f: + f.write(zip_data) + + logging.info("下载最新版本完成: {}".format("temp/updater/{}.zip".format(latest_rls['tag_name']))) + + # 解压zip到temp/updater// + import zipfile + # 检查目标文件夹 + if os.path.exists("temp/updater/{}".format(latest_rls['tag_name'])): + import shutil + shutil.rmtree("temp/updater/{}".format(latest_rls['tag_name'])) + os.mkdir("temp/updater/{}".format(latest_rls['tag_name'])) + with zipfile.ZipFile("temp/updater/{}.zip".format(latest_rls['tag_name']), 'r') as zip_ref: + zip_ref.extractall("temp/updater/{}".format(latest_rls['tag_name'])) + + # 覆盖源码 + source_root = "" + # 找到temp/updater//中的第一个子目录路径 + for root, dirs, files in os.walk("temp/updater/{}".format(latest_rls['tag_name'])): + if root != "temp/updater/{}".format(latest_rls['tag_name']): + source_root = root + break + + # 覆盖源码 + import shutil + for root, dirs, files in os.walk(source_root): + # 覆盖所有子文件子目录 + for file in files: + src = os.path.join(root, file) + dst = src.replace(source_root, ".") + if os.path.exists(dst): + os.remove(dst) + shutil.copy(src, dst) + + # 把current_tag写入文件 + current_tag = latest_rls['tag_name'] + with open("current_tag", "w") as f: + f.write(current_tag) + + # 通知管理员 + import pkg.utils.context + pkg.utils.context.get_qqbot_manager().notify_admin("已更新到最新版本: {}\n更新日志:\n{}".format(current_tag, "\n".join(rls_notes))) + return True def is_repo(path: str) -> bool: @@ -144,3 +201,7 @@ def is_new_version_available() -> bool: latest_commit_id = str(fetch_res[b'HEAD'])[2:-1] return current_commit_id != latest_commit_id + + +if __name__ == "__main__": + update_all() diff --git a/requirements.txt b/requirements.txt index 3bae775..a61c1a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ requests~=2.28.1 openai~=0.27.0 +dulwich~=0.21.3 colorlog~=6.6.0 yiri-mirai~=0.2.6.1 websockets~=10.4