diff --git a/config-template.py b/config-template.py index 59e83a0..987e63d 100644 --- a/config-template.py +++ b/config-template.py @@ -70,6 +70,11 @@ openai_config = { "reverse_proxy": None } +# api-key切换策略 +# active:每次请求时都会切换api-key +# passive:仅当api-key超额时才会切换api-key +switch_strategy = "active" + # [必需] 管理员QQ号,用于接收报错等通知及执行管理员级别指令 # 支持多个管理员,可以使用list形式设置,例如: # admin_qq = [12345678, 87654321] diff --git a/override-all.json b/override-all.json index 6fed8cf..e46e1c9 100644 --- a/override-all.json +++ b/override-all.json @@ -21,6 +21,7 @@ "http_proxy": null, "reverse_proxy": null }, + "switch_strategy": "active", "admin_qq": 0, "default_prompt": { "default": "如果我之后想获取帮助,请你说“输入!help获取帮助”" diff --git a/pkg/openai/api/model.py b/pkg/openai/api/model.py index 2edacc9..58f3e3f 100644 --- a/pkg/openai/api/model.py +++ b/pkg/openai/api/model.py @@ -13,8 +13,15 @@ class RequestBase: def __init__(self, *args, **kwargs): raise NotImplementedError + def _next_key(self): + import pkg.utils.context as context + switched, name = context.get_openai_manager().key_mgr.auto_switch() + logging.debug("切换api-key: switched={}, name={}".format(switched, name)) + openai.api_key = context.get_openai_manager().key_mgr.get_using_key() + def _req(self, **kwargs): """处理代理问题""" + import config ret: dict = {} exception: Exception = None @@ -25,6 +32,10 @@ class RequestBase: try: ret = await self.req_func(**kwargs) logging.debug("接口请求返回:%s", str(ret)) + + if config.switch_strategy == 'active': + self._next_key() + return ret except Exception as e: exception = e diff --git a/pkg/openai/keymgr.py b/pkg/openai/keymgr.py index 4428b0d..bed4433 100644 --- a/pkg/openai/keymgr.py +++ b/pkg/openai/keymgr.py @@ -54,7 +54,24 @@ class KeysManager: 是否切换成功, 切换后的api-key的别名 """ + index = 0 + for key_name in self.api_key: + if self.api_key[key_name] == self.using_key: + break + + index += 1 + + # 从当前key开始向后轮询 + start_index = index + index += 1 + if index >= len(self.api_key): + index = 0 + + while index != start_index: + + key_name = list(self.api_key.keys())[index] + if self.api_key[key_name] not in self.exceeded: self.using_key = self.api_key[key_name] @@ -69,10 +86,14 @@ class KeysManager: return True, key_name - self.using_key = list(self.api_key.values())[0] - logging.info("使用api-key:" + list(self.api_key.keys())[0]) + index += 1 + if index >= len(self.api_key): + index = 0 - return False, "" + self.using_key = list(self.api_key.values())[start_index] + logging.debug("使用api-key:" + list(self.api_key.keys())[start_index]) + + return False, list(self.api_key.keys())[start_index] def add(self, key_name, key): self.api_key[key_name] = key diff --git a/pkg/qqbot/cmds/funcs/func.py b/pkg/qqbot/cmds/funcs/func.py index f33efd1..93b3184 100644 --- a/pkg/qqbot/cmds/funcs/func.py +++ b/pkg/qqbot/cmds/funcs/func.py @@ -1,6 +1,8 @@ from ..aamgr import AbstractCommandNode, Context import logging +import json + @AbstractCommandNode.register( parent=None, @@ -19,6 +21,8 @@ class FuncCommand(AbstractCommandNode): reply_str = "当前已加载的内容函数:\n\n" + logging.debug("host.__callable_functions__: {}".format(json.dumps(host.__callable_functions__, indent=4))) + index = 1 for func in host.__callable_functions__: reply_str += "{}. {}{}:\n{}\n\n".format(index, ("(已禁用) " if not func['enabled'] else ""), func['name'], func['description'])