From 34f844a298bb2b72109b8709c5e1458f79fb0c29 Mon Sep 17 00:00:00 2001 From: yicycyc <1258085915@qq.com> Date: Wed, 10 Jun 2026 07:10:10 +0000 Subject: [PATCH] Add Phi-4 multimodal support --- README.md | 9 +- docs/zh/model_capability.md | 2 + paddleformers/datasets/SFTDataset.py | 23 + paddleformers/datasets/collate.py | 49 + paddleformers/datasets/template/mm_plugin.py | 85 + paddleformers/datasets/template/template.py | 18 + paddleformers/transformers/__init__.py | 22 + .../transformers/auto/configuration.py | 10 + .../transformers/auto/feature_extraction.py | 4 + .../transformers/auto/image_processing.py | 4 + paddleformers/transformers/auto/modeling.py | 10 +- paddleformers/transformers/auto/processing.py | 4 + .../transformers/phi4_multimodal/__init__.py | 53 + .../phi4_multimodal/configuration.py | 286 +++ .../phi4_multimodal/feature_extraction.py | 189 ++ .../phi4_multimodal/image_processor.py | 285 +++ .../transformers/phi4_multimodal/modeling.py | 1955 +++++++++++++++++ .../transformers/phi4_multimodal/processor.py | 150 ++ 18 files changed, 3155 insertions(+), 3 deletions(-) create mode 100644 paddleformers/transformers/phi4_multimodal/__init__.py create mode 100644 paddleformers/transformers/phi4_multimodal/configuration.py create mode 100644 paddleformers/transformers/phi4_multimodal/feature_extraction.py create mode 100644 paddleformers/transformers/phi4_multimodal/image_processor.py create mode 100644 paddleformers/transformers/phi4_multimodal/modeling.py create mode 100644 paddleformers/transformers/phi4_multimodal/processor.py diff --git a/README.md b/README.md index 76e7ea002a5..223642ab27f 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ PaddleFormers 是基于百度深度学习框架 PaddlePaddle 搭建的 Transform 结合业界主流优化方法与飞桨在业务实践中积累的高效特性,PaddleFormers 致力于打造**高性能、低资源占用**的训练体验,帮助用户高效便捷地完成大模型训练,而无需关注底层复杂的优化细节。 ## 🆕最新更新 -* 2026.03.31 - PaddleFormers v1.1 正式发布!在这个版本中我们支持了 GLM-4.5 系列模型的单步与多步 MTP 训练能力。依托 MTP 架构优势,开发者可显著提升推理效率;同时针对 MTP 模块训练场景,我们新增主干网络冻结开关,灵活满足各类模型精细化调优需求。此外,我们对视觉理解类模型进行了深度优化,Qwen3-VL 30B-A3B 模型性能相比上个版本提升48%,领先Megatron-LM 6%。 +* 2026.03.31 - PaddleFormers v1.1 正式发布!在这个版本中我们支持了 GLM-4.5 系列模型的单步与多步 MTP 训练能力。依托 MTP 架构优势,开发者可显著提升推理效率;同时针对 MTP 模块训练场景,我们新增主干网络冻结开关,灵活满足各类模型精细化调优需求。此外,我们对视觉理解类模型进行了深度优化,Qwen3-VL 30B-A3B 模型性能相比上个版本提升48%,领先 Megatron-LM 6%。 * 2026.01.21 - PaddleFomers v1.0版本发布啦!我们提供了针对 LLM 和 VLM 等模型的训练能力,针对 DeepSeek-V3模型和 GLM-4.5-Air 等重点模型,我们实现了极致性能优化(训练性能明显超越 Megatron-LM )。针对 PaddleOCR-VL,我们在昆仑芯 P800、天数天垓150等国产计算芯片上进行了适配,更好的满足国内用户需求。 ## ✨特性 @@ -102,11 +102,16 @@ PaddleFormers 是基于百度深度学习框架 PaddlePaddle 搭建的 Transform - VLM + VLM 🏛️ERNIE-4.5-VL baidu/ERNIE-4.5-VL-28B-A3B-Base-PT、baidu/ERNIE-4.5-VL-28B-A3B-PT、baidu/ERNIE-4.5-VL-424B-A47B-Base-PT、baidu/ERNIE-4.5-VL-424B-A47B-PT、baidu/ERNIE-4.5-VL-28B-A3B-Thinking ernie_vl、ernie_vl_nothink + + Phi-4-multimodal + microsoft/Phi-4-multimodal-instruct + phi4_multimodal + 🏛️PaddleOCR-VL PaddlePaddle/PaddleOCR-VL diff --git a/docs/zh/model_capability.md b/docs/zh/model_capability.md index 3b402f9d4b2..ba5beca392f 100644 --- a/docs/zh/model_capability.md +++ b/docs/zh/model_capability.md @@ -12,6 +12,7 @@ |Qwen3|✓|✓|✓|✓|✓| |Qwen3-Next|✓|✓|✓|✓|✓| |🏛️ERNIE-4.5-VL|x|✓|✓|x|x| +|Phi-4-multimodal|x|✓|✓|x|x| |🏛️PaddleOCR-VL|x|✓|✓|x|x| |Qwen2.5-VL|x|✓|✓|x|x| |Qwen3-VL|x|✓|✓|x|x| @@ -30,6 +31,7 @@ |Qwen3|✓|✓|✓|✓|✓|✓| |Qwen3-Next|✓|✓|✓|x|✓|✓| |🏛️ERNIE-4.5-VL|✓|✓|✓|x|✓|✓| +|Phi-4-multimodal|x|x|-|x|✓|✓| |🏛️PaddleOCR-VL|x|x|-|x|✓|✓| |Qwen2.5-VL|✓|x|-|x|✓|✓| |Qwen3-VL|x|x|✓|x|✓|✓| diff --git a/paddleformers/datasets/SFTDataset.py b/paddleformers/datasets/SFTDataset.py index 186d4ec0780..fb049da23eb 100644 --- a/paddleformers/datasets/SFTDataset.py +++ b/paddleformers/datasets/SFTDataset.py @@ -618,6 +618,7 @@ def _process_pretraining_sequence(self, example, actual_example_num): # label shift labels = labels[1:] + [-100] + labels = self._mask_mm_token_labels(tokens, labels) pos_ids = list(range(len(tokens))) # only pure text, mm_position_ids will be reconstructed in collate.py @@ -808,6 +809,7 @@ def _process_sft_sequence(self, example, actual_example_num): # label shift labels = labels[1:] + [-100] + labels = self._mask_mm_token_labels(tokens, labels) pos_ids = list(range(len(tokens))) @@ -887,6 +889,27 @@ def _truncate( return input_ids, labels + def _mask_mm_token_labels(self, tokens, labels): + if not self.use_template or self.template_backend == "jinja" or self.template is None: + return labels + + mm_plugin = getattr(self.template, "mm_plugin", None) + if not getattr(mm_plugin, "mask_mm_token_labels", False): + return labels + + mm_tokens = [token for token in [mm_plugin.image_token, mm_plugin.audio_token] if token is not None] + if not mm_tokens: + return labels + + mm_token_ids = self.tokenizer.convert_tokens_to_ids(mm_tokens) + if not isinstance(mm_token_ids, list): + mm_token_ids = [mm_token_ids] + mm_token_ids = {token_id for token_id in mm_token_ids if token_id is not None} + for i, token in enumerate(tokens): + if token in mm_token_ids or labels[i] in mm_token_ids: + labels[i] = -100 + return labels + def _encode_truncated(self, input_ids, labels): length = self._get_length(input_ids, labels) if self.max_seq_len is not None and length > self.max_seq_len: diff --git a/paddleformers/datasets/collate.py b/paddleformers/datasets/collate.py index 30be4d57c82..ee6779e9e0e 100644 --- a/paddleformers/datasets/collate.py +++ b/paddleformers/datasets/collate.py @@ -625,6 +625,13 @@ def mm_collate_fn( input_keys.append("video_grid_thw") input_keys.append("input_features") input_keys.append("feature_attention_mask") + input_keys.append("image_pixel_values") + input_keys.append("image_sizes") + input_keys.append("image_attention_mask") + input_keys.append("audio_input_features") + input_keys.append("audio_embed_sizes") + input_keys.append("audio_attention_mask") + input_keys.append("input_mode") if training_args.num_nextn_predict_layers > 0: input_keys.append("nbatch_pack_offset") @@ -652,6 +659,13 @@ def mm_collate_fn( video_grid_thw = [] input_features = [] feature_attention_mask = [] + image_pixel_values = [] + image_sizes = [] + image_attention_mask = [] + audio_input_features = [] + audio_embed_sizes = [] + audio_attention_mask = [] + input_mode = [] for seq in batch_sequence: original_token_ids.append(seq.token_ids) mm_inputs = seq.mm_inputs @@ -667,6 +681,20 @@ def mm_collate_fn( input_features.append(mm_inputs["input_features"]) if "feature_attention_mask" in mm_inputs: feature_attention_mask.append(mm_inputs["feature_attention_mask"]) + if "image_pixel_values" in mm_inputs: + image_pixel_values.append(mm_inputs["image_pixel_values"]) + if "image_sizes" in mm_inputs: + image_sizes.append(mm_inputs["image_sizes"]) + if "image_attention_mask" in mm_inputs: + image_attention_mask.append(mm_inputs["image_attention_mask"]) + if "audio_input_features" in mm_inputs: + audio_input_features.append(mm_inputs["audio_input_features"]) + if "audio_embed_sizes" in mm_inputs: + audio_embed_sizes.append(mm_inputs["audio_embed_sizes"]) + if "audio_attention_mask" in mm_inputs: + audio_attention_mask.append(mm_inputs["audio_attention_mask"]) + if "input_mode" in mm_inputs: + input_mode.append(mm_inputs["input_mode"]) if get_rope_func is not None: filtered_args = {k: paddle.to_tensor(mm_inputs[k]) for k in func_params if k in mm_inputs} total_input_ids = paddle.to_tensor([seq.token_ids]) @@ -712,6 +740,20 @@ def mm_collate_fn( input_features = paddle.concat(input_features, axis=0) if len(feature_attention_mask) > 0: feature_attention_mask = paddle.concat(feature_attention_mask, axis=0) + if len(image_pixel_values) > 0: + image_pixel_values = paddle.concat(image_pixel_values, axis=0) + if len(image_sizes) > 0: + image_sizes = paddle.concat(image_sizes, axis=0) + if len(image_attention_mask) > 0: + image_attention_mask = paddle.concat(image_attention_mask, axis=0) + if len(audio_input_features) > 0: + audio_input_features = paddle.concat(audio_input_features, axis=0) + if len(audio_embed_sizes) > 0: + audio_embed_sizes = paddle.concat(audio_embed_sizes, axis=0) + if len(audio_attention_mask) > 0: + audio_attention_mask = paddle.concat(audio_attention_mask, axis=0) + if len(input_mode) > 0: + input_mode = paddle.concat(input_mode, axis=0) if get_token_type_func is not None: # ernie45vl bs_idx_in_rope = 0 padded_position_ids = padded_position_ids.transpose([1, 2, 0]) @@ -736,6 +778,13 @@ def mm_collate_fn( video_grid_thw, input_features, feature_attention_mask, + image_pixel_values, + image_sizes, + image_attention_mask, + audio_input_features, + audio_embed_sizes, + audio_attention_mask, + input_mode, ] ) diff --git a/paddleformers/datasets/template/mm_plugin.py b/paddleformers/datasets/template/mm_plugin.py index 3b5c124237f..83a116a90b3 100644 --- a/paddleformers/datasets/template/mm_plugin.py +++ b/paddleformers/datasets/template/mm_plugin.py @@ -392,6 +392,90 @@ def get_mm_inputs( return self._get_mm_inputs(images, videos, audios, processor, **kwargs) +@dataclass +class Phi4MultimodalPlugin(BasePlugin): + mask_mm_token_labels: bool = True + + @staticmethod + def _as_int(value): + if hasattr(value, "item"): + value = value.item() + return int(value) + + @override + def _validate_input(self, processor, images, videos, audios) -> None: + if len(videos) != 0: + raise ValueError("Phi-4 multimodal does not support video input in this template.") + if len(images) != 0: + if processor is None: + raise ValueError("Processor was not found for Phi-4 multimodal image input.") + if getattr(processor, "image_processor", None) is None: + raise ValueError("Image processor was not found for Phi-4 multimodal image input.") + if len(audios) != 0: + if processor is None: + raise ValueError("Processor was not found for Phi-4 multimodal audio input.") + if getattr(processor, "feature_extractor", None) is None: + raise ValueError("Audio feature extractor was not found for Phi-4 multimodal audio input.") + + @override + def _get_mm_inputs(self, images, videos, audios, processor, **kwargs): + mm_inputs = {} + if len(images) != 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + )["images"] + mm_inputs.update(processor.image_processor(images, return_tensors="pd")) + + if len(audios) != 0: + feature_extractor = getattr(processor, "feature_extractor", None) + sampling_rate = getattr( + feature_extractor, "sampling_rate", getattr(processor, "audio_sampling_rate", 16000) + ) + audios = self._regularize_audios(audios, sampling_rate=sampling_rate)["audios"] + mm_inputs.update( + feature_extractor( + audios, + sampling_rate=sampling_rate, + return_attention_mask=True, + return_tensors="pd", + ) + ) + + if len(images) != 0 or len(audios) != 0: + input_mode = 3 if len(images) != 0 and len(audios) != 0 else 1 if len(images) != 0 else 2 + mm_inputs["input_mode"] = paddle.to_tensor([input_mode], dtype="int64") + + return mm_inputs + + @override + def process_messages(self, messages, images, videos, audios, mm_inputs, processor): + self._validate_input(processor, images, videos, audios) + self._validate_messages(messages, images, videos, audios) + messages = deepcopy(messages) + num_img_tokens = mm_inputs.get("num_img_tokens", []) + audio_embed_sizes = mm_inputs.get("audio_embed_sizes", []) + num_image_tokens, num_audio_tokens = 0, 0 + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + image_seqlen = self._as_int(num_img_tokens[num_image_tokens]) if self.expand_mm_tokens else 1 + content = content.replace(IMAGE_PLACEHOLDER, self.image_token * image_seqlen, 1) + num_image_tokens += 1 + + while AUDIO_PLACEHOLDER in content: + audio_seqlen = self._as_int(audio_embed_sizes[num_audio_tokens]) if self.expand_mm_tokens else 1 + content = content.replace(AUDIO_PLACEHOLDER, self.audio_token * audio_seqlen, 1) + num_audio_tokens += 1 + + message["content"] = content + + self.masked_tokens = [token for token in [self.image_token, self.audio_token] if token is not None] + return messages + + @dataclass class PaddleOCRVLPlugin(BasePlugin): image_bos_token: str = "<|IMAGE_START|>" @@ -1496,6 +1580,7 @@ def process_messages( PLUGINS = { "base": BasePlugin, + "phi4_multimodal": Phi4MultimodalPlugin, "ernie_vl": ErnieVLPlugin, "qwen2_vl": Qwen2VLPlugin, "paddleocr_vl": PaddleOCRVLPlugin, diff --git a/paddleformers/datasets/template/template.py b/paddleformers/datasets/template/template.py index 1c3757e34ad..c40c615b47b 100644 --- a/paddleformers/datasets/template/template.py +++ b/paddleformers/datasets/template/template.py @@ -977,6 +977,24 @@ def _get_gpt_oss_prefix(): chat_sep="<|im_end|>", ) +register_template( + name="phi4_multimodal", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), + format_assistant=StringFormatter(slots=["{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]), + format_observation=StringFormatter( + slots=["<|user|>\n\n{{content}}\n<|end|>\n<|assistant|>\n"] + ), + suffix=["<|end|>"], + chat_sep="<|end|>\n", + auto_add_bos=False, + mm_plugin=get_mm_plugin( + name="phi4_multimodal", + image_token="<|endoftext10|>", + audio_token="<|endoftext11|>", + ), +) + register_template( name="glm_ocr", format_user=StringFormatter(slots=["<|user|>\n{{content}}\n"]), diff --git a/paddleformers/transformers/__init__.py b/paddleformers/transformers/__init__.py index feeff603775..014ea816320 100644 --- a/paddleformers/transformers/__init__.py +++ b/paddleformers/transformers/__init__.py @@ -314,6 +314,27 @@ "phi3.configuration": ["Phi3Config"], "phi3.tokenizer": ["Phi3Tokenizer"], "phi3.modeling": ["Phi3Model", "Phi3ForCausalLM", "Phi3ForCausalLMPipe"], + "phi4_multimodal.configuration": [ + "Phi4MultimodalAudioConfig", + "Phi4MultimodalConfig", + "Phi4MultimodalVisionConfig", + ], + "phi4_multimodal.feature_extraction": ["Phi4MultimodalFeatureExtractor"], + "phi4_multimodal.image_processor": ["Phi4MultimodalImageProcessor"], + "phi4_multimodal.modeling": [ + "Phi4Multimodal", + "Phi4MultimodalAudio", + "Phi4MultimodalAudioModel", + "Phi4MultimodalAudioPreTrainedModel", + "Phi4MultimodalForCausalLM", + "Phi4MultimodalForConditionalGeneration", + "Phi4MultimodalModel", + "Phi4MultimodalPreTrainedModel", + "Phi4MultimodalVision", + "Phi4MultimodalVisionModel", + "Phi4MultimodalVisionPreTrainedModel", + ], + "phi4_multimodal.processor": ["Phi4MultimodalProcessor"], "glm4v_moe.configuration": ["Glm4vMoeConfig", "Glm4vMoeTextConfig", "Glm4vMoeVisionConfig"], "glm4v_moe.modeling": [ "Glm4vMoeForConditionalGeneration", @@ -408,6 +429,7 @@ from .minimax_m2 import * from .gpt_oss import * from .phi3 import * + from .phi4_multimodal import * from .gemma3_text import * from .glm_ocr import * else: diff --git a/paddleformers/transformers/auto/configuration.py b/paddleformers/transformers/auto/configuration.py index fc8c594f4cb..2a3cc026b01 100644 --- a/paddleformers/transformers/auto/configuration.py +++ b/paddleformers/transformers/auto/configuration.py @@ -56,6 +56,10 @@ ("minimax_m2", "MiniMaxM2Config"), ("gpt_oss", "GptOssConfig"), ("phi3", "Phi3Config"), + ("phi4mm", "Phi4MultimodalConfig"), + ("phi4_multimodal", "Phi4MultimodalConfig"), + ("phi4_multimodal_audio", "Phi4MultimodalAudioConfig"), + ("phi4_multimodal_vision", "Phi4MultimodalVisionConfig"), ("gemma3_text", "Gemma3TextConfig"), ("glm4v_moe", "Glm4vMoeConfig"), ("glm_ocr", "GlmOcrConfig"), @@ -87,6 +91,9 @@ ("qwen3_vl_moe", "Qwen3VLMoe"), ("qwen3_vl_moe_text", "Qwen3VLMoeText"), ("glm_ocr", "GlmOcrForConditionalGeneration"), + ("phi4_multimodal", "Phi4Multimodal"), + ("phi4_multimodal_audio", "Phi4MultimodalAudio"), + ("phi4_multimodal_vision", "Phi4MultimodalVision"), ("qwen3_5_moe", "Qwen3_5MoEForConditionalGeneration"), ("qwen3_5", "Qwen3_5ForConditionalGeneration"), ] @@ -100,6 +107,9 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict( [ ("qwen2_5_vl_text", "qwen2_5_vl"), + ("phi4_multimodal_audio", "phi4_multimodal"), + ("phi4_multimodal_vision", "phi4_multimodal"), + ("phi4mm", "phi4_multimodal"), ("qwen3_vl_text", "qwen3_vl"), ("qwen3_vl_moe_text", "qwen3_vl_moe"), ] diff --git a/paddleformers/transformers/auto/feature_extraction.py b/paddleformers/transformers/auto/feature_extraction.py index 555237c4142..33252ffedce 100644 --- a/paddleformers/transformers/auto/feature_extraction.py +++ b/paddleformers/transformers/auto/feature_extraction.py @@ -39,6 +39,7 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict( [ + ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"), ("whisper", "WhisperFeatureExtractor"), ] ) @@ -56,6 +57,9 @@ def safe_load_json_file(json_file: str): def feature_extractor_class_from_name(class_name: str): + if class_name == "Phi4MMAudioFeatureExtractor": + class_name = "Phi4MultimodalFeatureExtractor" + for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items(): if class_name in extractors: module_name = model_type_to_module_name(module_name) diff --git a/paddleformers/transformers/auto/image_processing.py b/paddleformers/transformers/auto/image_processing.py index 4244259c0e1..39fc1307141 100644 --- a/paddleformers/transformers/auto/image_processing.py +++ b/paddleformers/transformers/auto/image_processing.py @@ -55,6 +55,7 @@ "glm4v_moe": ("Glm4vImageProcessor", "Glm4vImageProcessorFast"), "kimi_k25": ("KimiK25VisionProcessor"), "paddleocr_vl": ("PaddleOCRVLImageProcessor"), + "phi4_multimodal": ("Phi4MultimodalImageProcessor"), "qwen2_5_vl": ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast"), "qwen2_vl": ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast"), "qwen3_vl": ("Qwen3VLImageProcessor", "Qwen3VLImageProcessorFast"), @@ -68,6 +69,9 @@ def get_image_processor_class_from_name(class_name: str): + if class_name == "Phi4MMImageProcessor": + class_name = "Phi4MultimodalImageProcessor" + if class_name == "BaseImageProcessorFast": return BaseImageProcessorFast diff --git a/paddleformers/transformers/auto/modeling.py b/paddleformers/transformers/auto/modeling.py index f450dc95656..67bb875e96e 100644 --- a/paddleformers/transformers/auto/modeling.py +++ b/paddleformers/transformers/auto/modeling.py @@ -75,6 +75,10 @@ ("MiniMaxM2", "minimax_m2"), ("GptOss", "gpt_oss"), ("Phi3", "phi3"), + ("Phi4MM", "phi4_multimodal"), + ("Phi4Multimodal", "phi4_multimodal"), + ("Phi4MultimodalAudio", "phi4_multimodal"), + ("Phi4MultimodalVision", "phi4_multimodal"), ("Gemma3", "gemma3_text"), ("Glm4vMoe", "glm4v_moe"), ("GlmOcr", "glm_ocr"), @@ -82,7 +86,11 @@ ) MAPPING_SPACIAL_KEY = OrderedDict( - [("Gemma3", "Gemma3Text"), ("Ernie4_5_VLMoe", "Ernie4_5_VLMoeForConditionalGeneration")] + [ + ("Gemma3", "Gemma3Text"), + ("Ernie4_5_VLMoe", "Ernie4_5_VLMoeForConditionalGeneration"), + ("Phi4MM", "Phi4Multimodal"), + ] ) CONFIGURATION_MODEL_MAPPING = OrderedDict([((), "Gemma3TextModel")]) diff --git a/paddleformers/transformers/auto/processing.py b/paddleformers/transformers/auto/processing.py index bca898e350d..40f8d1bf18d 100644 --- a/paddleformers/transformers/auto/processing.py +++ b/paddleformers/transformers/auto/processing.py @@ -49,6 +49,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( [ ("kimi_k25", "KimiK25Processor"), + ("phi4_multimodal", "Phi4MultimodalProcessor"), ("qwen2_5_vl", "Qwen2_5_VLProcessor"), ("qwen3_vl", "Qwen3VLProcessor"), ("qwen2_vl", "Qwen2VLProcessor"), @@ -64,6 +65,9 @@ def processor_class_from_name(class_name: str): + if class_name == "Phi4MMProcessor": + class_name = "Phi4MultimodalProcessor" + for module_name, extractors in PROCESSOR_MAPPING_NAMES.items(): if class_name in extractors: module_name = model_type_to_module_name(module_name) diff --git a/paddleformers/transformers/phi4_multimodal/__init__.py b/paddleformers/transformers/phi4_multimodal/__init__.py new file mode 100644 index 00000000000..1e60a074bd5 --- /dev/null +++ b/paddleformers/transformers/phi4_multimodal/__init__.py @@ -0,0 +1,53 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Package""" +import sys +from typing import TYPE_CHECKING + +from ...utils.lazy_import import _LazyModule + +import_structure = { + "configuration": [ + "Phi4MultimodalConfig", + "Phi4MultimodalVisionConfig", + "Phi4MultimodalAudioConfig", + ], + "feature_extraction": ["Phi4MultimodalFeatureExtractor"], + "image_processor": ["Phi4MultimodalImageProcessor"], + "modeling": [ + "Phi4MultimodalPreTrainedModel", + "Phi4MultimodalModel", + "Phi4MultimodalForCausalLM", + "Phi4MultimodalForConditionalGeneration", + "Phi4MultimodalForCausalLMPipe", + "Phi4MMForCausalLM", + "Phi4MMForConditionalGeneration", + "Phi4MMForCausalLMPipe", + ], + "processor": ["Phi4MultimodalProcessor"], +} + +if TYPE_CHECKING: + from .configuration import * + from .feature_extraction import * + from .image_processor import * + from .modeling import * + from .processor import * +else: + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + import_structure, + module_spec=__spec__, + ) diff --git a/paddleformers/transformers/phi4_multimodal/configuration.py b/paddleformers/transformers/phi4_multimodal/configuration.py new file mode 100644 index 00000000000..0d71c9e6948 --- /dev/null +++ b/paddleformers/transformers/phi4_multimodal/configuration.py @@ -0,0 +1,286 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Phi-4-Multimodal configuration.""" + +import math + +from ..configuration_utils import PretrainedConfig + + +def _convert_phi4mm_config(config_dict, with_lora_adapters=True): + config = dict(config_dict) + + config.pop("_name_or_path", None) + config.pop("architectures", None) + config.pop("auto_map", None) + vision_lora = config.pop("vision_lora", None) or {} + speech_lora = config.pop("speech_lora", None) or {} + config.pop("transformers_version", None) + config.pop("_attn_implementation", None) + config.pop("torch_dtype", None) + config.pop("model_type", None) + + embd_layer = config.pop("embd_layer") + audio_embd_layer = embd_layer["audio_embd_layer"] + vision_embd_layer = embd_layer["image_embd_layer"] + + audio_config = config.pop("audio_processor")["config"] + audio_config.pop("activation_checkpointing", None) + audio_config.pop("cnn_layer_norm", None) + audio_config.pop("input_layer", None) + audio_config.pop("batch_norm", None) + audio_config.pop("encoder_embedding_config", None) + audio_config.pop("ext_pw_kernel_size", None) + audio_config.pop("bias_in_glu", None) + audio_config.pop("causal", None) + + audio_config["hidden_size"] = audio_config.pop("attention_dim") + audio_config["num_attention_heads"] = audio_config.pop("attention_heads") + audio_config["intermediate_size"] = audio_config.pop("linear_units") + audio_config["nemo_conv_channels"] = audio_config.pop("nemo_conv_settings")["conv_channels"] + audio_config["bias_max_distance"] = audio_config.pop("relative_attention_bias_args")["t5_bias_max_distance"] + audio_config["downsample_rate"] = audio_embd_layer["downsample_rate"] + audio_config.pop("depthwise_seperable_out_channel", None) + + if "depthwise_separable_out_channel" not in audio_config: + audio_config["depthwise_separable_out_channel"] = audio_config.get("ext_pw_out_channel") + + config["audio_config"] = audio_config + config["vision_config"] = {"crop_size": vision_embd_layer["crop_size"]} + config["eos_token_id"] = [199999, 200020] + + if with_lora_adapters: + config.update( + { + "vision_lora_rank": vision_lora.get("r", 0), + "vision_lora_alpha": vision_lora.get("lora_alpha", 1), + "speech_lora_rank": speech_lora.get("r", 0), + "speech_lora_alpha": speech_lora.get("lora_alpha", 1), + } + ) + return config + + +class Phi4MultimodalVisionConfig(PretrainedConfig): + model_type = "phi4_multimodal_vision" + + def __init__( + self, + hidden_size=1152, + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, + num_channels=3, + image_size=448, + patch_size=14, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + crop_size=448, + image_token_id=200010, + feature_layer=-2, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + self.crop_size = crop_size + self.image_token_id = image_token_id + self.feature_layer = feature_layer + + +class Phi4MultimodalAudioConfig(PretrainedConfig): + model_type = "phi4_multimodal_audio" + + def __init__( + self, + hidden_size=1024, + intermediate_size=1536, + num_blocks=24, + num_attention_heads=16, + activation="swish", + chunk_size=-1, + left_chunk=18, + dropout_rate=0.0, + ext_pw_out_channel=1024, + depthwise_separable_out_channel=1024, + depthwise_multiplier=1, + kernel_size=3, + conv_activation="swish", + input_size=80, + conv_glu_type="swish", + time_reduction=8, + bias_max_distance=1000, + bias_symmetric=False, + nemo_activation="relu", + nemo_conv_channels=1024, + downsample_rate=1, + initializer_range=0.02, + audio_token_id=200011, + feature_layer=-2, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_blocks = num_blocks + self.num_attention_heads = num_attention_heads + self.activation = activation + self.chunk_size = chunk_size + self.left_chunk = left_chunk + self.dropout_rate = dropout_rate + self.ext_pw_out_channel = ext_pw_out_channel + self.depthwise_separable_out_channel = depthwise_separable_out_channel + self.depthwise_multiplier = depthwise_multiplier + self.kernel_size = kernel_size + self.conv_activation = conv_activation + self.input_size = input_size + self.conv_glu_type = conv_glu_type + self.time_reduction = time_reduction + self.bias_max_distance = bias_max_distance + self.bias_symmetric = bias_symmetric + self.nemo_activation = nemo_activation + self.nemo_conv_channels = nemo_conv_channels + self.downsample_rate = downsample_rate + self.initializer_range = initializer_range + self.audio_token_id = audio_token_id + self.feature_layer = feature_layer + + nemo_final_size = self.input_size + for _ in range(int(math.log2(self.time_reduction))): + nemo_final_size = math.floor((nemo_final_size - 1) / 2 + 1) + self.nemo_final_size = nemo_final_size + + +class Phi4MultimodalConfig(PretrainedConfig): + model_type = "phi4_multimodal" + + @classmethod + def from_dict(cls, config_dict, **kwargs): + if config_dict.get("model_type") == "phi4mm": + config_dict = _convert_phi4mm_config(config_dict) + return super().from_dict(config_dict, **kwargs) + + def __init__( + self, + vocab_size=200064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=24, + num_key_value_heads=8, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=131072, + original_max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + rope_parameters=None, + bos_token_id=199999, + eos_token_id=None, + pad_token_id=199999, + sliding_window=None, + partial_rotary_factor=1.0, + vision_config=None, + audio_config=None, + attention_bias=False, + mlp_bias=False, + lm_head_bias=False, + vision_lora_rank=0, + vision_lora_alpha=1, + speech_lora_rank=0, + speech_lora_alpha=1, + **kwargs, + ): + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id if eos_token_id is not None else [199999, 200020], + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.sliding_window = sliding_window + self.partial_rotary_factor = partial_rotary_factor + self.attention_bias = attention_bias + self.mlp_bias = mlp_bias + self.lm_head_bias = lm_head_bias + self.vision_lora_rank = vision_lora_rank + self.vision_lora_alpha = vision_lora_alpha + self.speech_lora_rank = speech_lora_rank + self.speech_lora_alpha = speech_lora_alpha + self._active_lora_adapter = None + + if isinstance(vision_config, dict): + self.vision_config = Phi4MultimodalVisionConfig(**vision_config) + elif vision_config is None: + self.vision_config = Phi4MultimodalVisionConfig() + else: + self.vision_config = vision_config + + if isinstance(audio_config, dict): + self.audio_config = Phi4MultimodalAudioConfig(**audio_config) + elif audio_config is None: + self.audio_config = Phi4MultimodalAudioConfig() + else: + self.audio_config = audio_config + + # Build rope_parameters dict for compatibility with rope utils + self.rope_parameters = rope_parameters if rope_parameters is not None else self._build_rope_parameters() + + def _build_rope_parameters(self): + rope_params = { + "rope_theta": self.rope_theta, + "partial_rotary_factor": self.partial_rotary_factor, + } + if self.rope_scaling is not None: + rope_params.update(self.rope_scaling) + if "rope_type" not in rope_params: + rope_params["rope_type"] = "longrope" + if rope_params.get("rope_type") in {"longrope", "yarn", "llama3"}: + rope_params.setdefault("original_max_position_embeddings", self.original_max_position_embeddings) + else: + rope_params["rope_type"] = "default" + return rope_params diff --git a/paddleformers/transformers/phi4_multimodal/feature_extraction.py b/paddleformers/transformers/phi4_multimodal/feature_extraction.py new file mode 100644 index 00000000000..37ca74fe42b --- /dev/null +++ b/paddleformers/transformers/phi4_multimodal/feature_extraction.py @@ -0,0 +1,189 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for Phi-4 Multimodal audio.""" + +import numpy as np +import paddle +from transformers.audio_utils import mel_filter_bank + +from ...utils.log import logger +from ..audio_processing_utils import SequenceFeatureExtractor +from ..feature_extraction_utils import BatchFeature + + +class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor): + model_input_names = ["audio_input_features", "audio_embed_sizes", "audio_attention_mask"] + + def __init__( + self, + feature_size=80, + sampling_rate=16000, + hop_length=160, + n_fft=512, + win_length=400, + preemphasis=0.97, + padding_value=0.0, + audio_compression_rate=8, + audio_downsample_rate=1, + audio_feat_stride=1, + mel_min_frequency=0, + mel_max_frequency=7690, + **kwargs, + ): + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + self.hop_length = hop_length + self.n_fft = n_fft + self.win_length = win_length + self.preemphasis = preemphasis + self.padding_value = padding_value + self.audio_compression_rate = audio_compression_rate + self.audio_downsample_rate = audio_downsample_rate + self.audio_feat_stride = audio_feat_stride + self.mel_filters = mel_filter_bank( + num_frequency_bins=self.n_fft // 2 + 1, + num_mel_filters=self.feature_size, + min_frequency=mel_min_frequency, + max_frequency=mel_max_frequency, + sampling_rate=self.sampling_rate, + triangularize_in_mel_space=True, + mel_scale="kaldi", + ).astype(np.float32) + + def __call__( + self, + raw_speech, + sampling_rate=None, + pad_to_multiple_of=None, + padding="longest", + max_length=None, + truncation=False, + return_tensors=None, + return_attention_mask=True, + device="cpu", + **kwargs, + ): + if sampling_rate is not None and sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor was trained using a sampling rate of " + f"{self.sampling_rate}. Please provide audio sampled at {self.sampling_rate}, not {sampling_rate}." + ) + if sampling_rate is None: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`." + ) + + speech_list = self._as_mono_float_list(raw_speech) + audio_lengths = np.asarray([speech.shape[0] for speech in speech_list], dtype=np.int64) + + if truncation and max_length is not None: + speech_list = [speech[:max_length] for speech in speech_list] + audio_lengths = np.minimum(audio_lengths, max_length) + + padded_length = max(int(length) for length in audio_lengths) + padded_length = max(padded_length, self.win_length) + if padding not in (True, "longest", "max_length"): + padded_length = max(int(audio_lengths[0]), self.win_length) + if max_length is not None and padding == "max_length": + padded_length = max(max_length, self.win_length) + if pad_to_multiple_of is not None and padded_length % pad_to_multiple_of != 0: + padded_length = ((padded_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + waveform = np.full((len(speech_list), padded_length), self.padding_value, dtype=np.float32) + for idx, speech in enumerate(speech_list): + length = min(speech.shape[0], padded_length) + waveform[idx, :length] = speech[:length] + + input_features = self._np_extract_fbank_features(waveform, audio_lengths) + feature_lengths = (audio_lengths - self.win_length) // self.hop_length + 1 + feature_lengths = np.maximum(feature_lengths, 1) * self.audio_feat_stride + audio_embed_sizes = self._compute_audio_embed_size(feature_lengths) + + feature_attention_mask = None + if return_attention_mask and len(feature_lengths) > 1: + max_feature_length = int(feature_lengths.max()) + feature_attention_mask = np.arange(max_feature_length)[None, :] < feature_lengths[:, None] + + data = { + "audio_input_features": input_features, + "audio_embed_sizes": audio_embed_sizes.astype(np.int64), + } + if feature_attention_mask is not None: + data["audio_attention_mask"] = feature_attention_mask + return BatchFeature(data=data, tensor_type=return_tensors) + + def _np_extract_fbank_features(self, waveform, audio_lengths): + fft_window = np.hamming(self.win_length).astype(np.float64) + batch_features = [] + max_frames = 0 + + for speech, length in zip(waveform, audio_lengths): + speech = speech[: max(int(length), self.win_length)] + if speech.shape[0] < self.win_length: + speech = np.pad(speech, (0, self.win_length - speech.shape[0]), constant_values=self.padding_value) + num_frames = (speech.shape[0] - self.win_length) // self.hop_length + 1 + frames = np.stack( + [speech[i * self.hop_length : i * self.hop_length + self.win_length] for i in range(num_frames)], + axis=0, + ) + frames_prev = np.roll(frames, 1, axis=-1) + frames_prev[:, 0] = frames_prev[:, 1] + frames = (frames - self.preemphasis * frames_prev) * 32768 + spectrum = np.fft.rfft(fft_window * frames, n=self.n_fft, axis=1).astype(np.complex64) + spec_power = np.abs(spectrum) ** 2 + log_spec = np.log(np.clip(spec_power @ self.mel_filters, a_min=1.0, a_max=None)).astype(np.float32) + batch_features.append(log_spec) + max_frames = max(max_frames, log_spec.shape[0]) + + padded = np.full((len(batch_features), max_frames, self.feature_size), self.padding_value, dtype=np.float32) + for idx, features in enumerate(batch_features): + padded[idx, : features.shape[0]] = features + return padded + + def _compute_audio_embed_size(self, audio_frames): + integer = audio_frames // self.audio_compression_rate + remainder = audio_frames % self.audio_compression_rate + result = integer + (remainder > 0).astype(integer.dtype) + integer = result // self.audio_downsample_rate + remainder = result % self.audio_downsample_rate + return integer + (remainder > 0).astype(integer.dtype) + + @staticmethod + def _as_mono_float_list(raw_speech): + if isinstance(raw_speech, paddle.Tensor): + raw_speech = raw_speech.detach().cpu().numpy() + if isinstance(raw_speech, np.ndarray): + if raw_speech.ndim == 1: + raw_speech = [raw_speech] + elif raw_speech.ndim == 2: + raw_speech = list(raw_speech) + else: + raw_speech = [raw_speech.mean(axis=-1)] + elif isinstance(raw_speech, (list, tuple)): + raw_speech = [ + speech.detach().cpu().numpy() if isinstance(speech, paddle.Tensor) else speech for speech in raw_speech + ] + else: + raise ValueError(f"Unsupported audio input type: {type(raw_speech)}") + + result = [] + for speech in raw_speech: + speech = np.asarray(speech) + if speech.ndim > 1: + speech = speech.mean(axis=-1) + result.append(speech.astype(np.float32)) + return result + + +__all__ = ["Phi4MultimodalFeatureExtractor"] diff --git a/paddleformers/transformers/phi4_multimodal/image_processor.py b/paddleformers/transformers/phi4_multimodal/image_processor.py new file mode 100644 index 00000000000..80a17f46dc9 --- /dev/null +++ b/paddleformers/transformers/phi4_multimodal/image_processor.py @@ -0,0 +1,285 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Phi-4 Multimodal.""" + +import math + +import numpy as np +import paddle +from PIL import Image, ImageOps + +from ..feature_extraction_utils import BatchFeature +from ..image_processing_utils import PaddleImageProcessingMixin +from ..image_utils import PILImageResampling, make_flat_list_of_images + + +class Phi4MultimodalImageProcessor(PaddleImageProcessingMixin): + model_input_names = ["image_pixel_values", "image_sizes", "image_attention_mask"] + + def __init__( + self, + size=None, + patch_size=14, + dynamic_hd=36, + image_mean=None, + image_std=None, + do_resize=True, + do_rescale=True, + do_normalize=True, + do_convert_rgb=True, + resample=PILImageResampling.BICUBIC, + rescale_factor=1 / 255, + **kwargs, + ): + self.size = size or {"height": 448, "width": 448} + self.patch_size = patch_size + self.dynamic_hd = dynamic_hd + self.image_mean = image_mean or [0.5, 0.5, 0.5] + self.image_std = image_std or [0.5, 0.5, 0.5] + self.do_resize = do_resize + self.do_rescale = do_rescale + self.do_normalize = do_normalize + self.do_convert_rgb = do_convert_rgb + self.resample = resample + self.rescale_factor = rescale_factor + super().__init__(**kwargs) + + def __call__(self, images, **kwargs): + return self.preprocess(images, **kwargs) + + def preprocess( + self, + images, + size=None, + patch_size=None, + dynamic_hd=None, + image_mean=None, + image_std=None, + do_rescale=None, + do_normalize=None, + rescale_factor=None, + return_tensors=None, + **kwargs, + ): + size = size or self.size + if isinstance(size, int): + size = {"height": size, "width": size} + patch_size = patch_size if patch_size is not None else self.patch_size + dynamic_hd = dynamic_hd if dynamic_hd is not None else self.dynamic_hd + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_rescale = self.do_rescale if do_rescale is None else do_rescale + do_normalize = self.do_normalize if do_normalize is None else do_normalize + rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor + + height = size["height"] + width = size["width"] + if height != width: + raise ValueError("Phi4MultimodalImageProcessor only supports square sizes.") + + images = make_flat_list_of_images(images) + mask_size = height // patch_size + images_transformed = [] + masks_transformed = [] + image_tokens = [] + image_sizes = [] + + for image in images: + image = self._to_pil_image(image) + resized_image, attention_mask = self.dynamic_preprocess( + image, height, patch_size, mask_size, max_num=dynamic_hd + ) + processed_image = self._to_chw_array(resized_image) + if do_rescale: + processed_image = processed_image * rescale_factor + if do_normalize: + mean = np.asarray(image_mean, dtype=np.float32)[:, None, None] + std = np.asarray(image_std, dtype=np.float32)[:, None, None] + processed_image = (processed_image - mean) / std + + global_image = self._resize_chw(processed_image, height, height) + image_height, image_width = processed_image.shape[-2:] + mask_height, mask_width = attention_mask.shape[-2:] + global_attention_mask = np.ones((1, mask_size, mask_size), dtype=bool) + + hd_image = processed_image.reshape(1, 3, image_height // height, height, image_width // width, width) + hd_image = hd_image.transpose(0, 2, 4, 1, 3, 5).reshape(-1, 3, height, width) + + attention_mask = attention_mask.reshape( + mask_height // mask_size, mask_size, mask_width // mask_size, mask_size + ) + attention_mask = attention_mask.transpose(0, 2, 1, 3).reshape(-1, mask_size, mask_size) + + downsample_attention_mask = attention_mask[:, 0::2, 0::2] + pooled_mask_size = mask_size // 2 + mask_size % 2 + downsample_attention_mask = downsample_attention_mask.reshape( + mask_height // mask_size, + mask_width // mask_size, + pooled_mask_size, + pooled_mask_size, + ) + downsample_attention_mask = downsample_attention_mask.transpose(0, 2, 1, 3) + downsample_attention_mask = downsample_attention_mask.reshape( + downsample_attention_mask.shape[0] * downsample_attention_mask.shape[1], + downsample_attention_mask.shape[2] * downsample_attention_mask.shape[3], + ) + + base_feat_size = mask_size // 2 + mask_size % 2 + num_img_tokens = ( + base_feat_size**2 + + 1 + + int(downsample_attention_mask.sum().item()) + + int(downsample_attention_mask[:, 0].sum().item()) + + base_feat_size + ) + + hd_image = np.concatenate([global_image[None, ...], hd_image], axis=0) + attention_mask = np.concatenate([global_attention_mask, attention_mask], axis=0) + + images_transformed.append(hd_image.astype(np.float32)) + masks_transformed.append(attention_mask.astype(bool)) + image_tokens.append(num_img_tokens) + image_sizes.append([image_height, image_width]) + + max_crops = max(image.shape[0] for image in images_transformed) + images_transformed = np.stack( + [self._pad_images_to_max_crops(image, max_crops) for image in images_transformed], axis=0 + ) + masks_transformed = np.stack( + [self._pad_masks_to_max_crops(mask, max_crops) for mask in masks_transformed], axis=0 + ) + + data = { + "image_pixel_values": images_transformed, + "image_sizes": np.asarray(image_sizes, dtype=np.int64), + "image_attention_mask": masks_transformed, + "num_img_tokens": image_tokens, + } + return BatchFeature(data=data, tensor_type=return_tensors) + + def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + def dynamic_preprocess(self, image, image_size, patch_size, mask_size, max_num=36, min_num=1): + orig_width, orig_height = image.size + w_crop_num = math.ceil(orig_width / float(image_size)) + h_crop_num = math.ceil(orig_height / float(image_size)) + if w_crop_num * h_crop_num > max_num: + aspect_ratio = orig_width / orig_height + target_ratios = { + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if min_num <= i * j <= max_num + } + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + target_aspect_ratio = self.find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + else: + target_width = image_size * w_crop_num + target_height = image_size * h_crop_num + target_aspect_ratio = (w_crop_num, h_crop_num) + + ratio_width = target_width / orig_width + ratio_height = target_height / orig_height + if ratio_width < ratio_height: + new_size = (target_width, int(orig_height * ratio_width)) + padding_width = 0 + padding_height = target_height - int(orig_height * ratio_width) + else: + new_size = (int(orig_width * ratio_height), target_height) + padding_width = target_width - int(orig_width * ratio_height) + padding_height = 0 + + attention_mask = np.ones( + (int(mask_size * target_aspect_ratio[1]), int(mask_size * target_aspect_ratio[0])), + dtype=bool, + ) + if padding_width >= patch_size: + attention_mask[:, -math.floor(padding_width / patch_size) :] = False + if padding_height >= patch_size: + attention_mask[-math.floor(padding_height / patch_size) :, :] = False + + if min(new_size[1], target_height) < 10 or min(new_size[0], target_width) < 10: + raise ValueError(f"the aspect ratio is very extreme {new_size}") + + image = image.resize((int(new_size[0]), int(new_size[1])), resample=self.resample) + resized_img = ImageOps.expand( + image, border=(0, 0, int(padding_width), int(padding_height)), fill=(255, 255, 255) + ) + return resized_img, attention_mask + + def _to_pil_image(self, image): + if isinstance(image, Image.Image): + pil_image = image + else: + if isinstance(image, paddle.Tensor): + image = image.detach().cpu().numpy() + image = np.asarray(image) + if image.ndim != 3: + raise ValueError(f"Expected image with 3 dimensions, got shape {image.shape}.") + if image.shape[0] in (1, 3) and image.shape[-1] not in (1, 3): + image = image.transpose(1, 2, 0) + if image.dtype != np.uint8: + if image.max() <= 1.0: + image = image * 255 + image = np.clip(image, 0, 255).astype(np.uint8) + pil_image = Image.fromarray(image) + return pil_image.convert("RGB") if self.do_convert_rgb else pil_image + + @staticmethod + def _to_chw_array(image): + return np.asarray(image).astype(np.float32).transpose(2, 0, 1) + + def _resize_chw(self, image, height, width): + resized = [] + for channel in image: + pil_channel = Image.fromarray(channel.astype(np.float32), mode="F") + pil_channel = pil_channel.resize((width, height), resample=self.resample) + resized.append(np.asarray(pil_channel, dtype=np.float32)) + return np.stack(resized, axis=0) + + @staticmethod + def _pad_images_to_max_crops(images, max_crops): + if max_crops <= images.shape[0]: + return images + pad = np.zeros((max_crops - images.shape[0], *images.shape[1:]), dtype=images.dtype) + return np.concatenate([images, pad], axis=0) + + @staticmethod + def _pad_masks_to_max_crops(masks, max_crops): + if max_crops <= masks.shape[0]: + return masks + pad = np.ones((max_crops - masks.shape[0], *masks.shape[1:]), dtype=masks.dtype) + return np.concatenate([masks, pad], axis=0) + + +__all__ = ["Phi4MultimodalImageProcessor"] diff --git a/paddleformers/transformers/phi4_multimodal/modeling.py b/paddleformers/transformers/phi4_multimodal/modeling.py new file mode 100644 index 00000000000..295bfe31033 --- /dev/null +++ b/paddleformers/transformers/phi4_multimodal/modeling.py @@ -0,0 +1,1955 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Paddle Phi-4-Multimodal model.""" + +import math +from typing import Optional, Tuple, Union + +import numpy as np +import paddle +import paddle.nn.functional as F +from paddle import nn +from paddle.distributed.fleet.utils import recompute + +from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS +from ...nn.criterion.interface import CriterionLayer +from ...nn.embedding import Embedding as GeneralEmbedding +from ...nn.linear import Linear as GeneralLinear +from ...nn.lm_head import LMHead as GeneralLMHead +from ...nn.pp_model import GeneralModelForCausalLMPipe +from ...utils.log import logger +from ..activations import ACT2FN +from ..cache_utils import Cache, DynamicCache +from ..masking_utils import create_causal_mask_and_row_indices +from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ..model_utils import PretrainedModel, register_base_model +from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS +from .configuration import ( + Phi4MultimodalAudioConfig, + Phi4MultimodalConfig, + Phi4MultimodalVisionConfig, +) + +# ======================= Utility functions ======================= + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin) + k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin) + + q_embed = paddle.concat([q_embed, q_pass], axis=-1) + k_embed = paddle.concat([k_embed, k_pass], axis=-1) + return q_embed, k_embed + + +def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand([batch, num_key_value_heads, n_rep, slen, head_dim]) + return hidden_states.reshape([batch, num_key_value_heads * n_rep, slen, head_dim]) + + +def _create_lora_parameter(layer: nn.Layer, shape): + return layer.create_parameter( + shape=shape, + default_initializer=nn.initializer.Constant(0.0), + ) + + +def _lora_delta(hidden_states: paddle.Tensor, lora_a: paddle.Tensor, lora_b: paddle.Tensor, scaling: float): + input_dtype = hidden_states.dtype + delta = F.linear(hidden_states, lora_a.astype(input_dtype)) + delta = F.linear(delta, lora_b.astype(input_dtype)) + return delta * scaling + + +def _active_lora_adapter(config): + adapter = getattr(config, "_active_lora_adapter", None) + if adapter in ("vision", "speech"): + return adapter + return None + + +def _lora_adapter_from_input_mode(input_mode, image_pixel_values=None, audio_input_features=None): + if input_mode is not None: + if isinstance(input_mode, paddle.Tensor): + input_mode = int(input_mode.flatten()[0].item()) + if input_mode in (1, 3): + return "vision" + if input_mode == 2: + return "speech" + return None + if image_pixel_values is not None: + return "vision" + if audio_input_features is not None: + return "speech" + return None + + +# ======================= Vision Encoder ======================= + + +class Phi4MultimodalVisionMLP(nn.Layer): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Phi4MultimodalVisionAttention(nn.Layer): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + + self.k_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.v_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.q_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = list(input_shape) + [self.num_heads, self.head_dim] + + query_states = self.q_proj(hidden_states).reshape(hidden_shape).transpose([0, 2, 1, 3]) + key_states = self.k_proj(hidden_states).reshape(hidden_shape).transpose([0, 2, 1, 3]) + value_states = self.v_proj(hidden_states).reshape(hidden_shape).transpose([0, 2, 1, 3]) + + attn_weights = paddle.matmul(query_states, key_states.transpose([0, 1, 3, 2])) * self.scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = F.softmax(attn_weights, axis=-1, dtype=paddle.float32).astype(query_states.dtype) + if self.training and self.attention_dropout > 0: + attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = paddle.matmul(attn_weights, value_states) + attn_output = attn_output.transpose([0, 2, 1, 3]) + + attn_output = attn_output.reshape(list(input_shape) + [-1]) + attn_output = self.out_proj(attn_output) + return attn_output, attn_weights + + +class Phi4MultimodalVisionEncoderLayer(nn.Layer): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps) + self.self_attn = Phi4MultimodalVisionAttention(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps) + self.mlp = Phi4MultimodalVisionMLP(config) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Phi4MultimodalVisionEncoder(nn.Layer): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.layers = nn.LayerList([Phi4MultimodalVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + inputs_embeds: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + output_hidden_states: bool = False, + ): + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + hidden_states = encoder_layer(hidden_states, attention_mask) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return hidden_states, all_hidden_states + + +class Phi4MultimodalVisionEmbeddings(nn.Layer): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.patch_size = config.patch_size + self.num_patches_per_side = config.image_size // self.patch_size + + self.patch_embedding = nn.Conv2D( + in_channels=config.num_channels, + out_channels=config.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + ) + self.position_embedding = nn.Embedding(self.num_patches_per_side**2, config.hidden_size) + + def forward(self, pixel_values: paddle.Tensor, patch_attention_mask: paddle.Tensor) -> paddle.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose([0, 2, 1]) + + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + boundaries = paddle.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side).astype( + pixel_values.dtype + ) + position_ids = paddle.full( + shape=[batch_size, max_nb_patches_h * max_nb_patches_w], fill_value=0, dtype="int64" + ) + + nb_patches_h = patch_attention_mask[:, :, 0].astype("int64").sum(axis=1) + nb_patches_w = patch_attention_mask[:, 0, :].astype("int64").sum(axis=1) + + step_h = 1.0 / nb_patches_h.astype("float32") + step_w = 1.0 / nb_patches_w.astype("float32") + + max_patches_h = patch_attention_mask.shape[1] + max_patches_w = patch_attention_mask.shape[2] + h_indices = paddle.arange(max_patches_h, dtype="float32") + w_indices = paddle.arange(max_patches_w, dtype="float32") + + fractional_coords_h = h_indices[None, :] * step_h[:, None] + fractional_coords_w = w_indices[None, :] * step_w[:, None] + + fractional_coords_h = paddle.clip(fractional_coords_h, max=(1.0 - 1e-6)) + fractional_coords_w = paddle.clip(fractional_coords_w, max=(1.0 - 1e-6)) + + fractional_coords_h = fractional_coords_h.astype(pixel_values.dtype) + fractional_coords_w = fractional_coords_w.astype(pixel_values.dtype) + + bucket_coords_h = paddle.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = paddle.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = bucket_coords_h[:, :, None] * self.num_patches_per_side + bucket_coords_w[:, None, :] + pos_ids = pos_ids.reshape([batch_size, -1]) + + flat_mask = patch_attention_mask.reshape([batch_size, -1]).astype("bool") + for i in range(batch_size): + mask_i = flat_mask[i] + position_ids[i][mask_i] = pos_ids[i][mask_i] + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Phi4MultimodalVisionMultiheadAttentionPoolingHead(nn.Layer): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.probe = self.create_parameter( + shape=[1, 1, config.hidden_size], + default_initializer=nn.initializer.Normal(std=1.0), + ) + self.attention = nn.MultiHeadAttention(config.hidden_size, config.num_attention_heads) + self.layernorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.mlp = Phi4MultimodalVisionMLP(config) + + def forward(self, hidden_state: paddle.Tensor, attention_mask: paddle.Tensor) -> paddle.Tensor: + batch_size = hidden_state.shape[0] + probe = self.probe.expand([batch_size, -1, -1]) + + # attention_mask: [B, S] bool -> key_padding_mask for MHA + # Paddle MHA uses attn_mask, we need to convert + # ~attention_mask gives True for padded positions + key_padding_mask = ~attention_mask.astype("bool") + # Convert to float mask: 0 for valid, -inf for padded + attn_mask = key_padding_mask.astype(hidden_state.dtype) * paddle.finfo(hidden_state.dtype).min + attn_mask = attn_mask.unsqueeze([1, 2]) # [B, 1, 1, S] + + hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attn_mask) + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class Phi4MultimodalVisionModel(nn.Layer): + def __init__(self, config: Phi4MultimodalVisionConfig): + super().__init__() + self.config = config + self.embeddings = Phi4MultimodalVisionEmbeddings(config) + self.encoder = Phi4MultimodalVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.head = Phi4MultimodalVisionMultiheadAttentionPoolingHead(config) + + def forward( + self, + pixel_values: paddle.Tensor, + patch_attention_mask: Optional[paddle.Tensor] = None, + output_hidden_states: bool = False, + ): + batch_size = pixel_values.shape[0] + if patch_attention_mask is None: + patch_attention_mask = paddle.ones( + shape=[ + batch_size, + pixel_values.shape[2] // self.config.patch_size, + pixel_values.shape[3] // self.config.patch_size, + ], + dtype="bool", + ) + + hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + + patch_attention_mask_flat = patch_attention_mask.reshape([batch_size, -1]) + # Create bidirectional attention mask + mask_expanded = patch_attention_mask_flat.unsqueeze([1, 2]).astype(hidden_states.dtype) + attention_mask = (1.0 - mask_expanded) * paddle.finfo(hidden_states.dtype).min + + last_hidden_state, all_hidden_states = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head( + hidden_state=last_hidden_state, + attention_mask=patch_attention_mask_flat, + ) + + return last_hidden_state, pooled_output, all_hidden_states + + +class Phi4MultimodalImageEmbedding(nn.Layer): + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.layer_idx = config.vision_config.feature_layer + self.crop_size = config.vision_config.crop_size + self.image_dim_out = config.vision_config.hidden_size + + n_patches = config.vision_config.image_size // config.vision_config.patch_size + if n_patches % 2 != 0: + self.img_processor_padding = nn.Pad2D([0, 1, 0, 1], mode="reflect") + n_patches += 1 + self.num_img_tokens = (n_patches // 2) ** 2 + + self.drop = nn.Dropout(config.embd_pdrop) + self.img_processor = Phi4MultimodalVisionModel(config.vision_config) + self.image_token_compression = nn.AvgPool2D(kernel_size=2, stride=2) + self.img_projection_up = nn.Linear(self.image_dim_out, config.hidden_size) + self.img_projection_down = nn.Linear(config.hidden_size, config.hidden_size) + self.global_img_feature_extensor = self.create_parameter( + shape=[1, 1, self.image_dim_out], + default_initializer=nn.initializer.Constant(0.0), + ) + self.sub_img_feature_extensor = self.create_parameter( + shape=[1, 1, 1, self.image_dim_out], + default_initializer=nn.initializer.Constant(0.0), + ) + + def _repeat_sub_img_feature_extensor(self, repeat_height: int) -> paddle.Tensor: + return paddle.tile(self.sub_img_feature_extensor, repeat_times=[1, repeat_height, 1, 1]) + + def get_img_features(self, img_embeds: paddle.Tensor, attention_mask=None) -> paddle.Tensor: + _, _, all_hidden_states = self.img_processor( + img_embeds, patch_attention_mask=attention_mask, output_hidden_states=True + ) + img_feature = all_hidden_states[self.layer_idx] + + patch_feature = img_feature + width = int(math.sqrt(patch_feature.shape[1])) + patch_feature = patch_feature.reshape([-1, width, width, patch_feature.shape[-1]]) + # convert to NCHW + patch_feature = patch_feature.transpose([0, 3, 1, 2]) + if hasattr(self, "img_processor_padding"): + patch_feature = self.img_processor_padding(patch_feature) + patch_feature = self.image_token_compression(patch_feature) + # convert to NHWC + patch_feature = patch_feature.transpose([0, 2, 3, 1]) + patch_feature = patch_feature.reshape( + [-1, patch_feature.shape[1] * patch_feature.shape[2], patch_feature.shape[-1]] + ) + return patch_feature + + def forward( + self, + input_ids: paddle.Tensor, + inputs_embeds: paddle.Tensor, + image_pixel_values: paddle.Tensor, + image_sizes: Optional[paddle.Tensor] = None, + image_attention_mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + image_pixel_values = image_pixel_values.astype(self.img_processor.embeddings.patch_embedding.weight.dtype) + + target_dtype = self.img_projection_up.bias.dtype + + batch_size = image_pixel_values.shape[0] + + img_features = self.get_img_features( + image_pixel_values.flatten(0, 1), + attention_mask=image_attention_mask.flatten(0, 1).astype("bool") + if image_attention_mask is not None + else None, + ) + base_feat_size = int(np.sqrt(img_features.shape[1])) + img_features = img_features.reshape([batch_size, -1, base_feat_size**2, self.image_dim_out]) + image_sizes_flat = image_sizes.reshape([-1, 2]) + + output_imgs = [] + for idx in range(batch_size): + height = int(image_sizes_flat[idx, 0].item()) + width = int(image_sizes_flat[idx, 1].item()) + height_ratio = height // self.crop_size + width_ratio = width // self.crop_size + area_ratio = height_ratio * width_ratio + + global_img = img_features[idx, :1] + global_img = global_img.reshape([1, base_feat_size, base_feat_size, self.image_dim_out]) + temporary_extensor = self._repeat_sub_img_feature_extensor(base_feat_size) + global_img = paddle.concat([global_img, temporary_extensor], axis=2).reshape([1, -1, self.image_dim_out]) + + sub_img = img_features[idx, 1:] + sub_img = sub_img[:area_ratio] + sub_img = ( + sub_img.reshape([height_ratio, width_ratio, base_feat_size, base_feat_size, self.image_dim_out]) + .transpose([0, 2, 1, 3, 4]) + .reshape([1, height_ratio * base_feat_size, width_ratio * base_feat_size, self.image_dim_out]) + ) + + if image_attention_mask is not None: + reshaped_image_attention_mask = ( + image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2] + .reshape([height_ratio, width_ratio, base_feat_size, base_feat_size]) + .transpose([0, 2, 1, 3]) + .reshape([1, height_ratio * base_feat_size, width_ratio * base_feat_size]) + ) + reshaped_image_attention_mask_int = reshaped_image_attention_mask.astype("int64") + useful_height = int(reshaped_image_attention_mask_int[0, :, 0].sum().item()) + useful_width = int(reshaped_image_attention_mask_int[0, 0, :].sum().item()) + sub_img = sub_img[:, :useful_height, :useful_width] + temporary_extensor = self._repeat_sub_img_feature_extensor(useful_height) + else: + temporary_extensor = self._repeat_sub_img_feature_extensor(height_ratio * base_feat_size) + + sub_img = paddle.concat([sub_img, temporary_extensor], axis=2).reshape([1, -1, self.image_dim_out]) + + output_imgs.append(paddle.concat([sub_img, self.global_img_feature_extensor, global_img], axis=1)) + + img_set_tensor = [] + for output_img in output_imgs: + output_img = output_img.astype(target_dtype) + img_feature_proj = self.img_projection_up(output_img) + img_feature_proj = F.gelu(img_feature_proj) + img_feature_proj = self.img_projection_down(img_feature_proj) + img_set_tensor.append(img_feature_proj) + + merged_img_set_tensor = paddle.concat(img_set_tensor, axis=1).squeeze(0) + merged_img_set_tensor = merged_img_set_tensor.astype(inputs_embeds.dtype) + + positions = paddle.nonzero(input_ids == self.config.vision_config.image_token_id) + if positions.shape[0] > 0: + image_embeds = inputs_embeds.clone() + for i in range(positions.shape[0]): + batch_idx = positions[i, 0] + seq_idx = positions[i, 1] + image_embeds[batch_idx, seq_idx] = merged_img_set_tensor[i] + else: + image_embeds = inputs_embeds + + image_embeds = self.drop(image_embeds) + return image_embeds + + +# ======================= Audio Encoder ======================= + + +class Phi4MultimodalAudioMLP(nn.Layer): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.act_fn = ACT2FN[config.activation] + self.gate_up_proj = nn.Linear(config.hidden_size, config.intermediate_size * 2) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + hidden_states = self.layer_norm(hidden_states) + up_states = self.gate_up_proj(hidden_states) + up_states, gate = up_states.chunk(2, axis=-1) + up_states = up_states * self.act_fn(gate) + up_states = self.dropout(up_states) + hidden_states = self.down_proj(up_states) + out = self.dropout(hidden_states) + return out + + +class Phi4MultimodalAudioAttention(nn.Layer): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_heads = config.num_attention_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.dropout_rate + + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim) + self.k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim) + self.v_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: paddle.Tensor, + ) -> paddle.Tensor: + input_shape = hidden_states.shape[:-1] + hidden_shape = list(input_shape) + [self.num_heads, self.head_dim] + + query_states = self.q_proj(hidden_states).reshape(hidden_shape).transpose([0, 2, 1, 3]) + key_states = self.k_proj(hidden_states).reshape(hidden_shape).transpose([0, 2, 1, 3]) + value_states = self.v_proj(hidden_states).reshape(hidden_shape).transpose([0, 2, 1, 3]) + + attn_weights = paddle.matmul(query_states, key_states.transpose([0, 1, 3, 2])) * self.scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = F.softmax(attn_weights, axis=-1, dtype=paddle.float32).astype(query_states.dtype) + if self.training and self.attention_dropout > 0: + attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = paddle.matmul(attn_weights, value_states) + attn_output = attn_output.transpose([0, 2, 1, 3]) + attn_output = attn_output.reshape(list(input_shape) + [-1]) + attn_output = self.o_proj(attn_output) + return attn_output + + +class Phi4MultimodalAudioDepthWiseSeparableConv1d(nn.Layer): + def __init__(self, config: Phi4MultimodalAudioConfig, padding: int = 0): + super().__init__() + self.dw_conv = nn.Conv1D( + config.hidden_size, + config.hidden_size * config.depthwise_multiplier, + config.kernel_size, + stride=1, + padding=padding, + groups=config.hidden_size, + ) + self.pw_conv = nn.Conv1D( + config.hidden_size * config.depthwise_multiplier, config.depthwise_separable_out_channel, 1, 1, 0 + ) + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + return self.pw_conv(self.dw_conv(hidden_states)) + + +class Phi4MultimodalAudioGluPointWiseConv(nn.Layer): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.output_dim = config.ext_pw_out_channel + + self.ext_pw_conv_1d = nn.Conv1D(config.hidden_size, config.ext_pw_out_channel * 2, kernel_size=1, stride=1) + self.glu_act = ACT2FN[config.conv_glu_type] + self.b1 = self.create_parameter( + shape=[1, config.ext_pw_out_channel, 1], + default_initializer=nn.initializer.Constant(0.0), + ) + self.b2 = self.create_parameter( + shape=[1, config.ext_pw_out_channel, 1], + default_initializer=nn.initializer.Constant(0.0), + ) + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + hidden_states = hidden_states.transpose([0, 2, 1]) + hidden_states = self.ext_pw_conv_1d(hidden_states) + out = hidden_states[:, 0 : self.output_dim, :] + self.b1 + out = out * self.glu_act(hidden_states[:, self.output_dim : self.output_dim * 2, :] + self.b2) + return out.transpose([0, 2, 1]) + + +class Phi4MultimodalAudioConvModule(nn.Layer): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.kernel_size = config.kernel_size + + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.glu = Phi4MultimodalAudioGluPointWiseConv(config) + self.dw_sep_conv_1d = Phi4MultimodalAudioDepthWiseSeparableConv1d(config, padding=config.kernel_size - 1) + self.act = ACT2FN[config.conv_activation] + self.ext_pw_conv_1d = nn.Conv1D(config.hidden_size, config.ext_pw_out_channel, kernel_size=1, stride=1) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + hidden_states = self.glu(self.layer_norm(hidden_states)) + hidden_states = self.dw_sep_conv_1d(hidden_states.transpose([0, 2, 1])) + + if self.kernel_size > 1: + hidden_states = hidden_states[:, :, : -(self.kernel_size - 1)] + + hidden_states = self.act(hidden_states) + hidden_states = self.ext_pw_conv_1d(hidden_states) + out = self.dropout(hidden_states.transpose([0, 2, 1])) + return out + + +class Phi4MultimodalAudioConformerEncoderLayer(nn.Layer): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.feed_forward_in = Phi4MultimodalAudioMLP(config) + self.self_attn = Phi4MultimodalAudioAttention(config) + self.conv = Phi4MultimodalAudioConvModule(config) + self.feed_forward_out = Phi4MultimodalAudioMLP(config) + self.layer_norm_att = nn.LayerNorm(config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: paddle.Tensor, + ) -> paddle.Tensor: + residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states) + hidden_states = self.layer_norm_att(residual) + hidden_states = residual + self.self_attn(hidden_states, attention_mask) + hidden_states = hidden_states + self.conv(hidden_states) + hidden_states = hidden_states + 0.5 * self.feed_forward_out(hidden_states) + out = self.layer_norm(hidden_states) + return out + + +class Phi4MultimodalAudioNemoConvSubsampling(nn.Layer): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.subsampling_factor = config.time_reduction + self.sampling_num = int(math.log2(self.subsampling_factor)) + self.act_fn = ACT2FN[config.nemo_activation] + conv_channels = config.nemo_conv_channels + + layers = [ + nn.Conv2D(1, conv_channels, kernel_size=3, stride=2, padding=1), + self.act_fn, + ] + for _ in range(self.sampling_num - 1): + layers.extend( + [ + nn.Conv2D(conv_channels, conv_channels, kernel_size=3, stride=2, padding=1, groups=conv_channels), + nn.Conv2D(conv_channels, conv_channels, kernel_size=1, stride=1, padding=0, groups=1), + self.act_fn, + ] + ) + + self.conv = nn.Sequential(*layers) + self.out = nn.Linear(conv_channels * config.nemo_final_size, config.hidden_size) + + def forward(self, hidden_states: paddle.Tensor, mask: Optional[paddle.Tensor]): + hidden_states = hidden_states.unsqueeze(1) + hidden_states = self.conv(hidden_states) + + b, _, t, _ = hidden_states.shape + hidden_states = self.out(hidden_states.transpose([0, 2, 1, 3]).reshape([b, t, -1])) + + if mask is None: + return hidden_states, None + + max_audio_length = hidden_states.shape[1] + feature_lens = mask.sum(1) + padding_length = paddle.ceil(feature_lens / self.subsampling_factor).astype("int64") + arange_ = paddle.arange(0, max_audio_length, dtype="int64") + pad_mask = arange_.expand([padding_length.shape[0], -1]) < padding_length.unsqueeze(1) + return hidden_states, pad_mask.unsqueeze(1) + + +class Phi4MultimodalAudioRelativeAttentionBias(nn.Layer): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.max_distance = config.bias_max_distance + self.symmetric = config.bias_symmetric + self.num_buckets = self.max_distance + if not config.bias_symmetric: + self.num_buckets *= 2 + self.bias_values = nn.Embedding(self.num_buckets, config.num_attention_heads) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + max_pos = x.shape[1] + context_position = paddle.arange(max_pos, dtype="int64")[:, None] + memory_position = paddle.arange(max_pos, dtype="int64")[None, :] + relative_position = memory_position - context_position + + relative_position = paddle.clip(relative_position, min=-self.max_distance, max=self.max_distance - 1) + + if self.symmetric: + bias_idx = paddle.abs(relative_position) + else: + bias_idx = relative_position + self.num_buckets // 2 + + att_bias = self.bias_values(bias_idx) + att_bias = att_bias.transpose([2, 0, 1]).unsqueeze(0) + return att_bias + + +class Phi4MultimodalAudioMeanVarianceNormLayer(nn.Layer): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.register_buffer("global_mean", paddle.zeros([config.input_size])) + self.register_buffer("global_invstd", paddle.ones([config.input_size])) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + return (x - self.global_mean) * self.global_invstd + + +def unfold_tensor(tensor: paddle.Tensor, max_seq_len: int) -> paddle.Tensor: + _, T, D = tensor.shape + n_chunks = T // max_seq_len + tensor = tensor[:, : n_chunks * max_seq_len, :] + tensor = tensor.reshape([-1, max_seq_len, D]) + return tensor + + +def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): + import torch + + chunk_start_idx_t = torch.tensor(chunk_start_idx, dtype=torch.long) + start_pad = torch.nn.functional.pad(chunk_start_idx_t, (1, 0)) + end_pad = torch.nn.functional.pad(chunk_start_idx_t, (0, 1), value=x_len) + seq_range = torch.arange(0, x_len).unsqueeze(-1) + idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] + seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) + idx_left = idx - left_window + idx_left[idx_left < 0] = 0 + boundary_left = start_pad[idx_left] + mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) + idx_right = idx + right_window + idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) + boundary_right = end_pad[idx_right] + mask_right = seq_range_expand < boundary_right.unsqueeze(-1) + result = (mask_left & mask_right).numpy() + return paddle.to_tensor(result) + + +class Phi4MultimodalAudioModel(nn.Layer): + def __init__(self, config: Phi4MultimodalAudioConfig): + super().__init__() + self.config = config + self.encoder_embedding = Phi4MultimodalAudioMeanVarianceNormLayer(config) + self.embed = Phi4MultimodalAudioNemoConvSubsampling(config) + self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias(config) + self.encoders = nn.LayerList( + [Phi4MultimodalAudioConformerEncoderLayer(config) for _ in range(config.num_blocks)] + ) + + def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): + chunk_start_idx = np.arange(0, seq_len, chunk_size) + if self.training and np.random.rand() > 0.5: + chunk_start_idx = seq_len - chunk_start_idx + chunk_start_idx = chunk_start_idx[::-1] + chunk_start_idx = chunk_start_idx[:-1] + chunk_start_idx = np.insert(chunk_start_idx, 0, 0) + + enc_streaming_mask = adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk) + enc_streaming_mask = enc_streaming_mask.unsqueeze(0).expand([batch_size, -1, -1]) + return enc_streaming_mask + + def forward_embeddings(self, hidden_states, masks): + seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction) + if seq_len <= 0: + raise ValueError( + f"The sequence length after time reduction is invalid: {seq_len}. Your input feature is too short." + ) + batch_size = hidden_states.shape[0] + enc_streaming_mask = self._streaming_mask(seq_len, batch_size, self.config.chunk_size, self.config.left_chunk) + + hidden_states, masks = self.embed(hidden_states, masks) + + streaming_mask = enc_streaming_mask + if streaming_mask is not None and masks is not None: + hs_mask = masks.astype("bool") & streaming_mask.astype("bool") + elif masks is not None: + hs_mask = masks + else: + hs_mask = streaming_mask + + return hidden_states, hs_mask, masks + + def calculate_hs_mask(self, hidden_states, mask): + max_audio_length = hidden_states.shape[1] + batch_size = hidden_states.shape[0] + enc_streaming_mask = self._streaming_mask( + max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk + ) + if mask is None: + return enc_streaming_mask + + feature_lens = mask.sum(1) + padding_length = feature_lens + pad_mask = paddle.arange(0, max_audio_length).expand([padding_length.shape[0], -1]) < padding_length.unsqueeze( + 1 + ) + pad_mask = pad_mask.unsqueeze(1) + pad_mask = pad_mask.astype("bool") & enc_streaming_mask.astype("bool") + return pad_mask + + def forward(self, hidden_states: paddle.Tensor, mask: Optional[paddle.Tensor] = None, **kwargs): + hidden_states = self.encoder_embedding(hidden_states) + hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask) + + unfolded = False + bs, seq_len, _ = hidden_states.shape + max_seq_len = 500 + if seq_len > max_seq_len: + unfolded = True + if seq_len % max_seq_len > 0: + chunk_pad_size = max_seq_len - (seq_len % max_seq_len) + else: + chunk_pad_size = 0 + if chunk_pad_size > 0: + hidden_states = F.pad(hidden_states, [0, 0, 0, chunk_pad_size], data_format="NLC") + + hidden_states = unfold_tensor(hidden_states, max_seq_len) + masks_unfold = None + if mask is not None: + subsampled_pad_mask = mask.squeeze(1) + extra_padded = F.pad(subsampled_pad_mask.astype("float32"), [0, chunk_pad_size]) + extra_padded = extra_padded.unsqueeze(-1) + masks_unfold = unfold_tensor(extra_padded, max_seq_len) + masks_unfold = masks_unfold.squeeze(-1).astype("bool") + hs_mask = self.calculate_hs_mask(hidden_states, masks_unfold) + + relative_attention_bias = self.relative_attention_bias_layer(hidden_states) + if hs_mask is not None: + attention_mask = hs_mask.unsqueeze(1).astype(hidden_states.dtype) + relative_attention_bias + else: + attention_mask = relative_attention_bias + + for layer in self.encoders: + hidden_states = layer(hidden_states, attention_mask) + + if unfolded: + embed_dim = hidden_states.shape[-1] + hidden_states = hidden_states.reshape([bs, -1, embed_dim]) + if chunk_pad_size > 0: + hidden_states = hidden_states[:, :-chunk_pad_size, :] + + return hidden_states + + +class Phi4MultimodalAudioEmbedding(nn.Layer): + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.layer_idx = config.audio_config.feature_layer + + self.drop = nn.Dropout(config.embd_pdrop) + self.encoder = Phi4MultimodalAudioModel(config.audio_config) + self.up_proj_for_speech = nn.Linear( + config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size + ) + self.down_proj_for_speech = nn.Linear(config.hidden_size, config.hidden_size) + self.up_proj_for_vision_speech = nn.Linear( + config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size + ) + self.down_proj_for_vision_speech = nn.Linear(config.hidden_size, config.hidden_size) + + def forward( + self, + input_ids: paddle.Tensor, + inputs_embeds: paddle.Tensor, + audio_input_features: paddle.Tensor, + audio_embed_sizes=None, + audio_attention_mask=None, + audio_projection_mode="speech", + ) -> paddle.Tensor: + positions = paddle.nonzero(input_ids == self.config.audio_config.audio_token_id) + + up_proj = self.up_proj_for_speech if audio_projection_mode == "speech" else self.up_proj_for_vision_speech + down_proj = ( + self.down_proj_for_speech if audio_projection_mode == "speech" else self.down_proj_for_vision_speech + ) + + target_dtype = up_proj.bias.dtype + audio_input_features = audio_input_features.astype(target_dtype) + + audio_encoder_hidden_states = self.encoder(audio_input_features, audio_attention_mask) + audio_encoder_hidden_states = up_proj(audio_encoder_hidden_states) + audio_encoder_hidden_states = F.gelu(audio_encoder_hidden_states) + audio_embeds = down_proj(audio_encoder_hidden_states) + + merged_audio_embeds = paddle.concat( + [audio_embeds[i, : audio_embed_sizes[i], :] for i in range(len(audio_embed_sizes))], axis=0 + ) + merged_audio_embeds = merged_audio_embeds.astype(inputs_embeds.dtype) + + audio_embeds_out = inputs_embeds.clone() + if positions.shape[0] > 0: + for i in range(positions.shape[0]): + batch_idx = positions[i, 0] + seq_idx = positions[i, 1] + audio_embeds_out[batch_idx, seq_idx] = merged_audio_embeds[i] + + audio_embeds_out = self.drop(audio_embeds_out) + return audio_embeds_out + + +# ======================= Text Decoder ======================= + + +class Phi4MultimodalRMSNorm(nn.Layer): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = self.create_parameter( + shape=[hidden_size], + default_initializer=nn.initializer.Constant(1.0), + ) + self.variance_epsilon = eps + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + input_dtype = hidden_states.dtype + with paddle.amp.auto_cast(enable=False): + hidden_states = hidden_states.astype(paddle.float32) + variance = hidden_states.pow(2).mean(axis=-1, keepdim=True) + hidden_states = hidden_states * paddle.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states.astype(input_dtype) + weight = self.weight.astype(input_dtype) if self.weight.dtype != input_dtype else self.weight + return weight * hidden_states + + +class Phi4MultimodalMLP(nn.Layer): + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.gate_up_proj = GeneralLinear.create( + config.hidden_size, + 2 * config.intermediate_size, + has_bias=config.mlp_bias, + config=config, + tp_plan="colwise", + ) + self.down_proj = GeneralLinear.create( + config.intermediate_size, + config.hidden_size, + has_bias=config.mlp_bias, + config=config, + tp_plan="rowwise", + ) + self.activation_fn = ACT2FN[config.hidden_act] + self._init_lora(config.hidden_size, 2 * config.intermediate_size, config.intermediate_size, config.hidden_size) + + def _init_lora(self, gate_in, gate_out, down_in, down_out): + for adapter in ("vision", "speech"): + rank = getattr(self.config, f"{adapter}_lora_rank", 0) + alpha = getattr(self.config, f"{adapter}_lora_alpha", 1) + if rank <= 0: + continue + setattr(self, f"{adapter}_gate_up_lora_A", _create_lora_parameter(self, [gate_in, rank])) + setattr(self, f"{adapter}_gate_up_lora_B", _create_lora_parameter(self, [rank, gate_out])) + setattr(self, f"{adapter}_down_lora_A", _create_lora_parameter(self, [down_in, rank])) + setattr(self, f"{adapter}_down_lora_B", _create_lora_parameter(self, [rank, down_out])) + setattr(self, f"{adapter}_lora_scaling", alpha / rank) + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + up_states = self.gate_up_proj(hidden_states) + adapter = _active_lora_adapter(self.config) + if adapter is not None and hasattr(self, f"{adapter}_gate_up_lora_A"): + up_states = up_states + _lora_delta( + hidden_states, + getattr(self, f"{adapter}_gate_up_lora_A"), + getattr(self, f"{adapter}_gate_up_lora_B"), + getattr(self, f"{adapter}_lora_scaling"), + ) + gate, up_states = up_states.chunk(2, axis=-1) + up_states = up_states * self.activation_fn(gate) + hidden_states = self.down_proj(up_states) + if adapter is not None and hasattr(self, f"{adapter}_down_lora_A"): + hidden_states = hidden_states + _lora_delta( + up_states, + getattr(self, f"{adapter}_down_lora_A"), + getattr(self, f"{adapter}_down_lora_B"), + getattr(self, f"{adapter}_lora_scaling"), + ) + return hidden_states + + +class Phi4MultimodalRotaryEmbedding(nn.Layer): + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + self.config = config + + self.rope_type = config.rope_parameters.get("rope_type", "default") + rope_init_fn = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(config) + + self.register_buffer("inv_freq", inv_freq, persistable=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistable=False) + + def _match_torch_short_longrope_rounding(self, inv_freq): + rope_parameters = self.config.rope_parameters + short_factor = rope_parameters.get("short_factor", []) + if ( + self.rope_type == "longrope" + and inv_freq.shape[0] == 48 + and len(short_factor) == 48 + and all(float(factor) == 1.0 for factor in short_factor) + and float(rope_parameters.get("rope_theta", 10000.0)) == 10000.0 + ): + indices = paddle.to_tensor([7, 10, 14, 17, 20, 23, 25, 28, 31, 34, 37, 40, 43, 46], dtype="int64") + selected = paddle.gather(inv_freq, indices) + steps = [1, 1, 3, 3, 4, 4, 5, 5, 7, 7, 7, 9, 5, 6] + updates = selected + next_value = paddle.full_like(selected, float("inf")) + for step in range(max(steps)): + stepped = paddle.nextafter(updates, next_value) + mask = paddle.to_tensor([step < count for count in steps], dtype="bool") + updates = paddle.where(mask, stepped, updates) + inv_freq = paddle.scatter(inv_freq, indices, updates, overwrite=True) + return inv_freq + + def _update_longrope_inv_freq(self, position_ids, device): + if self.rope_type != "longrope": + return + + seq_len = int((paddle.max(position_ids) + 1).item()) + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len) + + original_max_position_embeddings = self.config.rope_parameters.get( + "original_max_position_embeddings", + getattr(self.config, "original_max_position_embeddings", self.config.max_position_embeddings), + ) + if seq_len <= original_max_position_embeddings: + inv_freq = self._match_torch_short_longrope_rounding(inv_freq) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistable=False) + self.register_buffer("inv_freq", inv_freq, persistable=False) + + @staticmethod + def compute_default_rope_parameters(config, seq_len=None): + base = config.rope_parameters["rope_theta"] + partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + attention_factor = 1.0 + inv_freq = 1.0 / (base ** (paddle.arange(0, dim, 2, dtype="int64").astype("float32") / dim)) + return inv_freq, attention_factor + + def forward(self, x, position_ids): + self._update_longrope_inv_freq(position_ids, x.place) + with paddle.amp.auto_cast(enable=False): + inv_freq_expanded = self.inv_freq.astype("float32")[None, :, None].expand([position_ids.shape[0], -1, 1]) + position_ids_expanded = position_ids[:, None, :].astype("float32") + freqs = (inv_freq_expanded @ position_ids_expanded).transpose([0, 2, 1]) + emb = paddle.concat((freqs, freqs), axis=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + return cos.astype(x.dtype), sin.astype(x.dtype) + + +class Phi4MultimodalAttention(nn.Layer): + def __init__(self, config: Phi4MultimodalConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.sequence_parallel = config.sequence_parallel + + if config.tensor_model_parallel_size > 1: + assert self.num_heads % config.tensor_model_parallel_size == 0 + assert self.num_key_value_heads % config.tensor_model_parallel_size == 0 + self.num_heads = self.num_heads // config.tensor_model_parallel_size + self.num_key_value_heads = self.num_key_value_heads // config.tensor_model_parallel_size + + op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) + self.qkv_proj = GeneralLinear.create( + config.hidden_size, + op_size, + has_bias=config.attention_bias, + config=config, + tp_plan="colwise", + ) + self.o_proj = GeneralLinear.create( + config.num_attention_heads * self.head_dim, + config.hidden_size, + has_bias=config.attention_bias, + config=config, + tp_plan="rowwise", + ) + self._init_lora(config.hidden_size, op_size, config.num_attention_heads * self.head_dim, config.hidden_size) + + def _init_lora(self, qkv_in, qkv_out, o_in, o_out): + for adapter in ("vision", "speech"): + rank = getattr(self.config, f"{adapter}_lora_rank", 0) + alpha = getattr(self.config, f"{adapter}_lora_alpha", 1) + if rank <= 0: + continue + setattr(self, f"{adapter}_qkv_lora_A", _create_lora_parameter(self, [qkv_in, rank])) + setattr(self, f"{adapter}_qkv_lora_B", _create_lora_parameter(self, [rank, qkv_out])) + setattr(self, f"{adapter}_o_lora_A", _create_lora_parameter(self, [o_in, rank])) + setattr(self, f"{adapter}_o_lora_B", _create_lora_parameter(self, [rank, o_out])) + setattr(self, f"{adapter}_lora_scaling", alpha / rank) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]: + if self.sequence_parallel: + seq_len, hidden_size = hidden_states.shape + batch_size = 1 + else: + batch_size, seq_len, hidden_size = hidden_states.shape + + qkv = self.qkv_proj(hidden_states) + adapter = _active_lora_adapter(self.config) + if adapter is not None and hasattr(self, f"{adapter}_qkv_lora_A"): + qkv = qkv + _lora_delta( + hidden_states, + getattr(self, f"{adapter}_qkv_lora_A"), + getattr(self, f"{adapter}_qkv_lora_B"), + getattr(self, f"{adapter}_lora_scaling"), + ) + query_pos = self.num_heads * self.head_dim + kv_pos = self.num_key_value_heads * self.head_dim + + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + kv_pos] + value_states = qkv[..., query_pos + kv_pos :] + + if self.sequence_parallel: + query_states = query_states.reshape([seq_len, self.num_heads, self.head_dim]) + key_states = key_states.reshape([seq_len, self.num_key_value_heads, self.head_dim]) + value_states = value_states.reshape([seq_len, self.num_key_value_heads, self.head_dim]) + query_states = query_states.transpose([1, 0, 2]) + key_states = key_states.transpose([1, 0, 2]) + value_states = value_states.transpose([1, 0, 2]) + else: + query_states = query_states.reshape([batch_size, seq_len, self.num_heads, self.head_dim]).transpose( + [0, 2, 1, 3] + ) + key_states = key_states.reshape([batch_size, seq_len, self.num_key_value_heads, self.head_dim]).transpose( + [0, 2, 1, 3] + ) + value_states = value_states.reshape( + [batch_size, seq_len, self.num_key_value_heads, self.head_dim] + ).transpose([0, 2, 1, 3]) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), + ) + + if self.sequence_parallel: + attn_output = attn_output.reshape([seq_len, -1]) + else: + attn_output = attn_output.reshape([batch_size, seq_len, -1]) + o_input = attn_output + attn_output = self.o_proj(o_input) + if adapter is not None and hasattr(self, f"{adapter}_o_lora_A"): + attn_output = attn_output + _lora_delta( + o_input, + getattr(self, f"{adapter}_o_lora_A"), + getattr(self, f"{adapter}_o_lora_B"), + getattr(self, f"{adapter}_lora_scaling"), + ) + return attn_output, attn_weights + + +class Phi4MultimodalDecoderLayer(nn.Layer): + def __init__(self, config: Phi4MultimodalConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Phi4MultimodalAttention(config=config, layer_idx=layer_idx) + self.mlp = Phi4MultimodalMLP(config) + self.input_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.config = config + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, + output_attentions: bool = False, + ) -> paddle.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + position_embeddings=position_embeddings, + past_key_values=past_key_values, + use_cache=use_cache, + ) + hidden_states = residual + self.resid_attn_dropout(hidden_states) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + return outputs + + +# ======================= Feature Embedding (bridge) ======================= + + +class Phi4MultimodalFeatureEmbedding(nn.Layer): + def __init__(self, config: Phi4MultimodalConfig): + super().__init__() + self.config = config + self.image_token_id = config.vision_config.image_token_id + self.audio_token_id = config.audio_config.audio_token_id + self.image_embed = Phi4MultimodalImageEmbedding(config) + self.audio_embed = Phi4MultimodalAudioEmbedding(config) + + def forward( + self, + input_ids: paddle.Tensor, + inputs_embeds: paddle.Tensor, + image_pixel_values: Optional[paddle.Tensor] = None, + audio_input_features: Optional[paddle.Tensor] = None, + image_sizes=None, + image_attention_mask=None, + audio_embed_sizes=None, + audio_attention_mask=None, + ) -> paddle.Tensor: + image_position_mask = (input_ids == self.config.vision_config.image_token_id).unsqueeze(-1) + non_image_position_mask = ~image_position_mask + + image_embeds = None + audio_embeds = None + if image_pixel_values is not None and (input_ids == self.image_token_id).any(): + image_embeds = self.image_embed( + input_ids, + inputs_embeds, + image_pixel_values=image_pixel_values, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + ) + if audio_input_features is not None and (input_ids == self.audio_token_id).any(): + audio_projection_mode = "vision" if image_pixel_values is not None else "speech" + audio_embeds = self.audio_embed( + input_ids, + inputs_embeds, + audio_input_features=audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + audio_projection_mode=audio_projection_mode, + ) + + if image_embeds is not None and audio_embeds is not None: + inputs_embeds = image_embeds * image_position_mask.astype( + image_embeds.dtype + ) + audio_embeds * non_image_position_mask.astype(audio_embeds.dtype) + elif image_embeds is not None: + inputs_embeds = image_embeds + elif audio_embeds is not None: + inputs_embeds = audio_embeds + + return inputs_embeds + + +# ======================= Main Model ======================= + + +class Phi4MultimodalPreTrainedModel(PretrainedModel): + config_class = Phi4MultimodalConfig + base_model_prefix = "model" + transpose_weight_keys = [ + "qkv_proj.weight", + "o_proj.weight", + "gate_up_proj.weight", + "down_proj.weight", + "img_projection_up.weight", + "img_projection_down.weight", + "up_proj_for_speech.weight", + "down_proj_for_speech.weight", + "up_proj_for_vision_speech.weight", + "down_proj_for_vision_speech.weight", + ] + + @classmethod + def _gen_aoa_config(cls, config: Phi4MultimodalConfig): + model_prefix = "" if cls == cls.base_model_class else "model." + aoa_config = {"aoa_statements": []} + stmts = aoa_config["aoa_statements"] + + # Embedding and norm + stmts.append(f"model.embed_tokens.weight -> {model_prefix}embed_tokens.weight") + stmts.append(f"model.norm.weight -> {model_prefix}norm.weight") + if cls != cls.base_model_class: + stmts.append("model.embed_tokens.weight -> lm_head.weight") + + # Decoder layers + for layer_id in range(config.num_hidden_layers): + lp = f"model.layers.{layer_id}" + tp = f"{model_prefix}layers.{layer_id}" + stmts.append(f"{lp}.input_layernorm.weight -> {tp}.input_layernorm.weight") + stmts.append(f"{lp}.post_attention_layernorm.weight -> {tp}.post_attention_layernorm.weight") + stmts.append(f"{lp}.mlp.down_proj.base_layer.weight^T -> {tp}.mlp.down_proj.weight") + stmts.append(f"{lp}.self_attn.o_proj.base_layer.weight^T -> {tp}.self_attn.o_proj.weight") + stmts.append( + f"{lp}.self_attn.qkv_proj.base_layer.weight^T -> {tp}.self_attn.qkv_proj.weight, " + f"fused_qkv_old, num_heads={config.num_attention_heads}, " + f"num_key_value_groups={config.num_key_value_heads}, axis=1" + ) + stmts.append(f"{lp}.mlp.gate_up_proj.base_layer.weight^T -> {tp}.mlp.gate_up_proj.weight, fused_ffn") + + lora_specs = [ + ("self_attn", "qkv_proj", "qkv"), + ("self_attn", "o_proj", "o"), + ("mlp", "gate_up_proj", "gate_up"), + ("mlp", "down_proj", "down"), + ] + for adapter in ("vision", "speech"): + for block, src_proj, dst_proj in lora_specs: + rank = getattr(config, f"{adapter}_lora_rank", 0) + if rank > 0: + stmts.append( + f"{lp}.{block}.{src_proj}.lora_A.{adapter}.weight^T -> " + f"{tp}.{block}.{adapter}_{dst_proj}_lora_A" + ) + stmts.append( + f"{lp}.{block}.{src_proj}.lora_B.{adapter}.weight^T -> " + f"{tp}.{block}.{adapter}_{dst_proj}_lora_B" + ) + + # Vision encoder + vis_prefix_src = "model.embed_tokens_extend.image_embed" + vis_prefix_dst = f"{model_prefix}embed_tokens_extend.image_embed" + + stmts.append(f"{vis_prefix_src}.glb_GN -> {vis_prefix_dst}.global_img_feature_extensor") + stmts.append(f"{vis_prefix_src}.sub_GN -> {vis_prefix_dst}.sub_img_feature_extensor") + stmts.append(f"{vis_prefix_src}.img_projection.0.weight^T -> {vis_prefix_dst}.img_projection_up.weight") + stmts.append(f"{vis_prefix_src}.img_projection.0.bias -> {vis_prefix_dst}.img_projection_up.bias") + stmts.append(f"{vis_prefix_src}.img_projection.2.weight^T -> {vis_prefix_dst}.img_projection_down.weight") + stmts.append(f"{vis_prefix_src}.img_projection.2.bias -> {vis_prefix_dst}.img_projection_down.bias") + + # Vision processor layers + vp_src = f"{vis_prefix_src}.img_processor" + vp_dst = f"{vis_prefix_dst}.img_processor" + stmts.append(f"{vp_src}.embeddings.patch_embedding.weight -> {vp_dst}.embeddings.patch_embedding.weight") + stmts.append(f"{vp_src}.embeddings.patch_embedding.bias -> {vp_dst}.embeddings.patch_embedding.bias") + stmts.append(f"{vp_src}.embeddings.position_embedding.weight -> {vp_dst}.embeddings.position_embedding.weight") + stmts.append(f"{vp_src}.post_layernorm.weight -> {vp_dst}.post_layernorm.weight") + stmts.append(f"{vp_src}.post_layernorm.bias -> {vp_dst}.post_layernorm.bias") + + # Vision head (MultiheadAttentionPoolingHead) + stmts.append(f"{vp_src}.head.probe -> {vp_dst}.head.probe") + stmts.append(f"{vp_src}.head.layernorm.weight -> {vp_dst}.head.layernorm.weight") + stmts.append(f"{vp_src}.head.layernorm.bias -> {vp_dst}.head.layernorm.bias") + stmts.append(f"{vp_src}.head.mlp.fc1.weight^T -> {vp_dst}.head.mlp.fc1.weight") + stmts.append(f"{vp_src}.head.mlp.fc1.bias -> {vp_dst}.head.mlp.fc1.bias") + stmts.append(f"{vp_src}.head.mlp.fc2.weight^T -> {vp_dst}.head.mlp.fc2.weight") + stmts.append(f"{vp_src}.head.mlp.fc2.bias -> {vp_dst}.head.mlp.fc2.bias") + + for i in range(config.vision_config.num_hidden_layers): + vs = f"{vp_src}.encoder.layers.{i}" + vd = f"{vp_dst}.encoder.layers.{i}" + stmts.append(f"{vs}.layer_norm1.weight -> {vd}.layer_norm1.weight") + stmts.append(f"{vs}.layer_norm1.bias -> {vd}.layer_norm1.bias") + stmts.append(f"{vs}.layer_norm2.weight -> {vd}.layer_norm2.weight") + stmts.append(f"{vs}.layer_norm2.bias -> {vd}.layer_norm2.bias") + stmts.append(f"{vs}.self_attn.q_proj.weight^T -> {vd}.self_attn.q_proj.weight") + stmts.append(f"{vs}.self_attn.q_proj.bias -> {vd}.self_attn.q_proj.bias") + stmts.append(f"{vs}.self_attn.k_proj.weight^T -> {vd}.self_attn.k_proj.weight") + stmts.append(f"{vs}.self_attn.k_proj.bias -> {vd}.self_attn.k_proj.bias") + stmts.append(f"{vs}.self_attn.v_proj.weight^T -> {vd}.self_attn.v_proj.weight") + stmts.append(f"{vs}.self_attn.v_proj.bias -> {vd}.self_attn.v_proj.bias") + stmts.append(f"{vs}.self_attn.out_proj.weight^T -> {vd}.self_attn.out_proj.weight") + stmts.append(f"{vs}.self_attn.out_proj.bias -> {vd}.self_attn.out_proj.bias") + stmts.append(f"{vs}.mlp.fc1.weight^T -> {vd}.mlp.fc1.weight") + stmts.append(f"{vs}.mlp.fc1.bias -> {vd}.mlp.fc1.bias") + stmts.append(f"{vs}.mlp.fc2.weight^T -> {vd}.mlp.fc2.weight") + stmts.append(f"{vs}.mlp.fc2.bias -> {vd}.mlp.fc2.bias") + + # Audio encoder + aud_prefix_src = "model.embed_tokens_extend.audio_embed" + aud_prefix_dst = f"{model_prefix}embed_tokens_extend.audio_embed" + + # Audio projections + audio_projection_map = [ + ("audio_projection.speech.0", "up_proj_for_speech"), + ("audio_projection.speech.2", "down_proj_for_speech"), + ("audio_projection.vision.0", "up_proj_for_vision_speech"), + ("audio_projection.vision.2", "down_proj_for_vision_speech"), + ] + for src_proj, dst_proj in audio_projection_map: + stmts.append(f"{aud_prefix_src}.{src_proj}.weight^T -> {aud_prefix_dst}.{dst_proj}.weight") + stmts.append(f"{aud_prefix_src}.{src_proj}.bias -> {aud_prefix_dst}.{dst_proj}.bias") + + # Audio encoder internals + ae_src = f"{aud_prefix_src}.encoder" + ae_dst = f"{aud_prefix_dst}.encoder" + + stmts.append(f"{ae_src}.encoder_embedding.global_mean -> {ae_dst}.encoder_embedding.global_mean") + stmts.append(f"{ae_src}.encoder_embedding.global_invstd -> {ae_dst}.encoder_embedding.global_invstd") + stmts.append( + f"{ae_src}.relative_attention_bias_layer.bias_values.weight -> {ae_dst}.relative_attention_bias_layer.bias_values.weight" + ) + + # Nemo conv subsampling + stmts.append(f"{ae_src}.embed.conv.0.weight -> {ae_dst}.embed.conv.0.weight") + stmts.append(f"{ae_src}.embed.conv.0.bias -> {ae_dst}.embed.conv.0.bias") + for conv_idx in [2, 3, 5, 6]: + stmts.append(f"{ae_src}.embed.conv.{conv_idx}.weight -> {ae_dst}.embed.conv.{conv_idx}.weight") + stmts.append(f"{ae_src}.embed.conv.{conv_idx}.bias -> {ae_dst}.embed.conv.{conv_idx}.bias") + stmts.append(f"{ae_src}.embed.out.weight^T -> {ae_dst}.embed.out.weight") + stmts.append(f"{ae_src}.embed.out.bias -> {ae_dst}.embed.out.bias") + + # Audio conformer encoder layers + for i in range(config.audio_config.num_blocks): + al_src = f"{ae_src}.encoders.{i}" + al_dst = f"{ae_dst}.encoders.{i}" + + # feed_forward_in / feed_forward_out + for ff in ["feed_forward_in", "feed_forward_out"]: + stmts.append(f"{al_src}.{ff}.layer_norm.weight -> {al_dst}.{ff}.layer_norm.weight") + stmts.append(f"{al_src}.{ff}.layer_norm.bias -> {al_dst}.{ff}.layer_norm.bias") + stmts.append(f"{al_src}.{ff}.net.0.linear.weight^T -> {al_dst}.{ff}.gate_up_proj.weight") + stmts.append(f"{al_src}.{ff}.net.0.linear.bias -> {al_dst}.{ff}.gate_up_proj.bias") + stmts.append(f"{al_src}.{ff}.net.2.weight^T -> {al_dst}.{ff}.down_proj.weight") + stmts.append(f"{al_src}.{ff}.net.2.bias -> {al_dst}.{ff}.down_proj.bias") + + # self_attn + stmts.append(f"{al_src}.self_attn.linear_q.weight^T -> {al_dst}.self_attn.q_proj.weight") + stmts.append(f"{al_src}.self_attn.linear_q.bias -> {al_dst}.self_attn.q_proj.bias") + stmts.append(f"{al_src}.self_attn.linear_k.weight^T -> {al_dst}.self_attn.k_proj.weight") + stmts.append(f"{al_src}.self_attn.linear_k.bias -> {al_dst}.self_attn.k_proj.bias") + stmts.append(f"{al_src}.self_attn.linear_v.weight^T -> {al_dst}.self_attn.v_proj.weight") + stmts.append(f"{al_src}.self_attn.linear_v.bias -> {al_dst}.self_attn.v_proj.bias") + stmts.append(f"{al_src}.self_attn.linear_out.weight^T -> {al_dst}.self_attn.o_proj.weight") + stmts.append(f"{al_src}.self_attn.linear_out.bias -> {al_dst}.self_attn.o_proj.bias") + + # conv module + stmts.append(f"{al_src}.conv.layer_norm.weight -> {al_dst}.conv.layer_norm.weight") + stmts.append(f"{al_src}.conv.layer_norm.bias -> {al_dst}.conv.layer_norm.bias") + stmts.append(f"{al_src}.conv.glu.ext_pw_conv_1d.weight -> {al_dst}.conv.glu.ext_pw_conv_1d.weight") + stmts.append(f"{al_src}.conv.glu.ext_pw_conv_1d.bias -> {al_dst}.conv.glu.ext_pw_conv_1d.bias") + stmts.append(f"{al_src}.conv.glu.b1 -> {al_dst}.conv.glu.b1") + stmts.append(f"{al_src}.conv.glu.b2 -> {al_dst}.conv.glu.b2") + stmts.append(f"{al_src}.conv.dw_sep_conv_1d.dw_conv.weight -> {al_dst}.conv.dw_sep_conv_1d.dw_conv.weight") + stmts.append(f"{al_src}.conv.dw_sep_conv_1d.dw_conv.bias -> {al_dst}.conv.dw_sep_conv_1d.dw_conv.bias") + stmts.append(f"{al_src}.conv.dw_sep_conv_1d.pw_conv.weight -> {al_dst}.conv.dw_sep_conv_1d.pw_conv.weight") + stmts.append(f"{al_src}.conv.dw_sep_conv_1d.pw_conv.bias -> {al_dst}.conv.dw_sep_conv_1d.pw_conv.bias") + stmts.append(f"{al_src}.conv.ext_pw_conv_1d.weight -> {al_dst}.conv.ext_pw_conv_1d.weight") + stmts.append(f"{al_src}.conv.ext_pw_conv_1d.bias -> {al_dst}.conv.ext_pw_conv_1d.bias") + + # layer norms + stmts.append(f"{al_src}.layer_norm_att.weight -> {al_dst}.layer_norm_att.weight") + stmts.append(f"{al_src}.layer_norm_att.bias -> {al_dst}.layer_norm_att.bias") + stmts.append(f"{al_src}.layer_norm.weight -> {al_dst}.layer_norm.weight") + stmts.append(f"{al_src}.layer_norm.bias -> {al_dst}.layer_norm.bias") + + # Vision head attention (nn.MultiHeadAttention has different weight names in Paddle) + # The HF model uses torch.nn.MultiheadAttention with in_proj_weight/in_proj_bias + # We'll handle this mapping for the vision pooling head attention + head_src = f"{vp_src}.head.attention" + head_dst = f"{vp_dst}.head.attention" + stmts.append(f"{head_src}.in_proj_weight^T -> {head_dst}.in_proj_weight_t") + stmts.append(f"{head_src}.in_proj_bias -> {head_dst}.in_proj_bias_t") + stmts.append(f"{head_src}.out_proj.weight^T -> {head_dst}.out_proj.weight") + stmts.append(f"{head_src}.out_proj.bias -> {head_dst}.out_proj.bias") + + return aoa_config + + @classmethod + def _gen_inv_aoa_config(cls, config: Phi4MultimodalConfig): + model_prefix = "" if cls == cls.base_model_class else "model." + aoa_statements = [] + + # Decoder layers + for layer_id in range(config.num_hidden_layers): + tp = f"{model_prefix}layers.{layer_id}" + lp = f"model.layers.{layer_id}" + aoa_statements.append(f"{tp}.input_layernorm.weight -> {lp}.input_layernorm.weight") + aoa_statements.append(f"{tp}.post_attention_layernorm.weight -> {lp}.post_attention_layernorm.weight") + aoa_statements.append(f"{tp}.mlp.down_proj.weight^T -> {lp}.mlp.down_proj.weight") + aoa_statements.append(f"{tp}.self_attn.o_proj.weight^T -> {lp}.self_attn.o_proj.weight") + aoa_statements.append( + f"{tp}.self_attn.qkv_proj.weight -> {lp}.self_attn.qkv_proj.weight, " + f"fused_qkv_old, num_heads={config.num_attention_heads}, " + f"num_key_value_groups={config.num_key_value_heads}, axis=1" + ) + aoa_statements.append(f"{tp}.self_attn.qkv_proj.weight^T -> {lp}.self_attn.qkv_proj.weight") + aoa_statements.append(f"{tp}.mlp.gate_up_proj.weight^T -> {lp}.mlp.gate_up_proj.weight, fused_ffn") + + aoa_statements.append(f"{model_prefix}embed_tokens.weight -> model.embed_tokens.weight") + aoa_statements.append(f"{model_prefix}norm.weight -> model.norm.weight") + + return {"aoa_statements": aoa_statements} + + +@register_base_model +class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): + def __init__(self, config: Phi4MultimodalConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + self.sequence_parallel = config.sequence_parallel + + self.embed_tokens = GeneralEmbedding.create( + config=config, num_embeddings=config.vocab_size, embedding_dim=config.hidden_size + ) + self.layers = nn.LayerList( + [Phi4MultimodalDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Phi4MultimodalRotaryEmbedding(config) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + self.embed_tokens_extend = Phi4MultimodalFeatureEmbedding(config) + + @paddle.jit.not_to_static + def recompute_training_full(self, layer_module, hidden_states, *args): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = recompute(create_custom_forward(layer_module), hidden_states, *args) + return hidden_states + + def forward( + self, + input_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + image_pixel_values: Optional[paddle.Tensor] = None, + image_sizes: Optional[paddle.Tensor] = None, + image_attention_mask=None, + audio_input_features: Optional[paddle.Tensor] = None, + audio_embed_sizes=None, + audio_attention_mask=None, + input_mode=None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + attn_mask_startend_row_indices=None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + previous_adapter = getattr(self.config, "_active_lora_adapter", None) + self.config._active_lora_adapter = _lora_adapter_from_input_mode( + input_mode, + image_pixel_values=image_pixel_values, + audio_input_features=audio_input_features, + ) + + try: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens_extend( + input_ids, + inputs_embeds, + image_pixel_values=image_pixel_values, + audio_input_features=audio_input_features, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + ) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + + if position_ids is None: + position_ids = ( + paddle.arange(seq_length, dtype="int64").unsqueeze(0).expand([batch_size, -1]) + cache_length + ) + + # Create causal mask + mask_kwargs = { + "config": self.config, + "inputs_embeds": inputs_embeds, + "batch_size": batch_size, + "seq_length": seq_length, + "cache_length": cache_length, + "attention_mask": attention_mask, + "attn_mask_startend_row_indices": attn_mask_startend_row_indices, + "prepare_decoder_attention_mask": self._prepare_decoder_attention_mask, + } + causal_mask, attn_mask_startend_row_indices = create_causal_mask_and_row_indices(**mask_kwargs) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + has_gradient = not hidden_states.stop_gradient + if ( + self.config.recompute_granularity == "full" + and self.config.recompute_method == "uniform" + and self.config.recompute_num_layers == 1 + and has_gradient + ): + layer_outputs = self.recompute_training_full( + decoder_layer, + hidden_states, + causal_mask, + attn_mask_startend_row_indices, + position_ids, + past_key_values, + use_cache, + position_embeddings, + output_attentions, + ) + else: + layer_outputs = decoder_layer( + hidden_states=hidden_states, + attention_mask=causal_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + finally: + self.config._active_lora_adapter = previous_adapter + + +class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Phi4MultimodalConfig): + super().__init__(config) + self.model = Phi4MultimodalModel(config) + self.vocab_size = config.vocab_size + self.lm_head = GeneralLMHead(config) + self.criterion = CriterionLayer(config) + self._apply_multimodal_freeze_config() + + def _apply_multimodal_freeze_config(self): + freeze_prefixes = [] + if getattr(self.config, "freeze_vision_model", False): + freeze_prefixes.extend( + [ + "model.embed_tokens_extend.image_embed.img_processor", + "model.embed_tokens_extend.audio_embed.encoder", + ] + ) + if getattr(self.config, "freeze_vision_projection", False): + freeze_prefixes.extend( + [ + "model.embed_tokens_extend.image_embed.img_projection_up", + "model.embed_tokens_extend.image_embed.img_projection_down", + "model.embed_tokens_extend.image_embed.global_img_feature_extensor", + "model.embed_tokens_extend.image_embed.sub_img_feature_extensor", + "model.embed_tokens_extend.audio_embed.up_proj_for_speech", + "model.embed_tokens_extend.audio_embed.down_proj_for_speech", + "model.embed_tokens_extend.audio_embed.up_proj_for_vision_speech", + "model.embed_tokens_extend.audio_embed.down_proj_for_vision_speech", + ] + ) + if getattr(self.config, "freeze_language_model", False): + freeze_prefixes.extend(["model.embed_tokens", "model.layers", "model.norm", "lm_head"]) + + freeze_multimodal_adapters = ( + getattr(self.config, "freeze_vision_model", False) + and getattr(self.config, "freeze_vision_projection", False) + and not getattr(self.config, "freeze_language_model", False) + ) + + if not freeze_prefixes: + return + + frozen = 0 + for name, param in self.named_parameters(): + if any(name.startswith(prefix) for prefix in freeze_prefixes) or ( + freeze_multimodal_adapters and "_lora_" in name + ): + param.stop_gradient = True + frozen += 1 + logger.info(f"Phi-4 multimodal freeze_config applied. Frozen parameter tensors: {frozen}") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): + model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + model._apply_multimodal_freeze_config() + return model + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation( + self, + input_ids, + use_cache=True, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + image_pixel_values=None, + image_sizes=None, + image_attention_mask=None, + audio_input_features=None, + audio_embed_sizes=None, + audio_attention_mask=None, + input_mode=None, + position_ids=None, + **kwargs, + ): + if input_mode is None: + has_image = image_pixel_values is not None + has_audio = audio_input_features is not None + if has_image and has_audio: + input_mode = paddle.to_tensor([3], dtype="int64") + elif has_image: + input_mode = paddle.to_tensor([1], dtype="int64") + elif has_audio: + input_mode = paddle.to_tensor([2], dtype="int64") + + batch_size, seq_length = input_ids.shape + if position_ids is None: + position_ids = paddle.arange(seq_length, dtype="int64").unsqueeze(0).expand([batch_size, -1]) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(axis=-1) + position_ids = position_ids[:, -1].unsqueeze(axis=-1) + + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "attention_mask": attention_mask, + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "input_mode": input_mode, + } + ) + + if past_key_values is None: + model_inputs.update( + { + "image_pixel_values": image_pixel_values, + "image_sizes": image_sizes, + "image_attention_mask": image_attention_mask, + "audio_input_features": audio_input_features, + "audio_embed_sizes": audio_embed_sizes, + "audio_attention_mask": audio_attention_mask, + } + ) + + return model_inputs + + def forward( + self, + input_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + image_pixel_values: Optional[paddle.Tensor] = None, + image_sizes: Optional[paddle.Tensor] = None, + image_attention_mask=None, + audio_input_features: Optional[paddle.Tensor] = None, + audio_embed_sizes=None, + audio_attention_mask=None, + input_mode=None, + labels: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + attn_mask_startend_row_indices=None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + image_pixel_values=image_pixel_values, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_input_features=audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + input_mode=input_mode, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + **kwargs, + ) + + hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + loss = self.criterion(logits, labels) + if isinstance(loss, tuple): + loss = loss[0] + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +Phi4MultimodalForConditionalGeneration = Phi4MultimodalForCausalLM +Phi4MMForCausalLM = Phi4MultimodalForCausalLM +Phi4MMForConditionalGeneration = Phi4MultimodalForCausalLM + + +class Phi4MultimodalForCausalLMPipe(GeneralModelForCausalLMPipe): + config_class = Phi4MultimodalConfig + _gen_aoa_config = Phi4MultimodalForCausalLM._gen_aoa_config + _gen_inv_aoa_config = Phi4MultimodalForCausalLM._gen_inv_aoa_config + + +Phi4MMForCausalLMPipe = Phi4MultimodalForCausalLMPipe diff --git a/paddleformers/transformers/phi4_multimodal/processor.py b/paddleformers/transformers/phi4_multimodal/processor.py new file mode 100644 index 00000000000..fbd0340d0f2 --- /dev/null +++ b/paddleformers/transformers/phi4_multimodal/processor.py @@ -0,0 +1,150 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Processor class for Phi-4 Multimodal.""" + +import inspect +import json +import os +import re + +from ..feature_extraction_utils import BatchFeature +from ..processing_utils import ProcessingKwargs, ProcessorMixin, Unpack + + +class Phi4MultimodalProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "audio_kwargs": {}, + } + + +class Phi4MultimodalProcessor(ProcessorMixin): + attributes = ["image_processor", "feature_extractor", "tokenizer"] + image_processor_class = "AutoImageProcessor" + feature_extractor_class = "AutoFeatureExtractor" + tokenizer_class = "AutoTokenizer" + + @classmethod + def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, processor_dict=None, **kwargs): + from ..auto.tokenizer import AutoTokenizer + from .feature_extraction import Phi4MultimodalFeatureExtractor + from .image_processor import Phi4MultimodalImageProcessor + + preprocessor_config = {} + preprocessor_config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json") + if os.path.exists(preprocessor_config_path): + with open(preprocessor_config_path, encoding="utf-8") as f: + preprocessor_config = json.load(f) + + def _filter_config(processor_cls): + valid_keys = { + key + for key, value in inspect.signature(processor_cls.__init__).parameters.items() + if key != "self" + and value.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + } + return {key: value for key, value in preprocessor_config.items() if key in valid_keys} + + image_processor = Phi4MultimodalImageProcessor(**_filter_config(Phi4MultimodalImageProcessor)) + feature_extractor = Phi4MultimodalFeatureExtractor(**_filter_config(Phi4MultimodalFeatureExtractor)) + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) + return [image_processor, feature_extractor, tokenizer] + + def __init__( + self, + image_processor=None, + feature_extractor=None, + tokenizer=None, + chat_template=None, + audio_processor=None, + **kwargs, + ): + if feature_extractor is None: + feature_extractor = audio_processor + self.image_token = getattr(tokenizer, "image_token", "<|image|>") + self.audio_token = getattr(tokenizer, "audio_token", "<|audio|>") + self.image_token_id = getattr(tokenizer, "image_token_id", None) + if self.image_token_id is None and tokenizer is not None: + self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) + self.audio_token_id = getattr(tokenizer, "audio_token_id", None) + if self.audio_token_id is None and tokenizer is not None: + self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token) + super().__init__(image_processor, feature_extractor, tokenizer, chat_template=chat_template, **kwargs) + self.audio_processor = self.feature_extractor + + def __call__(self, text, images=None, audio=None, **kwargs: Unpack[Phi4MultimodalProcessorKwargs]) -> BatchFeature: + output_kwargs = self._merge_kwargs( + Phi4MultimodalProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) if images is not None else {} + audio_inputs = self.audio_processor(audio, **output_kwargs["audio_kwargs"]) if audio is not None else {} + + num_img_tokens = image_inputs.pop("num_img_tokens", []) + audio_embed_sizes = audio_inputs.get("audio_embed_sizes", []) + if hasattr(audio_embed_sizes, "numpy"): + audio_embed_sizes = audio_embed_sizes.numpy().tolist() + elif not isinstance(audio_embed_sizes, list): + audio_embed_sizes = list(audio_embed_sizes) + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) or (len(text) > 0 and not isinstance(text[0], str)): + raise TypeError("Invalid input text. Please provide a string or a list of strings.") + + concatenated_prompt = "".join(text) + if concatenated_prompt.count(self.image_token) != len(num_img_tokens): + raise ValueError( + "You should add as many image tokens `<|image|>` in your prompt as images passed to the processor. " + f"Input contains {concatenated_prompt.count(self.image_token)} tokens != {len(num_img_tokens)} images." + ) + if concatenated_prompt.count(self.audio_token) != len(audio_embed_sizes): + raise ValueError( + "You should add as many audio tokens `<|audio|>` in your prompt as audios passed to the processor. " + f"Input contains {concatenated_prompt.count(self.audio_token)} tokens != {len(audio_embed_sizes)} audios." + ) + + image_count_iter = iter(num_img_tokens) + audio_count_iter = iter(audio_embed_sizes) + processed_text = [ + re.sub(re.escape(self.image_token), lambda _: self.image_token * int(next(image_count_iter)), sample) + for sample in text + ] + processed_text = [ + re.sub(re.escape(self.audio_token), lambda _: self.audio_token * int(next(audio_count_iter)), sample) + for sample in processed_text + ] + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + text_inputs = self.tokenizer(processed_text, **output_kwargs["text_kwargs"], return_tensors=None) + self._check_special_mm_tokens(processed_text, text_inputs, modalities=["image", "audio"]) + + if images is not None and audio is not None: + input_mode = 3 + elif images is not None: + input_mode = 1 + elif audio is not None: + input_mode = 2 + else: + input_mode = 0 + + return BatchFeature( + data={**text_inputs, **image_inputs, **audio_inputs, "input_mode": [input_mode]}, + tensor_type=return_tensors, + ) + + +__all__ = ["Phi4MultimodalProcessor"]