From 34f844a298bb2b72109b8709c5e1458f79fb0c29 Mon Sep 17 00:00:00 2001
From: yicycyc <1258085915@qq.com>
Date: Wed, 10 Jun 2026 07:10:10 +0000
Subject: [PATCH] Add Phi-4 multimodal support
---
README.md | 9 +-
docs/zh/model_capability.md | 2 +
paddleformers/datasets/SFTDataset.py | 23 +
paddleformers/datasets/collate.py | 49 +
paddleformers/datasets/template/mm_plugin.py | 85 +
paddleformers/datasets/template/template.py | 18 +
paddleformers/transformers/__init__.py | 22 +
.../transformers/auto/configuration.py | 10 +
.../transformers/auto/feature_extraction.py | 4 +
.../transformers/auto/image_processing.py | 4 +
paddleformers/transformers/auto/modeling.py | 10 +-
paddleformers/transformers/auto/processing.py | 4 +
.../transformers/phi4_multimodal/__init__.py | 53 +
.../phi4_multimodal/configuration.py | 286 +++
.../phi4_multimodal/feature_extraction.py | 189 ++
.../phi4_multimodal/image_processor.py | 285 +++
.../transformers/phi4_multimodal/modeling.py | 1955 +++++++++++++++++
.../transformers/phi4_multimodal/processor.py | 150 ++
18 files changed, 3155 insertions(+), 3 deletions(-)
create mode 100644 paddleformers/transformers/phi4_multimodal/__init__.py
create mode 100644 paddleformers/transformers/phi4_multimodal/configuration.py
create mode 100644 paddleformers/transformers/phi4_multimodal/feature_extraction.py
create mode 100644 paddleformers/transformers/phi4_multimodal/image_processor.py
create mode 100644 paddleformers/transformers/phi4_multimodal/modeling.py
create mode 100644 paddleformers/transformers/phi4_multimodal/processor.py
diff --git a/README.md b/README.md
index 76e7ea002a5..223642ab27f 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ PaddleFormers 是基于百度深度学习框架 PaddlePaddle 搭建的 Transform
结合业界主流优化方法与飞桨在业务实践中积累的高效特性,PaddleFormers 致力于打造**高性能、低资源占用**的训练体验,帮助用户高效便捷地完成大模型训练,而无需关注底层复杂的优化细节。
## 🆕最新更新
-* 2026.03.31 - PaddleFormers v1.1 正式发布!在这个版本中我们支持了 GLM-4.5 系列模型的单步与多步 MTP 训练能力。依托 MTP 架构优势,开发者可显著提升推理效率;同时针对 MTP 模块训练场景,我们新增主干网络冻结开关,灵活满足各类模型精细化调优需求。此外,我们对视觉理解类模型进行了深度优化,Qwen3-VL 30B-A3B 模型性能相比上个版本提升48%,领先Megatron-LM 6%。
+* 2026.03.31 - PaddleFormers v1.1 正式发布!在这个版本中我们支持了 GLM-4.5 系列模型的单步与多步 MTP 训练能力。依托 MTP 架构优势,开发者可显著提升推理效率;同时针对 MTP 模块训练场景,我们新增主干网络冻结开关,灵活满足各类模型精细化调优需求。此外,我们对视觉理解类模型进行了深度优化,Qwen3-VL 30B-A3B 模型性能相比上个版本提升48%,领先 Megatron-LM 6%。
* 2026.01.21 - PaddleFomers v1.0版本发布啦!我们提供了针对 LLM 和 VLM 等模型的训练能力,针对 DeepSeek-V3模型和 GLM-4.5-Air 等重点模型,我们实现了极致性能优化(训练性能明显超越 Megatron-LM )。针对 PaddleOCR-VL,我们在昆仑芯 P800、天数天垓150等国产计算芯片上进行了适配,更好的满足国内用户需求。
## ✨特性
@@ -102,11 +102,16 @@ PaddleFormers 是基于百度深度学习框架 PaddlePaddle 搭建的 Transform
- | VLM |
+ VLM |
🏛️ERNIE-4.5-VL |
baidu/ERNIE-4.5-VL-28B-A3B-Base-PT、baidu/ERNIE-4.5-VL-28B-A3B-PT、baidu/ERNIE-4.5-VL-424B-A47B-Base-PT、baidu/ERNIE-4.5-VL-424B-A47B-PT、baidu/ERNIE-4.5-VL-28B-A3B-Thinking |
ernie_vl、ernie_vl_nothink |
+
+ | Phi-4-multimodal |
+ microsoft/Phi-4-multimodal-instruct |
+ phi4_multimodal |
+
| 🏛️PaddleOCR-VL |
PaddlePaddle/PaddleOCR-VL |
diff --git a/docs/zh/model_capability.md b/docs/zh/model_capability.md
index 3b402f9d4b2..ba5beca392f 100644
--- a/docs/zh/model_capability.md
+++ b/docs/zh/model_capability.md
@@ -12,6 +12,7 @@
|Qwen3|✓|✓|✓|✓|✓|
|Qwen3-Next|✓|✓|✓|✓|✓|
|🏛️ERNIE-4.5-VL|x|✓|✓|x|x|
+|Phi-4-multimodal|x|✓|✓|x|x|
|🏛️PaddleOCR-VL|x|✓|✓|x|x|
|Qwen2.5-VL|x|✓|✓|x|x|
|Qwen3-VL|x|✓|✓|x|x|
@@ -30,6 +31,7 @@
|Qwen3|✓|✓|✓|✓|✓|✓|
|Qwen3-Next|✓|✓|✓|x|✓|✓|
|🏛️ERNIE-4.5-VL|✓|✓|✓|x|✓|✓|
+|Phi-4-multimodal|x|x|-|x|✓|✓|
|🏛️PaddleOCR-VL|x|x|-|x|✓|✓|
|Qwen2.5-VL|✓|x|-|x|✓|✓|
|Qwen3-VL|x|x|✓|x|✓|✓|
diff --git a/paddleformers/datasets/SFTDataset.py b/paddleformers/datasets/SFTDataset.py
index 186d4ec0780..fb049da23eb 100644
--- a/paddleformers/datasets/SFTDataset.py
+++ b/paddleformers/datasets/SFTDataset.py
@@ -618,6 +618,7 @@ def _process_pretraining_sequence(self, example, actual_example_num):
# label shift
labels = labels[1:] + [-100]
+ labels = self._mask_mm_token_labels(tokens, labels)
pos_ids = list(range(len(tokens))) # only pure text, mm_position_ids will be reconstructed in collate.py
@@ -808,6 +809,7 @@ def _process_sft_sequence(self, example, actual_example_num):
# label shift
labels = labels[1:] + [-100]
+ labels = self._mask_mm_token_labels(tokens, labels)
pos_ids = list(range(len(tokens)))
@@ -887,6 +889,27 @@ def _truncate(
return input_ids, labels
+    def _mask_mm_token_labels(self, tokens, labels):
+        """Set labels to -100 wherever the input token or its label is an image/audio placeholder id."""
+        # Masking is skipped unless a non-jinja template is active and its mm_plugin opts in.
+        templated = self.use_template and self.template_backend != "jinja" and self.template is not None
+        plugin = getattr(self.template, "mm_plugin", None) if templated else None
+        if not getattr(plugin, "mask_mm_token_labels", False):
+            return labels
+        placeholders = [tok for tok in (plugin.image_token, plugin.audio_token) if tok is not None]
+        if not placeholders:
+            return labels
+
+        ids = self.tokenizer.convert_tokens_to_ids(placeholders)
+        # Wrap a non-list result so the set comprehension below works either way.
+        ids = ids if isinstance(ids, list) else [ids]
+        masked_ids = {tid for tid in ids if tid is not None}
+        # labels is edited in place and also returned.
+        for pos, tok in enumerate(tokens):
+            if tok in masked_ids or labels[pos] in masked_ids:
+                labels[pos] = -100
+        return labels
+
def _encode_truncated(self, input_ids, labels):
length = self._get_length(input_ids, labels)
if self.max_seq_len is not None and length > self.max_seq_len:
diff --git a/paddleformers/datasets/collate.py b/paddleformers/datasets/collate.py
index 30be4d57c82..ee6779e9e0e 100644
--- a/paddleformers/datasets/collate.py
+++ b/paddleformers/datasets/collate.py
@@ -625,6 +625,13 @@ def mm_collate_fn(
input_keys.append("video_grid_thw")
input_keys.append("input_features")
input_keys.append("feature_attention_mask")
+ input_keys.append("image_pixel_values")
+ input_keys.append("image_sizes")
+ input_keys.append("image_attention_mask")
+ input_keys.append("audio_input_features")
+ input_keys.append("audio_embed_sizes")
+ input_keys.append("audio_attention_mask")
+ input_keys.append("input_mode")
if training_args.num_nextn_predict_layers > 0:
input_keys.append("nbatch_pack_offset")
@@ -652,6 +659,13 @@ def mm_collate_fn(
video_grid_thw = []
input_features = []
feature_attention_mask = []
+ image_pixel_values = []
+ image_sizes = []
+ image_attention_mask = []
+ audio_input_features = []
+ audio_embed_sizes = []
+ audio_attention_mask = []
+ input_mode = []
for seq in batch_sequence:
original_token_ids.append(seq.token_ids)
mm_inputs = seq.mm_inputs
@@ -667,6 +681,20 @@ def mm_collate_fn(
input_features.append(mm_inputs["input_features"])
if "feature_attention_mask" in mm_inputs:
feature_attention_mask.append(mm_inputs["feature_attention_mask"])
+ if "image_pixel_values" in mm_inputs:
+ image_pixel_values.append(mm_inputs["image_pixel_values"])
+ if "image_sizes" in mm_inputs:
+ image_sizes.append(mm_inputs["image_sizes"])
+ if "image_attention_mask" in mm_inputs:
+ image_attention_mask.append(mm_inputs["image_attention_mask"])
+ if "audio_input_features" in mm_inputs:
+ audio_input_features.append(mm_inputs["audio_input_features"])
+ if "audio_embed_sizes" in mm_inputs:
+ audio_embed_sizes.append(mm_inputs["audio_embed_sizes"])
+ if "audio_attention_mask" in mm_inputs:
+ audio_attention_mask.append(mm_inputs["audio_attention_mask"])
+ if "input_mode" in mm_inputs:
+ input_mode.append(mm_inputs["input_mode"])
if get_rope_func is not None:
filtered_args = {k: paddle.to_tensor(mm_inputs[k]) for k in func_params if k in mm_inputs}
total_input_ids = paddle.to_tensor([seq.token_ids])
@@ -712,6 +740,20 @@ def mm_collate_fn(
input_features = paddle.concat(input_features, axis=0)
if len(feature_attention_mask) > 0:
feature_attention_mask = paddle.concat(feature_attention_mask, axis=0)
+ if len(image_pixel_values) > 0:
+ image_pixel_values = paddle.concat(image_pixel_values, axis=0)
+ if len(image_sizes) > 0:
+ image_sizes = paddle.concat(image_sizes, axis=0)
+ if len(image_attention_mask) > 0:
+ image_attention_mask = paddle.concat(image_attention_mask, axis=0)
+ if len(audio_input_features) > 0:
+ audio_input_features = paddle.concat(audio_input_features, axis=0)
+ if len(audio_embed_sizes) > 0:
+ audio_embed_sizes = paddle.concat(audio_embed_sizes, axis=0)
+ if len(audio_attention_mask) > 0:
+ audio_attention_mask = paddle.concat(audio_attention_mask, axis=0)
+ if len(input_mode) > 0:
+ input_mode = paddle.concat(input_mode, axis=0)
if get_token_type_func is not None: # ernie45vl
bs_idx_in_rope = 0
padded_position_ids = padded_position_ids.transpose([1, 2, 0])
@@ -736,6 +778,13 @@ def mm_collate_fn(
video_grid_thw,
input_features,
feature_attention_mask,
+ image_pixel_values,
+ image_sizes,
+ image_attention_mask,
+ audio_input_features,
+ audio_embed_sizes,
+ audio_attention_mask,
+ input_mode,
]
)
diff --git a/paddleformers/datasets/template/mm_plugin.py b/paddleformers/datasets/template/mm_plugin.py
index 3b5c124237f..83a116a90b3 100644
--- a/paddleformers/datasets/template/mm_plugin.py
+++ b/paddleformers/datasets/template/mm_plugin.py
@@ -392,6 +392,90 @@ def get_mm_inputs(
return self._get_mm_inputs(images, videos, audios, processor, **kwargs)
+@dataclass
+class Phi4MultimodalPlugin(BasePlugin):  # image/audio placeholder handling for the phi4_multimodal template
+    mask_mm_token_labels: bool = True  # read via getattr by SFTDataset._mask_mm_token_labels
+
+    @staticmethod
+    def _as_int(value):
+        if hasattr(value, "item"):  # unwrap objects exposing .item() before int()
+            value = value.item()
+        return int(value)
+
+    @override
+    def _validate_input(self, processor, images, videos, audios) -> None:
+        if len(videos) != 0:
+            raise ValueError("Phi-4 multimodal does not support video input in this template.")
+        if len(images) != 0:
+            if processor is None:
+                raise ValueError("Processor was not found for Phi-4 multimodal image input.")
+            if getattr(processor, "image_processor", None) is None:
+                raise ValueError("Image processor was not found for Phi-4 multimodal image input.")
+        if len(audios) != 0:
+            if processor is None:
+                raise ValueError("Processor was not found for Phi-4 multimodal audio input.")
+            if getattr(processor, "feature_extractor", None) is None:
+                raise ValueError("Audio feature extractor was not found for Phi-4 multimodal audio input.")
+
+    @override
+    def _get_mm_inputs(self, images, videos, audios, processor, **kwargs):
+        mm_inputs = {}
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )["images"]
+            mm_inputs.update(processor.image_processor(images, return_tensors="pd"))
+        # Audio sampling rate: the feature extractor's value, else processor.audio_sampling_rate, else 16000.
+        if len(audios) != 0:
+            feature_extractor = getattr(processor, "feature_extractor", None)
+            sampling_rate = getattr(
+                feature_extractor, "sampling_rate", getattr(processor, "audio_sampling_rate", 16000)
+            )
+            audios = self._regularize_audios(audios, sampling_rate=sampling_rate)["audios"]
+            mm_inputs.update(
+                feature_extractor(
+                    audios,
+                    sampling_rate=sampling_rate,
+                    return_attention_mask=True,
+                    return_tensors="pd",
+                )
+            )
+        # input_mode: 3 = images and audio, 1 = images only, 2 = audio only; absent when neither is given.
+        if len(images) != 0 or len(audios) != 0:
+            input_mode = 3 if len(images) != 0 and len(audios) != 0 else 1 if len(images) != 0 else 2
+            mm_inputs["input_mode"] = paddle.to_tensor([input_mode], dtype="int64")
+
+        return mm_inputs
+
+    @override
+    def process_messages(self, messages, images, videos, audios, mm_inputs, processor):
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        messages = deepcopy(messages)  # the caller's message list is left untouched
+        num_img_tokens = mm_inputs.get("num_img_tokens", [])
+        audio_embed_sizes = mm_inputs.get("audio_embed_sizes", [])
+        num_image_tokens, num_audio_tokens = 0, 0
+        # The two counters run across all messages and index the per-item sizes read from mm_inputs.
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                image_seqlen = self._as_int(num_img_tokens[num_image_tokens]) if self.expand_mm_tokens else 1
+                content = content.replace(IMAGE_PLACEHOLDER, self.image_token * image_seqlen, 1)
+                num_image_tokens += 1
+
+            while AUDIO_PLACEHOLDER in content:
+                audio_seqlen = self._as_int(audio_embed_sizes[num_audio_tokens]) if self.expand_mm_tokens else 1
+                content = content.replace(AUDIO_PLACEHOLDER, self.audio_token * audio_seqlen, 1)
+                num_audio_tokens += 1
+
+            message["content"] = content
+
+        self.masked_tokens = [token for token in [self.image_token, self.audio_token] if token is not None]
+        return messages
+
+
@dataclass
class PaddleOCRVLPlugin(BasePlugin):
image_bos_token: str = "<|IMAGE_START|>"
@@ -1496,6 +1580,7 @@ def process_messages(
PLUGINS = {
"base": BasePlugin,
+ "phi4_multimodal": Phi4MultimodalPlugin,
"ernie_vl": ErnieVLPlugin,
"qwen2_vl": Qwen2VLPlugin,
"paddleocr_vl": PaddleOCRVLPlugin,
diff --git a/paddleformers/datasets/template/template.py b/paddleformers/datasets/template/template.py
index 1c3757e34ad..c40c615b47b 100644
--- a/paddleformers/datasets/template/template.py
+++ b/paddleformers/datasets/template/template.py
@@ -977,6 +977,24 @@ def _get_gpt_oss_prefix():
chat_sep="<|im_end|>",
)
+register_template(  # template listed for microsoft/Phi-4-multimodal-instruct in the README
+    name="phi4_multimodal",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
+    format_observation=StringFormatter(
+        slots=["<|user|>\n\n{{content}}\n<|end|>\n<|assistant|>\n"]
+    ),
+    suffix=["<|end|>"],
+    chat_sep="<|end|>\n",
+    auto_add_bos=False,
+    mm_plugin=get_mm_plugin(
+        name="phi4_multimodal",  # presumably looked up in mm_plugin.PLUGINS — TODO confirm
+        image_token="<|endoftext10|>",  # NOTE(review): presumably id 200010 (vision config default) — confirm
+        audio_token="<|endoftext11|>",  # NOTE(review): presumably id 200011 (audio config default) — confirm
+    ),
+)
+
register_template(
name="glm_ocr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}\n"]),
diff --git a/paddleformers/transformers/__init__.py b/paddleformers/transformers/__init__.py
index feeff603775..014ea816320 100644
--- a/paddleformers/transformers/__init__.py
+++ b/paddleformers/transformers/__init__.py
@@ -314,6 +314,27 @@
"phi3.configuration": ["Phi3Config"],
"phi3.tokenizer": ["Phi3Tokenizer"],
"phi3.modeling": ["Phi3Model", "Phi3ForCausalLM", "Phi3ForCausalLMPipe"],
+ "phi4_multimodal.configuration": [
+ "Phi4MultimodalAudioConfig",
+ "Phi4MultimodalConfig",
+ "Phi4MultimodalVisionConfig",
+ ],
+ "phi4_multimodal.feature_extraction": ["Phi4MultimodalFeatureExtractor"],
+ "phi4_multimodal.image_processor": ["Phi4MultimodalImageProcessor"],
+ "phi4_multimodal.modeling": [
+ "Phi4Multimodal",
+ "Phi4MultimodalAudio",
+ "Phi4MultimodalAudioModel",
+ "Phi4MultimodalAudioPreTrainedModel",
+ "Phi4MultimodalForCausalLM",
+ "Phi4MultimodalForConditionalGeneration",
+ "Phi4MultimodalModel",
+ "Phi4MultimodalPreTrainedModel",
+ "Phi4MultimodalVision",
+ "Phi4MultimodalVisionModel",
+ "Phi4MultimodalVisionPreTrainedModel",
+ ],
+ "phi4_multimodal.processor": ["Phi4MultimodalProcessor"],
"glm4v_moe.configuration": ["Glm4vMoeConfig", "Glm4vMoeTextConfig", "Glm4vMoeVisionConfig"],
"glm4v_moe.modeling": [
"Glm4vMoeForConditionalGeneration",
@@ -408,6 +429,7 @@
from .minimax_m2 import *
from .gpt_oss import *
from .phi3 import *
+ from .phi4_multimodal import *
from .gemma3_text import *
from .glm_ocr import *
else:
diff --git a/paddleformers/transformers/auto/configuration.py b/paddleformers/transformers/auto/configuration.py
index fc8c594f4cb..2a3cc026b01 100644
--- a/paddleformers/transformers/auto/configuration.py
+++ b/paddleformers/transformers/auto/configuration.py
@@ -56,6 +56,10 @@
("minimax_m2", "MiniMaxM2Config"),
("gpt_oss", "GptOssConfig"),
("phi3", "Phi3Config"),
+ ("phi4mm", "Phi4MultimodalConfig"),
+ ("phi4_multimodal", "Phi4MultimodalConfig"),
+ ("phi4_multimodal_audio", "Phi4MultimodalAudioConfig"),
+ ("phi4_multimodal_vision", "Phi4MultimodalVisionConfig"),
("gemma3_text", "Gemma3TextConfig"),
("glm4v_moe", "Glm4vMoeConfig"),
("glm_ocr", "GlmOcrConfig"),
@@ -87,6 +91,9 @@
("qwen3_vl_moe", "Qwen3VLMoe"),
("qwen3_vl_moe_text", "Qwen3VLMoeText"),
("glm_ocr", "GlmOcrForConditionalGeneration"),
+ ("phi4_multimodal", "Phi4Multimodal"),
+ ("phi4_multimodal_audio", "Phi4MultimodalAudio"),
+ ("phi4_multimodal_vision", "Phi4MultimodalVision"),
("qwen3_5_moe", "Qwen3_5MoEForConditionalGeneration"),
("qwen3_5", "Qwen3_5ForConditionalGeneration"),
]
@@ -100,6 +107,9 @@
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
[
("qwen2_5_vl_text", "qwen2_5_vl"),
+ ("phi4_multimodal_audio", "phi4_multimodal"),
+ ("phi4_multimodal_vision", "phi4_multimodal"),
+ ("phi4mm", "phi4_multimodal"),
("qwen3_vl_text", "qwen3_vl"),
("qwen3_vl_moe_text", "qwen3_vl_moe"),
]
diff --git a/paddleformers/transformers/auto/feature_extraction.py b/paddleformers/transformers/auto/feature_extraction.py
index 555237c4142..33252ffedce 100644
--- a/paddleformers/transformers/auto/feature_extraction.py
+++ b/paddleformers/transformers/auto/feature_extraction.py
@@ -39,6 +39,7 @@
FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
[
+ ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
("whisper", "WhisperFeatureExtractor"),
]
)
@@ -56,6 +57,9 @@ def safe_load_json_file(json_file: str):
def feature_extractor_class_from_name(class_name: str):
+ if class_name == "Phi4MMAudioFeatureExtractor":
+ class_name = "Phi4MultimodalFeatureExtractor"
+
for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
if class_name in extractors:
module_name = model_type_to_module_name(module_name)
diff --git a/paddleformers/transformers/auto/image_processing.py b/paddleformers/transformers/auto/image_processing.py
index 4244259c0e1..39fc1307141 100644
--- a/paddleformers/transformers/auto/image_processing.py
+++ b/paddleformers/transformers/auto/image_processing.py
@@ -55,6 +55,7 @@
"glm4v_moe": ("Glm4vImageProcessor", "Glm4vImageProcessorFast"),
"kimi_k25": ("KimiK25VisionProcessor"),
"paddleocr_vl": ("PaddleOCRVLImageProcessor"),
+ "phi4_multimodal": ("Phi4MultimodalImageProcessor"),
"qwen2_5_vl": ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast"),
"qwen2_vl": ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast"),
"qwen3_vl": ("Qwen3VLImageProcessor", "Qwen3VLImageProcessorFast"),
@@ -68,6 +69,9 @@
def get_image_processor_class_from_name(class_name: str):
+ if class_name == "Phi4MMImageProcessor":
+ class_name = "Phi4MultimodalImageProcessor"
+
if class_name == "BaseImageProcessorFast":
return BaseImageProcessorFast
diff --git a/paddleformers/transformers/auto/modeling.py b/paddleformers/transformers/auto/modeling.py
index f450dc95656..67bb875e96e 100644
--- a/paddleformers/transformers/auto/modeling.py
+++ b/paddleformers/transformers/auto/modeling.py
@@ -75,6 +75,10 @@
("MiniMaxM2", "minimax_m2"),
("GptOss", "gpt_oss"),
("Phi3", "phi3"),
+ ("Phi4MM", "phi4_multimodal"),
+ ("Phi4Multimodal", "phi4_multimodal"),
+ ("Phi4MultimodalAudio", "phi4_multimodal"),
+ ("Phi4MultimodalVision", "phi4_multimodal"),
("Gemma3", "gemma3_text"),
("Glm4vMoe", "glm4v_moe"),
("GlmOcr", "glm_ocr"),
@@ -82,7 +86,11 @@
)
MAPPING_SPACIAL_KEY = OrderedDict(
- [("Gemma3", "Gemma3Text"), ("Ernie4_5_VLMoe", "Ernie4_5_VLMoeForConditionalGeneration")]
+ [
+ ("Gemma3", "Gemma3Text"),
+ ("Ernie4_5_VLMoe", "Ernie4_5_VLMoeForConditionalGeneration"),
+ ("Phi4MM", "Phi4Multimodal"),
+ ]
)
CONFIGURATION_MODEL_MAPPING = OrderedDict([((), "Gemma3TextModel")])
diff --git a/paddleformers/transformers/auto/processing.py b/paddleformers/transformers/auto/processing.py
index bca898e350d..40f8d1bf18d 100644
--- a/paddleformers/transformers/auto/processing.py
+++ b/paddleformers/transformers/auto/processing.py
@@ -49,6 +49,7 @@
PROCESSOR_MAPPING_NAMES = OrderedDict(
[
("kimi_k25", "KimiK25Processor"),
+ ("phi4_multimodal", "Phi4MultimodalProcessor"),
("qwen2_5_vl", "Qwen2_5_VLProcessor"),
("qwen3_vl", "Qwen3VLProcessor"),
("qwen2_vl", "Qwen2VLProcessor"),
@@ -64,6 +65,9 @@
def processor_class_from_name(class_name: str):
+ if class_name == "Phi4MMProcessor":
+ class_name = "Phi4MultimodalProcessor"
+
for module_name, extractors in PROCESSOR_MAPPING_NAMES.items():
if class_name in extractors:
module_name = model_type_to_module_name(module_name)
diff --git a/paddleformers/transformers/phi4_multimodal/__init__.py b/paddleformers/transformers/phi4_multimodal/__init__.py
new file mode 100644
index 00000000000..1e60a074bd5
--- /dev/null
+++ b/paddleformers/transformers/phi4_multimodal/__init__.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Package"""
+import sys
+from typing import TYPE_CHECKING
+
+from ...utils.lazy_import import _LazyModule
+
+import_structure = {  # submodule name -> public names; passed to _LazyModule below
+    "configuration": [
+        "Phi4MultimodalConfig",
+        "Phi4MultimodalVisionConfig",
+        "Phi4MultimodalAudioConfig",
+    ],
+    "feature_extraction": ["Phi4MultimodalFeatureExtractor"],
+    "image_processor": ["Phi4MultimodalImageProcessor"],
+    "modeling": [  # NOTE(review): not the same names as the list in transformers/__init__.py — confirm
+        "Phi4MultimodalPreTrainedModel",
+        "Phi4MultimodalModel",
+        "Phi4MultimodalForCausalLM",
+        "Phi4MultimodalForConditionalGeneration",
+        "Phi4MultimodalForCausalLMPipe",
+        "Phi4MMForCausalLM",
+        "Phi4MMForConditionalGeneration",
+        "Phi4MMForCausalLMPipe",
+    ],
+    "processor": ["Phi4MultimodalProcessor"],
+}
+
+if TYPE_CHECKING:  # type checkers get the real star-imports
+    from .configuration import *
+    from .feature_extraction import *
+    from .image_processor import *
+    from .modeling import *
+    from .processor import *
+else:  # at runtime this module's sys.modules entry is replaced by a _LazyModule
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        import_structure,
+        module_spec=__spec__,
+    )
diff --git a/paddleformers/transformers/phi4_multimodal/configuration.py b/paddleformers/transformers/phi4_multimodal/configuration.py
new file mode 100644
index 00000000000..0d71c9e6948
--- /dev/null
+++ b/paddleformers/transformers/phi4_multimodal/configuration.py
@@ -0,0 +1,286 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Phi-4-Multimodal configuration."""
+
+import math
+
+from ..configuration_utils import PretrainedConfig
+
+
+def _convert_phi4mm_config(config_dict, with_lora_adapters=True):
+    """Translate an original ``phi4mm`` config dict into ``Phi4MultimodalConfig`` keyword arguments.
+
+    The input is not modified: the top level and the audio sub-config are copied before any key is
+    popped or renamed, so the same dict can be converted repeatedly.
+
+    Args:
+        config_dict: Mapping in the ``phi4mm`` layout; must contain ``embd_layer`` and ``audio_processor``.
+        with_lora_adapters: If true, also emit the ``vision_lora_*`` / ``speech_lora_*`` rank and alpha keys.
+
+    Returns:
+        dict: A new mapping with ``audio_config``, ``vision_config`` and ``eos_token_id`` filled in.
+
+    Raises:
+        KeyError: If a required section or audio field is absent.
+    """
+    config = dict(config_dict)
+    vision_lora = config.pop("vision_lora", None) or {}
+    speech_lora = config.pop("speech_lora", None) or {}
+    dropped = ("_name_or_path", "architectures", "auto_map", "transformers_version")
+    for key in dropped + ("_attn_implementation", "torch_dtype", "model_type"):
+        config.pop(key, None)
+
+    embd_layer = config.pop("embd_layer")
+    audio_embd_layer = embd_layer["audio_embd_layer"]
+    vision_embd_layer = embd_layer["image_embd_layer"]
+
+    # Copy before editing: dict(config_dict) above is shallow, so this sub-dict is still the caller's object.
+    audio_config = dict(config.pop("audio_processor")["config"])
+    unused = ("activation_checkpointing", "cnn_layer_norm", "input_layer", "batch_norm", "causal")
+    for key in unused + ("encoder_embedding_config", "ext_pw_kernel_size", "bias_in_glu"):
+        audio_config.pop(key, None)
+    audio_config["hidden_size"] = audio_config.pop("attention_dim")
+    audio_config["num_attention_heads"] = audio_config.pop("attention_heads")
+    audio_config["intermediate_size"] = audio_config.pop("linear_units")
+    audio_config["nemo_conv_channels"] = audio_config.pop("nemo_conv_settings")["conv_channels"]
+    audio_config["bias_max_distance"] = audio_config.pop("relative_attention_bias_args")["t5_bias_max_distance"]
+    audio_config["downsample_rate"] = audio_embd_layer["downsample_rate"]
+    # Discard the misspelled key; the correctly spelled one falls back to ext_pw_out_channel if absent.
+    audio_config.pop("depthwise_seperable_out_channel", None)
+    if "depthwise_separable_out_channel" not in audio_config:
+        audio_config["depthwise_separable_out_channel"] = audio_config.get("ext_pw_out_channel")
+
+    config["audio_config"] = audio_config
+    config["vision_config"] = {"crop_size": vision_embd_layer["crop_size"]}
+    config["eos_token_id"] = [199999, 200020]
+    if with_lora_adapters:
+        config["vision_lora_rank"] = vision_lora.get("r", 0)
+        config["vision_lora_alpha"] = vision_lora.get("lora_alpha", 1)
+        config["speech_lora_rank"] = speech_lora.get("r", 0)
+        config["speech_lora_alpha"] = speech_lora.get("lora_alpha", 1)
+    return config
+
+
+class Phi4MultimodalVisionConfig(PretrainedConfig):  # held as Phi4MultimodalConfig.vision_config
+    model_type = "phi4_multimodal_vision"  # key registered in auto/configuration.py
+
+    def __init__(
+        self,
+        hidden_size=1152,
+        intermediate_size=4304,
+        num_hidden_layers=27,
+        num_attention_heads=16,
+        num_channels=3,
+        image_size=448,
+        patch_size=14,
+        hidden_act="gelu_pytorch_tanh",
+        layer_norm_eps=1e-6,
+        attention_dropout=0.0,
+        crop_size=448,  # the only field _convert_phi4mm_config fills in (from image_embd_layer)
+        image_token_id=200010,
+        feature_layer=-2,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)  # remaining kwargs are handled by PretrainedConfig
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.hidden_act = hidden_act
+        self.layer_norm_eps = layer_norm_eps
+        self.attention_dropout = attention_dropout
+        self.crop_size = crop_size
+        self.image_token_id = image_token_id
+        self.feature_layer = feature_layer
+
+
+class Phi4MultimodalAudioConfig(PretrainedConfig):  # held as Phi4MultimodalConfig.audio_config
+    model_type = "phi4_multimodal_audio"  # key registered in auto/configuration.py
+
+    def __init__(
+        self,
+        hidden_size=1024,  # "attention_dim" in the phi4mm layout
+        intermediate_size=1536,  # "linear_units" in the phi4mm layout
+        num_blocks=24,
+        num_attention_heads=16,  # "attention_heads" in the phi4mm layout
+        activation="swish",
+        chunk_size=-1,
+        left_chunk=18,
+        dropout_rate=0.0,
+        ext_pw_out_channel=1024,
+        depthwise_separable_out_channel=1024,
+        depthwise_multiplier=1,
+        kernel_size=3,
+        conv_activation="swish",
+        input_size=80,
+        conv_glu_type="swish",
+        time_reduction=8,
+        bias_max_distance=1000,
+        bias_symmetric=False,
+        nemo_activation="relu",
+        nemo_conv_channels=1024,
+        downsample_rate=1,  # audio_embd_layer["downsample_rate"] in the phi4mm layout
+        initializer_range=0.02,
+        audio_token_id=200011,
+        feature_layer=-2,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)  # remaining kwargs are handled by PretrainedConfig
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_blocks = num_blocks
+        self.num_attention_heads = num_attention_heads
+        self.activation = activation
+        self.chunk_size = chunk_size
+        self.left_chunk = left_chunk
+        self.dropout_rate = dropout_rate
+        self.ext_pw_out_channel = ext_pw_out_channel
+        self.depthwise_separable_out_channel = depthwise_separable_out_channel
+        self.depthwise_multiplier = depthwise_multiplier
+        self.kernel_size = kernel_size
+        self.conv_activation = conv_activation
+        self.input_size = input_size
+        self.conv_glu_type = conv_glu_type
+        self.time_reduction = time_reduction
+        self.bias_max_distance = bias_max_distance
+        self.bias_symmetric = bias_symmetric
+        self.nemo_activation = nemo_activation
+        self.nemo_conv_channels = nemo_conv_channels
+        self.downsample_rate = downsample_rate
+        self.initializer_range = initializer_range
+        self.audio_token_id = audio_token_id
+        self.feature_layer = feature_layer
+        # Derived field: input_size after floor((n - 1) / 2 + 1) is applied int(log2(time_reduction)) times.
+        nemo_final_size = self.input_size
+        for _ in range(int(math.log2(self.time_reduction))):
+            nemo_final_size = math.floor((nemo_final_size - 1) / 2 + 1)
+        self.nemo_final_size = nemo_final_size
+
+
+class Phi4MultimodalConfig(PretrainedConfig):  # top-level config; from_dict also accepts phi4mm-layout dicts
+    model_type = "phi4_multimodal"
+
+    @classmethod
+    def from_dict(cls, config_dict, **kwargs):
+        if config_dict.get("model_type") == "phi4mm":  # original layout: convert before normal loading
+            config_dict = _convert_phi4mm_config(config_dict)
+        return super().from_dict(config_dict, **kwargs)
+
+    def __init__(
+        self,
+        vocab_size=200064,
+        hidden_size=3072,
+        intermediate_size=8192,
+        num_hidden_layers=32,
+        num_attention_heads=24,
+        num_key_value_heads=8,
+        resid_pdrop=0.0,
+        embd_pdrop=0.0,
+        attention_dropout=0.0,
+        hidden_act="silu",
+        max_position_embeddings=131072,
+        original_max_position_embeddings=4096,
+        initializer_range=0.02,
+        rms_norm_eps=1e-5,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        rope_parameters=None,  # when given, used as-is instead of being built from rope_scaling
+        bos_token_id=199999,
+        eos_token_id=None,  # None is replaced by [199999, 200020] below
+        pad_token_id=199999,
+        sliding_window=None,
+        partial_rotary_factor=1.0,
+        vision_config=None,
+        audio_config=None,
+        attention_bias=False,
+        mlp_bias=False,
+        lm_head_bias=False,
+        vision_lora_rank=0,
+        vision_lora_alpha=1,
+        speech_lora_rank=0,
+        speech_lora_alpha=1,
+        **kwargs,
+    ):
+        super().__init__(
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id if eos_token_id is not None else [199999, 200020],
+            pad_token_id=pad_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
+        self.resid_pdrop = resid_pdrop
+        self.embd_pdrop = embd_pdrop
+        self.attention_dropout = attention_dropout
+        self.hidden_act = hidden_act
+        self.max_position_embeddings = max_position_embeddings
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.sliding_window = sliding_window
+        self.partial_rotary_factor = partial_rotary_factor
+        self.attention_bias = attention_bias
+        self.mlp_bias = mlp_bias
+        self.lm_head_bias = lm_head_bias
+        self.vision_lora_rank = vision_lora_rank
+        self.vision_lora_alpha = vision_lora_alpha
+        self.speech_lora_rank = speech_lora_rank
+        self.speech_lora_alpha = speech_lora_alpha
+        self._active_lora_adapter = None
+        # vision_config / audio_config may be a dict, an already-built config object, or None (defaults).
+        if isinstance(vision_config, dict):
+            self.vision_config = Phi4MultimodalVisionConfig(**vision_config)
+        elif vision_config is None:
+            self.vision_config = Phi4MultimodalVisionConfig()
+        else:
+            self.vision_config = vision_config
+
+        if isinstance(audio_config, dict):
+            self.audio_config = Phi4MultimodalAudioConfig(**audio_config)
+        elif audio_config is None:
+            self.audio_config = Phi4MultimodalAudioConfig()
+        else:
+            self.audio_config = audio_config
+
+        # Build rope_parameters dict for compatibility with rope utils
+        self.rope_parameters = rope_parameters if rope_parameters is not None else self._build_rope_parameters()
+
+    def _build_rope_parameters(self):  # reads rope_theta, partial_rotary_factor and rope_scaling
+        rope_params = {
+            "rope_theta": self.rope_theta,
+            "partial_rotary_factor": self.partial_rotary_factor,
+        }
+        if self.rope_scaling is not None:
+            rope_params.update(self.rope_scaling)
+            if "rope_type" not in rope_params:
+                rope_params["rope_type"] = "longrope"  # fallback when rope_scaling carries no "rope_type"
+            if rope_params.get("rope_type") in {"longrope", "yarn", "llama3"}:
+                rope_params.setdefault("original_max_position_embeddings", self.original_max_position_embeddings)
+        else:
+            rope_params["rope_type"] = "default"  # no rope_scaling supplied
+        return rope_params
diff --git a/paddleformers/transformers/phi4_multimodal/feature_extraction.py b/paddleformers/transformers/phi4_multimodal/feature_extraction.py
new file mode 100644
index 00000000000..37ca74fe42b
--- /dev/null
+++ b/paddleformers/transformers/phi4_multimodal/feature_extraction.py
@@ -0,0 +1,189 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for Phi-4 Multimodal audio."""
+
+import numpy as np
+import paddle
+from transformers.audio_utils import mel_filter_bank
+
+from ...utils.log import logger
+from ..audio_processing_utils import SequenceFeatureExtractor
+from ..feature_extraction_utils import BatchFeature
+
+
+class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor):
+ model_input_names = ["audio_input_features", "audio_embed_sizes", "audio_attention_mask"]
+
+    def __init__(
+        self,
+        feature_size=80,
+        sampling_rate=16000,
+        hop_length=160,
+        n_fft=512,
+        win_length=400,
+        preemphasis=0.97,
+        padding_value=0.0,
+        audio_compression_rate=8,
+        audio_downsample_rate=1,
+        audio_feat_stride=1,
+        mel_min_frequency=0,
+        mel_max_frequency=7690,
+        **kwargs,
+    ):
+        """Store the framing/compression settings and precompute the mel filter bank.
+
+        Args:
+            feature_size: number of mel filters (feature dimension per frame).
+            sampling_rate: sampling rate the extractor expects.
+            hop_length: step between consecutive frames, in samples.
+            n_fft: FFT size used for the spectrum.
+            win_length: frame length, in samples.
+            preemphasis: pre-emphasis coefficient applied inside each frame.
+            padding_value: fill value for padded waveform samples and feature frames.
+            audio_compression_rate: first divisor used when computing embed sizes.
+            audio_downsample_rate: second divisor used when computing embed sizes.
+            audio_feat_stride: multiplier applied to the per-sample frame counts.
+            mel_min_frequency: lower edge of the mel filter bank.
+            mel_max_frequency: upper edge of the mel filter bank.
+            **kwargs: forwarded to the parent constructor.
+        """
+        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
+
+        # Framing / spectrum settings.
+        self.hop_length = hop_length
+        self.n_fft = n_fft
+        self.win_length = win_length
+        self.preemphasis = preemphasis
+        self.padding_value = padding_value
+
+        # Settings used to derive the number of audio embeddings per sample.
+        self.audio_compression_rate = audio_compression_rate
+        self.audio_downsample_rate = audio_downsample_rate
+        self.audio_feat_stride = audio_feat_stride
+
+        # One filter-bank row per real-FFT bin.
+        frequency_bins = self.n_fft // 2 + 1
+        filter_bank = mel_filter_bank(
+            num_frequency_bins=frequency_bins,
+            num_mel_filters=self.feature_size,
+            min_frequency=mel_min_frequency,
+            max_frequency=mel_max_frequency,
+            sampling_rate=self.sampling_rate,
+            triangularize_in_mel_space=True,
+            mel_scale="kaldi",
+        )
+        self.mel_filters = filter_bank.astype(np.float32)
+
+    def __call__(
+        self,
+        raw_speech,
+        sampling_rate=None,
+        pad_to_multiple_of=None,
+        padding="longest",
+        max_length=None,
+        truncation=False,
+        return_tensors=None,
+        return_attention_mask=True,
+        device="cpu",
+        **kwargs,
+    ):
+        """Convert raw audio into padded log-mel features plus per-sample embed sizes.
+
+        Args:
+            raw_speech: audio in any form accepted by ``_as_mono_float_list``.
+            sampling_rate: sampling rate of ``raw_speech``; must equal
+                ``self.sampling_rate`` when given.
+            pad_to_multiple_of: if set, the padded waveform length is rounded up
+                to a multiple of this value.
+            padding: ``True``/``"longest"`` pads to the longest sample,
+                ``"max_length"`` pads to ``max_length``; any other value uses the
+                length of the first sample.
+            max_length: length used for truncation and ``"max_length"`` padding.
+            truncation: when true (and ``max_length`` is set) samples are cut to
+                ``max_length``.
+            return_tensors: forwarded to ``BatchFeature`` as ``tensor_type``.
+            return_attention_mask: whether to build the frame-level mask (only
+                produced for batches of more than one sample).
+            device: accepted for signature compatibility; not used here.
+            **kwargs: accepted for signature compatibility; not used here.
+
+        Returns:
+            BatchFeature with ``audio_input_features``, ``audio_embed_sizes`` and,
+            when built, ``audio_attention_mask``.
+
+        Raises:
+            ValueError: on a sampling-rate mismatch or when no audio is provided.
+        """
+        if sampling_rate is not None and sampling_rate != self.sampling_rate:
+            raise ValueError(
+                f"The model corresponding to this feature extractor was trained using a sampling rate of "
+                f"{self.sampling_rate}. Please provide audio sampled at {self.sampling_rate}, not {sampling_rate}."
+            )
+        if sampling_rate is None:
+            logger.warning(
+                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`."
+            )
+
+        speech_list = self._as_mono_float_list(raw_speech)
+        if not speech_list:
+            # Fail with a clear message instead of an opaque `max()` error below.
+            raise ValueError("`raw_speech` must contain at least one audio sample.")
+        audio_lengths = np.asarray([speech.shape[0] for speech in speech_list], dtype=np.int64)
+
+        if truncation and max_length is not None:
+            speech_list = [speech[:max_length] for speech in speech_list]
+            audio_lengths = np.minimum(audio_lengths, max_length)
+
+        # The waveform is never shorter than one analysis window.
+        padded_length = max(int(length) for length in audio_lengths)
+        padded_length = max(padded_length, self.win_length)
+        if padding not in (True, "longest", "max_length"):
+            padded_length = max(int(audio_lengths[0]), self.win_length)
+        if max_length is not None and padding == "max_length":
+            padded_length = max(max_length, self.win_length)
+        if pad_to_multiple_of is not None and padded_length % pad_to_multiple_of != 0:
+            padded_length = ((padded_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+        waveform = np.full((len(speech_list), padded_length), self.padding_value, dtype=np.float32)
+        for idx, speech in enumerate(speech_list):
+            length = min(speech.shape[0], padded_length)
+            waveform[idx, :length] = speech[:length]
+
+        # Samples longer than the padded buffer were cut by the copy above; keep
+        # the recorded lengths in sync so the embed sizes and the attention mask
+        # describe the frames that are actually extracted.
+        audio_lengths = np.minimum(audio_lengths, padded_length)
+
+        input_features = self._np_extract_fbank_features(waveform, audio_lengths)
+        feature_lengths = (audio_lengths - self.win_length) // self.hop_length + 1
+        # Samples shorter than one window still yield a single frame.
+        feature_lengths = np.maximum(feature_lengths, 1) * self.audio_feat_stride
+        audio_embed_sizes = self._compute_audio_embed_size(feature_lengths)
+
+        feature_attention_mask = None
+        if return_attention_mask and len(feature_lengths) > 1:
+            max_feature_length = int(feature_lengths.max())
+            feature_attention_mask = np.arange(max_feature_length)[None, :] < feature_lengths[:, None]
+
+        data = {
+            "audio_input_features": input_features,
+            "audio_embed_sizes": audio_embed_sizes.astype(np.int64),
+        }
+        if feature_attention_mask is not None:
+            data["audio_attention_mask"] = feature_attention_mask
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    def _np_extract_fbank_features(self, waveform, audio_lengths):
+        """Compute log-mel features for each row of ``waveform``.
+
+        Args:
+            waveform: 2-D array, one (padded) waveform per row.
+            audio_lengths: valid sample count for each row.
+
+        Returns:
+            float32 array of shape ``(batch, max_frames, feature_size)``; rows
+            with fewer frames are filled with ``self.padding_value``.
+        """
+        window = np.hamming(self.win_length).astype(np.float64)
+        win, hop = self.win_length, self.hop_length
+        per_sample = []
+
+        for samples, valid in zip(waveform, audio_lengths):
+            # Keep the valid part, but never less than one full window.
+            samples = samples[: max(int(valid), win)]
+            if samples.shape[0] < win:
+                samples = np.pad(samples, (0, win - samples.shape[0]), constant_values=self.padding_value)
+
+            frame_count = (samples.shape[0] - win) // hop + 1
+            starts = [index * hop for index in range(frame_count)]
+            frames = np.stack([samples[start : start + win] for start in starts], axis=0)
+
+            # Pre-emphasis within each frame; the first sample reuses its right
+            # neighbour as the "previous" value.
+            shifted = np.roll(frames, 1, axis=-1)
+            shifted[:, 0] = shifted[:, 1]
+            frames = (frames - self.preemphasis * shifted) * 32768
+
+            spectrum = np.fft.rfft(window * frames, n=self.n_fft, axis=1).astype(np.complex64)
+            power = np.abs(spectrum) ** 2
+            mel_energy = power @ self.mel_filters
+            # Energies are floored at 1.0, so the log is never negative.
+            log_mel = np.log(np.clip(mel_energy, a_min=1.0, a_max=None)).astype(np.float32)
+            per_sample.append(log_mel)
+
+        max_frames = max((features.shape[0] for features in per_sample), default=0)
+        padded = np.full((len(per_sample), max_frames, self.feature_size), self.padding_value, dtype=np.float32)
+        for idx, features in enumerate(per_sample):
+            padded[idx, : features.shape[0]] = features
+        return padded
+
+    def _compute_audio_embed_size(self, audio_frames):
+        """Return the embed count per sample after two rounded-up divisions.
+
+        ``audio_frames`` (an integer array) is divided by
+        ``self.audio_compression_rate`` and the result by
+        ``self.audio_downsample_rate``; each division rounds up.
+        """
+
+        def _round_up_div(values, divisor):
+            whole = values // divisor
+            has_remainder = (values % divisor) > 0
+            return whole + has_remainder.astype(whole.dtype)
+
+        compressed = _round_up_div(audio_frames, self.audio_compression_rate)
+        return _round_up_div(compressed, self.audio_downsample_rate)
+
+    @staticmethod
+    def _as_mono_float_list(raw_speech):
+        """Normalize supported audio inputs to a list of 1-D float32 arrays.
+
+        Accepted inputs:
+          * a ``paddle.Tensor`` or ``np.ndarray``: 1-D is one waveform, 2-D is a
+            batch with one waveform per row, higher ranks are averaged over the
+            last axis;
+          * a list/tuple of waveforms (arrays, tensors or nested sequences);
+          * a flat list/tuple of scalar samples, treated as ONE waveform.
+
+        Any per-item array with more than one dimension is averaged over its
+        last axis.
+
+        Raises:
+            ValueError: if ``raw_speech`` is of an unsupported type.
+        """
+        if isinstance(raw_speech, paddle.Tensor):
+            raw_speech = raw_speech.detach().cpu().numpy()
+        if isinstance(raw_speech, np.ndarray):
+            if raw_speech.ndim == 1:
+                raw_speech = [raw_speech]
+            elif raw_speech.ndim == 2:
+                raw_speech = list(raw_speech)
+            else:
+                raw_speech = [raw_speech.mean(axis=-1)]
+        elif isinstance(raw_speech, (list, tuple)):
+            raw_speech = [
+                speech.detach().cpu().numpy() if isinstance(speech, paddle.Tensor) else speech for speech in raw_speech
+            ]
+            # A flat sequence of numbers is a single waveform, not a batch of
+            # 0-d "waveforms" (those would fail later when their length is read).
+            if raw_speech and all(np.ndim(sample) == 0 for sample in raw_speech):
+                raw_speech = [np.asarray(raw_speech)]
+        else:
+            raise ValueError(f"Unsupported audio input type: {type(raw_speech)}")
+
+        result = []
+        for speech in raw_speech:
+            speech = np.asarray(speech)
+            if speech.ndim > 1:
+                speech = speech.mean(axis=-1)
+            result.append(speech.astype(np.float32))
+        return result
+
+
+__all__ = ["Phi4MultimodalFeatureExtractor"]
diff --git a/paddleformers/transformers/phi4_multimodal/image_processor.py b/paddleformers/transformers/phi4_multimodal/image_processor.py
new file mode 100644
index 00000000000..80a17f46dc9
--- /dev/null
+++ b/paddleformers/transformers/phi4_multimodal/image_processor.py
@@ -0,0 +1,285 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Phi-4 Multimodal."""
+
+import math
+
+import numpy as np
+import paddle
+from PIL import Image, ImageOps
+
+from ..feature_extraction_utils import BatchFeature
+from ..image_processing_utils import PaddleImageProcessingMixin
+from ..image_utils import PILImageResampling, make_flat_list_of_images
+
+
+class Phi4MultimodalImageProcessor(PaddleImageProcessingMixin):
+ model_input_names = ["image_pixel_values", "image_sizes", "image_attention_mask"]
+
+    def __init__(
+        self,
+        size=None,
+        patch_size=14,
+        dynamic_hd=36,
+        image_mean=None,
+        image_std=None,
+        do_resize=True,
+        do_rescale=True,
+        do_normalize=True,
+        do_convert_rgb=True,
+        resample=PILImageResampling.BICUBIC,
+        rescale_factor=1 / 255,
+        **kwargs,
+    ):
+        """Store the preprocessing defaults, then initialise the parent mixin.
+
+        Args:
+            size: crop size dict; falsy values fall back to 448x448.
+            patch_size: patch edge length used to size the attention masks.
+            dynamic_hd: default maximum number of crops per image.
+            image_mean: per-channel mean; falsy values fall back to 0.5 per channel.
+            image_std: per-channel std; falsy values fall back to 0.5 per channel.
+            do_resize: stored as an attribute.
+            do_rescale: default for the rescale step of ``preprocess``.
+            do_normalize: default for the normalisation step of ``preprocess``.
+            do_convert_rgb: whether inputs are converted to RGB.
+            resample: PIL resampling filter used for resizing.
+            rescale_factor: default multiplier for the rescale step.
+            **kwargs: forwarded to the parent constructor.
+        """
+        # Geometry.
+        self.size = size if size else {"height": 448, "width": 448}
+        self.patch_size = patch_size
+        self.dynamic_hd = dynamic_hd
+
+        # Normalisation statistics.
+        self.image_mean = image_mean if image_mean else [0.5, 0.5, 0.5]
+        self.image_std = image_std if image_std else [0.5, 0.5, 0.5]
+
+        # Step toggles and their parameters.
+        self.do_resize = do_resize
+        self.do_rescale = do_rescale
+        self.do_normalize = do_normalize
+        self.do_convert_rgb = do_convert_rgb
+        self.resample = resample
+        self.rescale_factor = rescale_factor
+
+        super().__init__(**kwargs)
+
+    def __call__(self, images, **kwargs):
+        """Make the processor callable; delegates to ``preprocess`` unchanged."""
+        batch = self.preprocess(images, **kwargs)
+        return batch
+
+    def preprocess(
+        self,
+        images,
+        size=None,
+        patch_size=None,
+        dynamic_hd=None,
+        image_mean=None,
+        image_std=None,
+        do_rescale=None,
+        do_normalize=None,
+        rescale_factor=None,
+        return_tensors=None,
+        **kwargs,
+    ):
+        """Turn images into stacked global+local crops, masks and token counts.
+
+        For every image: convert to PIL, resize/pad it with
+        ``dynamic_preprocess``, optionally rescale and normalise, build a
+        ``height x height`` global view, cut the padded image into square crops
+        and cut the attention mask the same way. Images are then padded to the
+        same number of crops and stacked.
+
+        Args:
+            images: image(s) accepted by ``make_flat_list_of_images``.
+            size: int or ``{"height", "width"}`` dict; falls back to ``self.size``.
+            patch_size, dynamic_hd, image_mean, image_std, do_rescale,
+            do_normalize, rescale_factor: per-call overrides; ``None`` uses the
+                value stored on the processor.
+            return_tensors: forwarded to ``BatchFeature`` as ``tensor_type``.
+            **kwargs: accepted but not used in this method.
+
+        Returns:
+            BatchFeature with ``image_pixel_values``, ``image_sizes``,
+            ``image_attention_mask`` and ``num_img_tokens``.
+
+        Raises:
+            ValueError: if the requested size is not square.
+        """
+        size = size or self.size
+        if isinstance(size, int):
+            size = {"height": size, "width": size}
+        patch_size = patch_size if patch_size is not None else self.patch_size
+        dynamic_hd = dynamic_hd if dynamic_hd is not None else self.dynamic_hd
+        image_mean = image_mean if image_mean is not None else self.image_mean
+        image_std = image_std if image_std is not None else self.image_std
+        do_rescale = self.do_rescale if do_rescale is None else do_rescale
+        do_normalize = self.do_normalize if do_normalize is None else do_normalize
+        rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
+
+        height = size["height"]
+        width = size["width"]
+        if height != width:
+            raise ValueError("Phi4MultimodalImageProcessor only supports square sizes.")
+
+        images = make_flat_list_of_images(images)
+        # Number of mask cells along one side of a single crop.
+        mask_size = height // patch_size
+        images_transformed = []
+        masks_transformed = []
+        image_tokens = []
+        image_sizes = []
+
+        for image in images:
+            image = self._to_pil_image(image)
+            # Resized + padded image whose sides are multiples of `height`, plus
+            # a mask over the whole padded image.
+            resized_image, attention_mask = self.dynamic_preprocess(
+                image, height, patch_size, mask_size, max_num=dynamic_hd
+            )
+            processed_image = self._to_chw_array(resized_image)
+            if do_rescale:
+                processed_image = processed_image * rescale_factor
+            if do_normalize:
+                # Reshape the statistics to (C, 1, 1) so they broadcast over H and W.
+                mean = np.asarray(image_mean, dtype=np.float32)[:, None, None]
+                std = np.asarray(image_std, dtype=np.float32)[:, None, None]
+                processed_image = (processed_image - mean) / std
+
+            # Global view: the whole (already normalised) image resized to one crop.
+            global_image = self._resize_chw(processed_image, height, height)
+            image_height, image_width = processed_image.shape[-2:]
+            mask_height, mask_width = attention_mask.shape[-2:]
+            # The global view is fully attended.
+            global_attention_mask = np.ones((1, mask_size, mask_size), dtype=bool)
+
+            # Split the padded image into a row-major grid of (3, height, width) crops.
+            hd_image = processed_image.reshape(1, 3, image_height // height, height, image_width // width, width)
+            hd_image = hd_image.transpose(0, 2, 4, 1, 3, 5).reshape(-1, 3, height, width)
+
+            # Split the mask into the same grid of (mask_size, mask_size) tiles.
+            attention_mask = attention_mask.reshape(
+                mask_height // mask_size, mask_size, mask_width // mask_size, mask_size
+            )
+            attention_mask = attention_mask.transpose(0, 2, 1, 3).reshape(-1, mask_size, mask_size)
+
+            # Keep every second mask cell in both directions, then stitch the
+            # tiles back into one 2-D map covering the whole image.
+            downsample_attention_mask = attention_mask[:, 0::2, 0::2]
+            pooled_mask_size = mask_size // 2 + mask_size % 2
+            downsample_attention_mask = downsample_attention_mask.reshape(
+                mask_height // mask_size,
+                mask_width // mask_size,
+                pooled_mask_size,
+                pooled_mask_size,
+            )
+            downsample_attention_mask = downsample_attention_mask.transpose(0, 2, 1, 3)
+            downsample_attention_mask = downsample_attention_mask.reshape(
+                downsample_attention_mask.shape[0] * downsample_attention_mask.shape[1],
+                downsample_attention_mask.shape[2] * downsample_attention_mask.shape[3],
+            )
+
+            # Token count = base_feat_size**2 + 1 + (active cells of the stitched
+            # map) + (active cells in its first column) + base_feat_size.
+            # NOTE(review): the extra terms presumably account for separator /
+            # line-break tokens inserted by the model - confirm against the
+            # embedding code.
+            base_feat_size = mask_size // 2 + mask_size % 2
+            num_img_tokens = (
+                base_feat_size**2
+                + 1
+                + int(downsample_attention_mask.sum().item())
+                + int(downsample_attention_mask[:, 0].sum().item())
+                + base_feat_size
+            )
+
+            # The global view always comes first, followed by the local crops.
+            hd_image = np.concatenate([global_image[None, ...], hd_image], axis=0)
+            attention_mask = np.concatenate([global_attention_mask, attention_mask], axis=0)
+
+            images_transformed.append(hd_image.astype(np.float32))
+            masks_transformed.append(attention_mask.astype(bool))
+            image_tokens.append(num_img_tokens)
+            image_sizes.append([image_height, image_width])
+
+        # Pad every image to the largest crop count so they can be stacked.
+        max_crops = max(image.shape[0] for image in images_transformed)
+        images_transformed = np.stack(
+            [self._pad_images_to_max_crops(image, max_crops) for image in images_transformed], axis=0
+        )
+        masks_transformed = np.stack(
+            [self._pad_masks_to_max_crops(mask, max_crops) for mask in masks_transformed], axis=0
+        )
+
+        data = {
+            "image_pixel_values": images_transformed,
+            # Sizes are those of the padded image, as [height, width].
+            "image_sizes": np.asarray(image_sizes, dtype=np.int64),
+            "image_attention_mask": masks_transformed,
+            "num_img_tokens": image_tokens,
+        }
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size):
+        """Pick the ``(cols, rows)`` grid whose ratio is closest to ``aspect_ratio``.
+
+        Candidates are visited in the given order. A strictly closer candidate
+        always wins; a candidate that exactly ties the current best replaces it
+        only if the image area exceeds half of that candidate's pixel area.
+        Returns ``(1, 1)`` when ``target_ratios`` is empty.
+        """
+        chosen = (1, 1)
+        smallest_gap = float("inf")
+        image_area = width * height
+
+        for candidate in target_ratios:
+            gap = abs(aspect_ratio - candidate[0] / candidate[1])
+            if gap < smallest_gap:
+                smallest_gap = gap
+                chosen = candidate
+                continue
+            if gap != smallest_gap:
+                continue
+            # Tie: prefer this grid only when the image is large enough for it.
+            if image_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
+                chosen = candidate
+        return chosen
+
+    def dynamic_preprocess(self, image, image_size, patch_size, mask_size, max_num=36, min_num=1):
+        """Resize and pad ``image`` onto a grid of ``image_size`` crops.
+
+        The grid is the smallest one covering the image; if that needs more than
+        ``max_num`` crops, the grid with at most ``max_num`` crops and the
+        closest aspect ratio is used instead. The image is scaled to fit the
+        grid while keeping its aspect ratio, then padded with white on the
+        right/bottom.
+
+        Returns:
+            ``(padded_image, attention_mask)`` where the mask has ``mask_size``
+            cells per crop side and is ``False`` over whole patches of padding.
+
+        Raises:
+            ValueError: if either side of the resized image is below 10 pixels.
+        """
+        orig_width, orig_height = image.size
+        cols = math.ceil(orig_width / float(image_size))
+        rows = math.ceil(orig_height / float(image_size))
+
+        if cols * rows <= max_num:
+            grid = (cols, rows)
+        else:
+            aspect_ratio = orig_width / orig_height
+            # Enumerate every grid whose crop count lies in [min_num, max_num].
+            candidates = set()
+            for n in range(min_num, max_num + 1):
+                for i in range(1, n + 1):
+                    for j in range(1, n + 1):
+                        if min_num <= i * j <= max_num:
+                            candidates.add((i, j))
+            ordered = sorted(candidates, key=lambda pair: pair[0] * pair[1])
+            grid = self.find_closest_aspect_ratio(aspect_ratio, ordered, orig_width, orig_height, image_size)
+
+        target_width = image_size * grid[0]
+        target_height = image_size * grid[1]
+
+        # Scale by the tighter of the two ratios; the other side gets padded.
+        scale_w = target_width / orig_width
+        scale_h = target_height / orig_height
+        if scale_w < scale_h:
+            scaled_height = int(orig_height * scale_w)
+            new_size = (target_width, scaled_height)
+            padding_width = 0
+            padding_height = target_height - scaled_height
+        else:
+            scaled_width = int(orig_width * scale_h)
+            new_size = (scaled_width, target_height)
+            padding_width = target_width - scaled_width
+            padding_height = 0
+
+        mask_shape = (int(mask_size * grid[1]), int(mask_size * grid[0]))
+        attention_mask = np.ones(mask_shape, dtype=bool)
+        # Only whole patches of padding are masked out.
+        if padding_width >= patch_size:
+            padded_cols = math.floor(padding_width / patch_size)
+            attention_mask[:, -padded_cols:] = False
+        if padding_height >= patch_size:
+            padded_rows = math.floor(padding_height / patch_size)
+            attention_mask[-padded_rows:, :] = False
+
+        if min(new_size[1], target_height) < 10 or min(new_size[0], target_width) < 10:
+            raise ValueError(f"the aspect ratio is very extreme {new_size}")
+
+        scaled = image.resize((int(new_size[0]), int(new_size[1])), resample=self.resample)
+        padded = ImageOps.expand(
+            scaled, border=(0, 0, int(padding_width), int(padding_height)), fill=(255, 255, 255)
+        )
+        return padded, attention_mask
+
+    def _to_pil_image(self, image):
+        """Convert a PIL image, ``paddle.Tensor`` or array-like into a PIL image.
+
+        Array inputs must be 3-D. An array whose first axis has size 1 or 3 and
+        whose last axis does not is treated as channel-first and transposed to
+        channel-last. Non-uint8 data is multiplied by 255 when its maximum is at
+        most 1.0, then clipped to ``[0, 255]`` and cast to uint8. A singleton
+        channel axis is dropped, because ``Image.fromarray`` accepts ``(H, W)``
+        but not ``(H, W, 1)``.
+
+        Returns:
+            The PIL image, converted to RGB when ``self.do_convert_rgb`` is set.
+
+        Raises:
+            ValueError: if an array input is not 3-dimensional.
+        """
+        if isinstance(image, Image.Image):
+            pil_image = image
+        else:
+            if isinstance(image, paddle.Tensor):
+                image = image.detach().cpu().numpy()
+            image = np.asarray(image)
+            if image.ndim != 3:
+                raise ValueError(f"Expected image with 3 dimensions, got shape {image.shape}.")
+            if image.shape[0] in (1, 3) and image.shape[-1] not in (1, 3):
+                image = image.transpose(1, 2, 0)
+            if image.dtype != np.uint8:
+                if image.max() <= 1.0:
+                    image = image * 255
+                image = np.clip(image, 0, 255).astype(np.uint8)
+            if image.shape[-1] == 1:
+                # Grayscale: hand PIL a 2-D array instead of (H, W, 1).
+                image = image[..., 0]
+            pil_image = Image.fromarray(image)
+        return pil_image.convert("RGB") if self.do_convert_rgb else pil_image
+
+    @staticmethod
+    def _to_chw_array(image):
+        """Return ``image`` as a float32 array with the last axis moved to the front."""
+        hwc = np.asarray(image).astype(np.float32)
+        return hwc.transpose(2, 0, 1)
+
+    def _resize_chw(self, image, height, width):
+        """Resize a channel-first float image to ``(height, width)``.
+
+        Each channel is resized on its own as a 32-bit float ("F" mode) PIL
+        image with ``self.resample``, then the channels are stacked again.
+        """
+
+        def _resize_plane(plane):
+            as_pil = Image.fromarray(plane.astype(np.float32), mode="F")
+            as_pil = as_pil.resize((width, height), resample=self.resample)
+            return np.asarray(as_pil, dtype=np.float32)
+
+        return np.stack([_resize_plane(plane) for plane in image], axis=0)
+
+    @staticmethod
+    def _pad_images_to_max_crops(images, max_crops):
+        """Append all-zero crops until ``images`` has ``max_crops`` entries.
+
+        The input is returned unchanged when it already has at least
+        ``max_crops`` crops.
+        """
+        missing = max_crops - images.shape[0]
+        if missing <= 0:
+            return images
+        filler = np.zeros((missing,) + tuple(images.shape[1:]), dtype=images.dtype)
+        return np.concatenate([images, filler], axis=0)
+
+    @staticmethod
+    def _pad_masks_to_max_crops(masks, max_crops):
+        """Append all-ones masks until ``masks`` has ``max_crops`` entries.
+
+        The input is returned unchanged when it already has at least
+        ``max_crops`` entries.
+        """
+        missing = max_crops - masks.shape[0]
+        if missing <= 0:
+            return masks
+        filler = np.ones((missing,) + tuple(masks.shape[1:]), dtype=masks.dtype)
+        return np.concatenate([masks, filler], axis=0)
+
+
+__all__ = ["Phi4MultimodalImageProcessor"]
diff --git a/paddleformers/transformers/phi4_multimodal/modeling.py b/paddleformers/transformers/phi4_multimodal/modeling.py
new file mode 100644
index 00000000000..295bfe31033
--- /dev/null
+++ b/paddleformers/transformers/phi4_multimodal/modeling.py
@@ -0,0 +1,1955 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Paddle Phi-4-Multimodal model."""
+
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+from paddle import nn
+from paddle.distributed.fleet.utils import recompute
+
+from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
+from ...nn.criterion.interface import CriterionLayer
+from ...nn.embedding import Embedding as GeneralEmbedding
+from ...nn.linear import Linear as GeneralLinear
+from ...nn.lm_head import LMHead as GeneralLMHead
+from ...nn.pp_model import GeneralModelForCausalLMPipe
+from ...utils.log import logger
+from ..activations import ACT2FN
+from ..cache_utils import Cache, DynamicCache
+from ..masking_utils import create_causal_mask_and_row_indices
+from ..model_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ..model_utils import PretrainedModel, register_base_model
+from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from .configuration import (
+ Phi4MultimodalAudioConfig,
+ Phi4MultimodalConfig,
+ Phi4MultimodalVisionConfig,
+)
+
+# ======================= Utility functions =======================
+
+
+def rotate_half(x):
+    """Split the last axis in two halves ``(a, b)`` and return ``concat(-b, a)``."""
+    half = x.shape[-1] // 2
+    front = x[..., :half]
+    back = x[..., half:]
+    return paddle.concat((-back, front), axis=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+    """Apply rotary position embeddings to the leading features of ``q`` and ``k``.
+
+    Only the first ``cos.shape[-1]`` features of the last axis are rotated; the
+    remaining features are passed through unchanged.
+
+    Args:
+        q, k: query and key tensors.
+        cos, sin: rotary tables; a new axis is inserted at ``unsqueeze_dim`` so
+            they broadcast against ``q`` and ``k``.
+        unsqueeze_dim: position of the inserted axis.
+
+    Returns:
+        Tuple ``(q_embed, k_embed)`` with the same last-axis size as the inputs.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    rotary_dim = cos.shape[-1]
+
+    def _rotate_leading(states):
+        rotated_part = states[..., :rotary_dim]
+        untouched_part = states[..., rotary_dim:]
+        rotated = (rotated_part * cos) + (rotate_half(rotated_part) * sin)
+        return paddle.concat([rotated, untouched_part], axis=-1)
+
+    return _rotate_leading(q), _rotate_leading(k)
+
+
+def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
+    """Repeat every key/value head ``n_rep`` times along the head axis.
+
+    Expects a 4-D tensor ``[batch, num_key_value_heads, seq_len, head_dim]`` and
+    returns ``[batch, num_key_value_heads * n_rep, seq_len, head_dim]``. The
+    input is returned as-is when ``n_rep`` is 1.
+    """
+    batch, kv_heads, seq_len, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    # Insert a repeat axis after the head axis, broadcast it, then fold it in.
+    expanded = hidden_states.unsqueeze(2).expand([batch, kv_heads, n_rep, seq_len, head_dim])
+    return expanded.reshape([batch, kv_heads * n_rep, seq_len, head_dim])
+
+
+def _create_lora_parameter(layer: nn.Layer, shape):
+    """Create a parameter of ``shape`` on ``layer``, initialised to all zeros."""
+    zero_init = nn.initializer.Constant(0.0)
+    return layer.create_parameter(shape=shape, default_initializer=zero_init)
+
+
+def _lora_delta(hidden_states: paddle.Tensor, lora_a: paddle.Tensor, lora_b: paddle.Tensor, scaling: float):
+    """Return the scaled low-rank update for ``hidden_states``.
+
+    Both LoRA matrices are cast to the dtype of ``hidden_states``; the input is
+    projected through ``lora_a``, then ``lora_b``, and multiplied by ``scaling``.
+    """
+    compute_dtype = hidden_states.dtype
+    low_rank = F.linear(hidden_states, lora_a.astype(compute_dtype))
+    projected = F.linear(low_rank, lora_b.astype(compute_dtype))
+    return projected * scaling
+
+
+def _active_lora_adapter(config):
+    """Return ``config._active_lora_adapter`` if it is "vision" or "speech", else ``None``."""
+    selected = getattr(config, "_active_lora_adapter", None)
+    return selected if selected in ("vision", "speech") else None
+
+
+def _lora_adapter_from_input_mode(input_mode, image_pixel_values=None, audio_input_features=None):
+    """Choose the LoRA adapter name ("vision", "speech" or ``None``) for a batch.
+
+    An explicit ``input_mode`` takes precedence: 1 and 3 select "vision", 2
+    selects "speech", anything else selects no adapter. For a tensor, its first
+    element is used. Without ``input_mode`` the adapter follows whichever
+    modality input is present, with images checked before audio.
+    """
+    if input_mode is None:
+        if image_pixel_values is not None:
+            return "vision"
+        return "speech" if audio_input_features is not None else None
+
+    if isinstance(input_mode, paddle.Tensor):
+        input_mode = int(input_mode.flatten()[0].item())
+    if input_mode in (1, 3):
+        return "vision"
+    return "speech" if input_mode == 2 else None
+
+
+# ======================= Vision Encoder =======================
+
+
+class Phi4MultimodalVisionMLP(nn.Layer):
+    """Two-layer feed-forward block: linear, activation, linear."""
+
+    def __init__(self, config: Phi4MultimodalVisionConfig):
+        super().__init__()
+        self.config = config
+        self.activation_fn = ACT2FN[config.hidden_act]
+        hidden, inner = config.hidden_size, config.intermediate_size
+        self.fc1 = nn.Linear(hidden, inner)
+        self.fc2 = nn.Linear(inner, hidden)
+
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        """Expand to ``intermediate_size``, activate, project back to ``hidden_size``."""
+        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
+
+
+class Phi4MultimodalVisionAttention(nn.Layer):
+    """Multi-head self-attention implemented with explicit matmuls."""
+
+    def __init__(self, config: Phi4MultimodalVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
+
+        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
+        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
+        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
+        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        attention_mask: Optional[paddle.Tensor] = None,
+    ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
+        """Run self-attention over ``hidden_states``.
+
+        Args:
+            hidden_states: input whose last axis is ``hidden_size``.
+            attention_mask: optional additive mask, added to the scaled scores
+                before the softmax.
+
+        Returns:
+            ``(attn_output, attn_weights)``; the weights are the post-softmax
+            (and, in training, post-dropout) probabilities.
+        """
+        leading_shape = list(hidden_states.shape[:-1])
+        per_head_shape = leading_shape + [self.num_heads, self.head_dim]
+
+        def _split_heads(projected):
+            # [..., seq, hidden] -> [batch, heads, seq, head_dim]
+            return projected.reshape(per_head_shape).transpose([0, 2, 1, 3])
+
+        query_states = _split_heads(self.q_proj(hidden_states))
+        key_states = _split_heads(self.k_proj(hidden_states))
+        value_states = _split_heads(self.v_proj(hidden_states))
+
+        scores = paddle.matmul(query_states, key_states.transpose([0, 1, 3, 2])) * self.scaling
+        if attention_mask is not None:
+            scores = scores + attention_mask
+
+        # Softmax runs in float32 and is cast back to the query dtype.
+        attn_weights = F.softmax(scores, axis=-1, dtype=paddle.float32).astype(query_states.dtype)
+        if self.training and self.attention_dropout > 0:
+            attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+        context = paddle.matmul(attn_weights, value_states).transpose([0, 2, 1, 3])
+        context = context.reshape(leading_shape + [-1])
+        return self.out_proj(context), attn_weights
+
+
+class Phi4MultimodalVisionEncoderLayer(nn.Layer):
+    """Pre-norm transformer layer: attention and MLP, each with a residual connection."""
+
+    def __init__(self, config: Phi4MultimodalVisionConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
+        self.self_attn = Phi4MultimodalVisionAttention(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, epsilon=config.layer_norm_eps)
+        self.mlp = Phi4MultimodalVisionMLP(config)
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        attention_mask: Optional[paddle.Tensor] = None,
+    ) -> paddle.Tensor:
+        """Apply the attention sub-layer, then the MLP sub-layer.
+
+        The attention weights returned by the attention module are discarded.
+        """
+        # Attention sub-layer.
+        normed = self.layer_norm1(hidden_states)
+        attn_out, _ = self.self_attn(hidden_states=normed, attention_mask=attention_mask)
+        hidden_states = hidden_states + attn_out
+
+        # Feed-forward sub-layer.
+        mlp_out = self.mlp(self.layer_norm2(hidden_states))
+        return hidden_states + mlp_out
+
+
+class Phi4MultimodalVisionEncoder(nn.Layer):
+    """Stack of ``config.num_hidden_layers`` vision encoder layers."""
+
+    def __init__(self, config: Phi4MultimodalVisionConfig):
+        super().__init__()
+        self.config = config
+        self.layers = nn.LayerList([Phi4MultimodalVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+
+    def forward(
+        self,
+        inputs_embeds: paddle.Tensor,
+        attention_mask: Optional[paddle.Tensor] = None,
+        output_hidden_states: bool = False,
+    ):
+        """Run all layers in order.
+
+        Returns:
+            ``(last_hidden_state, all_hidden_states)``. When
+            ``output_hidden_states`` is set, the second item is a tuple holding
+            the input embeddings followed by the output of every layer;
+            otherwise it is ``None``.
+        """
+        collected = [] if output_hidden_states else None
+        hidden_states = inputs_embeds
+
+        for layer in self.layers:
+            if collected is not None:
+                collected.append(hidden_states)
+            hidden_states = layer(hidden_states, attention_mask)
+
+        if collected is None:
+            return hidden_states, None
+        collected.append(hidden_states)
+        return hidden_states, tuple(collected)
+
+
+class Phi4MultimodalVisionEmbeddings(nn.Layer):
+    """Patch embedding plus learned position embeddings for masked image crops."""
+
+    def __init__(self, config: Phi4MultimodalVisionConfig):
+        super().__init__()
+        self.config = config
+        self.patch_size = config.patch_size
+        self.num_patches_per_side = config.image_size // self.patch_size
+
+        # Non-overlapping patches: kernel size and stride both equal the patch size.
+        self.patch_embedding = nn.Conv2D(
+            in_channels=config.num_channels,
+            out_channels=config.hidden_size,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+        )
+        self.position_embedding = nn.Embedding(self.num_patches_per_side**2, config.hidden_size)
+
+    def forward(self, pixel_values: paddle.Tensor, patch_attention_mask: paddle.Tensor) -> paddle.Tensor:
+        """Embed patches and add position embeddings chosen from the valid area.
+
+        For each sample, the number of valid patch rows/columns is read from the
+        first column/row of ``patch_attention_mask``. Patch coordinates are
+        expressed as fractions of that valid extent and bucketed onto the
+        ``num_patches_per_side`` grid of learned positions. Masked-out patches
+        use position id 0.
+
+        Args:
+            pixel_values: image batch in NCHW layout.
+            patch_attention_mask: ``[batch, patches_h, patches_w]`` mask of valid
+                patches.
+
+        Returns:
+            Patch embeddings with position embeddings added, one row per patch.
+        """
+        batch_size = pixel_values.shape[0]
+
+        patch_embeds = self.patch_embedding(pixel_values)
+        embeddings = patch_embeds.flatten(2).transpose([0, 2, 1])
+
+        # Bucketing stays in float32 regardless of the pixel dtype: in bf16/fp16
+        # a fractional coordinate can round onto or across a bucket boundary and
+        # select the wrong position id.
+        bucket_step = 1 / self.num_patches_per_side
+        boundaries = paddle.arange(bucket_step, 1.0, bucket_step, dtype="float32")
+
+        nb_patches_h = patch_attention_mask[:, :, 0].astype("int64").sum(axis=1)
+        nb_patches_w = patch_attention_mask[:, 0, :].astype("int64").sum(axis=1)
+
+        step_h = 1.0 / nb_patches_h.astype("float32")
+        step_w = 1.0 / nb_patches_w.astype("float32")
+
+        max_patches_h = patch_attention_mask.shape[1]
+        max_patches_w = patch_attention_mask.shape[2]
+        h_indices = paddle.arange(max_patches_h, dtype="float32")
+        w_indices = paddle.arange(max_patches_w, dtype="float32")
+
+        # Fraction of the valid extent covered up to each patch, kept below 1.0.
+        fractional_coords_h = paddle.clip(h_indices[None, :] * step_h[:, None], max=(1.0 - 1e-6))
+        fractional_coords_w = paddle.clip(w_indices[None, :] * step_w[:, None], max=(1.0 - 1e-6))
+
+        bucket_coords_h = paddle.bucketize(fractional_coords_h, boundaries, right=True)
+        bucket_coords_w = paddle.bucketize(fractional_coords_w, boundaries, right=True)
+
+        pos_ids = bucket_coords_h[:, :, None] * self.num_patches_per_side + bucket_coords_w[:, None, :]
+        pos_ids = pos_ids.reshape([batch_size, -1])
+
+        # Masked patches fall back to position 0. A single `where` replaces the
+        # per-sample loop and does not rely on chained indexing
+        # (`position_ids[i][mask] = ...`) writing through to the parent tensor.
+        flat_mask = patch_attention_mask.reshape([batch_size, -1]).astype("bool")
+        position_ids = paddle.where(flat_mask, pos_ids, paddle.zeros_like(pos_ids))
+
+        embeddings = embeddings + self.position_embedding(position_ids)
+        return embeddings
+
+
+class Phi4MultimodalVisionMultiheadAttentionPoolingHead(nn.Layer):
+    """Pool a patch sequence into one vector by attending from a learned probe."""
+
+    def __init__(self, config: Phi4MultimodalVisionConfig):
+        super().__init__()
+        self.probe = self.create_parameter(
+            shape=[1, 1, config.hidden_size],
+            default_initializer=nn.initializer.Normal(std=1.0),
+        )
+        self.attention = nn.MultiHeadAttention(config.hidden_size, config.num_attention_heads)
+        self.layernorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
+        self.mlp = Phi4MultimodalVisionMLP(config)
+
+    def forward(self, hidden_state: paddle.Tensor, attention_mask: paddle.Tensor) -> paddle.Tensor:
+        """Return one pooled vector per sample.
+
+        Args:
+            hidden_state: ``[batch, seq, hidden]`` patch features.
+            attention_mask: ``[batch, seq]`` mask, truthy for valid patches.
+        """
+        num_samples = hidden_state.shape[0]
+        query = self.probe.expand([num_samples, -1, -1])
+
+        # Build an additive mask: 0 at valid keys, the most negative finite
+        # value of the working dtype at padded keys.
+        is_padding = paddle.logical_not(attention_mask.astype("bool"))
+        most_negative = paddle.finfo(hidden_state.dtype).min
+        additive_mask = is_padding.astype(hidden_state.dtype) * most_negative
+        additive_mask = additive_mask.unsqueeze([1, 2])  # [batch, 1, 1, seq]
+
+        pooled = self.attention(query, hidden_state, hidden_state, attn_mask=additive_mask)
+
+        # Residual MLP on top of the attended probe.
+        pooled = pooled + self.mlp(self.layernorm(pooled))
+        return pooled[:, 0]
+
+
+class Phi4MultimodalVisionModel(nn.Layer):
+    """Vision tower: patch embeddings, encoder stack, final norm and pooling head."""
+
+    def __init__(self, config: Phi4MultimodalVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embeddings = Phi4MultimodalVisionEmbeddings(config)
+        self.encoder = Phi4MultimodalVisionEncoder(config)
+        self.post_layernorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
+        self.head = Phi4MultimodalVisionMultiheadAttentionPoolingHead(config)
+
+    def forward(
+        self,
+        pixel_values: paddle.Tensor,
+        patch_attention_mask: Optional[paddle.Tensor] = None,
+        output_hidden_states: bool = False,
+    ):
+        """Encode a batch of images.
+
+        Args:
+            pixel_values: image batch in NCHW layout.
+            patch_attention_mask: optional ``[batch, patches_h, patches_w]`` mask;
+                when omitted every patch is treated as valid.
+            output_hidden_states: whether the encoder should return all
+                intermediate hidden states.
+
+        Returns:
+            ``(last_hidden_state, pooled_output, all_hidden_states)`` where
+            ``last_hidden_state`` has the final layer norm applied.
+        """
+        num_images = pixel_values.shape[0]
+        if patch_attention_mask is None:
+            patch = self.config.patch_size
+            full_shape = [num_images, pixel_values.shape[2] // patch, pixel_values.shape[3] // patch]
+            patch_attention_mask = paddle.ones(shape=full_shape, dtype="bool")
+
+        patch_embeddings = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+
+        # Additive, non-causal mask: 0 for valid patches, the most negative
+        # finite value of the working dtype for masked ones.
+        flat_mask = patch_attention_mask.reshape([num_images, -1])
+        keep = flat_mask.unsqueeze([1, 2]).astype(patch_embeddings.dtype)
+        additive_mask = (1.0 - keep) * paddle.finfo(patch_embeddings.dtype).min
+
+        encoded, all_hidden_states = self.encoder(
+            inputs_embeds=patch_embeddings,
+            attention_mask=additive_mask,
+            output_hidden_states=output_hidden_states,
+        )
+        encoded = self.post_layernorm(encoded)
+
+        pooled_output = self.head(hidden_state=encoded, attention_mask=flat_mask)
+        return encoded, pooled_output, all_hidden_states
+
+
+class Phi4MultimodalImageEmbedding(nn.Layer):
+    """Encode image crops with the vision tower and splice the projected features into the text embeddings.
+
+    Each image contributes one global crop plus a grid of sub-crops. Features are taken from encoder
+    layer ``vision_config.feature_layer``, 2x2 average-pooled, laid out on a 2-D grid with a learned
+    separator appended to every row, projected to the language-model width, and finally written at
+    the positions of ``vision_config.image_token_id`` in ``input_ids``.
+    """
+
+    def __init__(self, config: Phi4MultimodalConfig):
+        super().__init__()
+        self.config = config
+        self.layer_idx = config.vision_config.feature_layer
+        self.crop_size = config.vision_config.crop_size
+        self.image_dim_out = config.vision_config.hidden_size
+
+        n_patches = config.vision_config.image_size // config.vision_config.patch_size
+        if n_patches % 2 != 0:
+            # An odd patch grid is padded by one row/column so the 2x2 pooling divides it evenly.
+            self.img_processor_padding = nn.Pad2D([0, 1, 0, 1], mode="reflect")
+            n_patches += 1
+        self.num_img_tokens = (n_patches // 2) ** 2
+
+        self.drop = nn.Dropout(config.embd_pdrop)
+        self.img_processor = Phi4MultimodalVisionModel(config.vision_config)
+        self.image_token_compression = nn.AvgPool2D(kernel_size=2, stride=2)
+        self.img_projection_up = nn.Linear(self.image_dim_out, config.hidden_size)
+        self.img_projection_down = nn.Linear(config.hidden_size, config.hidden_size)
+        # Learned separator placed between the sub-crop block and the global crop.
+        self.global_img_feature_extensor = self.create_parameter(
+            shape=[1, 1, self.image_dim_out],
+            default_initializer=nn.initializer.Constant(0.0),
+        )
+        # Learned separator appended to the end of every feature row.
+        self.sub_img_feature_extensor = self.create_parameter(
+            shape=[1, 1, 1, self.image_dim_out],
+            default_initializer=nn.initializer.Constant(0.0),
+        )
+
+    def _repeat_sub_img_feature_extensor(self, repeat_height: int) -> paddle.Tensor:
+        """Tile the row separator to shape ``[1, repeat_height, 1, image_dim_out]``."""
+        return paddle.tile(self.sub_img_feature_extensor, repeat_times=[1, repeat_height, 1, 1])
+
+    def get_img_features(self, img_embeds: paddle.Tensor, attention_mask=None) -> paddle.Tensor:
+        """Run the vision tower and return 2x2 average-pooled patch features of the selected layer."""
+        _, _, all_hidden_states = self.img_processor(
+            img_embeds, patch_attention_mask=attention_mask, output_hidden_states=True
+        )
+        img_feature = all_hidden_states[self.layer_idx]
+
+        patch_feature = img_feature
+        # The token axis is treated as a square grid.
+        width = int(math.sqrt(patch_feature.shape[1]))
+        patch_feature = patch_feature.reshape([-1, width, width, patch_feature.shape[-1]])
+        # convert to NCHW
+        patch_feature = patch_feature.transpose([0, 3, 1, 2])
+        if hasattr(self, "img_processor_padding"):
+            patch_feature = self.img_processor_padding(patch_feature)
+        patch_feature = self.image_token_compression(patch_feature)
+        # convert to NHWC
+        patch_feature = patch_feature.transpose([0, 2, 3, 1])
+        patch_feature = patch_feature.reshape(
+            [-1, patch_feature.shape[1] * patch_feature.shape[2], patch_feature.shape[-1]]
+        )
+        return patch_feature
+
+    def forward(
+        self,
+        input_ids: paddle.Tensor,
+        inputs_embeds: paddle.Tensor,
+        image_pixel_values: paddle.Tensor,
+        image_sizes: Optional[paddle.Tensor] = None,
+        image_attention_mask: Optional[paddle.Tensor] = None,
+    ) -> paddle.Tensor:
+        """Return ``inputs_embeds`` with image-token positions replaced by projected image features.
+
+        Args:
+            input_ids: Token ids; positions equal to ``vision_config.image_token_id`` are overwritten.
+            inputs_embeds: Text embeddings to splice the image features into.
+            image_pixel_values: Crops per image; the first two axes are flattened before encoding.
+            image_sizes: Per-image (height, width); each is integer-divided by ``crop_size`` to get the
+                sub-crop grid. Required even though the signature defaults it to ``None``.
+            image_attention_mask: Optional per-crop patch mask used to trim padded rows/columns.
+
+        Raises:
+            ValueError: If ``image_sizes`` is ``None``, or if fewer image features were produced than
+                there are image-token positions in ``input_ids``.
+        """
+        if image_sizes is None:
+            # The sub-crop grid cannot be reconstructed without the per-image sizes.
+            raise ValueError("`image_sizes` must be provided together with `image_pixel_values`.")
+
+        image_pixel_values = image_pixel_values.astype(self.img_processor.embeddings.patch_embedding.weight.dtype)
+
+        target_dtype = self.img_projection_up.bias.dtype
+
+        batch_size = image_pixel_values.shape[0]
+
+        img_features = self.get_img_features(
+            image_pixel_values.flatten(0, 1),
+            attention_mask=image_attention_mask.flatten(0, 1).astype("bool")
+            if image_attention_mask is not None
+            else None,
+        )
+        base_feat_size = int(np.sqrt(img_features.shape[1]))
+        img_features = img_features.reshape([batch_size, -1, base_feat_size**2, self.image_dim_out])
+        image_sizes_flat = image_sizes.reshape([-1, 2])
+
+        output_imgs = []
+        for idx in range(batch_size):
+            height = int(image_sizes_flat[idx, 0].item())
+            width = int(image_sizes_flat[idx, 1].item())
+            height_ratio = height // self.crop_size
+            width_ratio = width // self.crop_size
+            area_ratio = height_ratio * width_ratio
+
+            # Crop 0 is the global view: lay it out as a grid and append a separator to each row.
+            global_img = img_features[idx, :1]
+            global_img = global_img.reshape([1, base_feat_size, base_feat_size, self.image_dim_out])
+            temporary_extensor = self._repeat_sub_img_feature_extensor(base_feat_size)
+            global_img = paddle.concat([global_img, temporary_extensor], axis=2).reshape([1, -1, self.image_dim_out])
+
+            # Remaining crops are tiles: stitch them into one (height_ratio x width_ratio) feature map.
+            sub_img = img_features[idx, 1:]
+            sub_img = sub_img[:area_ratio]
+            sub_img = (
+                sub_img.reshape([height_ratio, width_ratio, base_feat_size, base_feat_size, self.image_dim_out])
+                .transpose([0, 2, 1, 3, 4])
+                .reshape([1, height_ratio * base_feat_size, width_ratio * base_feat_size, self.image_dim_out])
+            )
+
+            if image_attention_mask is not None:
+                # Subsample the patch mask by 2 to match the pooled grid, then count the valid
+                # rows/columns along the first column/row to trim padding from the stitched map.
+                reshaped_image_attention_mask = (
+                    image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2]
+                    .reshape([height_ratio, width_ratio, base_feat_size, base_feat_size])
+                    .transpose([0, 2, 1, 3])
+                    .reshape([1, height_ratio * base_feat_size, width_ratio * base_feat_size])
+                )
+                reshaped_image_attention_mask_int = reshaped_image_attention_mask.astype("int64")
+                useful_height = int(reshaped_image_attention_mask_int[0, :, 0].sum().item())
+                useful_width = int(reshaped_image_attention_mask_int[0, 0, :].sum().item())
+                sub_img = sub_img[:, :useful_height, :useful_width]
+                temporary_extensor = self._repeat_sub_img_feature_extensor(useful_height)
+            else:
+                temporary_extensor = self._repeat_sub_img_feature_extensor(height_ratio * base_feat_size)
+
+            sub_img = paddle.concat([sub_img, temporary_extensor], axis=2).reshape([1, -1, self.image_dim_out])
+
+            # Final per-image order: sub-crops, global separator, global crop.
+            output_imgs.append(paddle.concat([sub_img, self.global_img_feature_extensor, global_img], axis=1))
+
+        img_set_tensor = []
+        for output_img in output_imgs:
+            output_img = output_img.astype(target_dtype)
+            img_feature_proj = self.img_projection_up(output_img)
+            img_feature_proj = F.gelu(img_feature_proj)
+            img_feature_proj = self.img_projection_down(img_feature_proj)
+            img_set_tensor.append(img_feature_proj)
+
+        merged_img_set_tensor = paddle.concat(img_set_tensor, axis=1).squeeze(0)
+        merged_img_set_tensor = merged_img_set_tensor.astype(inputs_embeds.dtype)
+
+        positions = paddle.nonzero(input_ids == self.config.vision_config.image_token_id)
+        num_positions = positions.shape[0]
+        if num_positions > 0:
+            if merged_img_set_tensor.shape[0] < num_positions:
+                raise ValueError(
+                    f"Found {num_positions} image tokens in `input_ids` but only "
+                    f"{merged_img_set_tensor.shape[0]} image feature vectors were produced."
+                )
+            # One vectorized scatter instead of a Python-level assignment per image token.
+            # `nonzero` yields positions in row-major order, which is the order of the features.
+            image_embeds = paddle.index_put(
+                inputs_embeds,
+                (positions[:, 0], positions[:, 1]),
+                merged_img_set_tensor[:num_positions],
+            )
+        else:
+            image_embeds = inputs_embeds
+
+        image_embeds = self.drop(image_embeds)
+        return image_embeds
+
+
+# ======================= Audio Encoder =======================
+
+
+class Phi4MultimodalAudioMLP(nn.Layer):
+    """Pre-norm gated feed-forward block of the audio encoder."""
+
+    def __init__(self, config: Phi4MultimodalAudioConfig):
+        super().__init__()
+        width = config.hidden_size
+        inner = config.intermediate_size
+        self.layer_norm = nn.LayerNorm(width)
+        self.act_fn = ACT2FN[config.activation]
+        # A single projection produces both the value half and the gate half.
+        self.gate_up_proj = nn.Linear(width, inner * 2)
+        self.down_proj = nn.Linear(inner, width)
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        """Apply LayerNorm, the gated projection, and the down projection with dropout after each."""
+        normed = self.layer_norm(hidden_states)
+        # First chunk is the value path, second chunk is passed through the activation as the gate.
+        value_half, gate_half = self.gate_up_proj(normed).chunk(2, axis=-1)
+        gated = self.dropout(value_half * self.act_fn(gate_half))
+        return self.dropout(self.down_proj(gated))
+
+
+class Phi4MultimodalAudioAttention(nn.Layer):
+    """Multi-head self-attention of the audio encoder with an additive mask/bias."""
+
+    def __init__(self, config: Phi4MultimodalAudioConfig):
+        super().__init__()
+        self.config = config
+        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        self.num_heads = config.num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.dropout_rate
+
+        proj_width = config.num_attention_heads * self.head_dim
+        self.q_proj = nn.Linear(config.hidden_size, proj_width)
+        self.k_proj = nn.Linear(config.hidden_size, proj_width)
+        self.v_proj = nn.Linear(config.hidden_size, proj_width)
+        self.o_proj = nn.Linear(proj_width, config.hidden_size)
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        attention_mask: paddle.Tensor,
+    ) -> paddle.Tensor:
+        """Return the attention output; ``attention_mask`` (if not None) is added to the raw scores."""
+        lead_shape = list(hidden_states.shape[:-1])
+        split_shape = lead_shape + [self.num_heads, self.head_dim]
+
+        def _to_heads(projection):
+            # Project, split the last axis into heads, and move the head axis in front of the sequence axis.
+            return projection(hidden_states).reshape(split_shape).transpose([0, 2, 1, 3])
+
+        queries = _to_heads(self.q_proj)
+        keys = _to_heads(self.k_proj)
+        values = _to_heads(self.v_proj)
+
+        scores = paddle.matmul(queries, keys.transpose([0, 1, 3, 2])) * self.scaling
+        if attention_mask is not None:
+            scores = scores + attention_mask
+
+        # Softmax runs in float32 and is cast back to the query dtype.
+        probs = F.softmax(scores, axis=-1, dtype=paddle.float32).astype(queries.dtype)
+        if self.training and self.attention_dropout > 0:
+            probs = F.dropout(probs, p=self.attention_dropout, training=self.training)
+
+        context = paddle.matmul(probs, values).transpose([0, 2, 1, 3])
+        context = context.reshape(lead_shape + [-1])
+        return self.o_proj(context)
+
+
+class Phi4MultimodalAudioDepthWiseSeparableConv1d(nn.Layer):
+    """Depthwise 1-D convolution followed by a kernel-size-1 pointwise convolution."""
+
+    def __init__(self, config: Phi4MultimodalAudioConfig, padding: int = 0):
+        super().__init__()
+        depthwise_channels = config.hidden_size * config.depthwise_multiplier
+        # groups == in_channels makes every input channel convolve independently.
+        self.dw_conv = nn.Conv1D(
+            config.hidden_size,
+            depthwise_channels,
+            config.kernel_size,
+            stride=1,
+            padding=padding,
+            groups=config.hidden_size,
+        )
+        self.pw_conv = nn.Conv1D(
+            depthwise_channels,
+            config.depthwise_separable_out_channel,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        """Apply the depthwise convolution, then the pointwise convolution."""
+        depthwise_out = self.dw_conv(hidden_states)
+        return self.pw_conv(depthwise_out)
+
+
+class Phi4MultimodalAudioGluPointWiseConv(nn.Layer):
+    """Pointwise convolution whose output is split into a value half and a gate half (GLU)."""
+
+    def __init__(self, config: Phi4MultimodalAudioConfig):
+        super().__init__()
+        self.config = config
+        self.output_dim = config.ext_pw_out_channel
+
+        self.ext_pw_conv_1d = nn.Conv1D(config.hidden_size, config.ext_pw_out_channel * 2, kernel_size=1, stride=1)
+        self.glu_act = ACT2FN[config.conv_glu_type]
+        # Separate zero-initialised biases for the value half (b1) and the gate half (b2).
+        bias_shape = [1, config.ext_pw_out_channel, 1]
+        self.b1 = self.create_parameter(
+            shape=bias_shape,
+            default_initializer=nn.initializer.Constant(0.0),
+        )
+        self.b2 = self.create_parameter(
+            shape=bias_shape,
+            default_initializer=nn.initializer.Constant(0.0),
+        )
+
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        """Return ``(value + b1) * act(gate + b2)`` with the last two axes restored to input order."""
+        # Conv1D works on the middle axis as channels, so swap the last two axes around the convolution.
+        projected = self.ext_pw_conv_1d(hidden_states.transpose([0, 2, 1]))
+        value_part = projected[:, 0 : self.output_dim, :] + self.b1
+        gate_part = projected[:, self.output_dim : self.output_dim * 2, :] + self.b2
+        gated = value_part * self.glu_act(gate_part)
+        return gated.transpose([0, 2, 1])
+
+
+class Phi4MultimodalAudioConvModule(nn.Layer):
+    """Convolution module of a conformer layer: LayerNorm, GLU, depthwise-separable conv, pointwise conv."""
+
+    def __init__(self, config: Phi4MultimodalAudioConfig):
+        super().__init__()
+        self.config = config
+        self.kernel_size = config.kernel_size
+
+        self.layer_norm = nn.LayerNorm(config.hidden_size)
+        self.glu = Phi4MultimodalAudioGluPointWiseConv(config)
+        # Padding of kernel_size - 1 is applied by the conv; forward() trims the same amount from the tail.
+        self.dw_sep_conv_1d = Phi4MultimodalAudioDepthWiseSeparableConv1d(config, padding=config.kernel_size - 1)
+        self.act = ACT2FN[config.conv_activation]
+        self.ext_pw_conv_1d = nn.Conv1D(config.hidden_size, config.ext_pw_out_channel, kernel_size=1, stride=1)
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        """Run the convolution module and return a tensor with the last two axes in input order."""
+        features = self.glu(self.layer_norm(hidden_states))
+        features = self.dw_sep_conv_1d(features.transpose([0, 2, 1]))
+
+        # Drop the trailing frames produced by the convolution padding.
+        trailing = self.kernel_size - 1
+        if trailing > 0:
+            features = features[:, :, :-trailing]
+
+        features = self.ext_pw_conv_1d(self.act(features))
+        return self.dropout(features.transpose([0, 2, 1]))
+
+
+class Phi4MultimodalAudioConformerEncoderLayer(nn.Layer):
+    """One conformer block: half-step MLP, self-attention, convolution, half-step MLP, LayerNorm."""
+
+    def __init__(self, config: Phi4MultimodalAudioConfig):
+        super().__init__()
+        self.feed_forward_in = Phi4MultimodalAudioMLP(config)
+        self.self_attn = Phi4MultimodalAudioAttention(config)
+        self.conv = Phi4MultimodalAudioConvModule(config)
+        self.feed_forward_out = Phi4MultimodalAudioMLP(config)
+        self.layer_norm_att = nn.LayerNorm(config.hidden_size)
+        self.layer_norm = nn.LayerNorm(config.hidden_size)
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        attention_mask: paddle.Tensor,
+    ) -> paddle.Tensor:
+        """Apply the block; both feed-forward modules contribute with a 0.5 residual weight."""
+        after_ffn_in = hidden_states + 0.5 * self.feed_forward_in(hidden_states)
+        # Attention is pre-normed; the residual is taken from the un-normed tensor.
+        after_attn = after_ffn_in + self.self_attn(self.layer_norm_att(after_ffn_in), attention_mask)
+        after_conv = after_attn + self.conv(after_attn)
+        after_ffn_out = after_conv + 0.5 * self.feed_forward_out(after_conv)
+        return self.layer_norm(after_ffn_out)
+
+
+class Phi4MultimodalAudioNemoConvSubsampling(nn.Layer):
+    """Stack of stride-2 2-D convolutions that subsamples the input, followed by a linear projection."""
+
+    def __init__(self, config: Phi4MultimodalAudioConfig):
+        super().__init__()
+        self.subsampling_factor = config.time_reduction
+        # Each stride-2 stage halves the length, so log2(factor) stages are built.
+        self.sampling_num = int(math.log2(self.subsampling_factor))
+        self.act_fn = ACT2FN[config.nemo_activation]
+        channels = config.nemo_conv_channels
+
+        stages = [nn.Conv2D(1, channels, kernel_size=3, stride=2, padding=1), self.act_fn]
+        for _ in range(self.sampling_num - 1):
+            # Depthwise stride-2 conv, then a 1x1 pointwise conv, then the activation.
+            stages += [
+                nn.Conv2D(channels, channels, kernel_size=3, stride=2, padding=1, groups=channels),
+                nn.Conv2D(channels, channels, kernel_size=1, stride=1, padding=0, groups=1),
+                self.act_fn,
+            ]
+
+        self.conv = nn.Sequential(*stages)
+        self.out = nn.Linear(channels * config.nemo_final_size, config.hidden_size)
+
+    def forward(self, hidden_states: paddle.Tensor, mask: Optional[paddle.Tensor]):
+        """Subsample ``hidden_states`` and return ``(features, pad_mask)``.
+
+        ``pad_mask`` is ``None`` when ``mask`` is ``None``. Otherwise it marks, per sample, the first
+        ``ceil(mask.sum(1) / subsampling_factor)`` output frames and carries an extra axis at position 1.
+        """
+        conv_out = self.conv(hidden_states.unsqueeze(1))
+
+        batch, _, frames, _ = conv_out.shape
+        # Fold the channel axis and the last axis together before the linear projection.
+        features = self.out(conv_out.transpose([0, 2, 1, 3]).reshape([batch, frames, -1]))
+
+        if mask is None:
+            return features, None
+
+        valid_frames = paddle.ceil(mask.sum(1) / self.subsampling_factor).astype("int64")
+        frame_index = paddle.arange(0, features.shape[1], dtype="int64")
+        pad_mask = frame_index.expand([valid_frames.shape[0], -1]) < valid_frames.unsqueeze(1)
+        return features, pad_mask.unsqueeze(1)
+
+
+class Phi4MultimodalAudioRelativeAttentionBias(nn.Layer):
+    """Learned per-head attention bias indexed by clipped relative position.
+
+    With ``bias_symmetric`` the table has ``bias_max_distance`` buckets indexed by ``|distance|``;
+    otherwise it has twice as many buckets and the signed distance is shifted to be non-negative.
+    """
+
+    def __init__(self, config: Phi4MultimodalAudioConfig):
+        super().__init__()
+        self.max_distance = config.bias_max_distance
+        self.symmetric = config.bias_symmetric
+        self.num_buckets = self.max_distance
+        if not config.bias_symmetric:
+            self.num_buckets *= 2
+        self.bias_values = nn.Embedding(self.num_buckets, config.num_attention_heads)
+
+    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+        """Return the bias of shape ``[1, num_heads, L, L]`` where ``L = x.shape[1]``."""
+        max_pos = x.shape[1]
+        context_position = paddle.arange(max_pos, dtype="int64")[:, None]
+        memory_position = paddle.arange(max_pos, dtype="int64")[None, :]
+        relative_position = memory_position - context_position
+
+        relative_position = paddle.clip(relative_position, min=-self.max_distance, max=self.max_distance - 1)
+
+        if self.symmetric:
+            # The lower clip bound is -max_distance, so abs() can yield max_distance, which is one
+            # past the last bucket of the symmetric table; clamp it into the valid index range.
+            bias_idx = paddle.clip(paddle.abs(relative_position), max=self.num_buckets - 1)
+        else:
+            # Shift [-max_distance, max_distance - 1] to [0, num_buckets - 1].
+            bias_idx = relative_position + self.num_buckets // 2
+
+        att_bias = self.bias_values(bias_idx)
+        # [L, L, num_heads] -> [1, num_heads, L, L]
+        att_bias = att_bias.transpose([2, 0, 1]).unsqueeze(0)
+        return att_bias
+
+
+class Phi4MultimodalAudioMeanVarianceNormLayer(nn.Layer):
+    """Normalise input features with stored global statistics: ``(x - mean) * invstd``."""
+
+    def __init__(self, config: Phi4MultimodalAudioConfig):
+        super().__init__()
+        feature_dim = config.input_size
+        # Identity statistics by default (mean 0, inverse std 1).
+        self.register_buffer("global_mean", paddle.zeros([feature_dim]))
+        self.register_buffer("global_invstd", paddle.ones([feature_dim]))
+
+    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+        """Subtract ``global_mean`` and scale by ``global_invstd``."""
+        centered = x - self.global_mean
+        return centered * self.global_invstd
+
+
+def unfold_tensor(tensor: paddle.Tensor, max_seq_len: int) -> paddle.Tensor:
+    """Split a ``[N, T, D]`` tensor into consecutive chunks, returning ``[N * (T // max_seq_len), max_seq_len, D]``.
+
+    Frames beyond the last full chunk are dropped.
+    """
+    _, total_len, feature_dim = tensor.shape
+    usable_len = (total_len // max_seq_len) * max_seq_len
+    trimmed = tensor[:, :usable_len, :]
+    return trimmed.reshape([-1, max_seq_len, feature_dim])
+
+
+def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
+    """Build a chunk-wise boolean attention mask of shape ``[x_len, x_len]``.
+
+    Position ``i`` may attend to position ``j`` when ``j`` lies in the chunk containing ``i``, in one of
+    the ``left_window`` chunks before it, or in one of the ``right_window`` chunks after it.
+
+    Implemented with numpy so that the Paddle model does not import PyTorch at run time.
+
+    Args:
+        x_len: Sequence length.
+        chunk_start_idx: Start index of every chunk (sequence or 1-D array of ints).
+        left_window: Number of preceding chunks that are visible.
+        right_window: Number of following chunks that are visible.
+
+    Returns:
+        A boolean ``paddle.Tensor`` with ``True`` where attention is allowed.
+    """
+    chunk_starts = np.asarray(chunk_start_idx, dtype=np.int64)
+    num_chunks = len(chunk_start_idx)
+
+    # Padded boundaries: interval k is [start_pad[k], end_pad[k]); a leading 0 is prepended to the
+    # starts and x_len is appended to the ends.
+    start_pad = np.concatenate([np.zeros(1, dtype=np.int64), chunk_starts])
+    end_pad = np.concatenate([chunk_starts, np.asarray([x_len], dtype=np.int64)])
+
+    positions = np.arange(x_len, dtype=np.int64)
+    in_chunk = (positions[:, None] < end_pad[None, :]) & (positions[:, None] >= start_pad[None, :])
+    # Column index of the padded interval that contains each position.
+    chunk_of_position = np.nonzero(in_chunk)[1]
+
+    # Widen the visible range by whole chunks, clamped to the first and last padded interval.
+    left_idx = np.maximum(chunk_of_position - left_window, 0)
+    right_idx = np.minimum(chunk_of_position + right_window, num_chunks)
+    boundary_left = start_pad[left_idx]
+    boundary_right = end_pad[right_idx]
+
+    mask_left = positions[None, :] >= boundary_left[:, None]
+    mask_right = positions[None, :] < boundary_right[:, None]
+    return paddle.to_tensor(mask_left & mask_right)
+
+
+class Phi4MultimodalAudioModel(nn.Layer):
+    """Audio encoder: feature normalisation, conv subsampling, relative attention bias, conformer stack."""
+
+    def __init__(self, config: Phi4MultimodalAudioConfig):
+        super().__init__()
+        self.config = config
+        self.encoder_embedding = Phi4MultimodalAudioMeanVarianceNormLayer(config)
+        self.embed = Phi4MultimodalAudioNemoConvSubsampling(config)
+        self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias(config)
+        self.encoders = nn.LayerList(
+            [Phi4MultimodalAudioConformerEncoderLayer(config) for _ in range(config.num_blocks)]
+        )
+
+    def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
+        """Return the chunk-wise attention mask expanded to ``[batch_size, seq_len, seq_len]``."""
+        # Chunk starts at 0, chunk_size, 2 * chunk_size, ...
+        chunk_start_idx = np.arange(0, seq_len, chunk_size)
+        # Only in training, and only when the random draw exceeds 0.5: re-anchor the chunk grid to the
+        # end of the sequence, so the short chunk (if any) is the first one instead of the last one.
+        if self.training and np.random.rand() > 0.5:
+            chunk_start_idx = seq_len - chunk_start_idx
+            chunk_start_idx = chunk_start_idx[::-1]
+            chunk_start_idx = chunk_start_idx[:-1]
+            chunk_start_idx = np.insert(chunk_start_idx, 0, 0)
+
+        enc_streaming_mask = adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk)
+        enc_streaming_mask = enc_streaming_mask.unsqueeze(0).expand([batch_size, -1, -1])
+        return enc_streaming_mask
+
+    def forward_embeddings(self, hidden_states, masks):
+        """Subsample the input and return ``(hidden_states, hs_mask, masks)``.
+
+        ``masks`` is the padding mask returned by the subsampling layer; ``hs_mask`` is that mask
+        combined (logical AND) with the streaming mask when both are available.
+
+        Raises:
+            ValueError: If the length after time reduction is not positive.
+        """
+        # Expected length after subsampling, used to size the streaming mask.
+        seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction)
+        if seq_len <= 0:
+            raise ValueError(
+                f"The sequence length after time reduction is invalid: {seq_len}. Your input feature is too short."
+            )
+        batch_size = hidden_states.shape[0]
+        enc_streaming_mask = self._streaming_mask(seq_len, batch_size, self.config.chunk_size, self.config.left_chunk)
+
+        hidden_states, masks = self.embed(hidden_states, masks)
+
+        streaming_mask = enc_streaming_mask
+        if streaming_mask is not None and masks is not None:
+            hs_mask = masks.astype("bool") & streaming_mask.astype("bool")
+        elif masks is not None:
+            hs_mask = masks
+        else:
+            hs_mask = streaming_mask
+
+        return hidden_states, hs_mask, masks
+
+    def calculate_hs_mask(self, hidden_states, mask):
+        """Recompute the attention mask for ``hidden_states`` from a per-frame validity ``mask``.
+
+        Returns the streaming mask alone when ``mask`` is ``None``; otherwise the streaming mask
+        AND-ed with a padding mask that keeps the first ``mask.sum(1)`` frames of every row.
+        """
+        max_audio_length = hidden_states.shape[1]
+        batch_size = hidden_states.shape[0]
+        enc_streaming_mask = self._streaming_mask(
+            max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk
+        )
+        if mask is None:
+            return enc_streaming_mask
+
+        feature_lens = mask.sum(1)
+        padding_length = feature_lens
+        pad_mask = paddle.arange(0, max_audio_length).expand([padding_length.shape[0], -1]) < padding_length.unsqueeze(
+            1
+        )
+        pad_mask = pad_mask.unsqueeze(1)
+        pad_mask = pad_mask.astype("bool") & enc_streaming_mask.astype("bool")
+        return pad_mask
+
+    def forward(self, hidden_states: paddle.Tensor, mask: Optional[paddle.Tensor] = None, **kwargs):
+        """Encode audio features and return the hidden states of the last conformer layer.
+
+        Sequences longer than 500 subsampled frames are padded to a multiple of 500, split into
+        500-frame chunks that are encoded as separate batch rows, and re-assembled afterwards.
+        """
+        hidden_states = self.encoder_embedding(hidden_states)
+        hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask)
+
+        unfolded = False
+        bs, seq_len, _ = hidden_states.shape
+        max_seq_len = 500
+        if seq_len > max_seq_len:
+            unfolded = True
+            # unfold_tensor drops leftover frames, so pad up to the next multiple of max_seq_len first.
+            if seq_len % max_seq_len > 0:
+                chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
+            else:
+                chunk_pad_size = 0
+            if chunk_pad_size > 0:
+                hidden_states = F.pad(hidden_states, [0, 0, 0, chunk_pad_size], data_format="NLC")
+
+            hidden_states = unfold_tensor(hidden_states, max_seq_len)
+            masks_unfold = None
+            if mask is not None:
+                # Pad and unfold the padding mask the same way as the features; it goes through
+                # float32 for the pad/unfold and is cast back to bool afterwards.
+                subsampled_pad_mask = mask.squeeze(1)
+                extra_padded = F.pad(subsampled_pad_mask.astype("float32"), [0, chunk_pad_size])
+                extra_padded = extra_padded.unsqueeze(-1)
+                masks_unfold = unfold_tensor(extra_padded, max_seq_len)
+                masks_unfold = masks_unfold.squeeze(-1).astype("bool")
+            # The mask computed in forward_embeddings did not account for the extra padding/chunking.
+            hs_mask = self.calculate_hs_mask(hidden_states, masks_unfold)
+
+        relative_attention_bias = self.relative_attention_bias_layer(hidden_states)
+        if hs_mask is not None:
+            # NOTE(review): the boolean mask is cast to 0/1 and added to the bias; masked positions are
+            # not pushed to a large negative value here. Confirm against the reference implementation
+            # before changing, since pretrained weights may depend on this behaviour.
+            attention_mask = hs_mask.unsqueeze(1).astype(hidden_states.dtype) + relative_attention_bias
+        else:
+            attention_mask = relative_attention_bias
+
+        for layer in self.encoders:
+            hidden_states = layer(hidden_states, attention_mask)
+
+        if unfolded:
+            # Merge the chunks back into one sequence per original batch row and strip the padding.
+            embed_dim = hidden_states.shape[-1]
+            hidden_states = hidden_states.reshape([bs, -1, embed_dim])
+            if chunk_pad_size > 0:
+                hidden_states = hidden_states[:, :-chunk_pad_size, :]
+
+        return hidden_states
+
+
+class Phi4MultimodalAudioEmbedding(nn.Layer):
+    """Encode audio features and splice the projected result into the text embeddings.
+
+    Two projection pairs are kept: one used when ``audio_projection_mode == "speech"`` and one used
+    for any other mode.
+    """
+
+    def __init__(self, config: Phi4MultimodalConfig):
+        super().__init__()
+        self.config = config
+        self.layer_idx = config.audio_config.feature_layer
+
+        self.drop = nn.Dropout(config.embd_pdrop)
+        self.encoder = Phi4MultimodalAudioModel(config.audio_config)
+        self.up_proj_for_speech = nn.Linear(
+            config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size
+        )
+        self.down_proj_for_speech = nn.Linear(config.hidden_size, config.hidden_size)
+        self.up_proj_for_vision_speech = nn.Linear(
+            config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size
+        )
+        self.down_proj_for_vision_speech = nn.Linear(config.hidden_size, config.hidden_size)
+
+    def forward(
+        self,
+        input_ids: paddle.Tensor,
+        inputs_embeds: paddle.Tensor,
+        audio_input_features: paddle.Tensor,
+        audio_embed_sizes=None,
+        audio_attention_mask=None,
+        audio_projection_mode="speech",
+    ) -> paddle.Tensor:
+        """Return a copy of ``inputs_embeds`` with audio-token positions replaced by audio features.
+
+        Args:
+            input_ids: Token ids; positions equal to ``audio_config.audio_token_id`` are overwritten.
+            inputs_embeds: Text embeddings to splice the audio features into.
+            audio_input_features: Input features for the audio encoder.
+            audio_embed_sizes: Number of leading encoder frames to keep per audio clip. Required even
+                though the signature defaults it to ``None``.
+            audio_attention_mask: Optional mask forwarded to the audio encoder.
+            audio_projection_mode: ``"speech"`` selects the speech projections; anything else selects
+                the vision-speech projections.
+
+        Raises:
+            ValueError: If ``audio_embed_sizes`` is ``None``, or if fewer audio feature vectors were
+                produced than there are audio-token positions in ``input_ids``.
+        """
+        if audio_embed_sizes is None:
+            # Without the per-clip lengths the padded encoder output cannot be trimmed.
+            raise ValueError("`audio_embed_sizes` must be provided together with `audio_input_features`.")
+
+        positions = paddle.nonzero(input_ids == self.config.audio_config.audio_token_id)
+
+        up_proj = self.up_proj_for_speech if audio_projection_mode == "speech" else self.up_proj_for_vision_speech
+        down_proj = (
+            self.down_proj_for_speech if audio_projection_mode == "speech" else self.down_proj_for_vision_speech
+        )
+
+        target_dtype = up_proj.bias.dtype
+        audio_input_features = audio_input_features.astype(target_dtype)
+
+        audio_encoder_hidden_states = self.encoder(audio_input_features, audio_attention_mask)
+        audio_encoder_hidden_states = up_proj(audio_encoder_hidden_states)
+        audio_encoder_hidden_states = F.gelu(audio_encoder_hidden_states)
+        audio_embeds = down_proj(audio_encoder_hidden_states)
+
+        # Keep only the first audio_embed_sizes[i] frames of clip i and stack all clips along axis 0.
+        merged_audio_embeds = paddle.concat(
+            [audio_embeds[i, : audio_embed_sizes[i], :] for i in range(len(audio_embed_sizes))], axis=0
+        )
+        merged_audio_embeds = merged_audio_embeds.astype(inputs_embeds.dtype)
+
+        num_positions = positions.shape[0]
+        if num_positions > 0:
+            if merged_audio_embeds.shape[0] < num_positions:
+                raise ValueError(
+                    f"Found {num_positions} audio tokens in `input_ids` but only "
+                    f"{merged_audio_embeds.shape[0]} audio feature vectors were produced."
+                )
+            # One vectorized scatter instead of a Python-level assignment per audio token.
+            # `nonzero` yields positions in row-major order, which is the order of the features.
+            audio_embeds_out = paddle.index_put(
+                inputs_embeds,
+                (positions[:, 0], positions[:, 1]),
+                merged_audio_embeds[:num_positions],
+            )
+        else:
+            audio_embeds_out = inputs_embeds.clone()
+
+        audio_embeds_out = self.drop(audio_embeds_out)
+        return audio_embeds_out
+
+
+# ======================= Text Decoder =======================
+
+
+class Phi4MultimodalRMSNorm(nn.Layer):
+    """Root-mean-square normalisation with a learned per-channel scale."""
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6):
+        super().__init__()
+        self.weight = self.create_parameter(
+            shape=[hidden_size],
+            default_initializer=nn.initializer.Constant(1.0),
+        )
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        """Normalise over the last axis in float32, cast back, and multiply by ``weight``."""
+        orig_dtype = hidden_states.dtype
+        # Statistics are computed in float32 with autocast disabled.
+        with paddle.amp.auto_cast(enable=False):
+            as_fp32 = hidden_states.astype(paddle.float32)
+            mean_square = as_fp32.pow(2).mean(axis=-1, keepdim=True)
+            normed = as_fp32 * paddle.rsqrt(mean_square + self.variance_epsilon)
+        normed = normed.astype(orig_dtype)
+        # Match the scale to the input dtype only when they differ.
+        scale = self.weight if self.weight.dtype == orig_dtype else self.weight.astype(orig_dtype)
+        return scale * normed
+
+
+class Phi4MultimodalMLP(nn.Layer):
+    """Gated MLP of the text decoder with optional per-modality ("vision"/"speech") LoRA deltas."""
+
+    def __init__(self, config: Phi4MultimodalConfig):
+        super().__init__()
+        self.config = config
+        hidden = config.hidden_size
+        inner = config.intermediate_size
+        # Fused projection: the output holds the gate half followed by the up half.
+        self.gate_up_proj = GeneralLinear.create(
+            hidden,
+            2 * inner,
+            has_bias=config.mlp_bias,
+            config=config,
+            tp_plan="colwise",
+        )
+        self.down_proj = GeneralLinear.create(
+            inner,
+            hidden,
+            has_bias=config.mlp_bias,
+            config=config,
+            tp_plan="rowwise",
+        )
+        self.activation_fn = ACT2FN[config.hidden_act]
+        self._init_lora(hidden, 2 * inner, inner, hidden)
+
+    def _init_lora(self, gate_in, gate_out, down_in, down_out):
+        """Create LoRA A/B matrices and the scaling factor for every adapter with a positive rank."""
+        for adapter in ("vision", "speech"):
+            rank = getattr(self.config, f"{adapter}_lora_rank", 0)
+            alpha = getattr(self.config, f"{adapter}_lora_alpha", 1)
+            if rank > 0:
+                matrix_shapes = (
+                    ("gate_up_lora_A", [gate_in, rank]),
+                    ("gate_up_lora_B", [rank, gate_out]),
+                    ("down_lora_A", [down_in, rank]),
+                    ("down_lora_B", [rank, down_out]),
+                )
+                for suffix, shape in matrix_shapes:
+                    setattr(self, f"{adapter}_{suffix}", _create_lora_parameter(self, shape))
+                setattr(self, f"{adapter}_lora_scaling", alpha / rank)
+
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        """Apply the gated MLP, adding the active adapter's LoRA delta to both projections if present."""
+        fused = self.gate_up_proj(hidden_states)
+        adapter = _active_lora_adapter(self.config)
+
+        def _delta(prefix, inputs):
+            # Returns None when no adapter is active or the adapter has no matrices for this projection.
+            if adapter is None or not hasattr(self, f"{adapter}_{prefix}_lora_A"):
+                return None
+            return _lora_delta(
+                inputs,
+                getattr(self, f"{adapter}_{prefix}_lora_A"),
+                getattr(self, f"{adapter}_{prefix}_lora_B"),
+                getattr(self, f"{adapter}_lora_scaling"),
+            )
+
+        gate_up_delta = _delta("gate_up", hidden_states)
+        if gate_up_delta is not None:
+            fused = fused + gate_up_delta
+
+        gate, up = fused.chunk(2, axis=-1)
+        activated = up * self.activation_fn(gate)
+
+        output = self.down_proj(activated)
+        down_delta = _delta("down", activated)
+        if down_delta is not None:
+            output = output + down_delta
+        return output
+
+
+class Phi4MultimodalRotaryEmbedding(nn.Layer):
+    """Rotary position embedding producing ``(cos, sin)`` tables, with optional "longrope" handling."""
+
+    def __init__(self, config: Phi4MultimodalConfig):
+        super().__init__()
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+        self.config = config
+
+        # "default" uses the local static method; any other type is looked up in ROPE_INIT_FUNCTIONS.
+        self.rope_type = config.rope_parameters.get("rope_type", "default")
+        rope_init_fn = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(config)
+
+        # Non-persistable buffers: they are not saved as part of the state dict.
+        self.register_buffer("inv_freq", inv_freq, persistable=False)
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistable=False)
+
+    def _match_torch_short_longrope_rounding(self, inv_freq):
+        """Nudge selected ``inv_freq`` entries upward by a few floating-point steps.
+
+        Only applies to one specific configuration: longrope, 48 frequencies, 48 short factors all
+        equal to 1.0 and ``rope_theta == 10000``. Any other input is returned unchanged.
+        NOTE(review): the method name suggests the hard-coded indices/step counts reproduce PyTorch's
+        float rounding for that configuration; confirm the origin of these constants.
+        """
+        rope_parameters = self.config.rope_parameters
+        short_factor = rope_parameters.get("short_factor", [])
+        if (
+            self.rope_type == "longrope"
+            and inv_freq.shape[0] == 48
+            and len(short_factor) == 48
+            and all(float(factor) == 1.0 for factor in short_factor)
+            and float(rope_parameters.get("rope_theta", 10000.0)) == 10000.0
+        ):
+            # indices[k] is advanced steps[k] times toward +inf with nextafter.
+            indices = paddle.to_tensor([7, 10, 14, 17, 20, 23, 25, 28, 31, 34, 37, 40, 43, 46], dtype="int64")
+            selected = paddle.gather(inv_freq, indices)
+            steps = [1, 1, 3, 3, 4, 4, 5, 5, 7, 7, 7, 9, 5, 6]
+            updates = selected
+            next_value = paddle.full_like(selected, float("inf"))
+            for step in range(max(steps)):
+                stepped = paddle.nextafter(updates, next_value)
+                # Keep stepping only the entries that have not yet reached their step count.
+                mask = paddle.to_tensor([step < count for count in steps], dtype="bool")
+                updates = paddle.where(mask, stepped, updates)
+            inv_freq = paddle.scatter(inv_freq, indices, updates, overwrite=True)
+        return inv_freq
+
+    def _update_longrope_inv_freq(self, position_ids, device):
+        """Recompute ``inv_freq`` and ``attention_scaling`` for the current sequence length (longrope only)."""
+        if self.rope_type != "longrope":
+            return
+
+        # The sequence length is taken as the largest position id plus one; .item() reads it back to Python.
+        seq_len = int((paddle.max(position_ids) + 1).item())
+        rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config, device, seq_len=seq_len)
+
+        # Prefer the value in rope_parameters, then the config attribute, then max_position_embeddings.
+        original_max_position_embeddings = self.config.rope_parameters.get(
+            "original_max_position_embeddings",
+            getattr(self.config, "original_max_position_embeddings", self.config.max_position_embeddings),
+        )
+        if seq_len <= original_max_position_embeddings:
+            inv_freq = self._match_torch_short_longrope_rounding(inv_freq)
+            self.register_buffer("original_inv_freq", inv_freq.clone(), persistable=False)
+        self.register_buffer("inv_freq", inv_freq, persistable=False)
+
+    @staticmethod
+    def compute_default_rope_parameters(config, seq_len=None):
+        """Return ``(inv_freq, attention_factor)`` for plain RoPE; ``seq_len`` is unused here."""
+        base = config.rope_parameters["rope_theta"]
+        partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
+        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+        # Only the first `dim` channels of each head are rotated.
+        dim = int(head_dim * partial_rotary_factor)
+        attention_factor = 1.0
+        inv_freq = 1.0 / (base ** (paddle.arange(0, dim, 2, dtype="int64").astype("float32") / dim))
+        return inv_freq, attention_factor
+
+    def forward(self, x, position_ids):
+        """Return ``(cos, sin)`` for ``position_ids``, scaled by ``attention_scaling`` and cast to ``x.dtype``."""
+        self._update_longrope_inv_freq(position_ids, x.place)
+        # Angles are computed in float32 with autocast disabled.
+        with paddle.amp.auto_cast(enable=False):
+            inv_freq_expanded = self.inv_freq.astype("float32")[None, :, None].expand([position_ids.shape[0], -1, 1])
+            position_ids_expanded = position_ids[:, None, :].astype("float32")
+            freqs = (inv_freq_expanded @ position_ids_expanded).transpose([0, 2, 1])
+            # The frequency block is duplicated along the last axis.
+            emb = paddle.concat((freqs, freqs), axis=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+        return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
+class Phi4MultimodalAttention(nn.Layer):
+ def __init__(self, config: Phi4MultimodalConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_heads = config.num_attention_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.sequence_parallel = config.sequence_parallel
+
+ if config.tensor_model_parallel_size > 1:
+ assert self.num_heads % config.tensor_model_parallel_size == 0
+ assert self.num_key_value_heads % config.tensor_model_parallel_size == 0
+ self.num_heads = self.num_heads // config.tensor_model_parallel_size
+ self.num_key_value_heads = self.num_key_value_heads // config.tensor_model_parallel_size
+
+ op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim)
+ self.qkv_proj = GeneralLinear.create(
+ config.hidden_size,
+ op_size,
+ has_bias=config.attention_bias,
+ config=config,
+ tp_plan="colwise",
+ )
+ self.o_proj = GeneralLinear.create(
+ config.num_attention_heads * self.head_dim,
+ config.hidden_size,
+ has_bias=config.attention_bias,
+ config=config,
+ tp_plan="rowwise",
+ )
+ self._init_lora(config.hidden_size, op_size, config.num_attention_heads * self.head_dim, config.hidden_size)
+
+ def _init_lora(self, qkv_in, qkv_out, o_in, o_out):
+ for adapter in ("vision", "speech"):
+ rank = getattr(self.config, f"{adapter}_lora_rank", 0)
+ alpha = getattr(self.config, f"{adapter}_lora_alpha", 1)
+ if rank <= 0:
+ continue
+ setattr(self, f"{adapter}_qkv_lora_A", _create_lora_parameter(self, [qkv_in, rank]))
+ setattr(self, f"{adapter}_qkv_lora_B", _create_lora_parameter(self, [rank, qkv_out]))
+ setattr(self, f"{adapter}_o_lora_A", _create_lora_parameter(self, [o_in, rank]))
+ setattr(self, f"{adapter}_o_lora_B", _create_lora_parameter(self, [rank, o_out]))
+ setattr(self, f"{adapter}_lora_scaling", alpha / rank)
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        attention_mask: Optional[paddle.Tensor] = None,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
+        position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: bool = False,  # accepted but not read inside this method
+    ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:  # (projected output, weights from the attention kernel)
+        if self.sequence_parallel:
+            seq_len, hidden_size = hidden_states.shape  # 2-D input on this path; batch_size is fixed to 1
+            batch_size = 1
+        else:
+            batch_size, seq_len, hidden_size = hidden_states.shape
+        # Fused QKV projection, plus the active adapter's low-rank delta when that adapter's tensors exist.
+        qkv = self.qkv_proj(hidden_states)
+        adapter = _active_lora_adapter(self.config)
+        if adapter is not None and hasattr(self, f"{adapter}_qkv_lora_A"):
+            qkv = qkv + _lora_delta(
+                hidden_states,
+                getattr(self, f"{adapter}_qkv_lora_A"),
+                getattr(self, f"{adapter}_qkv_lora_B"),
+                getattr(self, f"{adapter}_lora_scaling"),
+            )
+        query_pos = self.num_heads * self.head_dim
+        kv_pos = self.num_key_value_heads * self.head_dim
+        # The fused output is sliced along its last axis as [query | key | value].
+        query_states = qkv[..., :query_pos]
+        key_states = qkv[..., query_pos : query_pos + kv_pos]
+        value_states = qkv[..., query_pos + kv_pos :]
+        # Split the feature axis into heads and move the head axis in front of the sequence axis.
+        if self.sequence_parallel:
+            query_states = query_states.reshape([seq_len, self.num_heads, self.head_dim])
+            key_states = key_states.reshape([seq_len, self.num_key_value_heads, self.head_dim])
+            value_states = value_states.reshape([seq_len, self.num_key_value_heads, self.head_dim])
+            query_states = query_states.transpose([1, 0, 2])
+            key_states = key_states.transpose([1, 0, 2])
+            value_states = value_states.transpose([1, 0, 2])
+        else:
+            query_states = query_states.reshape([batch_size, seq_len, self.num_heads, self.head_dim]).transpose(
+                [0, 2, 1, 3]
+            )
+            key_states = key_states.reshape([batch_size, seq_len, self.num_key_value_heads, self.head_dim]).transpose(
+                [0, 2, 1, 3]
+            )
+            value_states = value_states.reshape(
+                [batch_size, seq_len, self.num_key_value_heads, self.head_dim]
+            ).transpose([0, 2, 1, 3])
+        # The rotary embedding helper is applied to queries and keys only; values are left as they are.
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        # With a cache, update() returns the key/value tensors that attention uses for this layer index.
+        if past_key_values is not None:
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
+        # The attention kernel is looked up by config._attn_implementation.
+        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query=query_states,
+            key=key_states,
+            value=value_states,
+            attention_mask=attention_mask,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            sliding_window=getattr(self.config, "sliding_window", None),
+        )
+        # Collapse the trailing axes back into one feature axis for the output projection.
+        if self.sequence_parallel:
+            attn_output = attn_output.reshape([seq_len, -1])
+        else:
+            attn_output = attn_output.reshape([batch_size, seq_len, -1])
+        o_input = attn_output  # kept so the LoRA delta below is computed from the same input as o_proj
+        attn_output = self.o_proj(o_input)
+        if adapter is not None and hasattr(self, f"{adapter}_o_lora_A"):
+            attn_output = attn_output + _lora_delta(
+                o_input,
+                getattr(self, f"{adapter}_o_lora_A"),
+                getattr(self, f"{adapter}_o_lora_B"),
+                getattr(self, f"{adapter}_lora_scaling"),
+            )
+        return attn_output, attn_weights
+
+
+class Phi4MultimodalDecoderLayer(nn.Layer):
+    """Pre-norm decoder block: self-attention then MLP, each followed by residual dropout and a skip add."""
+
+    def __init__(self, config: Phi4MultimodalConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = Phi4MultimodalAttention(config=config, layer_idx=layer_idx)
+        self.mlp = Phi4MultimodalMLP(config)
+        self.input_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.config = config
+        self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
+        self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        attention_mask: Optional[paddle.Tensor] = None,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
+        position_ids: Optional[paddle.Tensor] = None,  # kept in the signature for callers; not read here
+        past_key_values: Optional[Cache] = None,
+        use_cache: bool = False,
+        position_embeddings: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[paddle.Tensor, ...]:
+        """Return ``(hidden_states,)``, extended with the attention weights when ``output_attentions`` is true."""
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            position_embeddings=position_embeddings,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+        )
+        hidden_states = residual + self.resid_attn_dropout(hidden_states)
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + self.resid_mlp_dropout(hidden_states)
+
+        return (hidden_states, self_attn_weights) if output_attentions else (hidden_states,)
+
+
+# ======================= Feature Embedding (bridge) =======================
+
+
+class Phi4MultimodalFeatureEmbedding(nn.Layer):
+    """Combine the outputs of the image and audio embedding sub-modules with the token embeddings."""
+
+    def __init__(self, config: Phi4MultimodalConfig):
+        super().__init__()
+        self.config = config
+        self.image_token_id = config.vision_config.image_token_id
+        self.audio_token_id = config.audio_config.audio_token_id
+        self.image_embed = Phi4MultimodalImageEmbedding(config)
+        self.audio_embed = Phi4MultimodalAudioEmbedding(config)
+
+    def forward(
+        self,
+        input_ids: paddle.Tensor,
+        inputs_embeds: paddle.Tensor,
+        image_pixel_values: Optional[paddle.Tensor] = None,
+        audio_input_features: Optional[paddle.Tensor] = None,
+        image_sizes=None,
+        image_attention_mask=None,
+        audio_embed_sizes=None,
+        audio_attention_mask=None,
+    ) -> paddle.Tensor:
+        """Return ``inputs_embeds`` itself, or the image/audio embeddings that apply to this batch."""
+        image_embeds = None
+        if image_pixel_values is not None and (input_ids == self.image_token_id).any():
+            image_embeds = self.image_embed(
+                input_ids,
+                inputs_embeds,
+                image_pixel_values=image_pixel_values,
+                image_sizes=image_sizes,
+                image_attention_mask=image_attention_mask,
+            )
+        audio_embeds = None
+        if audio_input_features is not None and (input_ids == self.audio_token_id).any():
+            # The audio projection is chosen by whether pixel values were passed, not by image_embeds.
+            audio_projection_mode = "vision" if image_pixel_values is not None else "speech"
+            audio_embeds = self.audio_embed(
+                input_ids,
+                inputs_embeds,
+                audio_input_features=audio_input_features,
+                audio_embed_sizes=audio_embed_sizes,
+                audio_attention_mask=audio_attention_mask,
+                audio_projection_mode=audio_projection_mode,
+            )
+
+        if image_embeds is None:
+            return inputs_embeds if audio_embeds is None else audio_embeds
+        if audio_embeds is None:
+            return image_embeds
+        # Both modalities: image embeddings at image-token positions, audio embeddings everywhere else.
+        image_slots = (input_ids == self.config.vision_config.image_token_id).unsqueeze(-1)
+        image_part = image_embeds * image_slots.astype(image_embeds.dtype)
+        audio_part = audio_embeds * (~image_slots).astype(audio_embeds.dtype)
+        return image_part + audio_part
+
+
+# ======================= Main Model =======================
+
+
+class Phi4MultimodalPreTrainedModel(PretrainedModel):
+    config_class = Phi4MultimodalConfig
+    base_model_prefix = "model"
+    transpose_weight_keys = [  # consumed by the PretrainedModel base class, which is not visible from here
+        "qkv_proj.weight",
+        "o_proj.weight",
+        "gate_up_proj.weight",
+        "down_proj.weight",
+        "img_projection_up.weight",
+        "img_projection_down.weight",
+        "up_proj_for_speech.weight",
+        "down_proj_for_speech.weight",
+        "up_proj_for_vision_speech.weight",
+        "down_proj_for_vision_speech.weight",
+    ]
+
+    @classmethod
+    def _gen_aoa_config(cls, config: Phi4MultimodalConfig):  # builds "<source key> -> <destination key>" statements
+        model_prefix = "" if cls == cls.base_model_class else "model."  # non-base classes get a "model." prefix
+        aoa_config = {"aoa_statements": []}
+        stmts = aoa_config["aoa_statements"]
+
+        # Embedding and norm
+        stmts.append(f"model.embed_tokens.weight -> {model_prefix}embed_tokens.weight")
+        stmts.append(f"model.norm.weight -> {model_prefix}norm.weight")
+        if cls != cls.base_model_class:
+            stmts.append("model.embed_tokens.weight -> lm_head.weight")  # the LM head is filled from the embedding
+
+        # Decoder layers
+        for layer_id in range(config.num_hidden_layers):
+            lp = f"model.layers.{layer_id}"
+            tp = f"{model_prefix}layers.{layer_id}"
+            stmts.append(f"{lp}.input_layernorm.weight -> {tp}.input_layernorm.weight")
+            stmts.append(f"{lp}.post_attention_layernorm.weight -> {tp}.post_attention_layernorm.weight")
+            stmts.append(f"{lp}.mlp.down_proj.base_layer.weight^T -> {tp}.mlp.down_proj.weight")
+            stmts.append(f"{lp}.self_attn.o_proj.base_layer.weight^T -> {tp}.self_attn.o_proj.weight")
+            stmts.append(
+                f"{lp}.self_attn.qkv_proj.base_layer.weight^T -> {tp}.self_attn.qkv_proj.weight, "
+                f"fused_qkv_old, num_heads={config.num_attention_heads}, "
+                f"num_key_value_groups={config.num_key_value_heads}, axis=1"
+            )
+            stmts.append(f"{lp}.mlp.gate_up_proj.base_layer.weight^T -> {tp}.mlp.gate_up_proj.weight, fused_ffn")
+            # (block, source projection name, destination name fragment) for every LoRA-wrapped projection.
+            lora_specs = [
+                ("self_attn", "qkv_proj", "qkv"),
+                ("self_attn", "o_proj", "o"),
+                ("mlp", "gate_up_proj", "gate_up"),
+                ("mlp", "down_proj", "down"),
+            ]
+            for adapter in ("vision", "speech"):
+                for block, src_proj, dst_proj in lora_specs:
+                    rank = getattr(config, f"{adapter}_lora_rank", 0)  # adapters without a positive rank are skipped
+                    if rank > 0:
+                        stmts.append(
+                            f"{lp}.{block}.{src_proj}.lora_A.{adapter}.weight^T -> "
+                            f"{tp}.{block}.{adapter}_{dst_proj}_lora_A"
+                        )
+                        stmts.append(
+                            f"{lp}.{block}.{src_proj}.lora_B.{adapter}.weight^T -> "
+                            f"{tp}.{block}.{adapter}_{dst_proj}_lora_B"
+                        )
+
+        # Vision encoder
+        vis_prefix_src = "model.embed_tokens_extend.image_embed"
+        vis_prefix_dst = f"{model_prefix}embed_tokens_extend.image_embed"
+
+        stmts.append(f"{vis_prefix_src}.glb_GN -> {vis_prefix_dst}.global_img_feature_extensor")
+        stmts.append(f"{vis_prefix_src}.sub_GN -> {vis_prefix_dst}.sub_img_feature_extensor")
+        stmts.append(f"{vis_prefix_src}.img_projection.0.weight^T -> {vis_prefix_dst}.img_projection_up.weight")
+        stmts.append(f"{vis_prefix_src}.img_projection.0.bias -> {vis_prefix_dst}.img_projection_up.bias")
+        stmts.append(f"{vis_prefix_src}.img_projection.2.weight^T -> {vis_prefix_dst}.img_projection_down.weight")
+        stmts.append(f"{vis_prefix_src}.img_projection.2.bias -> {vis_prefix_dst}.img_projection_down.bias")
+
+        # Vision processor layers
+        vp_src = f"{vis_prefix_src}.img_processor"
+        vp_dst = f"{vis_prefix_dst}.img_processor"
+        stmts.append(f"{vp_src}.embeddings.patch_embedding.weight -> {vp_dst}.embeddings.patch_embedding.weight")
+        stmts.append(f"{vp_src}.embeddings.patch_embedding.bias -> {vp_dst}.embeddings.patch_embedding.bias")
+        stmts.append(f"{vp_src}.embeddings.position_embedding.weight -> {vp_dst}.embeddings.position_embedding.weight")
+        stmts.append(f"{vp_src}.post_layernorm.weight -> {vp_dst}.post_layernorm.weight")
+        stmts.append(f"{vp_src}.post_layernorm.bias -> {vp_dst}.post_layernorm.bias")
+
+        # Vision head (MultiheadAttentionPoolingHead)
+        stmts.append(f"{vp_src}.head.probe -> {vp_dst}.head.probe")
+        stmts.append(f"{vp_src}.head.layernorm.weight -> {vp_dst}.head.layernorm.weight")
+        stmts.append(f"{vp_src}.head.layernorm.bias -> {vp_dst}.head.layernorm.bias")
+        stmts.append(f"{vp_src}.head.mlp.fc1.weight^T -> {vp_dst}.head.mlp.fc1.weight")
+        stmts.append(f"{vp_src}.head.mlp.fc1.bias -> {vp_dst}.head.mlp.fc1.bias")
+        stmts.append(f"{vp_src}.head.mlp.fc2.weight^T -> {vp_dst}.head.mlp.fc2.weight")
+        stmts.append(f"{vp_src}.head.mlp.fc2.bias -> {vp_dst}.head.mlp.fc2.bias")
+
+        for i in range(config.vision_config.num_hidden_layers):
+            vs = f"{vp_src}.encoder.layers.{i}"
+            vd = f"{vp_dst}.encoder.layers.{i}"
+            stmts.append(f"{vs}.layer_norm1.weight -> {vd}.layer_norm1.weight")
+            stmts.append(f"{vs}.layer_norm1.bias -> {vd}.layer_norm1.bias")
+            stmts.append(f"{vs}.layer_norm2.weight -> {vd}.layer_norm2.weight")
+            stmts.append(f"{vs}.layer_norm2.bias -> {vd}.layer_norm2.bias")
+            stmts.append(f"{vs}.self_attn.q_proj.weight^T -> {vd}.self_attn.q_proj.weight")
+            stmts.append(f"{vs}.self_attn.q_proj.bias -> {vd}.self_attn.q_proj.bias")
+            stmts.append(f"{vs}.self_attn.k_proj.weight^T -> {vd}.self_attn.k_proj.weight")
+            stmts.append(f"{vs}.self_attn.k_proj.bias -> {vd}.self_attn.k_proj.bias")
+            stmts.append(f"{vs}.self_attn.v_proj.weight^T -> {vd}.self_attn.v_proj.weight")
+            stmts.append(f"{vs}.self_attn.v_proj.bias -> {vd}.self_attn.v_proj.bias")
+            stmts.append(f"{vs}.self_attn.out_proj.weight^T -> {vd}.self_attn.out_proj.weight")
+            stmts.append(f"{vs}.self_attn.out_proj.bias -> {vd}.self_attn.out_proj.bias")
+            stmts.append(f"{vs}.mlp.fc1.weight^T -> {vd}.mlp.fc1.weight")
+            stmts.append(f"{vs}.mlp.fc1.bias -> {vd}.mlp.fc1.bias")
+            stmts.append(f"{vs}.mlp.fc2.weight^T -> {vd}.mlp.fc2.weight")
+            stmts.append(f"{vs}.mlp.fc2.bias -> {vd}.mlp.fc2.bias")
+
+        # Audio encoder
+        aud_prefix_src = "model.embed_tokens_extend.audio_embed"
+        aud_prefix_dst = f"{model_prefix}embed_tokens_extend.audio_embed"
+
+        # Audio projections
+        audio_projection_map = [
+            ("audio_projection.speech.0", "up_proj_for_speech"),
+            ("audio_projection.speech.2", "down_proj_for_speech"),
+            ("audio_projection.vision.0", "up_proj_for_vision_speech"),
+            ("audio_projection.vision.2", "down_proj_for_vision_speech"),
+        ]
+        for src_proj, dst_proj in audio_projection_map:
+            stmts.append(f"{aud_prefix_src}.{src_proj}.weight^T -> {aud_prefix_dst}.{dst_proj}.weight")
+            stmts.append(f"{aud_prefix_src}.{src_proj}.bias -> {aud_prefix_dst}.{dst_proj}.bias")
+
+        # Audio encoder internals
+        ae_src = f"{aud_prefix_src}.encoder"
+        ae_dst = f"{aud_prefix_dst}.encoder"
+
+        stmts.append(f"{ae_src}.encoder_embedding.global_mean -> {ae_dst}.encoder_embedding.global_mean")
+        stmts.append(f"{ae_src}.encoder_embedding.global_invstd -> {ae_dst}.encoder_embedding.global_invstd")
+        stmts.append(
+            f"{ae_src}.relative_attention_bias_layer.bias_values.weight -> {ae_dst}.relative_attention_bias_layer.bias_values.weight"
+        )
+
+        # Nemo conv subsampling
+        stmts.append(f"{ae_src}.embed.conv.0.weight -> {ae_dst}.embed.conv.0.weight")
+        stmts.append(f"{ae_src}.embed.conv.0.bias -> {ae_dst}.embed.conv.0.bias")
+        for conv_idx in [2, 3, 5, 6]:
+            stmts.append(f"{ae_src}.embed.conv.{conv_idx}.weight -> {ae_dst}.embed.conv.{conv_idx}.weight")
+            stmts.append(f"{ae_src}.embed.conv.{conv_idx}.bias -> {ae_dst}.embed.conv.{conv_idx}.bias")
+        stmts.append(f"{ae_src}.embed.out.weight^T -> {ae_dst}.embed.out.weight")
+        stmts.append(f"{ae_src}.embed.out.bias -> {ae_dst}.embed.out.bias")
+
+        # Audio conformer encoder layers
+        for i in range(config.audio_config.num_blocks):
+            al_src = f"{ae_src}.encoders.{i}"
+            al_dst = f"{ae_dst}.encoders.{i}"
+
+            # feed_forward_in / feed_forward_out
+            for ff in ["feed_forward_in", "feed_forward_out"]:
+                stmts.append(f"{al_src}.{ff}.layer_norm.weight -> {al_dst}.{ff}.layer_norm.weight")
+                stmts.append(f"{al_src}.{ff}.layer_norm.bias -> {al_dst}.{ff}.layer_norm.bias")
+                stmts.append(f"{al_src}.{ff}.net.0.linear.weight^T -> {al_dst}.{ff}.gate_up_proj.weight")
+                stmts.append(f"{al_src}.{ff}.net.0.linear.bias -> {al_dst}.{ff}.gate_up_proj.bias")
+                stmts.append(f"{al_src}.{ff}.net.2.weight^T -> {al_dst}.{ff}.down_proj.weight")
+                stmts.append(f"{al_src}.{ff}.net.2.bias -> {al_dst}.{ff}.down_proj.bias")
+
+            # self_attn
+            stmts.append(f"{al_src}.self_attn.linear_q.weight^T -> {al_dst}.self_attn.q_proj.weight")
+            stmts.append(f"{al_src}.self_attn.linear_q.bias -> {al_dst}.self_attn.q_proj.bias")
+            stmts.append(f"{al_src}.self_attn.linear_k.weight^T -> {al_dst}.self_attn.k_proj.weight")
+            stmts.append(f"{al_src}.self_attn.linear_k.bias -> {al_dst}.self_attn.k_proj.bias")
+            stmts.append(f"{al_src}.self_attn.linear_v.weight^T -> {al_dst}.self_attn.v_proj.weight")
+            stmts.append(f"{al_src}.self_attn.linear_v.bias -> {al_dst}.self_attn.v_proj.bias")
+            stmts.append(f"{al_src}.self_attn.linear_out.weight^T -> {al_dst}.self_attn.o_proj.weight")
+            stmts.append(f"{al_src}.self_attn.linear_out.bias -> {al_dst}.self_attn.o_proj.bias")
+
+            # conv module
+            stmts.append(f"{al_src}.conv.layer_norm.weight -> {al_dst}.conv.layer_norm.weight")
+            stmts.append(f"{al_src}.conv.layer_norm.bias -> {al_dst}.conv.layer_norm.bias")
+            stmts.append(f"{al_src}.conv.glu.ext_pw_conv_1d.weight -> {al_dst}.conv.glu.ext_pw_conv_1d.weight")
+            stmts.append(f"{al_src}.conv.glu.ext_pw_conv_1d.bias -> {al_dst}.conv.glu.ext_pw_conv_1d.bias")
+            stmts.append(f"{al_src}.conv.glu.b1 -> {al_dst}.conv.glu.b1")
+            stmts.append(f"{al_src}.conv.glu.b2 -> {al_dst}.conv.glu.b2")
+            stmts.append(f"{al_src}.conv.dw_sep_conv_1d.dw_conv.weight -> {al_dst}.conv.dw_sep_conv_1d.dw_conv.weight")
+            stmts.append(f"{al_src}.conv.dw_sep_conv_1d.dw_conv.bias -> {al_dst}.conv.dw_sep_conv_1d.dw_conv.bias")
+            stmts.append(f"{al_src}.conv.dw_sep_conv_1d.pw_conv.weight -> {al_dst}.conv.dw_sep_conv_1d.pw_conv.weight")
+            stmts.append(f"{al_src}.conv.dw_sep_conv_1d.pw_conv.bias -> {al_dst}.conv.dw_sep_conv_1d.pw_conv.bias")
+            stmts.append(f"{al_src}.conv.ext_pw_conv_1d.weight -> {al_dst}.conv.ext_pw_conv_1d.weight")
+            stmts.append(f"{al_src}.conv.ext_pw_conv_1d.bias -> {al_dst}.conv.ext_pw_conv_1d.bias")
+
+            # layer norms
+            stmts.append(f"{al_src}.layer_norm_att.weight -> {al_dst}.layer_norm_att.weight")
+            stmts.append(f"{al_src}.layer_norm_att.bias -> {al_dst}.layer_norm_att.bias")
+            stmts.append(f"{al_src}.layer_norm.weight -> {al_dst}.layer_norm.weight")
+            stmts.append(f"{al_src}.layer_norm.bias -> {al_dst}.layer_norm.bias")
+
+        # Vision head attention (nn.MultiHeadAttention has different weight names in Paddle)
+        # The HF model uses torch.nn.MultiheadAttention with in_proj_weight/in_proj_bias
+        # We'll handle this mapping for the vision pooling head attention
+        head_src = f"{vp_src}.head.attention"
+        head_dst = f"{vp_dst}.head.attention"
+        stmts.append(f"{head_src}.in_proj_weight^T -> {head_dst}.in_proj_weight_t")
+        stmts.append(f"{head_src}.in_proj_bias -> {head_dst}.in_proj_bias_t")
+        stmts.append(f"{head_src}.out_proj.weight^T -> {head_dst}.out_proj.weight")
+        stmts.append(f"{head_src}.out_proj.bias -> {head_dst}.out_proj.bias")
+
+        return aoa_config
+
+    @classmethod
+    def _gen_inv_aoa_config(cls, config: Phi4MultimodalConfig):  # reverse direction: model names -> checkpoint keys
+        model_prefix = "" if cls == cls.base_model_class else "model."
+        aoa_statements = []
+        # NOTE(review): only decoder, embedding and final-norm weights are listed; vision/audio/LoRA are not.
+        # Decoder layers
+        for layer_id in range(config.num_hidden_layers):
+            tp = f"{model_prefix}layers.{layer_id}"
+            lp = f"model.layers.{layer_id}"
+            aoa_statements.append(f"{tp}.input_layernorm.weight -> {lp}.input_layernorm.weight")
+            aoa_statements.append(f"{tp}.post_attention_layernorm.weight -> {lp}.post_attention_layernorm.weight")
+            aoa_statements.append(f"{tp}.mlp.down_proj.weight^T -> {lp}.mlp.down_proj.weight")
+            aoa_statements.append(f"{tp}.self_attn.o_proj.weight^T -> {lp}.self_attn.o_proj.weight")
+            aoa_statements.append(
+                f"{tp}.self_attn.qkv_proj.weight -> {lp}.self_attn.qkv_proj.weight, "
+                f"fused_qkv_old, num_heads={config.num_attention_heads}, "
+                f"num_key_value_groups={config.num_key_value_heads}, axis=1"
+            )  # NOTE(review): the next statement targets the same destination key — confirm both are intended
+            aoa_statements.append(f"{tp}.self_attn.qkv_proj.weight^T -> {lp}.self_attn.qkv_proj.weight")
+            aoa_statements.append(f"{tp}.mlp.gate_up_proj.weight^T -> {lp}.mlp.gate_up_proj.weight, fused_ffn")
+        # NOTE(review): destinations above omit the ".base_layer" segment that _gen_aoa_config reads — verify.
+        aoa_statements.append(f"{model_prefix}embed_tokens.weight -> model.embed_tokens.weight")
+        aoa_statements.append(f"{model_prefix}norm.weight -> model.norm.weight")
+
+        return {"aoa_statements": aoa_statements}
+
+
+@register_base_model
+class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
+    """Decoder stack with multimodal input embedding; picks the active LoRA adapter on every forward call."""
+
+    def __init__(self, config: Phi4MultimodalConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.config = config
+        self.sequence_parallel = config.sequence_parallel
+
+        self.embed_tokens = GeneralEmbedding.create(
+            config=config, num_embeddings=config.vocab_size, embedding_dim=config.hidden_size
+        )
+        self.layers = nn.LayerList([Phi4MultimodalDecoderLayer(config, i) for i in range(config.num_hidden_layers)])
+        self.norm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = Phi4MultimodalRotaryEmbedding(config)
+        self.embed_dropout = nn.Dropout(config.embd_pdrop)
+        self.embed_tokens_extend = Phi4MultimodalFeatureEmbedding(config)
+
+    @paddle.jit.not_to_static
+    def recompute_training_full(self, layer_module, hidden_states, *args):
+        """Run ``layer_module`` under recompute, pinning the LoRA adapter that is active right now."""
+        # The backward-time replay runs after forward() has already restored config._active_lora_adapter,
+        # so the adapter is captured here and re-installed around every invocation of the layer.
+        adapter = getattr(self.config, "_active_lora_adapter", None)
+
+        def custom_forward(*inputs):
+            saved = getattr(self.config, "_active_lora_adapter", None)
+            self.config._active_lora_adapter = adapter
+            try:
+                return layer_module(*inputs)
+            finally:
+                self.config._active_lora_adapter = saved
+
+        return recompute(custom_forward, hidden_states, *args)
+
+    def forward(
+        self,
+        input_ids: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        position_ids: Optional[paddle.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[paddle.Tensor] = None,
+        image_pixel_values: Optional[paddle.Tensor] = None,
+        image_sizes: Optional[paddle.Tensor] = None,
+        image_attention_mask=None,
+        audio_input_features: Optional[paddle.Tensor] = None,
+        audio_embed_sizes=None,
+        audio_attention_mask=None,
+        input_mode=None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        attn_mask_startend_row_indices=None,
+        **kwargs,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        if output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        previous_adapter = getattr(self.config, "_active_lora_adapter", None)
+        self.config._active_lora_adapter = _lora_adapter_from_input_mode(
+            input_mode, image_pixel_values=image_pixel_values, audio_input_features=audio_input_features
+        )
+        # Everything below runs inside try/finally so the previous adapter selection is always restored.
+        try:
+            if input_ids is not None and inputs_embeds is not None:
+                raise ValueError("You cannot specify both input_ids and inputs_embeds")
+            elif input_ids is not None:
+                batch_size, seq_length = input_ids.shape
+            elif inputs_embeds is not None:
+                batch_size, seq_length, _ = inputs_embeds.shape
+            else:
+                raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+            if inputs_embeds is None:
+                inputs_embeds = self.embed_tokens(input_ids)
+                inputs_embeds = self.embed_tokens_extend(
+                    input_ids,
+                    inputs_embeds,
+                    image_pixel_values=image_pixel_values,
+                    audio_input_features=audio_input_features,
+                    image_sizes=image_sizes,
+                    image_attention_mask=image_attention_mask,
+                    audio_embed_sizes=audio_embed_sizes,
+                    audio_attention_mask=audio_attention_mask,
+                )
+
+            if use_cache and past_key_values is None:
+                past_key_values = DynamicCache(config=self.config)
+            cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+
+            if position_ids is None:
+                position_ids = paddle.arange(seq_length, dtype="int64").unsqueeze(0).expand([batch_size, -1])
+                position_ids = position_ids + cache_length
+
+            # Create causal mask
+            causal_mask, attn_mask_startend_row_indices = create_causal_mask_and_row_indices(
+                config=self.config,
+                inputs_embeds=inputs_embeds,
+                batch_size=batch_size,
+                seq_length=seq_length,
+                cache_length=cache_length,
+                attention_mask=attention_mask,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                prepare_decoder_attention_mask=self._prepare_decoder_attention_mask,
+            )
+
+            hidden_states = inputs_embeds
+            position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+            all_hidden_states = () if output_hidden_states else None
+            all_self_attns = () if output_attentions else None
+
+            for decoder_layer in self.layers:
+                if output_hidden_states:
+                    all_hidden_states += (hidden_states,)
+
+                has_gradient = not hidden_states.stop_gradient
+                if (
+                    self.config.recompute_granularity == "full"
+                    and self.config.recompute_method == "uniform"
+                    and self.config.recompute_num_layers == 1
+                    and has_gradient
+                ):
+                    layer_outputs = self.recompute_training_full(
+                        decoder_layer,
+                        hidden_states,
+                        causal_mask,
+                        attn_mask_startend_row_indices,
+                        position_ids,
+                        past_key_values,
+                        use_cache,
+                        position_embeddings,
+                        output_attentions,
+                    )
+                else:
+                    layer_outputs = decoder_layer(
+                        hidden_states=hidden_states,
+                        attention_mask=causal_mask,
+                        attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                        position_ids=position_ids,
+                        past_key_values=past_key_values,
+                        use_cache=use_cache,
+                        position_embeddings=position_embeddings,
+                        output_attentions=output_attentions,
+                    )
+
+                hidden_states = layer_outputs[0]
+                if output_attentions:
+                    all_self_attns += (layer_outputs[1],)
+
+            hidden_states = self.norm(hidden_states)
+
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if not return_dict:
+                outputs = [hidden_states, past_key_values, all_hidden_states, all_self_attns]
+                return tuple(v for v in outputs if v is not None)
+
+            return BaseModelOutputWithPast(
+                last_hidden_state=hidden_states,
+                past_key_values=past_key_values if use_cache else None,
+                hidden_states=all_hidden_states,
+                attentions=all_self_attns,
+            )
+        finally:
+            self.config._active_lora_adapter = previous_adapter
+
+
+class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel):
+    _tied_weights_keys = ["lm_head.weight"]  # declared as tied; handling is in the base class (not visible here)
+
+    def __init__(self, config: Phi4MultimodalConfig):
+        super().__init__(config)
+        self.model = Phi4MultimodalModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = GeneralLMHead(config)
+        self.criterion = CriterionLayer(config)
+        self._apply_multimodal_freeze_config()
+
+    def _apply_multimodal_freeze_config(self):  # sets stop_gradient on parameters picked by the freeze_* flags
+        freeze_prefixes = []
+        if getattr(self.config, "freeze_vision_model", False):
+            freeze_prefixes.extend(
+                [
+                    "model.embed_tokens_extend.image_embed.img_processor",
+                    "model.embed_tokens_extend.audio_embed.encoder",  # NOTE(review): audio under the vision flag
+                ]
+            )
+        if getattr(self.config, "freeze_vision_projection", False):
+            freeze_prefixes.extend(
+                [
+                    "model.embed_tokens_extend.image_embed.img_projection_up",
+                    "model.embed_tokens_extend.image_embed.img_projection_down",
+                    "model.embed_tokens_extend.image_embed.global_img_feature_extensor",
+                    "model.embed_tokens_extend.image_embed.sub_img_feature_extensor",
+                    "model.embed_tokens_extend.audio_embed.up_proj_for_speech",
+                    "model.embed_tokens_extend.audio_embed.down_proj_for_speech",
+                    "model.embed_tokens_extend.audio_embed.up_proj_for_vision_speech",
+                    "model.embed_tokens_extend.audio_embed.down_proj_for_vision_speech",
+                ]
+            )
+        if getattr(self.config, "freeze_language_model", False):
+            freeze_prefixes.extend(["model.embed_tokens", "model.layers", "model.norm", "lm_head"])
+        # With both multimodal flags set and the language model left trainable, "_lora_" tensors are frozen too.
+        freeze_multimodal_adapters = (
+            getattr(self.config, "freeze_vision_model", False)
+            and getattr(self.config, "freeze_vision_projection", False)
+            and not getattr(self.config, "freeze_language_model", False)
+        )
+
+        if not freeze_prefixes:
+            return
+
+        frozen = 0
+        for name, param in self.named_parameters():
+            if any(name.startswith(prefix) for prefix in freeze_prefixes) or (
+                freeze_multimodal_adapters and "_lora_" in name
+            ):
+                param.stop_gradient = True
+                frozen += 1
+        logger.info(f"Phi-4 multimodal freeze_config applied. Frozen parameter tensors: {frozen}")
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        model._apply_multimodal_freeze_config()  # applied again after loading; __init__ already ran it once
+        return model
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        use_cache=True,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        image_pixel_values=None,
+        image_sizes=None,
+        image_attention_mask=None,
+        audio_input_features=None,
+        audio_embed_sizes=None,
+        audio_attention_mask=None,
+        input_mode=None,
+        position_ids=None,
+        **kwargs,
+    ):
+        if input_mode is None:  # infer the mode code from the inputs: 3 = image + audio, 1 = image, 2 = audio
+            has_image = image_pixel_values is not None
+            has_audio = audio_input_features is not None
+            if has_image and has_audio:
+                input_mode = paddle.to_tensor([3], dtype="int64")
+            elif has_image:
+                input_mode = paddle.to_tensor([1], dtype="int64")
+            elif has_audio:
+                input_mode = paddle.to_tensor([2], dtype="int64")
+
+        batch_size, seq_length = input_ids.shape
+        if position_ids is None:
+            position_ids = paddle.arange(seq_length, dtype="int64").unsqueeze(0).expand([batch_size, -1])
+        if past_key_values:  # NOTE(review): truthiness here vs `is None` below — confirm both agree for an empty cache
+            input_ids = input_ids[:, -1].unsqueeze(axis=-1)
+            position_ids = position_ids[:, -1].unsqueeze(axis=-1)
+        # inputs_embeds is forwarded only when no cache object was passed; otherwise input_ids is used.
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+                "input_mode": input_mode,
+            }
+        )
+        # The image/audio tensors are likewise forwarded only when no cache object was passed.
+        if past_key_values is None:
+            model_inputs.update(
+                {
+                    "image_pixel_values": image_pixel_values,
+                    "image_sizes": image_sizes,
+                    "image_attention_mask": image_attention_mask,
+                    "audio_input_features": audio_input_features,
+                    "audio_embed_sizes": audio_embed_sizes,
+                    "audio_attention_mask": audio_attention_mask,
+                }
+            )
+
+        return model_inputs
+
+    def forward(
+        self,
+        input_ids: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        position_ids: Optional[paddle.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[paddle.Tensor] = None,
+        image_pixel_values: Optional[paddle.Tensor] = None,
+        image_sizes: Optional[paddle.Tensor] = None,
+        image_attention_mask=None,
+        audio_input_features: Optional[paddle.Tensor] = None,
+        audio_embed_sizes=None,
+        audio_attention_mask=None,
+        input_mode=None,
+        labels: Optional[paddle.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        attn_mask_startend_row_indices=None,
+        **kwargs,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            image_pixel_values=image_pixel_values,
+            image_sizes=image_sizes,
+            image_attention_mask=image_attention_mask,
+            audio_input_features=audio_input_features,
+            audio_embed_sizes=audio_embed_sizes,
+            audio_attention_mask=audio_attention_mask,
+            input_mode=input_mode,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            **kwargs,
+        )
+
+        hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
+        logits = self.lm_head(hidden_states)
+        # The loss is computed only when labels are given.
+        loss = None
+        if labels is not None:
+            loss = self.criterion(logits, labels)
+            if isinstance(loss, tuple):  # when the criterion returns a tuple, only its first element is kept
+                loss = loss[0]
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]  # logits take the place of the hidden states in slot 0
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+Phi4MultimodalForConditionalGeneration = Phi4MultimodalForCausalLM  # the three names below alias one class
+Phi4MMForCausalLM = Phi4MultimodalForCausalLM
+Phi4MMForConditionalGeneration = Phi4MultimodalForCausalLM
+
+
+class Phi4MultimodalForCausalLMPipe(GeneralModelForCausalLMPipe):
+    config_class = Phi4MultimodalConfig
+    _gen_aoa_config = Phi4MultimodalForCausalLM._gen_aoa_config  # classmethods stay bound to the causal-LM class
+    _gen_inv_aoa_config = Phi4MultimodalForCausalLM._gen_inv_aoa_config
+
+
+Phi4MMForCausalLMPipe = Phi4MultimodalForCausalLMPipe  # short alias for the same class
diff --git a/paddleformers/transformers/phi4_multimodal/processor.py b/paddleformers/transformers/phi4_multimodal/processor.py
new file mode 100644
index 00000000000..fbd0340d0f2
--- /dev/null
+++ b/paddleformers/transformers/phi4_multimodal/processor.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Processor class for Phi-4 Multimodal."""
+
+import inspect
+import json
+import os
+import re
+
+from ..feature_extraction_utils import BatchFeature
+from ..processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+
+
+class Phi4MultimodalProcessorKwargs(ProcessingKwargs, total=False):
+    """Keyword options for ``Phi4MultimodalProcessor.__call__``; ``audio_kwargs`` defaults to empty."""
+
+    _defaults = {"audio_kwargs": {}}
+
+
+class Phi4MultimodalProcessor(ProcessorMixin):
+    """Bundle the Phi-4 multimodal image processor, audio feature extractor and tokenizer."""
+
+    attributes = ["image_processor", "feature_extractor", "tokenizer"]
+    image_processor_class = "AutoImageProcessor"
+    feature_extractor_class = "AutoFeatureExtractor"
+    tokenizer_class = "AutoTokenizer"
+
+    @classmethod
+    def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, processor_dict=None, **kwargs):
+        """Build ``[image_processor, feature_extractor, tokenizer]`` for a checkpoint path.
+
+        ``preprocessor_config.json`` is read only if it exists under the path as a local directory;
+        ``processor_dict`` is accepted but unused, and ``kwargs`` are forwarded to the tokenizer.
+        """
+        from ..auto.tokenizer import AutoTokenizer
+        from .feature_extraction import Phi4MultimodalFeatureExtractor
+        from .image_processor import Phi4MultimodalImageProcessor
+
+        preprocessor_config = {}
+        preprocessor_config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
+        if os.path.exists(preprocessor_config_path):
+            with open(preprocessor_config_path, encoding="utf-8") as f:
+                preprocessor_config = json.load(f)
+
+        def _filter_config(processor_cls):
+            # Both preprocessors share one config file; keep only keys this ``__init__`` takes by name.
+            keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
+            params = inspect.signature(processor_cls.__init__).parameters
+            valid_keys = {key for key, param in params.items() if key != "self" and param.kind in keyword_kinds}
+            return {key: value for key, value in preprocessor_config.items() if key in valid_keys}
+
+        image_processor = Phi4MultimodalImageProcessor(**_filter_config(Phi4MultimodalImageProcessor))
+        feature_extractor = Phi4MultimodalFeatureExtractor(**_filter_config(Phi4MultimodalFeatureExtractor))
+        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        return [image_processor, feature_extractor, tokenizer]
+
+    def __init__(
+        self,
+        image_processor=None,
+        feature_extractor=None,
+        tokenizer=None,
+        chat_template=None,
+        audio_processor=None,
+        **kwargs,
+    ):
+        """Register the sub-processors; ``audio_processor`` is used when ``feature_extractor`` is None."""
+        if feature_extractor is None:
+            feature_extractor = audio_processor
+        self.image_token = getattr(tokenizer, "image_token", "<|image|>")
+        self.audio_token = getattr(tokenizer, "audio_token", "<|audio|>")
+        self.image_token_id = getattr(tokenizer, "image_token_id", None)
+        if self.image_token_id is None and tokenizer is not None:
+            self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+        self.audio_token_id = getattr(tokenizer, "audio_token_id", None)
+        if self.audio_token_id is None and tokenizer is not None:
+            self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token)
+        super().__init__(image_processor, feature_extractor, tokenizer, chat_template=chat_template, **kwargs)
+        self.audio_processor = self.feature_extractor
+
+    def __call__(self, text, images=None, audio=None, **kwargs: Unpack[Phi4MultimodalProcessorKwargs]) -> BatchFeature:
+        """Preprocess ``images``/``audio`` and tokenize ``text`` with expanded placeholders.
+
+        ``text`` is a string or a list of strings. Its image placeholders must number the entries of
+        the image processor's ``num_img_tokens`` output (likewise ``audio_embed_sizes`` for audio).
+        Returns a ``BatchFeature`` of the tokenizer, image and audio outputs plus ``input_mode``.
+
+        Raises:
+            TypeError: If ``text`` is not a string or a list of strings.
+            ValueError: If a placeholder count differs from the number of inputs of that modality.
+        """
+        output_kwargs = self._merge_kwargs(
+            Phi4MultimodalProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, **kwargs
+        )
+        image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) if images is not None else {}
+        audio_inputs = self.audio_processor(audio, **output_kwargs["audio_kwargs"]) if audio is not None else {}
+
+        num_img_tokens = image_inputs.pop("num_img_tokens", [])
+        audio_embed_sizes = audio_inputs.get("audio_embed_sizes", [])
+        if hasattr(audio_embed_sizes, "numpy"):
+            audio_embed_sizes = audio_embed_sizes.numpy().tolist()
+        elif not isinstance(audio_embed_sizes, list):
+            audio_embed_sizes = list(audio_embed_sizes)
+
+        if isinstance(text, str):
+            text = [text]
+        elif not isinstance(text, list) or not all(isinstance(sample, str) for sample in text):
+            raise TypeError("Invalid input text. Please provide a string or a list of strings.")
+
+        # Count per sample, exactly as the substitutions below match, so the two steps always agree.
+        expected = (("image", self.image_token, num_img_tokens), ("audio", self.audio_token, audio_embed_sizes))
+        for modality, token, sizes in expected:
+            found = sum(sample.count(token) for sample in text)
+            if found != len(sizes):
+                raise ValueError(
+                    f"You should add as many {modality} tokens `{token}` in your prompt as {modality}s passed to "
+                    f"the processor. Input contains {found} tokens != {len(sizes)} {modality}s."
+                )
+
+        # The iterators span all samples: each placeholder is repeated by the next size in prompt order.
+        image_counts, audio_counts = iter(num_img_tokens), iter(audio_embed_sizes)
+        image_re, audio_re = re.escape(self.image_token), re.escape(self.audio_token)
+        processed_text = [re.sub(image_re, lambda _: self.image_token * int(next(image_counts)), s) for s in text]
+        processed_text = [
+            re.sub(audio_re, lambda _: self.audio_token * int(next(audio_counts)), s) for s in processed_text
+        ]
+
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+        text_inputs = self.tokenizer(processed_text, **output_kwargs["text_kwargs"], return_tensors=None)
+        self._check_special_mm_tokens(processed_text, text_inputs, modalities=["image", "audio"])
+
+        input_mode = int(images is not None) + 2 * int(audio is not None)  # 0 text, 1 image, 2 audio, 3 both
+        return BatchFeature(
+            data={**text_inputs, **image_inputs, **audio_inputs, "input_mode": [input_mode]},
+            tensor_type=return_tensors,
+        )
+
+
+__all__ = ["Phi4MultimodalProcessor"]  # only the processor is exported; the kwargs class is not listed