feat: prompt 加载器的扩展性

This commit is contained in:
Junyan Qin 2024-03-12 16:22:07 +00:00
parent 8c6ce1f030
commit b9fa11c0c3
4 changed files with 23 additions and 5 deletions

View File

@ -1,13 +1,27 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from . import entities
preregistered_loaders: list[typing.Type[PromptLoader]] = []
def loader_class(name: str):
def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]:
cls.name = name
preregistered_loaders.append(cls)
return cls
return decorator
class PromptLoader(metaclass=abc.ABCMeta):
"""Prompt加载器抽象类
"""
name: str
ap: app.Application

View File

@ -8,6 +8,7 @@ from .. import entities
from ....provider import entities as llm_entities
@loader.loader_class("full_scenario")
class ScenarioPromptLoader(loader.PromptLoader):
"""加载scenario目录下的json"""

View File

@ -6,6 +6,7 @@ from .. import entities
from ....provider import entities as llm_entities
@loader.loader_class("normal")
class SingleSystemPromptLoader(loader.PromptLoader):
"""配置文件中的单条system prompt的prompt加载器
"""

View File

@ -20,12 +20,14 @@ class PromptManager:
async def initialize(self):
loader_map = {
"normal": single.SingleSystemPromptLoader,
"full_scenario": scenario.ScenarioPromptLoader
}
mode_name = self.ap.provider_cfg.data['prompt-mode']
loader_cls = loader_map[self.ap.provider_cfg.data['prompt-mode']]
for loader_cls in loader.preregistered_loaders:
if loader_cls.name == mode_name:
loader_cls = loader_cls
break
else:
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)