diff --git a/pkg/provider/sysprompt/loader.py b/pkg/provider/sysprompt/loader.py index ca9e873..9e0a614 100644 --- a/pkg/provider/sysprompt/loader.py +++ b/pkg/provider/sysprompt/loader.py @@ -1,13 +1,27 @@ from __future__ import annotations import abc +import typing from ...core import app from . import entities +preregistered_loaders: list[typing.Type[PromptLoader]] = [] + +def loader_class(name: str): + + def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]: + cls.name = name + preregistered_loaders.append(cls) + return cls + + return decorator + + class PromptLoader(metaclass=abc.ABCMeta): """Prompt加载器抽象类 """ + name: str ap: app.Application diff --git a/pkg/provider/sysprompt/loaders/scenario.py b/pkg/provider/sysprompt/loaders/scenario.py index a559ff7..9c19d96 100644 --- a/pkg/provider/sysprompt/loaders/scenario.py +++ b/pkg/provider/sysprompt/loaders/scenario.py @@ -8,6 +8,7 @@ from .. import entities from ....provider import entities as llm_entities +@loader.loader_class("full_scenario") class ScenarioPromptLoader(loader.PromptLoader): """加载scenario目录下的json""" diff --git a/pkg/provider/sysprompt/loaders/single.py b/pkg/provider/sysprompt/loaders/single.py index 57e06ed..3ac9c26 100644 --- a/pkg/provider/sysprompt/loaders/single.py +++ b/pkg/provider/sysprompt/loaders/single.py @@ -6,6 +6,7 @@ from .. import entities from ....provider import entities as llm_entities +@loader.loader_class("normal") class SingleSystemPromptLoader(loader.PromptLoader): """配置文件中的单条system prompt的prompt加载器 """ diff --git a/pkg/provider/sysprompt/sysprompt.py b/pkg/provider/sysprompt/sysprompt.py index eb89e8a..61c598e 100644 --- a/pkg/provider/sysprompt/sysprompt.py +++ b/pkg/provider/sysprompt/sysprompt.py @@ -20,12 +20,14 @@ class PromptManager: async def initialize(self): - loader_map = { - "normal": single.SingleSystemPromptLoader, - "full_scenario": scenario.ScenarioPromptLoader - } + mode_name = self.ap.provider_cfg.data['prompt-mode'] - loader_cls = loader_map[self.ap.provider_cfg.data['prompt-mode']] + for loader_cls in loader.preregistered_loaders: + if loader_cls.name == mode_name: + loader_cls = loader_cls + break + else: + raise ValueError(f'未知的 Prompt 加载器: {mode_name}') self.loader_inst: loader.PromptLoader = loader_cls(self.ap)