feat: support LLM understand video (#9828)

This commit is contained in:
非法操作 2024-11-08 13:22:52 +08:00 committed by GitHub
parent c9f785e00f
commit 033ab5490b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 69 additions and 20 deletions

View File

@ -285,8 +285,9 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Model Configuration
# Model configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64
MULTIMODAL_SEND_VIDEO_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024

View File

@ -634,12 +634,17 @@ class IndexingConfig(BaseSettings):
)
class ImageFormatConfig(BaseSettings):
class VisionFormatConfig(BaseSettings):
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)
MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
default="base64",
)
class CeleryBeatConfig(BaseSettings):
CELERY_BEAT_SCHEDULER_TIME: int = Field(
@ -742,7 +747,7 @@ class FeatureConfig(
FileAccessConfig,
FileUploadConfig,
HttpConfig,
ImageFormatConfig,
VisionFormatConfig,
InnerAPIConfig,
IndexingConfig,
LoggingConfig,

View File

@ -3,7 +3,7 @@ import base64
from configs import dify_config
from core.file import file_repository
from core.helper import ssrf_proxy
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
from extensions.ext_database import db
from extensions.ext_storage import storage
@ -71,6 +71,12 @@ def to_prompt_message_content(f: File, /):
if f.extension is None:
raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
case FileType.VIDEO:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
data = _to_url(f)
else:
data = _to_base64_data_string(f)
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case _:
raise ValueError(f"file type {f.type} is not supported")
@ -112,7 +118,7 @@ def _download_file_content(path: str, /):
def _get_encoded_string(f: File, /):
match f.transfer_method:
case FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url)
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
content = response.content
encoded_string = base64.b64encode(content).decode("utf-8")
@ -140,6 +146,8 @@ def _file_to_encoded_string(f: File, /):
match f.type:
case FileType.IMAGE:
return _to_base64_data_string(f)
case FileType.VIDEO:
return _to_base64_data_string(f)
case FileType.AUDIO:
return _get_encoded_string(f)
case _:

View File

@ -12,11 +12,13 @@ from .message_entities import (
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
VideoPromptMessageContent,
)
from .model_entities import ModelPropertyKey
__all__ = [
"ImagePromptMessageContent",
"VideoPromptMessageContent",
"PromptMessage",
"PromptMessageRole",
"LLMUsage",

View File

@ -56,6 +56,7 @@ class PromptMessageContentType(Enum):
TEXT = "text"
IMAGE = "image"
AUDIO = "audio"
VIDEO = "video"
class PromptMessageContent(BaseModel):
@ -75,6 +76,12 @@ class TextPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.TEXT
class VideoPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.VIDEO
data: str = Field(..., description="Base64 encoded video data")
format: str = Field(..., description="Video format")
class AudioPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO
data: str = Field(..., description="Base64 encoded audio data")

View File

@ -29,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
VideoPromptMessageContent,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
@ -431,6 +432,14 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
sub_message_dict = {"image": image_url}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.VIDEO:
message_content = cast(VideoPromptMessageContent, message_content)
video_url = message_content.data
if message_content.data.startswith("data:"):
raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url")
sub_message_dict = {"video": video_url}
sub_messages.append(sub_message_dict)
# resort sub_messages to ensure text is always at last
sub_messages = sorted(sub_messages, key=lambda x: "text" in x)

View File

@ -313,21 +313,35 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
return params
def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]:
if isinstance(prompt_message, str):
if isinstance(prompt_message, list):
sub_messages = []
for item in prompt_message:
if item.type == PromptMessageContentType.IMAGE:
sub_messages.append(
{
"type": "image_url",
"image_url": {"url": self._remove_base64_header(item.data)},
}
)
elif item.type == PromptMessageContentType.VIDEO:
sub_messages.append(
{
"type": "video_url",
"video_url": {"url": self._remove_base64_header(item.data)},
}
)
else:
sub_messages.append({"type": "text", "text": item.data})
return sub_messages
else:
return [{"type": "text", "text": prompt_message}]
return [
{"type": "image_url", "image_url": {"url": self._remove_image_header(item.data)}}
if item.type == PromptMessageContentType.IMAGE
else {"type": "text", "text": item.data}
for item in prompt_message
]
def _remove_base64_header(self, file_content: str) -> str:
if file_content.startswith("data:"):
data_split = file_content.split(";base64,")
return data_split[1]
def _remove_image_header(self, image: str) -> str:
if image.startswith("data:image"):
return image.split(",")[1]
return image
return file_content
def _handle_generate_response(
self,

View File

@ -14,6 +14,7 @@ from core.model_runtime.entities import (
PromptMessage,
PromptMessageContentType,
TextPromptMessageContent,
VideoPromptMessageContent,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.model_entities import ModelType
@ -560,7 +561,9 @@ class LLMNode(BaseNode[LLMNodeData]):
# cuz vision detail is related to the configuration from FileUpload feature.
content_item.detail = vision_detail
prompt_message_content.append(content_item)
elif isinstance(content_item, TextPromptMessageContent | AudioPromptMessageContent):
elif isinstance(
content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
):
prompt_message_content.append(content_item)
if len(prompt_message_content) > 1:

View File

@ -468,8 +468,8 @@ const Configuration: FC = () => {
transfer_methods: modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
},
enabled: !!(modelConfig.file_upload?.enabled || modelConfig.file_upload?.image?.enabled),
allowed_file_types: modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image],
allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`),
allowed_file_types: modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image, SupportUploadFileTypes.video],
allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image], ...FILE_EXTS[SupportUploadFileTypes.video]].map(ext => `.${ext}`),
allowed_file_upload_methods: modelConfig.file_upload?.allowed_file_upload_methods || modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
number_limits: modelConfig.file_upload?.number_limits || modelConfig.file_upload?.image?.number_limits || 3,
fileUploadConfig: fileUploadConfigResponse,