From d5d8b98d8245a968316479af17c0e41a7a453104 Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Tue, 7 May 2024 13:49:45 +0800
Subject: [PATCH] feat: support openai stream usage (#4140)

---
 .../model_providers/openai/llm/llm.py | 105 +++++++++++++-----
 api/requirements.txt                  |   2 +-
 2 files changed, 77 insertions(+), 30 deletions(-)

diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py
index b7db39376c..69afabadb3 100644
--- a/api/core/model_runtime/model_providers/openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai/llm/llm.py
@@ -378,6 +378,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs['user'] = user
 
+        if stream:
+            extra_model_kwargs['stream_options'] = {
+                "include_usage": True
+            }
+
         # text completion model
         response = client.completions.create(
             prompt=prompt_messages[0].content,
@@ -446,8 +451,24 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         :return: llm response chunk generator result
         """
         full_text = ''
+        prompt_tokens = 0
+        completion_tokens = 0
+
+        final_chunk = LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=''),
+            )
+        )
+
         for chunk in response:
             if len(chunk.choices) == 0:
+                if chunk.usage:
+                    # calculate num tokens
+                    prompt_tokens = chunk.usage.prompt_tokens
+                    completion_tokens = chunk.usage.completion_tokens
                 continue
 
             delta = chunk.choices[0]
@@ -464,20 +485,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             full_text += text
 
             if delta.finish_reason is not None:
-                # calculate num tokens
-                if chunk.usage:
-                    # transform usage
-                    prompt_tokens = chunk.usage.prompt_tokens
-                    completion_tokens = chunk.usage.completion_tokens
-                else:
-                    # calculate num tokens
-                    prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
-                    completion_tokens = self._num_tokens_from_string(model, full_text)
-
-                # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-                yield LLMResultChunk(
+                final_chunk = LLMResultChunk(
                     model=chunk.model,
                     prompt_messages=prompt_messages,
                     system_fingerprint=chunk.system_fingerprint,
@@ -485,7 +493,6 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                         index=delta.index,
                         message=assistant_prompt_message,
                         finish_reason=delta.finish_reason,
-                        usage=usage
                     )
                 )
             else:
@@ -499,6 +506,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                     )
                 )
 
+        if not prompt_tokens:
+            prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
+
+        if not completion_tokens:
+            completion_tokens = self._num_tokens_from_string(model, full_text)
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+        final_chunk.delta.usage = usage
+
+        yield final_chunk
+
     def _chat_generate(self, model: str, credentials: dict,
                        prompt_messages: list[PromptMessage], model_parameters: dict,
                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
@@ -531,6 +551,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
 
             model_parameters["response_format"] = response_format
 
+
         extra_model_kwargs = {}
 
         if tools:
@@ -547,6 +568,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         if user:
             extra_model_kwargs['user'] = user
 
+        if stream:
+            extra_model_kwargs['stream_options'] = {
+                'include_usage': True
+            }
+
         # clear illegal prompt messages
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
@@ -630,8 +656,24 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         """
         full_assistant_content = ''
         delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None
+        prompt_tokens = 0
+        completion_tokens = 0
+        final_tool_calls = []
+        final_chunk = LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=''),
+            )
+        )
+
         for chunk in response:
             if len(chunk.choices) == 0:
+                if chunk.usage:
+                    # calculate num tokens
+                    prompt_tokens = chunk.usage.prompt_tokens
+                    completion_tokens = chunk.usage.completion_tokens
                 continue
 
             delta = chunk.choices[0]
@@ -667,6 +709,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
             function_call = self._extract_response_function_call(assistant_message_function_call)
             tool_calls = [function_call] if function_call else []
+            if tool_calls:
+                final_tool_calls.extend(tool_calls)
 
             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
@@ -677,19 +721,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             full_assistant_content += delta.delta.content if delta.delta.content else ''
 
             if has_finish_reason:
-                # calculate num tokens
-                prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
-
-                full_assistant_prompt_message = AssistantPromptMessage(
-                    content=full_assistant_content,
-                    tool_calls=tool_calls
-                )
-                completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message])
-
-                # transform usage
-                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-                yield LLMResultChunk(
+                final_chunk = LLMResultChunk(
                     model=chunk.model,
                     prompt_messages=prompt_messages,
                     system_fingerprint=chunk.system_fingerprint,
@@ -697,7 +729,6 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                         index=delta.index,
                         message=assistant_prompt_message,
                         finish_reason=delta.finish_reason,
-                        usage=usage
                     )
                 )
             else:
@@ -711,6 +742,22 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                     )
                 )
 
+        if not prompt_tokens:
+            prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools)
+
+        if not completion_tokens:
+            full_assistant_prompt_message = AssistantPromptMessage(
+                content=full_assistant_content,
+                tool_calls=final_tool_calls
+            )
+            completion_tokens = self._num_tokens_from_messages(model, [full_assistant_prompt_message])
+
+        # transform usage
+        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+        final_chunk.delta.usage = usage
+
+        yield final_chunk
+
     def _extract_response_tool_calls(self,
                                      response_tool_calls: list[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]) \
             -> list[AssistantPromptMessage.ToolCall]:
diff --git a/api/requirements.txt b/api/requirements.txt
index 9d79afa4ec..e2c430c9d6 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -9,7 +9,7 @@ flask-restful~=0.3.10
 flask-cors~=4.0.0
 gunicorn~=22.0.0
 gevent~=23.9.1
-openai~=1.13.3
+openai~=1.26.0
 tiktoken~=0.6.0
 psycopg2-binary~=2.9.6
 pycryptodome==3.19.1
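
Note on the behaviour this patch relies on: starting with openai>=1.26.0 (hence the requirements bump), passing stream_options={"include_usage": True} to a streaming completion makes the API emit one extra final chunk whose choices list is empty and whose usage field carries the prompt/completion token counts. That is why both stream handlers above read prompt_tokens/completion_tokens from empty-choice chunks and fall back to local token counting only when no usage arrived. A minimal standalone sketch of that behaviour, outside Dify; the model name and environment-based API key are illustrative assumptions, not part of this patch:

# Sketch only: shows the final usage chunk produced by stream_options.
# Assumes openai>=1.26.0 and OPENAI_API_KEY set in the environment;
# "gpt-3.5-turbo" is just an example model name.
from openai import OpenAI

client = OpenAI()  # picks up OPENAI_API_KEY from the environment

stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
    stream_options={"include_usage": True},
)

prompt_tokens = completion_tokens = 0
full_text = ""

for chunk in stream:
    if len(chunk.choices) == 0:
        # The last chunk has no choices and carries the usage object.
        if chunk.usage:
            prompt_tokens = chunk.usage.prompt_tokens
            completion_tokens = chunk.usage.completion_tokens
        continue

    delta = chunk.choices[0].delta
    if delta.content:
        full_text += delta.content

print(full_text)
print(f"prompt_tokens={prompt_tokens}, completion_tokens={completion_tokens}")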