diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index e73f47b8e8..102f54eff6 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -224,7 +224,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
             else:
                 raise ValueError(f"Unknown completion type {credentials['completion_type']}")
-        
+
         return entity
 
     # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
@@ -343,32 +343,44 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 )
             )
 
-        for chunk in response.iter_lines(decode_unicode=True, delimiter='\n\n'):
+        # delimiter for stream response, need unicode_escape
+        import codecs
+        delimiter = credentials.get("stream_mode_delimiter", "\n\n")
+        delimiter = codecs.decode(delimiter, "unicode_escape")
+
+        for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             if chunk:
                 decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
-
                 chunk_json = None
                 try:
                     chunk_json = json.loads(decoded_chunk)
                 # stream ended
                 except json.JSONDecodeError as e:
+                    logger.error(f"decoded_chunk error,delimiter={delimiter},decoded_chunk={decoded_chunk}")
                     yield create_final_llm_result_chunk(
                         index=chunk_index + 1,
                         message=AssistantPromptMessage(content=""),
                         finish_reason="Non-JSON encountered."
                     )
                     break
-
                 if not chunk_json or len(chunk_json['choices']) == 0:
                     continue
 
                 choice = chunk_json['choices'][0]
+                finish_reason = chunk_json['choices'][0].get('finish_reason')
                 chunk_index += 1
 
                 if 'delta' in choice:
                     delta = choice['delta']
                     if delta.get('content') is None or delta.get('content') == '':
-                        continue
+                        if finish_reason is not None:
+                            yield create_final_llm_result_chunk(
+                                index=chunk_index,
+                                message=AssistantPromptMessage(content=choice.get('text', '')),
+                                finish_reason=finish_reason
+                            )
+                        else:
+                            continue
 
                     assistant_message_tool_calls = delta.get('tool_calls', None)
                     # assistant_message_function_call = delta.delta.function_call
@@ -387,24 +399,22 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
                     full_assistant_content += delta.get('content', '')
                 elif 'text' in choice:
-                    if choice.get('text') is None or choice.get('text') == '':
+                    choice_text = choice.get('text', '')
+                    if choice_text == '':
                         continue
 
                     # transform assistant message to prompt message
-                    assistant_prompt_message = AssistantPromptMessage(
-                        content=choice.get('text', '')
-                    )
-
-                    full_assistant_content += choice.get('text', '')
+                    assistant_prompt_message = AssistantPromptMessage(content=choice_text)
+                    full_assistant_content += choice_text
                 else:
                     continue
 
                 # check payload indicator for completion
-                if chunk_json['choices'][0].get('finish_reason') is not None:
+                if finish_reason is not None:
                     yield create_final_llm_result_chunk(
                         index=chunk_index,
                         message=assistant_prompt_message,
-                        finish_reason=chunk_json['choices'][0]['finish_reason']
+                        finish_reason=finish_reason
                     )
                 else:
                     yield LLMResultChunk(
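Note on the `unicode_escape` step above: the credential form stores the delimiter as plain text, so a user-entered `\n\n` arrives as four literal characters (backslash, `n`, backslash, `n`) rather than two newlines. A minimal standalone sketch of the decode (the `credentials` dict here is illustrative, not the provider's real schema):

```python
import codecs

# Illustrative stand-in for the stored credential value:
# typing "\n\n" into a text input yields literal backslash-n pairs.
credentials = {"stream_mode_delimiter": "\\n\\n"}

raw = credentials.get("stream_mode_delimiter", "\n\n")
delimiter = codecs.decode(raw, "unicode_escape")  # '\\n\\n' -> '\n\n'

assert delimiter == "\n\n"  # two real newlines, usable as iter_lines(delimiter=...)
```

This is also why the YAML default below is written as `'\n\n'`: single-quoted YAML keeps the backslashes literal, and the decode step turns them into real newlines at runtime.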
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
index 088738c0ff..213d334fe8 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml
@@ -75,3 +75,12 @@ model_credential_schema:
         value: llm
     default: '4096'
     type: text-input
+  - variable: stream_mode_delimiter
+    label:
+      zh_Hans: 流模式返回结果的分隔符
+      en_US: Delimiter for streaming results
+    show_on:
+      - variable: __model_type
+        value: llm
+    default: '\n\n'
+    type: text-input
diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py
index 9e94e562e9..c3cb5a481c 100644
--- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py
@@ -12,6 +12,7 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
 Using Together.ai's OpenAI-compatible API as testing endpoint
 """
 
+
 def test_validate_credentials():
     model = OAIAPICompatLargeLanguageModel()
 
@@ -34,6 +35,7 @@
         }
     )
 
+
 def test_invoke_model():
     model = OAIAPICompatLargeLanguageModel()
 
@@ -65,9 +67,47 @@
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
+
 def test_invoke_stream_model():
     model = OAIAPICompatLargeLanguageModel()
 
+    response = model.invoke(
+        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
+        credentials={
+            'api_key': os.environ.get('TOGETHER_API_KEY'),
+            'endpoint_url': 'https://api.together.xyz/v1/',
+            'mode': 'chat',
+            'stream_mode_delimiter': '\\n\\n'
+        },
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Who are you?'
+            )
+        ],
+        model_parameters={
+            'temperature': 1.0,
+            'top_k': 2,
+            'top_p': 0.5,
+        },
+        stop=['How'],
+        stream=True,
+        user="abc-123"
+    )
+
+    assert isinstance(response, Generator)
+
+    for chunk in response:
+        assert isinstance(chunk, LLMResultChunk)
+        assert isinstance(chunk.delta, LLMResultChunkDelta)
+        assert isinstance(chunk.delta.message, AssistantPromptMessage)
+
+
+def test_invoke_stream_model_without_delimiter():
+    model = OAIAPICompatLargeLanguageModel()
+
     response = model.invoke(
         model='mistralai/Mixtral-8x7B-Instruct-v0.1',
         credentials={
@@ -100,6 +140,7 @@
         assert isinstance(chunk.delta, LLMResultChunkDelta)
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
 
+
 # using OpenAI's ChatGPT-3.5 as testing endpoint
 def test_invoke_chat_model_with_tools():
     model = OAIAPICompatLargeLanguageModel()
@@ -126,22 +167,22 @@
             parameters={
                 "type": "object",
                 "properties": {
-                    "location": {
-                        "type": "string",
-                        "description": "The city and state e.g. San Francisco, CA"
-                    },
-                    "unit": {
-                        "type": "string",
-                        "enum": [
-                            "celsius",
-                            "fahrenheit"
-                        ]
-                    }
+                  "location": {
+                    "type": "string",
+                    "description": "The city and state e.g. San Francisco, CA"
+                  },
+                  "unit": {
+                    "type": "string",
+                    "enum": [
+                      "celsius",
+                      "fahrenheit"
+                    ]
+                  }
                 },
                 "required": [
-                    "location"
+                  "location"
                 ]
-            }
+                }
         ),
     ],
     model_parameters={
@@ -156,6 +197,7 @@
     assert isinstance(result.message, AssistantPromptMessage)
     assert len(result.message.tool_calls) > 0
 
+
 def test_get_num_tokens():
     model = OAIAPICompatLargeLanguageModel()
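For reference, the parsing behavior the llm.py change targets can be sketched offline: the final event of an OpenAI-compatible stream often carries an empty `delta.content` alongside a `finish_reason`, and the old code `continue`d past it without ever emitting a final chunk. A self-contained sketch on canned data (the event payloads are illustrative, shaped like OpenAI-compatible SSE responses):

```python
import json

# Canned SSE events split by the default "\n\n" delimiter; the last event
# has empty content but carries the finish_reason.
payload = (
    'data: {"choices": [{"delta": {"content": "Hello"}, "finish_reason": null}]}\n\n'
    'data: {"choices": [{"delta": {"content": ""}, "finish_reason": "stop"}]}\n\n'
)

delimiter = "\n\n"  # what stream_mode_delimiter decodes to by default
for raw in filter(None, payload.split(delimiter)):
    chunk_json = json.loads(raw.strip().removeprefix("data:").strip())
    choice = chunk_json["choices"][0]
    finish_reason = choice.get("finish_reason")
    content = choice.get("delta", {}).get("content") or ""
    print(repr(content), finish_reason)

# Output:
# 'Hello' None
# '' stop   <- the patched code now yields a final chunk here instead of skipping it
```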