From 5e97eb18408c268e091027d919ca6a2f369c4e4c Mon Sep 17 00:00:00 2001
From: takatost
Date: Thu, 11 Jan 2024 17:34:58 +0800
Subject: [PATCH] fix: azure openai stream response usage missing (#1998)

---
 api/core/app_runner/app_runner.py           |  3 +
 .../model_providers/azure_openai/llm/llm.py | 69 ++++++++++---------
 .../model_runtime/azure_openai/test_llm.py  |  1 -
 3 files changed, 41 insertions(+), 32 deletions(-)

diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py
index a155de09b3..fe3c08f03f 100644
--- a/api/core/app_runner/app_runner.py
+++ b/api/core/app_runner/app_runner.py
@@ -257,6 +257,9 @@ class AppRunner:
             if not usage and result.delta.usage:
                 usage = result.delta.usage
 
+        if not usage:
+            usage = LLMUsage.empty_usage()
+
         llm_result = LLMResult(
             model=model,
             prompt_messages=prompt_messages,
diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
index 2e4cd069ab..72965a7613 100644
--- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
@@ -322,8 +322,11 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                                                response: Stream[ChatCompletionChunk],
                                                prompt_messages: list[PromptMessage],
                                                tools: Optional[list[PromptMessageTool]] = None) -> Generator:
-
+        index = 0
         full_assistant_content = ''
+        real_model = model
+        system_fingerprint = None
+        completion = ''
         for chunk in response:
             if len(chunk.choices) == 0:
                 continue
@@ -349,40 +352,44 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
 
             full_assistant_content += delta.delta.content if delta.delta.content else ''
 
-            if delta.finish_reason is not None:
-                # calculate num tokens
-                prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
-
-                full_assistant_prompt_message = AssistantPromptMessage(
-                    content=full_assistant_content,
-                    tool_calls=tool_calls
-                )
-                completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
-
-                # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-                yield LLMResultChunk(
-                    model=chunk.model,
-                    prompt_messages=prompt_messages,
-                    system_fingerprint=chunk.system_fingerprint,
-                    delta=LLMResultChunkDelta(
-                        index=delta.index,
-                        message=assistant_prompt_message,
-                        finish_reason=delta.finish_reason,
-                        usage=usage
-                    )
-                )
-            else:
-                yield LLMResultChunk(
-                    model=chunk.model,
-                    prompt_messages=prompt_messages,
-                    system_fingerprint=chunk.system_fingerprint,
-                    delta=LLMResultChunkDelta(
-                        index=delta.index,
-                        message=assistant_prompt_message,
-                    )
-                )
+            real_model = chunk.model
+            system_fingerprint = chunk.system_fingerprint
+            completion += delta.delta.content if delta.delta.content else ''
+
+            yield LLMResultChunk(
+                model=real_model,
+                prompt_messages=prompt_messages,
+                system_fingerprint=system_fingerprint,
+                delta=LLMResultChunkDelta(
+                    index=index,
+                    message=assistant_prompt_message,
+                )
+            )
+
+            index += 1
+
+        # calculate num tokens
+        prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
+
+        full_assistant_prompt_message = AssistantPromptMessage(
+            content=completion
+        )
+        completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        yield LLMResultChunk(
+            model=real_model,
+            prompt_messages=prompt_messages,
+            system_fingerprint=system_fingerprint,
+            delta=LLMResultChunkDelta(
+                index=index,
+                message=AssistantPromptMessage(content=''),
+                finish_reason='stop',
+                usage=usage
+            )
+        )
 
     @staticmethod
     def _extract_response_tool_calls(response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py
index e74465283e..90a81b1d97 100644
--- a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py
@@ -190,7 +190,6 @@ def test_invoke_stream_chat_model(setup_openai_mock):
         assert isinstance(chunk, LLMResultChunk)
         assert isinstance(chunk.delta, LLMResultChunkDelta)
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
-        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
         if chunk.delta.finish_reason is not None:
            assert chunk.delta.usage is not None
            assert chunk.delta.usage.completion_tokens > 0
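
Notes on the llm.py change, outside diff form: the old handler computed usage only on a chunk that carried a finish_reason, so an Azure stream that never reported usage produced none at all. The patched handler accumulates the completion text itself, counts tokens once the stream ends, and always yields one final empty chunk with finish_reason='stop' and the computed usage attached. Below is a minimal runnable sketch of that pattern; Chunk and Usage are hypothetical stand-ins for Dify's LLMResultChunk and LLMUsage, and a naive whitespace count stands in for _num_tokens_from_messages:

    from dataclasses import dataclass
    from typing import Iterator, Optional

    @dataclass
    class Usage:            # hypothetical stand-in for LLMUsage
        prompt_tokens: int
        completion_tokens: int

    @dataclass
    class Chunk:            # hypothetical stand-in for LLMResultChunk
        index: int
        content: str
        finish_reason: Optional[str] = None
        usage: Optional[Usage] = None

    def stream_with_trailing_usage(deltas: Iterator[str],
                                   prompt_tokens: int) -> Iterator[Chunk]:
        index = 0
        completion = ''
        for delta in deltas:
            completion += delta
            # content chunks carry no usage, mirroring the Azure stream
            yield Chunk(index=index, content=delta)
            index += 1

        # the provider stream carried no usage, so compute it locally and
        # attach it to one final, empty 'stop' chunk
        usage = Usage(prompt_tokens=prompt_tokens,
                      completion_tokens=len(completion.split()))
        yield Chunk(index=index, content='', finish_reason='stop', usage=usage)

    for chunk in stream_with_trailing_usage(iter(['Hello', ' world']), prompt_tokens=4):
        print(chunk)

With the final chunk guaranteed, consumers like AppRunner can take usage from the last delta, and the new LLMUsage.empty_usage() fallback covers any provider that still yields nothing.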