mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
feat: prompt 加载器的扩展性
This commit is contained in:
parent
8c6ce1f030
commit
b9fa11c0c3
|
@ -1,13 +1,27 @@
|
|||
from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
preregistered_loaders: list[typing.Type[PromptLoader]] = []
|
||||
|
||||
def loader_class(name: str):
|
||||
|
||||
def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]:
|
||||
cls.name = name
|
||||
preregistered_loaders.append(cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class PromptLoader(metaclass=abc.ABCMeta):
|
||||
"""Prompt加载器抽象类
|
||||
"""
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from .. import entities
|
|||
from ....provider import entities as llm_entities
|
||||
|
||||
|
||||
@loader.loader_class("full_scenario")
|
||||
class ScenarioPromptLoader(loader.PromptLoader):
|
||||
"""加载scenario目录下的json"""
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ from .. import entities
|
|||
from ....provider import entities as llm_entities
|
||||
|
||||
|
||||
@loader.loader_class("normal")
|
||||
class SingleSystemPromptLoader(loader.PromptLoader):
|
||||
"""配置文件中的单条system prompt的prompt加载器
|
||||
"""
|
||||
|
|
|
@ -20,12 +20,14 @@ class PromptManager:
|
|||
|
||||
async def initialize(self):
|
||||
|
||||
loader_map = {
|
||||
"normal": single.SingleSystemPromptLoader,
|
||||
"full_scenario": scenario.ScenarioPromptLoader
|
||||
}
|
||||
mode_name = self.ap.provider_cfg.data['prompt-mode']
|
||||
|
||||
loader_cls = loader_map[self.ap.provider_cfg.data['prompt-mode']]
|
||||
for loader_cls in loader.preregistered_loaders:
|
||||
if loader_cls.name == mode_name:
|
||||
loader_cls = loader_cls
|
||||
break
|
||||
else:
|
||||
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
|
||||
|
||||
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user