Merge pull request #529 from RockChinQ/feat-funcs-called-args

[Feat] NormalMessageResponded添加func_called参数
This commit is contained in:
Junyan Qin 2023-08-02 18:02:48 +08:00 committed by GitHub
commit 819339142e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 14 additions and 7 deletions

View File

@ -194,14 +194,14 @@ class Session:
# 请求回复 # 请求回复
# 这个函数是阻塞的 # 这个函数是阻塞的
def append(self, text: str=None) -> tuple[str, str]: def append(self, text: str=None) -> tuple[str, str, list[str]]:
"""向session中添加一条消息返回接口回复 """向session中添加一条消息返回接口回复
Args: Args:
text (str): 用户消息 text (str): 用户消息
Returns: Returns:
tuple[str, str]: (接口回复, finish_reason) tuple[str, str]: (接口回复, finish_reason, 已调用的函数列表)
""" """
self.last_interact_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time())
@ -216,7 +216,7 @@ class Session:
event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args) event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args)
if event.is_prevented_default(): if event.is_prevented_default():
return None, None return None, None, None
config = pkg.utils.context.get_config() config = pkg.utils.context.get_config()
max_length = config.prompt_submit_length max_length = config.prompt_submit_length
@ -253,6 +253,8 @@ class Session:
finish_reason: str = "" finish_reason: str = ""
funcs = []
for resp in pkg.utils.context.get_openai_manager().request_completion(prompts): for resp in pkg.utils.context.get_openai_manager().request_completion(prompts):
finish_reason = resp['choices'][0]['finish_reason'] finish_reason = resp['choices'][0]['finish_reason']
@ -288,10 +290,12 @@ class Session:
# ) # )
# total_tokens += resp['usage']['total_tokens'] # total_tokens += resp['usage']['total_tokens']
funcs.append(
resp['choices'][0]['message']['function_name']
)
pass pass
# 向API请求补全 # 向API请求补全
# message, total_token = pkg.utils.context.get_openai_manager().request_completion( # message, total_token = pkg.utils.context.get_openai_manager().request_completion(
# prompts, # prompts,
@ -317,7 +321,7 @@ class Session:
self.just_switched_to_exist_session = False self.just_switched_to_exist_session = False
self.set_ongoing() self.set_ongoing()
return res_ans if res_ans[0] != '\n' else res_ans[1:], finish_reason return res_ans if res_ans[0] != '\n' else res_ans[1:], finish_reason, funcs
# 删除上一回合并返回上一回合的问题 # 删除上一回合并返回上一回合的问题
def undo(self) -> str: def undo(self) -> str:

View File

@ -89,6 +89,7 @@ NormalMessageResponded = "normal_message_responded"
prefix: str 回复文字消息的前缀 prefix: str 回复文字消息的前缀
response_text: str 响应文本 response_text: str 响应文本
finish_reason: str 响应结束原因 finish_reason: str 响应结束原因
funcs_called: list[str] 此次响应中调用的函数列表
returns (optional): returns (optional):
prefix: str 修改后的回复文字消息的前缀 prefix: str 修改后的回复文字消息的前缀

View File

@ -20,7 +20,7 @@ class ContinueCommand(AbstractCommandNode):
session = pkg.openai.session.get_session(session_name) session = pkg.openai.session.get_session(session_name)
text, _ = session.append() text, _, _ = session.append()
reply = [text] reply = [text]

View File

@ -40,7 +40,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str,
try: try:
prefix = "[GPT]" if config.show_prefix else "" prefix = "[GPT]" if config.show_prefix else ""
text, finish_reason = session.append(text_message) text, finish_reason, funcs = session.append(text_message)
# 触发插件事件 # 触发插件事件
args = { args = {
@ -51,6 +51,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str,
"prefix": prefix, "prefix": prefix,
"response_text": text, "response_text": text,
"finish_reason": finish_reason, "finish_reason": finish_reason,
"funcs_called": funcs,
} }
event = pkg.plugin.host.emit(plugin_models.NormalMessageResponded, **args) event = pkg.plugin.host.emit(plugin_models.NormalMessageResponded, **args)
@ -63,6 +64,7 @@ def process_normal_message(text_message: str, mgr, config, launcher_type: str,
if not event.is_prevented_default(): if not event.is_prevented_default():
reply = [prefix + text] reply = [prefix + text]
except openai.error.APIConnectionError as e: except openai.error.APIConnectionError as e:
err_msg = str(e) err_msg = str(e)
if err_msg.__contains__('Error communicating with OpenAI'): if err_msg.__contains__('Error communicating with OpenAI'):