mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
feat: prompt 加载器的扩展性
This commit is contained in:
parent
8c6ce1f030
commit
b9fa11c0c3
|
@ -1,13 +1,27 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import abc
|
import abc
|
||||||
|
import typing
|
||||||
|
|
||||||
from ...core import app
|
from ...core import app
|
||||||
from . import entities
|
from . import entities
|
||||||
|
|
||||||
|
|
||||||
|
preregistered_loaders: list[typing.Type[PromptLoader]] = []
|
||||||
|
|
||||||
|
def loader_class(name: str):
|
||||||
|
|
||||||
|
def decorator(cls: typing.Type[PromptLoader]) -> typing.Type[PromptLoader]:
|
||||||
|
cls.name = name
|
||||||
|
preregistered_loaders.append(cls)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
class PromptLoader(metaclass=abc.ABCMeta):
|
class PromptLoader(metaclass=abc.ABCMeta):
|
||||||
"""Prompt加载器抽象类
|
"""Prompt加载器抽象类
|
||||||
"""
|
"""
|
||||||
|
name: str
|
||||||
|
|
||||||
ap: app.Application
|
ap: app.Application
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ from .. import entities
|
||||||
from ....provider import entities as llm_entities
|
from ....provider import entities as llm_entities
|
||||||
|
|
||||||
|
|
||||||
|
@loader.loader_class("full_scenario")
|
||||||
class ScenarioPromptLoader(loader.PromptLoader):
|
class ScenarioPromptLoader(loader.PromptLoader):
|
||||||
"""加载scenario目录下的json"""
|
"""加载scenario目录下的json"""
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ from .. import entities
|
||||||
from ....provider import entities as llm_entities
|
from ....provider import entities as llm_entities
|
||||||
|
|
||||||
|
|
||||||
|
@loader.loader_class("normal")
|
||||||
class SingleSystemPromptLoader(loader.PromptLoader):
|
class SingleSystemPromptLoader(loader.PromptLoader):
|
||||||
"""配置文件中的单条system prompt的prompt加载器
|
"""配置文件中的单条system prompt的prompt加载器
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -20,12 +20,14 @@ class PromptManager:
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
|
|
||||||
loader_map = {
|
mode_name = self.ap.provider_cfg.data['prompt-mode']
|
||||||
"normal": single.SingleSystemPromptLoader,
|
|
||||||
"full_scenario": scenario.ScenarioPromptLoader
|
|
||||||
}
|
|
||||||
|
|
||||||
loader_cls = loader_map[self.ap.provider_cfg.data['prompt-mode']]
|
for loader_cls in loader.preregistered_loaders:
|
||||||
|
if loader_cls.name == mode_name:
|
||||||
|
loader_cls = loader_cls
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
|
||||||
|
|
||||||
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
|
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user