Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ PaddleFormers 是基于百度深度学习框架 PaddlePaddle 搭建的 Transform
结合业界主流优化方法与飞桨在业务实践中积累的高效特性,PaddleFormers 致力于打造**高性能、低资源占用**的训练体验,帮助用户高效便捷地完成大模型训练,而无需关注底层复杂的优化细节。

## 🆕最新更新
* 2026.03.31 - PaddleFormers v1.1 正式发布!在这个版本中我们支持了 GLM-4.5 系列模型的单步与多步 MTP 训练能力。依托 MTP 架构优势,开发者可显著提升推理效率;同时针对 MTP 模块训练场景,我们新增主干网络冻结开关,灵活满足各类模型精细化调优需求。此外,我们对视觉理解类模型进行了深度优化,Qwen3-VL 30B-A3B 模型性能相比上个版本提升48%,领先Megatron-LM 6%。
* 2026.03.31 - PaddleFormers v1.1 正式发布!在这个版本中我们支持了 GLM-4.5 系列模型的单步与多步 MTP 训练能力。依托 MTP 架构优势,开发者可显著提升推理效率;同时针对 MTP 模块训练场景,我们新增主干网络冻结开关,灵活满足各类模型精细化调优需求。此外,我们对视觉理解类模型进行了深度优化,Qwen3-VL 30B-A3B 模型性能相比上个版本提升48%,领先 Megatron-LM 6%。
* 2026.01.21 - PaddleFomers v1.0版本发布啦!我们提供了针对 LLM 和 VLM 等模型的训练能力,针对 DeepSeek-V3模型和 GLM-4.5-Air 等重点模型,我们实现了极致性能优化(训练性能明显超越 Megatron-LM )。针对 PaddleOCR-VL,我们在昆仑芯 P800、天数天垓150等国产计算芯片上进行了适配,更好的满足国内用户需求。

## ✨特性
Expand Down Expand Up @@ -102,11 +102,16 @@ PaddleFormers 是基于百度深度学习框架 PaddlePaddle 搭建的 Transform
</tr>
<!-- VLM 分类 - 跨行合并开始 -->
<tr>
<td rowspan="4" style="vertical-align: top;">VLM</td>
<td rowspan="5" style="vertical-align: top;">VLM</td>
<td>🏛️ERNIE-4.5-VL</td>
<td>baidu/ERNIE-4.5-VL-28B-A3B-Base-PT、baidu/ERNIE-4.5-VL-28B-A3B-PT、baidu/ERNIE-4.5-VL-424B-A47B-Base-PT、baidu/ERNIE-4.5-VL-424B-A47B-PT、baidu/ERNIE-4.5-VL-28B-A3B-Thinking</td>
<td>ernie_vl、ernie_vl_nothink</td>
</tr>
<tr>
<td>Phi-4-multimodal</td>
<td>microsoft/Phi-4-multimodal-instruct</td>
<td>phi4_multimodal</td>
</tr>
<tr>
<td>🏛️PaddleOCR-VL</td>
<td>PaddlePaddle/PaddleOCR-VL</td>
Expand Down
2 changes: 2 additions & 0 deletions docs/zh/model_capability.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
|Qwen3|✓|✓|✓|✓|✓|
|Qwen3-Next|✓|✓|✓|✓|✓|
|🏛️ERNIE-4.5-VL|x|✓|✓|x|x|
|Phi-4-multimodal|x|✓|✓|x|x|
|🏛️PaddleOCR-VL|x|✓|✓|x|x|
|Qwen2.5-VL|x|✓|✓|x|x|
|Qwen3-VL|x|✓|✓|x|x|
Expand All @@ -30,6 +31,7 @@
|Qwen3|✓|✓|✓|✓|✓|✓|
|Qwen3-Next|✓|✓|✓|x|✓|✓|
|🏛️ERNIE-4.5-VL|✓|✓|✓|x|✓|✓|
|Phi-4-multimodal|x|x|-|x|✓|✓|
|🏛️PaddleOCR-VL|x|x|-|x|✓|✓|
|Qwen2.5-VL|✓|x|-|x|✓|✓|
|Qwen3-VL|x|x|✓|x|✓|✓|
Expand Down
23 changes: 23 additions & 0 deletions paddleformers/datasets/SFTDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ def _process_pretraining_sequence(self, example, actual_example_num):

# label shift
labels = labels[1:] + [-100]
labels = self._mask_mm_token_labels(tokens, labels)

pos_ids = list(range(len(tokens))) # only pure text, mm_position_ids will be reconstructed in collate.py

Expand Down Expand Up @@ -808,6 +809,7 @@ def _process_sft_sequence(self, example, actual_example_num):

# label shift
labels = labels[1:] + [-100]
labels = self._mask_mm_token_labels(tokens, labels)

pos_ids = list(range(len(tokens)))

Expand Down Expand Up @@ -887,6 +889,27 @@ def _truncate(

return input_ids, labels

def _mask_mm_token_labels(self, tokens, labels):
if not self.use_template or self.template_backend == "jinja" or self.template is None:
return labels

mm_plugin = getattr(self.template, "mm_plugin", None)
if not getattr(mm_plugin, "mask_mm_token_labels", False):
return labels

mm_tokens = [token for token in [mm_plugin.image_token, mm_plugin.audio_token] if token is not None]
if not mm_tokens:
return labels

mm_token_ids = self.tokenizer.convert_tokens_to_ids(mm_tokens)
if not isinstance(mm_token_ids, list):
mm_token_ids = [mm_token_ids]
mm_token_ids = {token_id for token_id in mm_token_ids if token_id is not None}
for i, token in enumerate(tokens):
if token in mm_token_ids or labels[i] in mm_token_ids:
labels[i] = -100
return labels

def _encode_truncated(self, input_ids, labels):
length = self._get_length(input_ids, labels)
if self.max_seq_len is not None and length > self.max_seq_len:
Expand Down
49 changes: 49 additions & 0 deletions paddleformers/datasets/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,13 @@ def mm_collate_fn(
input_keys.append("video_grid_thw")
input_keys.append("input_features")
input_keys.append("feature_attention_mask")
input_keys.append("image_pixel_values")
input_keys.append("image_sizes")
input_keys.append("image_attention_mask")
input_keys.append("audio_input_features")
input_keys.append("audio_embed_sizes")
input_keys.append("audio_attention_mask")
input_keys.append("input_mode")

if training_args.num_nextn_predict_layers > 0:
input_keys.append("nbatch_pack_offset")
Expand Down Expand Up @@ -652,6 +659,13 @@ def mm_collate_fn(
video_grid_thw = []
input_features = []
feature_attention_mask = []
image_pixel_values = []
image_sizes = []
image_attention_mask = []
audio_input_features = []
audio_embed_sizes = []
audio_attention_mask = []
input_mode = []
for seq in batch_sequence:
original_token_ids.append(seq.token_ids)
mm_inputs = seq.mm_inputs
Expand All @@ -667,6 +681,20 @@ def mm_collate_fn(
input_features.append(mm_inputs["input_features"])
if "feature_attention_mask" in mm_inputs:
feature_attention_mask.append(mm_inputs["feature_attention_mask"])
if "image_pixel_values" in mm_inputs:
image_pixel_values.append(mm_inputs["image_pixel_values"])
if "image_sizes" in mm_inputs:
image_sizes.append(mm_inputs["image_sizes"])
if "image_attention_mask" in mm_inputs:
image_attention_mask.append(mm_inputs["image_attention_mask"])
if "audio_input_features" in mm_inputs:
audio_input_features.append(mm_inputs["audio_input_features"])
if "audio_embed_sizes" in mm_inputs:
audio_embed_sizes.append(mm_inputs["audio_embed_sizes"])
if "audio_attention_mask" in mm_inputs:
audio_attention_mask.append(mm_inputs["audio_attention_mask"])
if "input_mode" in mm_inputs:
input_mode.append(mm_inputs["input_mode"])
if get_rope_func is not None:
filtered_args = {k: paddle.to_tensor(mm_inputs[k]) for k in func_params if k in mm_inputs}
total_input_ids = paddle.to_tensor([seq.token_ids])
Expand Down Expand Up @@ -712,6 +740,20 @@ def mm_collate_fn(
input_features = paddle.concat(input_features, axis=0)
if len(feature_attention_mask) > 0:
feature_attention_mask = paddle.concat(feature_attention_mask, axis=0)
if len(image_pixel_values) > 0:
image_pixel_values = paddle.concat(image_pixel_values, axis=0)
if len(image_sizes) > 0:
image_sizes = paddle.concat(image_sizes, axis=0)
if len(image_attention_mask) > 0:
image_attention_mask = paddle.concat(image_attention_mask, axis=0)
if len(audio_input_features) > 0:
audio_input_features = paddle.concat(audio_input_features, axis=0)
if len(audio_embed_sizes) > 0:
audio_embed_sizes = paddle.concat(audio_embed_sizes, axis=0)
if len(audio_attention_mask) > 0:
audio_attention_mask = paddle.concat(audio_attention_mask, axis=0)
if len(input_mode) > 0:
input_mode = paddle.concat(input_mode, axis=0)
if get_token_type_func is not None: # ernie45vl
bs_idx_in_rope = 0
padded_position_ids = padded_position_ids.transpose([1, 2, 0])
Expand All @@ -736,6 +778,13 @@ def mm_collate_fn(
video_grid_thw,
input_features,
feature_attention_mask,
image_pixel_values,
image_sizes,
image_attention_mask,
audio_input_features,
audio_embed_sizes,
audio_attention_mask,
input_mode,
]
)

Expand Down
85 changes: 85 additions & 0 deletions paddleformers/datasets/template/mm_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,90 @@ def get_mm_inputs(
return self._get_mm_inputs(images, videos, audios, processor, **kwargs)


@dataclass
class Phi4MultimodalPlugin(BasePlugin):
mask_mm_token_labels: bool = True

@staticmethod
def _as_int(value):
if hasattr(value, "item"):
value = value.item()
return int(value)

@override
def _validate_input(self, processor, images, videos, audios) -> None:
if len(videos) != 0:
raise ValueError("Phi-4 multimodal does not support video input in this template.")
if len(images) != 0:
if processor is None:
raise ValueError("Processor was not found for Phi-4 multimodal image input.")
if getattr(processor, "image_processor", None) is None:
raise ValueError("Image processor was not found for Phi-4 multimodal image input.")
if len(audios) != 0:
if processor is None:
raise ValueError("Processor was not found for Phi-4 multimodal audio input.")
if getattr(processor, "feature_extractor", None) is None:
raise ValueError("Audio feature extractor was not found for Phi-4 multimodal audio input.")

@override
def _get_mm_inputs(self, images, videos, audios, processor, **kwargs):
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)["images"]
mm_inputs.update(processor.image_processor(images, return_tensors="pd"))

if len(audios) != 0:
feature_extractor = getattr(processor, "feature_extractor", None)
sampling_rate = getattr(
feature_extractor, "sampling_rate", getattr(processor, "audio_sampling_rate", 16000)
)
audios = self._regularize_audios(audios, sampling_rate=sampling_rate)["audios"]
mm_inputs.update(
feature_extractor(
audios,
sampling_rate=sampling_rate,
return_attention_mask=True,
return_tensors="pd",
)
)

if len(images) != 0 or len(audios) != 0:
input_mode = 3 if len(images) != 0 and len(audios) != 0 else 1 if len(images) != 0 else 2
mm_inputs["input_mode"] = paddle.to_tensor([input_mode], dtype="int64")

return mm_inputs

@override
def process_messages(self, messages, images, videos, audios, mm_inputs, processor):
self._validate_input(processor, images, videos, audios)
self._validate_messages(messages, images, videos, audios)
messages = deepcopy(messages)
num_img_tokens = mm_inputs.get("num_img_tokens", [])
audio_embed_sizes = mm_inputs.get("audio_embed_sizes", [])
num_image_tokens, num_audio_tokens = 0, 0

for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
image_seqlen = self._as_int(num_img_tokens[num_image_tokens]) if self.expand_mm_tokens else 1
content = content.replace(IMAGE_PLACEHOLDER, self.image_token * image_seqlen, 1)
num_image_tokens += 1

while AUDIO_PLACEHOLDER in content:
audio_seqlen = self._as_int(audio_embed_sizes[num_audio_tokens]) if self.expand_mm_tokens else 1
content = content.replace(AUDIO_PLACEHOLDER, self.audio_token * audio_seqlen, 1)
num_audio_tokens += 1

message["content"] = content

self.masked_tokens = [token for token in [self.image_token, self.audio_token] if token is not None]
return messages


@dataclass
class PaddleOCRVLPlugin(BasePlugin):
image_bos_token: str = "<|IMAGE_START|>"
Expand Down Expand Up @@ -1496,6 +1580,7 @@ def process_messages(

PLUGINS = {
"base": BasePlugin,
"phi4_multimodal": Phi4MultimodalPlugin,
"ernie_vl": ErnieVLPlugin,
"qwen2_vl": Qwen2VLPlugin,
"paddleocr_vl": PaddleOCRVLPlugin,
Expand Down
18 changes: 18 additions & 0 deletions paddleformers/datasets/template/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,24 @@ def _get_gpt_oss_prefix():
chat_sep="<|im_end|>",
)

register_template(
name="phi4_multimodal",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
format_assistant=StringFormatter(slots=["{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
format_observation=StringFormatter(
slots=["<|user|>\n<tool_response>\n{{content}}\n</tool_response><|end|>\n<|assistant|>\n"]
),
suffix=["<|end|>"],
chat_sep="<|end|>\n",
auto_add_bos=False,
mm_plugin=get_mm_plugin(
name="phi4_multimodal",
image_token="<|endoftext10|>",
audio_token="<|endoftext11|>",
),
)

register_template(
name="glm_ocr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}\n"]),
Expand Down
22 changes: 22 additions & 0 deletions paddleformers/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,27 @@
"phi3.configuration": ["Phi3Config"],
"phi3.tokenizer": ["Phi3Tokenizer"],
"phi3.modeling": ["Phi3Model", "Phi3ForCausalLM", "Phi3ForCausalLMPipe"],
"phi4_multimodal.configuration": [
"Phi4MultimodalAudioConfig",
"Phi4MultimodalConfig",
"Phi4MultimodalVisionConfig",
],
"phi4_multimodal.feature_extraction": ["Phi4MultimodalFeatureExtractor"],
"phi4_multimodal.image_processor": ["Phi4MultimodalImageProcessor"],
"phi4_multimodal.modeling": [
"Phi4Multimodal",
"Phi4MultimodalAudio",
"Phi4MultimodalAudioModel",
"Phi4MultimodalAudioPreTrainedModel",
"Phi4MultimodalForCausalLM",
"Phi4MultimodalForConditionalGeneration",
"Phi4MultimodalModel",
"Phi4MultimodalPreTrainedModel",
"Phi4MultimodalVision",
"Phi4MultimodalVisionModel",
"Phi4MultimodalVisionPreTrainedModel",
],
"phi4_multimodal.processor": ["Phi4MultimodalProcessor"],
"glm4v_moe.configuration": ["Glm4vMoeConfig", "Glm4vMoeTextConfig", "Glm4vMoeVisionConfig"],
"glm4v_moe.modeling": [
"Glm4vMoeForConditionalGeneration",
Expand Down Expand Up @@ -408,6 +429,7 @@
from .minimax_m2 import *
from .gpt_oss import *
from .phi3 import *
from .phi4_multimodal import *
from .gemma3_text import *
from .glm_ocr import *
else:
Expand Down
10 changes: 10 additions & 0 deletions paddleformers/transformers/auto/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@
("minimax_m2", "MiniMaxM2Config"),
("gpt_oss", "GptOssConfig"),
("phi3", "Phi3Config"),
("phi4mm", "Phi4MultimodalConfig"),
("phi4_multimodal", "Phi4MultimodalConfig"),
("phi4_multimodal_audio", "Phi4MultimodalAudioConfig"),
("phi4_multimodal_vision", "Phi4MultimodalVisionConfig"),
("gemma3_text", "Gemma3TextConfig"),
("glm4v_moe", "Glm4vMoeConfig"),
("glm_ocr", "GlmOcrConfig"),
Expand Down Expand Up @@ -87,6 +91,9 @@
("qwen3_vl_moe", "Qwen3VLMoe"),
("qwen3_vl_moe_text", "Qwen3VLMoeText"),
("glm_ocr", "GlmOcrForConditionalGeneration"),
("phi4_multimodal", "Phi4Multimodal"),
("phi4_multimodal_audio", "Phi4MultimodalAudio"),
("phi4_multimodal_vision", "Phi4MultimodalVision"),
("qwen3_5_moe", "Qwen3_5MoEForConditionalGeneration"),
("qwen3_5", "Qwen3_5ForConditionalGeneration"),
]
Expand All @@ -100,6 +107,9 @@
SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
[
("qwen2_5_vl_text", "qwen2_5_vl"),
("phi4_multimodal_audio", "phi4_multimodal"),
("phi4_multimodal_vision", "phi4_multimodal"),
("phi4mm", "phi4_multimodal"),
("qwen3_vl_text", "qwen3_vl"),
("qwen3_vl_moe_text", "qwen3_vl_moe"),
]
Expand Down
4 changes: 4 additions & 0 deletions paddleformers/transformers/auto/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
[
("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
("whisper", "WhisperFeatureExtractor"),
]
)
Expand All @@ -56,6 +57,9 @@ def safe_load_json_file(json_file: str):


def feature_extractor_class_from_name(class_name: str):
if class_name == "Phi4MMAudioFeatureExtractor":
class_name = "Phi4MultimodalFeatureExtractor"

for module_name, extractors in FEATURE_EXTRACTOR_MAPPING_NAMES.items():
if class_name in extractors:
module_name = model_type_to_module_name(module_name)
Expand Down
4 changes: 4 additions & 0 deletions paddleformers/transformers/auto/image_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"glm4v_moe": ("Glm4vImageProcessor", "Glm4vImageProcessorFast"),
"kimi_k25": ("KimiK25VisionProcessor"),
"paddleocr_vl": ("PaddleOCRVLImageProcessor"),
"phi4_multimodal": ("Phi4MultimodalImageProcessor"),
"qwen2_5_vl": ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast"),
"qwen2_vl": ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast"),
"qwen3_vl": ("Qwen3VLImageProcessor", "Qwen3VLImageProcessorFast"),
Expand All @@ -68,6 +69,9 @@


def get_image_processor_class_from_name(class_name: str):
if class_name == "Phi4MMImageProcessor":
class_name = "Phi4MultimodalImageProcessor"

if class_name == "BaseImageProcessorFast":
return BaseImageProcessorFast

Expand Down
Loading
Loading