enhance: claude stream tool call (#4469)

Yeuoly 2024-05-17 12:43:58 +08:00 committed by GitHub
parent 083ef2e6fc
commit 091fba74cb
GPG Key ID: B5690EEEBB952194
4 changed files with 34 additions and 3 deletions

View File

@@ -6,6 +6,7 @@ features:
   - agent-thought
   - vision
   - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 200000
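
The new `stream-tool-call` entry advertises that this Claude model can emit tool calls incrementally while streaming, not only in blocking responses; the two identical hunks that follow apply the same one-line change to the other Claude model specs. A minimal sketch of what the flag expresses, assuming a hypothetical consumer that reads the spec file directly (the path and variable names are illustrative, not Dify's actual runtime API):

```python
# Illustrative only: Dify's model runtime resolves these YAML specs internally.
# This just shows what the new feature flag communicates to a consumer.
import yaml

with open("claude-3-opus-20240229.yaml") as f:  # hypothetical path
    spec = yaml.safe_load(f)

# True once this commit adds 'stream-tool-call' to the features list.
supports_stream_tool_call = "stream-tool-call" in spec.get("features", [])
print(supports_stream_tool_call)
```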

View File

@@ -6,6 +6,7 @@ features:
   - agent-thought
   - vision
   - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 200000

View File

@@ -6,6 +6,7 @@ features:
   - agent-thought
   - vision
   - tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 200000

View File

@@ -324,8 +324,30 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         output_tokens = 0
         finish_reason = None
         index = 0
+        tool_calls: list[AssistantPromptMessage.ToolCall] = []
 
         for chunk in response:
             if isinstance(chunk, MessageStartEvent):
-                return_model = chunk.message.model
-                input_tokens = chunk.message.usage.input_tokens
+                if hasattr(chunk, 'content_block'):
+                    content_block = chunk.content_block
+                    if isinstance(content_block, dict):
+                        if content_block.get('type') == 'tool_use':
+                            tool_call = AssistantPromptMessage.ToolCall(
+                                id=content_block.get('id'),
+                                type='function',
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=content_block.get('name'),
+                                    arguments=''
+                                )
+                            )
+                            tool_calls.append(tool_call)
+                elif hasattr(chunk, 'delta'):
+                    delta = chunk.delta
+                    if isinstance(delta, dict) and len(tool_calls) > 0:
+                        if delta.get('type') == 'input_json_delta':
+                            tool_calls[-1].function.arguments += delta.get('partial_json', '')
+                elif chunk.message:
+                    return_model = chunk.message.model
+                    input_tokens = chunk.message.usage.input_tokens
             elif isinstance(chunk, MessageDeltaEvent):
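
For context on what this hunk accumulates: Anthropic's streaming API announces a tool call with a `content_block_start` event whose `content_block` has type `tool_use` (carrying the call's `id` and `name`), then streams the argument JSON in fragments through `content_block_delta` events whose delta type is `input_json_delta` with a `partial_json` field. A minimal sketch of the same accumulation pattern against the raw `anthropic` SDK, with Dify's `AssistantPromptMessage.ToolCall` wrapper replaced by a plain dict (the tool definition and prompt are illustrative):

```python
import anthropic

client = anthropic.Anthropic()  # assumes ANTHROPIC_API_KEY is set in the environment

stream = client.messages.create(
    model="claude-3-haiku-20240307",
    max_tokens=1024,
    tools=[{
        "name": "get_weather",  # illustrative tool
        "description": "Get the current weather for a city",
        "input_schema": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }],
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    stream=True,
)

tool_calls: list[dict] = []
for event in stream:
    if event.type == "content_block_start" and event.content_block.type == "tool_use":
        # A new tool call begins: id and name arrive here, arguments arrive later.
        tool_calls.append({
            "id": event.content_block.id,
            "name": event.content_block.name,
            "arguments": "",
        })
    elif event.type == "content_block_delta" and event.delta.type == "input_json_delta":
        # Argument JSON is streamed in fragments; concatenate them in order.
        tool_calls[-1]["arguments"] += event.delta.partial_json

print(tool_calls)
```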
@@ -335,13 +357,19 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                 # transform usage
                 usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
 
+                # transform empty tool call arguments to {}
+                for tool_call in tool_calls:
+                    if not tool_call.function.arguments:
+                        tool_call.function.arguments = '{}'
+
                 yield LLMResultChunk(
                     model=return_model,
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
                         index=index + 1,
                         message=AssistantPromptMessage(
-                            content=''
+                            content='',
+                            tool_calls=tool_calls
                         ),
                         finish_reason=finish_reason,
                         usage=usage
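
The empty-arguments fallback in this second hunk exists because a tool invoked with no parameters streams no `input_json_delta` fragments at all, so the accumulated argument string stays `''`, which is not valid JSON for whatever parses the arguments downstream; with the fallback applied, the finished `tool_calls` list rides on the final, usage-bearing chunk via `tool_calls=tool_calls`. A quick illustration of the difference:

```python
import json

print(json.loads('{}'))    # {} — the normalized empty-argument object parses cleanly

try:
    json.loads('')         # what a downstream parser would hit without the fallback
except json.JSONDecodeError as exc:
    print(exc)             # Expecting value: line 1 column 1 (char 0)
```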