From 8811677154869dc70b49e83a577dc495d3ae04b8 Mon Sep 17 00:00:00 2001
From: takatost
Date: Mon, 15 Apr 2024 00:23:42 +0800
Subject: [PATCH] feat: remove langchain from output parsers (#3473)

---
 api/core/entities/message_entities.py        | 108 +-----------------
 api/core/llm_generator/llm_generator.py      |   3 +-
 .../llm_generator/output_parser/errors.py    |   2 +
 .../output_parser/rule_config_generator.py   |   5 +-
 .../suggested_questions_after_answer.py      |   4 +-
 .../builtin/twilio/tools/send_message.py     |  11 --
 api/libs/json_in_md_parser.py                |   2 +-
 7 files changed, 8 insertions(+), 127 deletions(-)
 create mode 100644 api/core/llm_generator/output_parser/errors.py

diff --git a/api/core/entities/message_entities.py b/api/core/entities/message_entities.py
index 6f767aafc7..b9406e24c4 100644
--- a/api/core/entities/message_entities.py
+++ b/api/core/entities/message_entities.py
@@ -1,19 +1,8 @@
 import enum
-from typing import Any, cast
+from typing import Any
 
-from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage
 from pydantic import BaseModel
 
-from core.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    ImagePromptMessageContent,
-    PromptMessage,
-    SystemPromptMessage,
-    TextPromptMessageContent,
-    ToolPromptMessage,
-    UserPromptMessage,
-)
-
 
 class PromptMessageFileType(enum.Enum):
     IMAGE = 'image'
@@ -38,98 +27,3 @@ class ImagePromptMessageFile(PromptMessageFile):
 
     type: PromptMessageFileType = PromptMessageFileType.IMAGE
     detail: DETAIL = DETAIL.LOW
-
-
-class LCHumanMessageWithFiles(HumanMessage):
-    # content: Union[str, list[Union[str, Dict]]]
-    content: str
-    files: list[PromptMessageFile]
-
-
-def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
-    prompt_messages = []
-    for message in messages:
-        if isinstance(message, HumanMessage):
-            if isinstance(message, LCHumanMessageWithFiles):
-                file_prompt_message_contents = []
-                for file in message.files:
-                    if file.type == PromptMessageFileType.IMAGE:
-                        file = cast(ImagePromptMessageFile, file)
-                        file_prompt_message_contents.append(ImagePromptMessageContent(
-                            data=file.data,
-                            detail=ImagePromptMessageContent.DETAIL.HIGH
-                            if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW
-                        ))
-
-                prompt_message_contents = [TextPromptMessageContent(data=message.content)]
-                prompt_message_contents.extend(file_prompt_message_contents)
-
-                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
-            else:
-                prompt_messages.append(UserPromptMessage(content=message.content))
-        elif isinstance(message, AIMessage):
-            message_kwargs = {
-                'content': message.content
-            }
-
-            if 'function_call' in message.additional_kwargs:
-                message_kwargs['tool_calls'] = [
-                    AssistantPromptMessage.ToolCall(
-                        id=message.additional_kwargs['function_call']['id'],
-                        type='function',
-                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                            name=message.additional_kwargs['function_call']['name'],
-                            arguments=message.additional_kwargs['function_call']['arguments']
-                        )
-                    )
-                ]
-
-            prompt_messages.append(AssistantPromptMessage(**message_kwargs))
-        elif isinstance(message, SystemMessage):
-            prompt_messages.append(SystemPromptMessage(content=message.content))
-        elif isinstance(message, FunctionMessage):
-            prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name))
-
-    return prompt_messages
-
-
-def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]:
-    messages = []
-    for prompt_message in prompt_messages:
-        if isinstance(prompt_message, UserPromptMessage):
-            if isinstance(prompt_message.content, str):
-                messages.append(HumanMessage(content=prompt_message.content))
-            else:
-                message_contents = []
-                for content in prompt_message.content:
-                    if isinstance(content, TextPromptMessageContent):
-                        message_contents.append(content.data)
-                    elif isinstance(content, ImagePromptMessageContent):
-                        message_contents.append({
-                            'type': 'image',
-                            'data': content.data,
-                            'detail': content.detail.value
-                        })
-
-                messages.append(HumanMessage(content=message_contents))
-        elif isinstance(prompt_message, AssistantPromptMessage):
-            message_kwargs = {
-                'content': prompt_message.content
-            }
-
-            if prompt_message.tool_calls:
-                message_kwargs['additional_kwargs'] = {
-                    'function_call': {
-                        'id': prompt_message.tool_calls[0].id,
-                        'name': prompt_message.tool_calls[0].function.name,
-                        'arguments': prompt_message.tool_calls[0].function.arguments
-                    }
-                }
-
-            messages.append(AIMessage(**message_kwargs))
-        elif isinstance(prompt_message, SystemPromptMessage):
-            messages.append(SystemMessage(content=prompt_message.content))
-        elif isinstance(prompt_message, ToolPromptMessage):
-            messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content))
-
-    return messages
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index 2fc60daab4..14de8649c6 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -1,8 +1,7 @@
 import json
 import logging
 
-from langchain.schema import OutputParserException
-
+from core.llm_generator.output_parser.errors import OutputParserException
 from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
 from core.llm_generator.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
diff --git a/api/core/llm_generator/output_parser/errors.py b/api/core/llm_generator/output_parser/errors.py
new file mode 100644
index 0000000000..6a60f8de80
--- /dev/null
+++ b/api/core/llm_generator/output_parser/errors.py
@@ -0,0 +1,2 @@
+class OutputParserException(Exception):
+    pass
diff --git a/api/core/llm_generator/output_parser/rule_config_generator.py b/api/core/llm_generator/output_parser/rule_config_generator.py
index b95653f69c..f6d4bcf11a 100644
--- a/api/core/llm_generator/output_parser/rule_config_generator.py
+++ b/api/core/llm_generator/output_parser/rule_config_generator.py
@@ -1,12 +1,11 @@
 from typing import Any
 
-from langchain.schema import BaseOutputParser, OutputParserException
-
+from core.llm_generator.output_parser.errors import OutputParserException
 from core.llm_generator.prompts import RULE_CONFIG_GENERATE_TEMPLATE
 from libs.json_in_md_parser import parse_and_check_json_markdown
 
 
-class RuleConfigGeneratorOutputParser(BaseOutputParser):
+class RuleConfigGeneratorOutputParser:
 
     def get_format_instructions(self) -> str:
         return RULE_CONFIG_GENERATE_TEMPLATE
diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py
index ad30bcfa07..3f046c68fc 100644
--- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py
+++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py
@@ -2,12 +2,10 @@ import json
 import re
 from typing import Any
 
-from langchain.schema import BaseOutputParser
-
 from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
 
 
-class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):
+class SuggestedQuestionsAfterAnswerOutputParser:
 
     def get_format_instructions(self) -> str:
         return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
diff --git a/api/core/tools/provider/builtin/twilio/tools/send_message.py b/api/core/tools/provider/builtin/twilio/tools/send_message.py
index 33c5d62e49..24502a3565 100644
--- a/api/core/tools/provider/builtin/twilio/tools/send_message.py
+++ b/api/core/tools/provider/builtin/twilio/tools/send_message.py
@@ -13,17 +13,6 @@ class TwilioAPIWrapper(BaseModel):
     and the environment variables ``TWILIO_ACCOUNT_SID``, ``TWILIO_AUTH_TOKEN``, and
     ``TWILIO_FROM_NUMBER``, or pass `account_sid`, `auth_token`, and `from_number` as
     named parameters to the constructor.
-
-    Example:
-        .. code-block:: python
-
-            from langchain.utilities.twilio import TwilioAPIWrapper
-            twilio = TwilioAPIWrapper(
-                account_sid="ACxxx",
-                auth_token="xxx",
-                from_number="+10123456789"
-            )
-            twilio.run('test', '+12484345508')
     """
 
     client: Any  #: :meta private:
diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py
index 5569519524..2cf023a399 100644
--- a/api/libs/json_in_md_parser.py
+++ b/api/libs/json_in_md_parser.py
@@ -1,6 +1,6 @@
 import json
 
-from langchain.schema import OutputParserException
+from core.llm_generator.output_parser.errors import OutputParserException
 
 
 def parse_json_markdown(json_string: str) -> dict:
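
For review context, not part of the applied diff: a minimal sketch of the call pattern this change enables, assuming only what the hunks above declare (the new errors module, the now-langchain-free RuleConfigGeneratorOutputParser, and its import of parse_and_check_json_markdown). The sample llm_text payload is illustrative, not real model output:

    # Sketch only; llm_text stands in for a real model completion.
    from core.llm_generator.output_parser.errors import OutputParserException
    from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser

    parser = RuleConfigGeneratorOutputParser()
    llm_text = '```json\n{"prompt": "...", "variables": [], "opening_statement": ""}\n```'

    try:
        # parse() goes through libs.json_in_md_parser, which after this patch
        # raises the local OutputParserException instead of
        # langchain.schema.OutputParserException, so callers such as
        # llm_generator.py catch the same class imported here.
        rule_config = parser.parse(llm_text)
    except OutputParserException:
        rule_config = None  # malformed output; the caller can fall back to defaults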