feat: remove langchain from output parsers (#3473)

This commit is contained in:
takatost 2024-04-15 00:23:42 +08:00 committed by GitHub
parent 12f1ce4794
commit 8811677154
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 8 additions and 127 deletions

View File

@ -1,19 +1,8 @@
import enum
from typing import Any, cast
from typing import Any
from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage
from pydantic import BaseModel
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
class PromptMessageFileType(enum.Enum):
IMAGE = 'image'
@ -38,98 +27,3 @@ class ImagePromptMessageFile(PromptMessageFile):
type: PromptMessageFileType = PromptMessageFileType.IMAGE
detail: DETAIL = DETAIL.LOW
class LCHumanMessageWithFiles(HumanMessage):
# content: Union[str, list[Union[str, Dict]]]
content: str
files: list[PromptMessageFile]
def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]:
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
if isinstance(message, LCHumanMessageWithFiles):
file_prompt_message_contents = []
for file in message.files:
if file.type == PromptMessageFileType.IMAGE:
file = cast(ImagePromptMessageFile, file)
file_prompt_message_contents.append(ImagePromptMessageContent(
data=file.data,
detail=ImagePromptMessageContent.DETAIL.HIGH
if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW
))
prompt_message_contents = [TextPromptMessageContent(data=message.content)]
prompt_message_contents.extend(file_prompt_message_contents)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=message.content))
elif isinstance(message, AIMessage):
message_kwargs = {
'content': message.content
}
if 'function_call' in message.additional_kwargs:
message_kwargs['tool_calls'] = [
AssistantPromptMessage.ToolCall(
id=message.additional_kwargs['function_call']['id'],
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=message.additional_kwargs['function_call']['name'],
arguments=message.additional_kwargs['function_call']['arguments']
)
)
]
prompt_messages.append(AssistantPromptMessage(**message_kwargs))
elif isinstance(message, SystemMessage):
prompt_messages.append(SystemPromptMessage(content=message.content))
elif isinstance(message, FunctionMessage):
prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name))
return prompt_messages
def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]:
messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, str):
messages.append(HumanMessage(content=prompt_message.content))
else:
message_contents = []
for content in prompt_message.content:
if isinstance(content, TextPromptMessageContent):
message_contents.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
message_contents.append({
'type': 'image',
'data': content.data,
'detail': content.detail.value
})
messages.append(HumanMessage(content=message_contents))
elif isinstance(prompt_message, AssistantPromptMessage):
message_kwargs = {
'content': prompt_message.content
}
if prompt_message.tool_calls:
message_kwargs['additional_kwargs'] = {
'function_call': {
'id': prompt_message.tool_calls[0].id,
'name': prompt_message.tool_calls[0].function.name,
'arguments': prompt_message.tool_calls[0].function.arguments
}
}
messages.append(AIMessage(**message_kwargs))
elif isinstance(prompt_message, SystemPromptMessage):
messages.append(SystemMessage(content=prompt_message.content))
elif isinstance(prompt_message, ToolPromptMessage):
messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content))
return messages

View File

@ -1,8 +1,7 @@
import json
import logging
from langchain.schema import OutputParserException
from core.llm_generator.output_parser.errors import OutputParserException
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT

View File

@ -0,0 +1,2 @@
class OutputParserException(Exception):
pass

View File

@ -1,12 +1,11 @@
from typing import Any
from langchain.schema import BaseOutputParser, OutputParserException
from core.llm_generator.output_parser.errors import OutputParserException
from core.llm_generator.prompts import RULE_CONFIG_GENERATE_TEMPLATE
from libs.json_in_md_parser import parse_and_check_json_markdown
class RuleConfigGeneratorOutputParser(BaseOutputParser):
class RuleConfigGeneratorOutputParser:
def get_format_instructions(self) -> str:
return RULE_CONFIG_GENERATE_TEMPLATE

View File

@ -2,12 +2,10 @@ import json
import re
from typing import Any
from langchain.schema import BaseOutputParser
from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser):
class SuggestedQuestionsAfterAnswerOutputParser:
def get_format_instructions(self) -> str:
return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT

View File

@ -13,17 +13,6 @@ class TwilioAPIWrapper(BaseModel):
and the environment variables ``TWILIO_ACCOUNT_SID``, ``TWILIO_AUTH_TOKEN``, and
``TWILIO_FROM_NUMBER``, or pass `account_sid`, `auth_token`, and `from_number` as
named parameters to the constructor.
Example:
.. code-block:: python
from langchain.utilities.twilio import TwilioAPIWrapper
twilio = TwilioAPIWrapper(
account_sid="ACxxx",
auth_token="xxx",
from_number="+10123456789"
)
twilio.run('test', '+12484345508')
"""
client: Any #: :meta private:

View File

@ -1,6 +1,6 @@
import json
from langchain.schema import OutputParserException
from core.llm_generator.output_parser.errors import OutputParserException
def parse_json_markdown(json_string: str) -> dict: