fix: In the output, the characters 'ta' are sometimes emitted in reversed order as 'at'. #8015 (#8791)

This commit is contained in:
Wei-shun Bao 2024-10-15 16:24:29 +08:00 committed by GitHub
parent cd7ab6231f
commit fb32e5ca9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -62,6 +62,8 @@ class CotAgentOutputParser:
thought_str = "thought:" thought_str = "thought:"
thought_idx = 0 thought_idx = 0
last_character = ""
for response in llm_response: for response in llm_response:
if response.delta.usage: if response.delta.usage:
usage_dict["usage"] = response.delta.usage usage_dict["usage"] = response.delta.usage
@ -74,35 +76,38 @@ class CotAgentOutputParser:
while index < len(response): while index < len(response):
steps = 1 steps = 1
delta = response[index : index + steps] delta = response[index : index + steps]
last_character = response[index - 1] if index > 0 else "" yield_delta = False
if delta == "`": if delta == "`":
last_character = delta
code_block_cache += delta code_block_cache += delta
code_block_delimiter_count += 1 code_block_delimiter_count += 1
else: else:
if not in_code_block: if not in_code_block:
if code_block_delimiter_count > 0: if code_block_delimiter_count > 0:
last_character = delta
yield code_block_cache yield code_block_cache
code_block_cache = "" code_block_cache = ""
else: else:
last_character = delta
code_block_cache += delta code_block_cache += delta
code_block_delimiter_count = 0 code_block_delimiter_count = 0
if not in_code_block and not in_json: if not in_code_block and not in_json:
if delta.lower() == action_str[action_idx] and action_idx == 0: if delta.lower() == action_str[action_idx] and action_idx == 0:
if last_character not in {"\n", " ", ""}: if last_character not in {"\n", " ", ""}:
yield_delta = True
else:
last_character = delta
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ""
action_idx = 0
index += steps index += steps
yield delta
continue continue
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ""
action_idx = 0
index += steps
continue
elif delta.lower() == action_str[action_idx] and action_idx > 0: elif delta.lower() == action_str[action_idx] and action_idx > 0:
last_character = delta
action_cache += delta action_cache += delta
action_idx += 1 action_idx += 1
if action_idx == len(action_str): if action_idx == len(action_str):
@ -112,24 +117,25 @@ class CotAgentOutputParser:
continue continue
else: else:
if action_cache: if action_cache:
last_character = delta
yield action_cache yield action_cache
action_cache = "" action_cache = ""
action_idx = 0 action_idx = 0
if delta.lower() == thought_str[thought_idx] and thought_idx == 0: if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
if last_character not in {"\n", " ", ""}: if last_character not in {"\n", " ", ""}:
yield_delta = True
else:
last_character = delta
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ""
thought_idx = 0
index += steps index += steps
yield delta
continue continue
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ""
thought_idx = 0
index += steps
continue
elif delta.lower() == thought_str[thought_idx] and thought_idx > 0: elif delta.lower() == thought_str[thought_idx] and thought_idx > 0:
last_character = delta
thought_cache += delta thought_cache += delta
thought_idx += 1 thought_idx += 1
if thought_idx == len(thought_str): if thought_idx == len(thought_str):
@ -139,12 +145,20 @@ class CotAgentOutputParser:
continue continue
else: else:
if thought_cache: if thought_cache:
last_character = delta
yield thought_cache yield thought_cache
thought_cache = "" thought_cache = ""
thought_idx = 0 thought_idx = 0
if yield_delta:
index += steps
last_character = delta
yield delta
continue
if code_block_delimiter_count == 3: if code_block_delimiter_count == 3:
if in_code_block: if in_code_block:
last_character = delta
yield from extra_json_from_code_block(code_block_cache) yield from extra_json_from_code_block(code_block_cache)
code_block_cache = "" code_block_cache = ""
@ -156,8 +170,10 @@ class CotAgentOutputParser:
if delta == "{": if delta == "{":
json_quote_count += 1 json_quote_count += 1
in_json = True in_json = True
last_character = delta
json_cache += delta json_cache += delta
elif delta == "}": elif delta == "}":
last_character = delta
json_cache += delta json_cache += delta
if json_quote_count > 0: if json_quote_count > 0:
json_quote_count -= 1 json_quote_count -= 1
@ -168,16 +184,19 @@ class CotAgentOutputParser:
continue continue
else: else:
if in_json: if in_json:
last_character = delta
json_cache += delta json_cache += delta
if got_json: if got_json:
got_json = False got_json = False
last_character = delta
yield parse_action(json_cache) yield parse_action(json_cache)
json_cache = "" json_cache = ""
json_quote_count = 0 json_quote_count = 0
in_json = False in_json = False
if not in_code_block and not in_json: if not in_code_block and not in_json:
last_character = delta
yield delta.replace("`", "") yield delta.replace("`", "")
index += steps index += steps