diff --git a/README.md b/README.md index 76e7ea002a5..f610ba184bf 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ PaddleFormers 是基于百度深度学习框架 PaddlePaddle 搭建的 Transform 结合业界主流优化方法与飞桨在业务实践中积累的高效特性,PaddleFormers 致力于打造**高性能、低资源占用**的训练体验,帮助用户高效便捷地完成大模型训练,而无需关注底层复杂的优化细节。 ## 🆕最新更新 -* 2026.03.31 - PaddleFormers v1.1 正式发布!在这个版本中我们支持了 GLM-4.5 系列模型的单步与多步 MTP 训练能力。依托 MTP 架构优势,开发者可显著提升推理效率;同时针对 MTP 模块训练场景,我们新增主干网络冻结开关,灵活满足各类模型精细化调优需求。此外,我们对视觉理解类模型进行了深度优化,Qwen3-VL 30B-A3B 模型性能相比上个版本提升48%,领先Megatron-LM 6%。 +* 2026.03.31 - PaddleFormers v1.1 正式发布!在这个版本中我们支持了 GLM-4.5 系列模型的单步与多步 MTP 训练能力。依托 MTP 架构优势,开发者可显著提升推理效率;同时针对 MTP 模块训练场景,我们新增主干网络冻结开关,灵活满足各类模型精细化调优需求。此外,我们对视觉理解类模型进行了深度优化,Qwen3-VL 30B-A3B 模型性能相比上个版本提升48%,领先 Megatron-LM 6%。 * 2026.01.21 - PaddleFomers v1.0版本发布啦!我们提供了针对 LLM 和 VLM 等模型的训练能力,针对 DeepSeek-V3模型和 GLM-4.5-Air 等重点模型,我们实现了极致性能优化(训练性能明显超越 Megatron-LM )。针对 PaddleOCR-VL,我们在昆仑芯 P800、天数天垓150等国产计算芯片上进行了适配,更好的满足国内用户需求。 ## ✨特性 @@ -50,7 +50,7 @@ PaddleFormers 是基于百度深度学习框架 PaddlePaddle 搭建的 Transform - LLM + LLM DeepSeekv3 deepseek-ai/DeepSeek-V3-Base、deepseek-ai/DeepSeek-V3、deepseek-ai/DeepSeek-V3-0324 deepseek3 @@ -80,6 +80,11 @@ PaddleFormers 是基于百度深度学习框架 PaddlePaddle 搭建的 Transform meta-llama/Meta-Llama-3-8B、meta-llama/Meta-Llama-3-8B-Instruct、meta-llama/Meta-Llama-3-70B、meta-llama/Meta-Llama-3-70B-Instruct、meta-llama/Llama-3.1-8B、meta-llama/Llama-3.1-8B-Instruct、meta-llama/Llama-3.1-70B、meta-llama/Llama-3.1-70B-Instruct、meta-llama/Llama-3.1-405B、meta-llama/Llama-3.1-405B-Instruct、meta-llama/Llama-3.2-1B、meta-llama/Llama-3.2-1B-Instruct、meta-llama/Llama-3.2-3B、meta-llama/Llama-3.2-3B-Instruct、meta-llama/Llama-3.3-70B-Instruct llama3 + + MiniMax-Text-01 + MiniMaxAI/MiniMax-Text-01 + minimax + phi-4 microsoft/phi-4 diff --git a/docs/zh/model_capability.md b/docs/zh/model_capability.md index 3b402f9d4b2..7e2b0b21374 100644 --- a/docs/zh/model_capability.md +++ b/docs/zh/model_capability.md @@ -7,6 +7,7 @@ |GLM-4.5|✓|✓|✓|✓|✓| |GPT-OSS|✓|✓|✓|x|x| |LLaMA3|✓|✓|✓|✓|✓| 
+|MiniMax-Text-01|✓|✓|✓|x|x| |Phi4|✓|✓|✓|✓|✓| |Qwen2|✓|✓|✓|✓|✓| |Qwen3|✓|✓|✓|✓|✓| @@ -25,6 +26,7 @@ |GLM-4.5|✓|✓|✓|✓|✓|✓| |GPT-OSS|✓|✓|x|x|✓|✓| |LLaMA3|✓|✓|-|x|✓|✓| +|MiniMax-Text-01|x|x|x|x|✓|✓| |Phi4|✓|✓|-|x|✓|✓| |Qwen2|✓|✓|x|x|✓|✓| |Qwen3|✓|✓|✓|✓|✓|✓| diff --git a/paddleformers/datasets/template/template.py b/paddleformers/datasets/template/template.py index 6d0d42cd39c..5e4f4f4d546 100644 --- a/paddleformers/datasets/template/template.py +++ b/paddleformers/datasets/template/template.py @@ -638,6 +638,22 @@ def get_template_and_fix_tokenizer(dataset_config) -> "Template": thought_words=("\n", "\n\n\n"), ) +register_template( + name="minimax", + format_user=StringFormatter( + slots=[ + "user name=user\n{{content}}\n" + "ai name=assistant\n" + ] + ), + format_assistant=StringFormatter(slots=["{{content}}"]), + format_system=StringFormatter( + slots=["system ai_setting=assistant\n{{content}}\n"] + ), + chat_sep="\n", + suffix=[""], +) + register_template( name="paddleocr_vl", format_user=StringFormatter(slots=["User: {{content}}\nAssistant: "]), diff --git a/paddleformers/transformers/__init__.py b/paddleformers/transformers/__init__.py index ae1c4f20c73..068135cff2b 100644 --- a/paddleformers/transformers/__init__.py +++ b/paddleformers/transformers/__init__.py @@ -313,6 +313,8 @@ "minimax_m2": ["MiniMaxM2ForCausalLMPipe", "MiniMaxM2ForCausalLM"], "deepseek_v4.configuration": ["DeepseekV4Config"], "deepseek_v4": ["DeepseekV4ForCausalLMPipe", "DeepseekV4ForCausalLM"], + "minimax.configuration": ["MiniMaxConfig"], + "minimax": ["MiniMaxModel", "MiniMaxForCausalLM", "MiniMaxForCausalLMPipe"], "glm4v_moe.image_processor": ["Glm4vImageProcessor"], "glm4v_moe.image_processor_fast": ["Glm4vImageProcessorFast"], "auto": ["AutoModelForCausalLM"], @@ -414,6 +416,7 @@ from .glm_moe_dsa import * from .minimax_m2 import * from .deepseek_v4 import * + from .minimax import * from .gpt_oss import * from .phi3 import * from .gemma3_text import * diff --git 
a/paddleformers/transformers/auto/configuration.py b/paddleformers/transformers/auto/configuration.py index c04e1f34a5a..296e3eb6757 100644 --- a/paddleformers/transformers/auto/configuration.py +++ b/paddleformers/transformers/auto/configuration.py @@ -54,6 +54,7 @@ ("qwen3_vl_moe_text", "Qwen3VLMoeTextConfig"), ("glm4_moe", "Glm4MoeConfig"), ("glm_moe_dsa", "GlmMoeDsaConfig"), + ("minimax", "MiniMaxConfig"), ("minimax_m2", "MiniMaxM2Config"), ("deepseek_v4", "DeepseekV4Config"), ("gpt_oss", "GptOssConfig"), @@ -89,6 +90,7 @@ ("qwen3_vl_moe", "Qwen3VLMoe"), ("qwen3_vl_moe_text", "Qwen3VLMoeText"), ("glm_ocr", "GlmOcrForConditionalGeneration"), + ("minimax", "MiniMaxForCausalLM"), ("qwen3_5_moe", "Qwen3_5MoEForConditionalGeneration"), ("qwen3_5", "Qwen3_5ForConditionalGeneration"), ] diff --git a/paddleformers/transformers/auto/modeling.py b/paddleformers/transformers/auto/modeling.py index 11321baba1f..ea5a350fbdf 100644 --- a/paddleformers/transformers/auto/modeling.py +++ b/paddleformers/transformers/auto/modeling.py @@ -73,6 +73,7 @@ ("Qwen3_5", "qwen3_5"), ("Glm4Moe", "glm4_moe"), ("GlmMoeDsa", "glm_moe_dsa"), + ("MiniMax", "minimax"), ("MiniMaxM2", "minimax_m2"), ("DeepseekV4", "deepseek_v4"), ("GptOss", "gpt_oss"), diff --git a/paddleformers/transformers/minimax/__init__.py b/paddleformers/transformers/minimax/__init__.py new file mode 100644 index 00000000000..0e968840f32 --- /dev/null +++ b/paddleformers/transformers/minimax/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Package""" +import sys +from typing import TYPE_CHECKING + +from ...utils.lazy_import import _LazyModule + +import_structure = { + "configuration": ["MiniMaxConfig"], + "modeling": [ + "MiniMaxPretrainedModel", + "MiniMaxModel", + "MiniMaxForCausalLM", + "MiniMaxForCausalLMPipe", + ], +} + +if TYPE_CHECKING: + from .configuration import * + from .modeling import * +else: + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + import_structure, + module_spec=__spec__, + ) diff --git a/paddleformers/transformers/minimax/configuration.py b/paddleformers/transformers/minimax/configuration.py new file mode 100644 index 00000000000..70697549671 --- /dev/null +++ b/paddleformers/transformers/minimax/configuration.py @@ -0,0 +1,224 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class MiniMaxConfig(PretrainedConfig):
    r"""
    Configuration class for a [`MiniMaxModel`] (MiniMax Text-01 style architecture).

    The model mixes two kinds of attention layers, selected per layer through `layer_types`:

    - `"full_attention"`: standard causal self-attention with RoPE.
    - `"linear_attention"`: "lightning attention" with intra-/inter-block components.

    When neither `layer_types` nor `attn_type_list` is given, the layers alternate starting with
    full attention: layer 0 is `"full_attention"`, layer 1 is `"linear_attention"`, layer 2 is
    `"full_attention"`, and so on (0-based layer indices).

    Configuration objects inherit from [`PretrainedConfig`]; see its documentation for the options
    handled by the base class.

    Args:
        vocab_size (`int`, *optional*, defaults to 200064):
            Vocabulary size of the MiniMax model.
        hidden_size (`int`, *optional*, defaults to 6144):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 9216):
            Dimension of the routed expert MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 80):
            Number of hidden layers.
        num_attention_heads (`int`, *optional*, defaults to 64):
            Number of attention heads for each attention layer.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            Number of key/value heads used for Grouped Query Attention.
        head_dim (`int`, *optional*):
            Dimension of each attention head. Stored as given; the modeling code falls back to
            `hidden_size // num_attention_heads` when it is `None`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function.
        max_position_embeddings (`int`, *optional*, defaults to 10240000):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation used when initializing weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions.
        pad_token_id (`int`, *optional*):
            Padding token id, forwarded to [`PretrainedConfig`].
        bos_token_id (`int`, *optional*):
            Beginning-of-sequence token id, forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*):
            End-of-sequence token id, forwarded to [`PretrainedConfig`].
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        sliding_window (`int`, *optional*):
            Forwarded unchanged to [`PretrainedConfig`].
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_experts_per_tok (`int`, *optional*, defaults to 2):
            Number of selected experts per token.
        num_local_experts (`int`, *optional*, defaults to 32):
            Number of routed experts.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether the router logits should be returned by the model.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            Coefficient for the auxiliary load-balancing loss.
        router_jitter_noise (`float`, *optional*, defaults to 0.0):
            Jitter noise for the router.
        attn_type_list (`list[int]`, *optional*):
            Legacy per-layer attention selector, used only when `layer_types` is `None`:
            `0` means `"linear_attention"`, any other value means `"full_attention"`.
            The stored `attn_type_list` attribute is always re-derived from the final `layer_types`.
        rope_theta (`float`, *optional*, defaults to 10000000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. A legacy
            `"type"` key is mirrored to `"rope_type"`. The dictionary is copied, so the caller's
            object is never modified.
        layer_types (`list[str]`, *optional*):
            A list that maps each layer index to its attention type. Can be `"full_attention"` or
            `"linear_attention"`. Takes precedence over `attn_type_list`.
        block_size (`int`, *optional*, defaults to 256):
            The length of each attention block for the lightning attention.
        full_attn_alpha_factor (`float`, *optional*, defaults to 1):
            Weight for residual value in residual connection after full attention.
        full_attn_beta_factor (`float`, *optional*, defaults to 1):
            Weight for hidden state value in residual connection after full attention.
        linear_attn_alpha_factor (`float`, *optional*, defaults to 1):
            Weight for residual value in residual connection after lightning attention.
        linear_attn_beta_factor (`float`, *optional*, defaults to 1):
            Weight for hidden state value in residual connection after lightning attention.
        mlp_alpha_factor (`float`, *optional*, defaults to 1):
            Weight for residual value in residual connection after MLP.
        mlp_beta_factor (`float`, *optional*, defaults to 1):
            Weight for hidden state value in residual connection after MLP.
        **kwargs:
            Remaining options for [`PretrainedConfig`]. The legacy keys
            `layernorm_full_attention_alpha`, `layernorm_full_attention_beta`,
            `layernorm_linear_attention_alpha`, `layernorm_linear_attention_beta`,
            `layernorm_mlp_alpha` and `layernorm_mlp_beta` are consumed here and override the
            corresponding `*_alpha_factor` / `*_beta_factor` arguments.

    Raises:
        ValueError: If `layer_types` (or, when it is `None`, `attn_type_list`) does not contain
            exactly `num_hidden_layers` entries.

    ```python
    >>> from paddleformers.transformers import MiniMaxModel, MiniMaxConfig

    >>> # Initializing a MiniMax (Text-01) style configuration
    >>> configuration = MiniMaxConfig()

    >>> # Initializing a model from the MiniMax (Text-01) style configuration
    >>> model = MiniMaxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "minimax"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=200064,
        hidden_size=6144,
        intermediate_size=9216,
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,
        head_dim=None,
        hidden_act="silu",
        max_position_embeddings=10240000,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=None,
        eos_token_id=None,
        tie_word_embeddings=False,
        sliding_window=None,
        attention_dropout=0.0,
        num_experts_per_tok=2,
        num_local_experts=32,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        router_jitter_noise=0.0,
        attn_type_list=None,
        rope_theta=10000000.0,
        rope_scaling=None,
        layer_types=None,
        block_size=256,
        full_attn_alpha_factor=1.0,
        full_attn_beta_factor=1.0,
        linear_attn_alpha_factor=1.0,
        linear_attn_beta_factor=1.0,
        mlp_alpha_factor=1.0,
        mlp_beta_factor=1.0,
        **kwargs,
    ):
        # Legacy key names take precedence over the explicit arguments and must be
        # removed from kwargs before they reach the base class.
        full_attn_alpha_factor = kwargs.pop("layernorm_full_attention_alpha", full_attn_alpha_factor)
        full_attn_beta_factor = kwargs.pop("layernorm_full_attention_beta", full_attn_beta_factor)
        linear_attn_alpha_factor = kwargs.pop("layernorm_linear_attention_alpha", linear_attn_alpha_factor)
        linear_attn_beta_factor = kwargs.pop("layernorm_linear_attention_beta", linear_attn_beta_factor)
        mlp_alpha_factor = kwargs.pop("layernorm_mlp_alpha", mlp_alpha_factor)
        mlp_beta_factor = kwargs.pop("layernorm_mlp_beta", mlp_beta_factor)

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.attention_dropout = attention_dropout
        self.rope_theta = rope_theta

        # Work on a private copy: the "rope_type" alias below (and any later
        # standardization) must not leak into the dictionary owned by the caller.
        self.rope_scaling = dict(rope_scaling) if rope_scaling is not None else None
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        # rope_parameters deliberately aliases rope_scaling (same object), as before.
        self.rope_parameters = self.rope_scaling
        standardize_rope_params(self, rope_theta=rope_theta)
        rope_config_validation(self)

        self.num_experts_per_tok = num_experts_per_tok
        self.num_local_experts = num_local_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.router_jitter_noise = router_jitter_noise

        if layer_types is not None:
            if len(layer_types) != num_hidden_layers:
                raise ValueError(
                    f"layer_types length ({len(layer_types)}) must equal num_hidden_layers ({num_hidden_layers})."
                )
            self.layer_types = list(layer_types)
        elif attn_type_list is not None:
            if len(attn_type_list) != num_hidden_layers:
                raise ValueError(
                    f"attn_type_list length ({len(attn_type_list)}) must equal "
                    f"num_hidden_layers ({num_hidden_layers})."
                )
            # Legacy encoding: 0 -> linear attention, anything else -> full attention.
            self.layer_types = [
                "linear_attention" if int(attn_type) == 0 else "full_attention" for attn_type in attn_type_list
            ]
        else:
            # Default pattern: even 0-based indices are full attention, odd ones linear.
            self.layer_types = [
                "full_attention" if i % 2 == 0 else "linear_attention" for i in range(self.num_hidden_layers)
            ]

        # Always keep the legacy integer view consistent with layer_types.
        self.attn_type_list = [0 if layer_type == "linear_attention" else 1 for layer_type in self.layer_types]
        self.block_size = block_size
        self.full_attn_alpha_factor = full_attn_alpha_factor
        self.full_attn_beta_factor = full_attn_beta_factor
        self.linear_attn_alpha_factor = linear_attn_alpha_factor
        self.linear_attn_beta_factor = linear_attn_beta_factor
        self.mlp_alpha_factor = mlp_alpha_factor
        self.mlp_beta_factor = mlp_beta_factor

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            sliding_window=sliding_window,
            **kwargs,
        )
def rotate_half(x: paddle.Tensor) -> paddle.Tensor:
    """Return `x` with the two halves of its last axis swapped and the moved-forward half negated.

    For a last axis split as `[first, second]` the result is `[-second, first]`.
    """
    midpoint = x.shape[-1] // 2
    first_half, second_half = x[..., :midpoint], x[..., midpoint:]
    return paddle.cat((-second_half, first_half), axis=-1)
class MiniMaxRMSNorm(nn.Layer):
    """Root-mean-square layer normalization with a learnable per-feature scale.

    The statistics are computed in float32 over the last axis; there is no mean
    subtraction and no bias term.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.variance_epsilon = eps
        # Scale starts at 1 so the layer is initially a pure normalization.
        ones_init = nn.initializer.Constant(1.0)
        self.weight = self.create_parameter(shape=[hidden_size], default_initializer=ones_init)

    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
        """Normalize `hidden_states` over its last axis and apply the learned scale."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.astype("float32")
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * paddle.rsqrt(mean_square + self.variance_epsilon)
        # Cast back before scaling, exactly as the weight multiplication expects.
        return self.weight * normalized.astype(original_dtype)

    def extra_repr(self) -> str:
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
class MiniMaxLightningAttention(nn.Layer):
    """Linear attention ("lightning attention") for MiniMax.

    The layer keeps a running per-head key/value statistic of shape
    `[batch, heads, head_dim, head_dim]` (`attn_weights_inter` in `forward`).
    When no statistic is cached, the sequence is processed in blocks of
    `config.block_size` tokens: each block combines an intra-block term
    (decayed `Q K^T V` inside the block) with an inter-block term (decayed `Q`
    times the running statistic). When a cached statistic exists, tokens are
    processed one at a time with a per-token exponential decay.

    Queries, keys and values come from one fused projection followed by SiLU;
    the output is RMS-normalized, gated by `sigmoid(output_gate(hidden_states))`
    and projected back to `hidden_size`.
    """

    def __init__(self, config: MiniMaxConfig, layer_idx: int):
        """Build projections and precompute the per-head decay buffers.

        Args:
            config: Model configuration; reads `hidden_size`, `num_attention_heads`,
                `num_hidden_layers`, `block_size`, `rms_norm_eps` and optionally `head_dim`.
            layer_idx: Index of this layer; used for the decay slope and as the cache slot.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Falls back to hidden_size // num_attention_heads when head_dim is missing or None/0.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_attention_heads = config.num_attention_heads
        self.num_hidden_layers = config.num_hidden_layers
        self.block_size = config.block_size

        hidden_size = config.hidden_size
        # Fused projection output: q, k and v for every head.
        qkv_out = self.num_attention_heads * self.head_dim * 3
        attn_out = self.num_attention_heads * self.head_dim

        self.qkv_proj = GeneralLinear.create(
            hidden_size,
            qkv_out,
            has_bias=False,
            config=config,
            tp_plan="colwise",
        )
        self.out_proj = GeneralLinear.create(
            attn_out,
            hidden_size,
            has_bias=False,
            config=config,
            tp_plan="rowwise",
        )
        self.output_gate = GeneralLinear.create(
            hidden_size,
            attn_out,
            has_bias=False,
            config=config,
            tp_plan="colwise",
        )
        self.norm = MiniMaxRMSNorm(self.head_dim * self.num_attention_heads, eps=config.rms_norm_eps)

        self.act_fn = F.silu

        slope_rate = self.get_slope_rate()
        query_decay, key_decay, diagonal_decay = self.decay_factors(slope_rate)

        # Non-persistable: the buffers are derived from the config, so they are
        # recomputed at construction instead of being stored in checkpoints.
        self.register_buffer("slope_rate", slope_rate, persistable=False)
        self.register_buffer("query_decay", query_decay, persistable=False)
        self.register_buffer("key_decay", key_decay, persistable=False)
        self.register_buffer("diagonal_decay", diagonal_decay, persistable=False)

    def get_slope_rate(self) -> paddle.Tensor:
        """Return the per-head decay slopes with shape `[num_attention_heads, 1, 1]`.

        Head `h` (0-based) gets `base ** (h + 1) * factor`, where
        `base = 2 ** (-8 / num_attention_heads)` and `factor` shrinks as
        `layer_idx` grows.
        """
        base = 1.0 / (2.0 ** (8.0 / self.num_attention_heads))
        dtype = paddle.get_default_dtype()
        exponent = paddle.arange(self.num_attention_heads).astype(dtype) + 1
        # The 1e-5 terms avoid a zero denominator for single-layer models and keep
        # the factor strictly positive for the last layer.
        factor = 1.0 - self.layer_idx / (self.num_hidden_layers - 1 + 1e-5) + 1e-5

        rate = paddle.pow(paddle.to_tensor(base, dtype=dtype), exponent)
        rate = rate * factor
        rate = rate.unsqueeze(-1).unsqueeze(-1)
        return rate

    def decay_factors(self, slope_rate: paddle.Tensor) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Precompute the decay tables for one full block.

        Args:
            slope_rate: Per-head slopes as returned by `get_slope_rate`.

        Returns:
            `(query_decay, key_decay, diagonal_decay)` where, with positions
            `p = 1..block_size`:

            - `query_decay = exp(-slope * p)`
            - `key_decay = exp(-slope * (block_size - p))`
            - `diagonal_decay[..., i, j] = exp(-slope * (i - j))` for `i >= j`, else 0
              (a causal, distance-decayed mask inside the block).
        """
        block_size_range = paddle.arange(self.block_size).astype(slope_rate.dtype) + 1

        query_decay = paddle.exp(-slope_rate * block_size_range.unsqueeze(-1))
        key_decay = paddle.exp(-slope_rate * (self.block_size - block_size_range.unsqueeze(-1)))

        # diff[i, j] = i - j; negative entries are future positions.
        diff = block_size_range.unsqueeze(-1) - block_size_range.unsqueeze(0)
        diff = diff.unsqueeze(0).unsqueeze(0)
        decay = slope_rate * diff
        neg_inf = paddle.full_like(decay, fill_value=float("-inf"))
        # -inf exponent -> weight 0 for future positions after exp().
        decay = paddle.where(decay >= 0, -decay, neg_inf)
        diagonal_decay = paddle.exp(decay)

        return query_decay, key_decay, diagonal_decay

    def forward(
        self,
        hidden_states: paddle.Tensor,
        position_embeddings: tuple[paddle.Tensor, paddle.Tensor] | None = None,
        attention_mask: paddle.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool = False,
        **kwargs,
    ) -> tuple[paddle.Tensor, paddle.Tensor | None]:
        """Run lightning attention over `hidden_states`.

        Args:
            hidden_states: Rank-3 input, unpacked as `(batch_size, seq_len, hidden)`.
            position_embeddings: Accepted for interface parity; not read by this method.
            attention_mask: Optional mask; where it is false the value vectors are zeroed.
                Presumably a `[batch, seq_len]` padding mask, given how it is
                unsqueezed below — confirm against the caller.
            past_key_values: Optional cache. The running statistic is only read from and
                written to it when it is a `MiniMaxCache`.
            use_cache: Accepted for interface parity; not read by this method.

        Returns:
            `(attn_output, attn_weights_inter)`: the projected attention output and the
            updated running key/value statistic.
        """
        # NOTE(review): the branch nesting below was reconstructed from a
        # whitespace-collapsed diff; the `else` is assumed to pair with
        # `if attn_weights_inter is None` — confirm against the original file.
        batch_size, seq_len, _ = hidden_states.shape
        num_blocks = (seq_len + self.block_size - 1) // self.block_size
        slope_rate = self.slope_rate.astype(hidden_states.dtype)

        qkv_states = self.act_fn(self.qkv_proj(hidden_states))
        qkv_states = qkv_states.reshape([batch_size, seq_len, self.num_attention_heads, 3 * self.head_dim])
        query_states, key_states, value_states = paddle.split(qkv_states, num_or_sections=3, axis=-1)

        # [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
        query_states = query_states.transpose([0, 2, 1, 3])
        key_states = key_states.transpose([0, 2, 1, 3])
        value_states = value_states.transpose([0, 2, 1, 3])

        attn_weights_inter = None
        if past_key_values is not None and isinstance(past_key_values, MiniMaxCache):
            attn_weights_inter = past_key_values.get_linear_cache(self.layer_idx)

        if attn_weights_inter is None:
            # No cached statistic: start from zeros and process the sequence block by block.
            attn_weights_inter = paddle.zeros(
                [batch_size, self.num_attention_heads, self.head_dim, self.head_dim],
                dtype=hidden_states.dtype,
            )

            if attention_mask is not None:
                bool_mask = attention_mask.astype("bool")
                expanded = bool_mask.unsqueeze(1).unsqueeze(-1)
                # Zeroed values contribute nothing to either the output or the statistic.
                value_states = paddle.where(expanded, value_states, paddle.zeros_like(value_states))

            attn_output = []
            for i in range(num_blocks):
                start_idx = i * self.block_size
                end_idx = min(start_idx + self.block_size, seq_len)
                # The last block may be shorter than block_size.
                cur_bs = end_idx - start_idx

                cur_q = query_states[:, :, start_idx:end_idx]
                cur_k = key_states[:, :, start_idx:end_idx]
                cur_v = value_states[:, :, start_idx:end_idx]

                # Slice the precomputed tables to the current block length; key decay is
                # taken from the tail so the last token of the block has the least decay.
                cur_qd = self.query_decay[:, :cur_bs].astype(cur_q.dtype)
                cur_kd = self.key_decay[:, -cur_bs:].astype(cur_k.dtype)
                cur_dd = self.diagonal_decay[:, :, :cur_bs, :cur_bs].astype(cur_q.dtype)
                block_decay = paddle.exp(-slope_rate * cur_bs).astype(attn_weights_inter.dtype)

                # Intra-block: causal, distance-decayed attention inside the block.
                attn_intra = paddle.matmul(cur_q, cur_k.transpose([0, 1, 3, 2]))
                attn_output_intra = paddle.matmul(attn_intra * cur_dd, cur_v)

                # Inter-block: contribution of everything before this block.
                attn_output_inter = paddle.matmul(cur_q * cur_qd, attn_weights_inter)

                cur_out = attn_output_inter + attn_output_intra
                attn_output.append(cur_out)

                # Fold this block into the running statistic for the next block.
                next_stat = paddle.matmul((cur_k * cur_kd).transpose([0, 1, 3, 2]), cur_v)
                attn_weights_inter = attn_weights_inter * block_decay + next_stat
        else:
            # Cached statistic available: update it token by token.
            ratio = paddle.exp(-slope_rate).astype(attn_weights_inter.dtype)
            attn_output = []
            for i in range(seq_len):
                cur_q = query_states[:, :, i : i + 1]
                cur_k = key_states[:, :, i : i + 1]
                cur_v = value_states[:, :, i : i + 1]

                cur_stat = paddle.matmul(cur_k.transpose([0, 1, 3, 2]), cur_v)
                attn_weights_inter = ratio * attn_weights_inter + cur_stat
                cur_out = paddle.matmul(cur_q, attn_weights_inter)
                attn_output.append(cur_out)

        # Concatenate per-block / per-token outputs along the sequence axis.
        attn_output = paddle.cat(attn_output, axis=-2)

        # [batch, heads, seq, head_dim] -> [batch, seq, heads * head_dim]
        attn_output = attn_output.transpose([0, 2, 1, 3])
        attn_output = attn_output.reshape([batch_size, seq_len, self.num_attention_heads * self.head_dim])
        attn_output = self.norm(attn_output)
        attn_output = F.sigmoid(self.output_gate(hidden_states)).astype(attn_output.dtype) * attn_output
        attn_output = attn_output.astype(hidden_states.dtype)
        attn_output = self.out_proj(attn_output)

        if past_key_values is not None and isinstance(past_key_values, MiniMaxCache):
            past_key_values.set_linear_cache(self.layer_idx, attn_weights_inter)

        return attn_output, attn_weights_inter
def eager_attention_forward(
    module: nn.Layer,
    query: paddle.Tensor,
    key: paddle.Tensor,
    value: paddle.Tensor,
    attention_mask: paddle.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference (eager) scaled dot-product attention with grouped-query support.

    Args:
        module: The calling attention layer; `module.num_key_value_groups` gives how many
            query heads share one key/value head and `module.training` controls dropout.
        query: Query tensor laid out as `[batch, heads, q_len, head_dim]`
            (the key is transposed on its last two axes to match it).
        key: Key tensor `[batch, kv_heads, kv_len, head_dim]`.
        value: Value tensor `[batch, kv_heads, kv_len, head_dim]`.
        attention_mask: Optional additive mask, added to the scaled scores before softmax.
        scaling: Factor applied to the raw `Q K^T` scores.
        dropout: Drop probability for the attention probabilities.
        **kwargs: Ignored; accepted so all attention back-ends share one call signature.

    Returns:
        `(attn_output, attn_weights)` where `attn_output` is `[batch, q_len, heads, head_dim]`
        and `attn_weights` are the (post-dropout) attention probabilities.
    """
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = paddle.matmul(query, key_states.transpose([0, 1, 3, 2])) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # Softmax in float32 for numerical stability, then back to the compute dtype.
    attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query.dtype)
    if dropout > 0:
        # Real dropout (random masking with upscale-in-train). The previous
        # `attn_weights * (1.0 - dropout)` only rescaled the probabilities
        # deterministically, which is neither dropout nor the identity.
        attn_weights = F.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = paddle.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose([0, 2, 1, 3]).contiguous()
    return attn_output, attn_weights
self.k_proj(hidden_states).reshape(hidden_shape).transpose([0, 2, 1, 3]) + value_states = self.v_proj(hidden_states).reshape(hidden_shape).transpose([0, 2, 1, 3]) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + attention_interface: Callable = eager_attention_forward + if getattr(self.config, "_attn_implementation", "eager") != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + ) + + attn_output = attn_output.reshape([bsz, q_len, -1]).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class MiniMaxTopKRouter(nn.Layer): + """Top-K router for MiniMax MoE.""" + + def __init__(self, config: MiniMaxConfig): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.weight = self.create_parameter( + shape=[self.hidden_dim, self.num_experts], + dtype=paddle.get_default_dtype(), + is_bias=False, + ) + + def forward(self, hidden_states: paddle.Tensor) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + router_logits = F.linear(hidden_states, self.weight) + router_logits = F.softmax(router_logits.astype("float32"), axis=-1) + router_top_value, router_indices = paddle.topk(router_logits, self.top_k, axis=-1) + router_top_value = router_top_value / router_top_value.sum(axis=-1, keepdim=True) + router_top_value = router_top_value.astype(hidden_states.dtype) + return router_logits, 
router_top_value, router_indices + + +class MiniMaxBlockSparseTop2MLP(nn.Layer): + def __init__(self, config: MiniMaxConfig): + super().__init__() + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.w1 = GeneralLinear.create( + self.hidden_dim, + self.intermediate_dim, + has_bias=False, + config=config, + tp_plan="colwise", + ) + self.w2 = GeneralLinear.create( + self.intermediate_dim, + self.hidden_dim, + has_bias=False, + config=config, + tp_plan="rowwise", + ) + self.w3 = GeneralLinear.create( + self.hidden_dim, + self.intermediate_dim, + has_bias=False, + config=config, + tp_plan="colwise", + ) + self.act_fn = F.silu + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + return self.w2(hidden_states) + + +class MiniMaxSparseMoeBlock(nn.Layer): + """Sparse MoE block (router + experts) for MiniMax.""" + + def __init__(self, config: MiniMaxConfig): + super().__init__() + self.top_k = config.num_experts_per_tok + self.jitter_noise = config.router_jitter_noise + self.gate = MiniMaxTopKRouter(config) + self.experts = nn.LayerList([MiniMaxBlockSparseTop2MLP(config) for _ in range(config.num_local_experts)]) + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.training and self.jitter_noise > 0: + hidden_states = hidden_states * paddle.uniform( + hidden_states.shape, + dtype=hidden_states.dtype, + min=1.0 - self.jitter_noise, + max=1.0 + self.jitter_noise, + ) + hidden_states_flat = hidden_states.reshape([-1, hidden_states.shape[-1]]) + final_hidden_states = paddle.zeros_like(hidden_states_flat) + _, routing_weights, selected_experts = self.gate(hidden_states_flat) + + with paddle.no_grad(): + expert_mask = F.one_hot(selected_experts, num_classes=self.gate.num_experts) + expert_mask = expert_mask.transpose([2, 1, 0]) + + fixed_expert_order = 
os.environ.get("PADDLE_MINIMAX_MOE_FIXED_EXPERT_ORDER", "0") == "1" + connect_empty_experts = os.environ.get("PADDLE_MINIMAX_MOE_CONNECT_EMPTY_EXPERTS", "1") == "1" + expert_indices = ( + range(self.gate.num_experts) + if fixed_expert_order + else [ + int(expert_idx[0].item()) + for expert_idx in paddle.greater( + expert_mask.sum(axis=(-1, -2)), + paddle.to_tensor(0, dtype="int64"), + ).nonzero() + ] + ) + + for expert_idx in expert_indices: + top_k_pos, token_idx = paddle.where(expert_mask[expert_idx]) + current_state = hidden_states_flat[token_idx] + current_hidden_states = self.experts[expert_idx](current_state) + if token_idx.shape[0] > 0: + current_hidden_states = current_hidden_states * routing_weights[token_idx, top_k_pos, None] + final_hidden_states = final_hidden_states.index_add_( + axis=0, + index=token_idx, + value=current_hidden_states.astype(final_hidden_states.dtype), + ) + elif fixed_expert_order and connect_empty_experts: + final_hidden_states = ( + final_hidden_states + current_hidden_states.sum().astype(final_hidden_states.dtype) * 0.0 + ) + + return final_hidden_states.reshape([batch_size, sequence_length, hidden_dim]) + + +class MiniMaxDecoderLayer(nn.Layer): + """A single decoder layer. 
Selects between full and linear attention by `config.layer_types[layer_idx]`.""" + + def __init__(self, config: MiniMaxConfig, layer_idx: int): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.layer_type = config.layer_types[layer_idx] + self.mlp_alpha_factor = config.mlp_alpha_factor + self.mlp_beta_factor = config.mlp_beta_factor + + self.input_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + if self.layer_type == "linear_attention": + self.self_attn = MiniMaxLightningAttention(config, layer_idx) + self.attn_alpha_factor = config.linear_attn_alpha_factor + self.attn_beta_factor = config.linear_attn_beta_factor + elif self.layer_type == "full_attention": + self.self_attn = MiniMaxAttention(config, layer_idx) + self.attn_alpha_factor = config.full_attn_alpha_factor + self.attn_beta_factor = config.full_attn_beta_factor + else: + raise ValueError( + f"Unknown layer_type '{self.layer_type}'. Expected 'full_attention' or 'linear_attention'." 
+ ) + + self.block_sparse_moe = MiniMaxSparseMoeBlock(config) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: paddle.Tensor | None = None, + position_ids: paddle.Tensor | None = None, + position_embeddings: tuple[paddle.Tensor, paddle.Tensor] | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + attn_mask_startend_row_indices: paddle.Tensor | None = None, + **kwargs, + ) -> paddle.Tensor: + hidden_states = self.input_layernorm(hidden_states) + residual = hidden_states + + attn_outputs = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + **kwargs, + ) + hidden_states = attn_outputs[0] + hidden_states = residual * self.attn_alpha_factor + hidden_states * self.attn_beta_factor + + hidden_states = self.post_attention_layernorm(hidden_states) + residual = hidden_states + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual * self.mlp_alpha_factor + hidden_states * self.mlp_beta_factor + + return hidden_states + + +def load_balancing_loss_func( + gate_logits: tuple[paddle.Tensor] | list[paddle.Tensor] | None, + num_experts: int | None = None, + top_k: int = 2, + attention_mask: paddle.Tensor | None = None, +) -> paddle.Tensor | int: + """Auxiliary load-balancing loss for MoE.""" + if gate_logits is None or not isinstance(gate_logits, (tuple, list)): + return 0 + + concatenated_gate_logits = paddle.cat([layer_gate for layer_gate in gate_logits], axis=0) + + routing_weights = F.softmax(concatenated_gate_logits, axis=-1) + _, selected_experts = paddle.topk(routing_weights, top_k, axis=-1) + expert_mask = F.one_hot(selected_experts, num_experts) + + if attention_mask is None: + tokens_per_expert = paddle.mean(expert_mask.astype("float32"), axis=0) + router_prob_per_expert = 
paddle.mean(routing_weights, axis=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand([num_hidden_layers, batch_size, sequence_length, top_k, num_experts]) + .reshape([-1, top_k, num_experts]) + ) + tokens_per_expert = paddle.sum(expert_mask.astype("float32") * expert_attention_mask, axis=0) / paddle.sum( + expert_attention_mask, axis=0 + ) + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand([num_hidden_layers, batch_size, sequence_length, num_experts]) + .reshape([-1, num_experts]) + ) + router_prob_per_expert = paddle.sum(routing_weights * router_per_expert_attention_mask, axis=0) / paddle.sum( + router_per_expert_attention_mask, axis=0 + ) + + overall_loss = paddle.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class MiniMaxPretrainedModel(PretrainedModel): + config_class = MiniMaxConfig + config: MiniMaxConfig + base_model_prefix = "model" + + transpose_weight_keys = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "gate", + "w1", + "w2", + "w3", + "output_gate", + "qkv_proj", + "out_proj", + ] + + @classmethod + def _gen_aoa_config(cls, config: MiniMaxConfig): + """AOA config: mapping from HF safetensors key to PaddleFormers model key.""" + if os.environ.get("PADDLE_MINIMAX_DEBUG_AOA", "0") == "1": + print( + "[PADDLE_MINIMAX_DEBUG_AOA] " + f"cls={cls.__name__} num_hidden_layers={config.num_hidden_layers} " + f"layer_types={getattr(config, 'layer_types', None)}", + flush=True, + ) + model_prefix = "model." 
if cls != cls.base_model_class else "" + + aoa_statements = [ + f"model.embed_tokens.weight -> {model_prefix}embed_tokens.weight", + f"model.norm.weight -> {model_prefix}norm.weight", + ] + + # Layer-wise mappings + for layer_idx in range(config.num_hidden_layers): + layer_type = config.layer_types[layer_idx] + prefix = f"model.layers.{layer_idx}" + dst_prefix = f"{model_prefix}layers.{layer_idx}" + + # Layer norms (always present). + aoa_statements.extend( + [ + f"{prefix}.input_layernorm.weight -> {dst_prefix}.input_layernorm.weight", + f"{prefix}.post_attention_layernorm.weight -> {dst_prefix}.post_attention_layernorm.weight", + ] + ) + + # Attention weights (with transpose for Linear) + if layer_type == "full_attention": + aoa_statements.extend( + [ + f"{prefix}.self_attn.q_proj.weight^T -> {dst_prefix}.self_attn.q_proj.weight", + f"{prefix}.self_attn.k_proj.weight^T -> {dst_prefix}.self_attn.k_proj.weight", + f"{prefix}.self_attn.v_proj.weight^T -> {dst_prefix}.self_attn.v_proj.weight", + f"{prefix}.self_attn.o_proj.weight^T -> {dst_prefix}.self_attn.o_proj.weight", + ] + ) + elif layer_type == "linear_attention": + aoa_statements.extend( + [ + f"{prefix}.self_attn.qkv_proj.weight^T -> {dst_prefix}.self_attn.qkv_proj.weight", + f"{prefix}.self_attn.output_gate.weight^T -> {dst_prefix}.self_attn.output_gate.weight", + f"{prefix}.self_attn.out_proj.weight^T -> {dst_prefix}.self_attn.out_proj.weight", + f"{prefix}.self_attn.norm.weight -> {dst_prefix}.self_attn.norm.weight", + ] + ) + + aoa_statements.append( + f"{prefix}.block_sparse_moe.gate.weight^T -> {dst_prefix}.block_sparse_moe.gate.weight" + ) + for expert_idx in range(config.num_local_experts): + expert_prefix = f"{prefix}.block_sparse_moe.experts.{expert_idx}" + dst_expert_prefix = f"{dst_prefix}.block_sparse_moe.experts.{expert_idx}" + aoa_statements.extend( + [ + f"{expert_prefix}.w1.weight^T -> {dst_expert_prefix}.w1.weight", + f"{expert_prefix}.w2.weight^T -> {dst_expert_prefix}.w2.weight", + 
f"{expert_prefix}.w3.weight^T -> {dst_expert_prefix}.w3.weight", + ] + ) + + # lm_head + if cls != cls.base_model_class: + if config.tie_word_embeddings: + aoa_statements.append("model.embed_tokens.weight -> lm_head.weight") + else: + aoa_statements.append("lm_head.weight^T -> lm_head.weight") + + return {"aoa_statements": aoa_statements} + + @classmethod + def _gen_inv_aoa_config(cls, config: MiniMaxConfig): + """AOA config: mapping from PaddleFormers model key back to HF safetensors key.""" + model_prefix = "model." if cls != cls.base_model_class else "" + + aoa_statements = [ + f"{model_prefix}embed_tokens.weight -> model.embed_tokens.weight", + f"{model_prefix}norm.weight -> model.norm.weight", + ] + + for layer_idx in range(config.num_hidden_layers): + layer_type = config.layer_types[layer_idx] + prefix = f"model.layers.{layer_idx}" + dst_prefix = f"{model_prefix}layers.{layer_idx}" + + aoa_statements.extend( + [ + f"{dst_prefix}.input_layernorm.weight -> {prefix}.input_layernorm.weight", + f"{dst_prefix}.post_attention_layernorm.weight -> {prefix}.post_attention_layernorm.weight", + ] + ) + + if layer_type == "full_attention": + aoa_statements.extend( + [ + f"{dst_prefix}.self_attn.q_proj.weight^T -> {prefix}.self_attn.q_proj.weight", + f"{dst_prefix}.self_attn.k_proj.weight^T -> {prefix}.self_attn.k_proj.weight", + f"{dst_prefix}.self_attn.v_proj.weight^T -> {prefix}.self_attn.v_proj.weight", + f"{dst_prefix}.self_attn.o_proj.weight^T -> {prefix}.self_attn.o_proj.weight", + ] + ) + elif layer_type == "linear_attention": + aoa_statements.extend( + [ + f"{dst_prefix}.self_attn.qkv_proj.weight^T -> {prefix}.self_attn.qkv_proj.weight", + f"{dst_prefix}.self_attn.output_gate.weight^T -> {prefix}.self_attn.output_gate.weight", + f"{dst_prefix}.self_attn.out_proj.weight^T -> {prefix}.self_attn.out_proj.weight", + f"{dst_prefix}.self_attn.norm.weight -> {prefix}.self_attn.norm.weight", + ] + ) + + aoa_statements.append( + 
f"{dst_prefix}.block_sparse_moe.gate.weight^T -> {prefix}.block_sparse_moe.gate.weight" + ) + for expert_idx in range(config.num_local_experts): + expert_prefix = f"{prefix}.block_sparse_moe.experts.{expert_idx}" + dst_expert_prefix = f"{dst_prefix}.block_sparse_moe.experts.{expert_idx}" + aoa_statements.extend( + [ + f"{dst_expert_prefix}.w1.weight^T -> {expert_prefix}.w1.weight", + f"{dst_expert_prefix}.w2.weight^T -> {expert_prefix}.w2.weight", + f"{dst_expert_prefix}.w3.weight^T -> {expert_prefix}.w3.weight", + ] + ) + + if not config.tie_word_embeddings and cls != cls.base_model_class: + aoa_statements.append("lm_head.weight^T -> lm_head.weight") + + return {"aoa_statements": aoa_statements} + + +@register_base_model +class MiniMaxModel(MiniMaxPretrainedModel): + """The bare MiniMax (Text-01) decoder model.""" + + def __init__(self, config: MiniMaxConfig): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + + self.embed_tokens = GeneralEmbedding.create( + config=config, + num_embeddings=self.vocab_size, + embedding_dim=self.hidden_size, + padding_idx=self.padding_idx, + ) + self.layers = nn.LayerList( + [MiniMaxDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = MiniMaxRotaryEmbedding(config=config) + + def forward( + self, + input_ids: paddle.Tensor | None = None, + attention_mask: paddle.Tensor | None = None, + position_ids: paddle.Tensor | None = None, + past_key_values: MiniMaxCache | None = None, + inputs_embeds: paddle.Tensor | None = None, + use_cache: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + attn_mask_startend_row_indices: paddle.Tensor | None = None, + **kwargs, + ): + output_hidden_states = output_hidden_states if output_hidden_states is not None 
else False + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) and (inputs_embeds is None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if (input_ids is not None) and (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds (not both)") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids).astype(self.embed_tokens.weight.dtype) + + bsz, seq_length, _ = inputs_embeds.shape + + if use_cache and past_key_values is None: + past_key_values = MiniMaxCache(config=self.config) + elif use_cache and not isinstance(past_key_values, MiniMaxCache): + raise ValueError( + f"MiniMax uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}." + ) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = paddle.arange( + past_seen_tokens, seq_length + past_seen_tokens, dtype=paddle.int64 + ).unsqueeze(0) + position_ids = position_ids.expand([bsz, -1]) + + mask_kwargs = { + "config": self.config, + "inputs_embeds": inputs_embeds, + "batch_size": bsz, + "seq_length": seq_length, + "cache_length": past_key_values.get_seq_length() if past_key_values is not None else 0, + "attention_mask": attention_mask, + "attn_mask_startend_row_indices": attn_mask_startend_row_indices, + "prepare_decoder_attention_mask": self._prepare_decoder_attention_mask, + } + causal_mask, attn_mask_startend_row_indices = create_causal_mask_and_row_indices(**mask_kwargs) + position_embeddings = self.rotary_emb(inputs_embeds, position_ids) + + hidden_states = inputs_embeds + all_hidden_states = [] if output_hidden_states else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + 
all_hidden_states.append(hidden_states) + if self.config.layer_types[idx] == "full_attention": + input_attention_mask = causal_mask + input_mask_startend = attn_mask_startend_row_indices + else: + input_attention_mask = attention_mask + input_mask_startend = None + + hidden_states = decoder_layer( + hidden_states, + attention_mask=input_attention_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + past_key_values=past_key_values, + use_cache=use_cache, + attn_mask_startend_row_indices=input_mask_startend, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + if output_hidden_states: + all_hidden_states.append(hidden_states) + + if not return_dict: + outputs = (hidden_states,) + if output_hidden_states: + outputs = outputs + (tuple(all_hidden_states) if all_hidden_states else None,) + if use_cache: + outputs = outputs + (past_key_values,) + return outputs + + return MoEModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=tuple(all_hidden_states) if all_hidden_states else None, + ) + + +class MiniMaxForCausalLM(MiniMaxPretrainedModel): + """MiniMax (Text-01) model with a language modeling head.""" + + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config: MiniMaxConfig): + super().__init__(config) + self.config = config + self.model = MiniMaxModel(config) + self.lm_head = GeneralLMHead(config) + self.criterion = CriterionLayer(config) + self.router_weights: list[paddle.Tensor] = [] + self.tie_weights() + + def forward( + self, + input_ids: paddle.Tensor, + position_ids: paddle.Tensor | None = None, + attention_mask: paddle.Tensor | None = None, + attn_mask_startend_row_indices: paddle.Tensor | None = None, + inputs_embeds: paddle.Tensor | None = None, + labels: paddle.Tensor | None = None, + loss_mask: paddle.Tensor | None = None, + use_cache: bool = False, + past_key_values: Cache | None = None, + output_hidden_states: bool | None = False, + 
output_router_logits: bool | None = None, + return_dict: bool = False, + **kwargs, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + if attention_mask is not None and attention_mask.dtype != paddle.bool: + attention_mask = paddle.cast(attention_mask, paddle.bool) + + if attn_mask_startend_row_indices is not None and attention_mask is not None: + logger.warning( + "You have provided both attn_mask_startend_row_indices and attention_mask. " + "The attn_mask_startend_row_indices will be used." + ) + attention_mask = None + + outputs = self.model( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + loss, _ = self.criterion(logits, labels) + + aux_loss = None + if output_router_logits: + aux_loss = None + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MoECausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +class MiniMaxForCausalLMPipe(GeneralModelForCausalLMPipe): + """Pipeline-parallel wrapper for MiniMax (Text-01).""" + + config_class = MiniMaxConfig + _decoder_layer_cls = MiniMaxDecoderLayer + _get_tensor_parallel_mappings = 
MiniMaxModel._get_tensor_parallel_mappings + _init_weights = MiniMaxModel._init_weights + _keep_in_fp32_modules = MiniMaxModel._keep_in_fp32_modules + _tied_weights_keys = ["lm_head.weight"] + transpose_weight_keys = MiniMaxModel.transpose_weight_keys + _gen_aoa_config = MiniMaxForCausalLM._gen_aoa_config + _gen_inv_aoa_config = MiniMaxForCausalLM._gen_inv_aoa_config + + +__all__ = [ + "MiniMaxConfig", + "MiniMaxPretrainedModel", + "MiniMaxModel", + "MiniMaxForCausalLM", + "MiniMaxForCausalLMPipe", +] diff --git a/tests/transformers/minimax/__init__.py b/tests/transformers/minimax/__init__.py new file mode 100644 index 00000000000..290f972cf31 --- /dev/null +++ b/tests/transformers/minimax/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/transformers/minimax/test_modeling.py b/tests/transformers/minimax/test_modeling.py new file mode 100644 index 00000000000..f3ef22b513e --- /dev/null +++ b/tests/transformers/minimax/test_modeling.py @@ -0,0 +1,170 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import unittest + +import paddle + +from paddleformers.transformers import MiniMaxConfig, MiniMaxForCausalLM, MiniMaxModel +from paddleformers.transformers.auto.modeling import AutoModelForCausalLM +from tests.testing_utils import gpu_device_initializer +from tests.transformers.test_configuration_common import ConfigTester +from tests.transformers.test_modeling_common import ( + ModelTesterMixin, + ids_tensor, + random_attention_mask, +) + + +class MiniMaxModelTester: + def __init__( + self, + parent, + batch_size=2, + seq_length=7, + is_training=True, + use_input_mask=True, + vocab_size=99, + hidden_size=32, + intermediate_size=37, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + num_local_experts=4, + num_experts_per_tok=2, + block_size=4, + rms_norm_eps=1e-5, + initializer_range=0.02, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ): + self.parent: MiniMaxModelTest = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.num_local_experts = num_local_experts + self.num_experts_per_tok = num_experts_per_tok + self.block_size = block_size + self.rms_norm_eps = rms_norm_eps + self.initializer_range = 
initializer_range + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size, dtype=paddle.int64) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + config = self.get_config() + return config, input_ids, input_mask + + def get_config(self) -> MiniMaxConfig: + return MiniMaxConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + head_dim=self.head_dim, + layer_types=["full_attention", "linear_attention"], + block_size=self.block_size, + num_local_experts=self.num_local_experts, + num_experts_per_tok=self.num_experts_per_tok, + rms_norm_eps=self.rms_norm_eps, + initializer_range=self.initializer_range, + use_cache=False, + pad_token_id=self.pad_token_id, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + ) + + def create_and_check_model(self, config: MiniMaxConfig, input_ids, input_mask): + model = MiniMaxModel(config) + model.eval() + result = model(input_ids, attention_mask=input_mask) + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.hidden_size]) + + def create_and_check_for_causal_lm(self, config: MiniMaxConfig, input_ids, input_mask): + model = MiniMaxForCausalLM(config) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=input_ids, return_dict=True) + self.parent.assertEqual(result.logits.shape, [self.batch_size, self.seq_length, self.vocab_size]) + self.parent.assertIsNotNone(result.loss) + + def create_and_check_training_step(self, config: MiniMaxConfig, input_ids, input_mask): + model = MiniMaxForCausalLM(config) + model.train() + result = 
model(input_ids, attention_mask=input_mask, labels=input_ids, return_dict=True) + result.loss.backward() + self.parent.assertEqual(result.logits.shape, [self.batch_size, self.seq_length, self.vocab_size]) + self.parent.assertIsNotNone(model.model.embed_tokens.weight.grad) + + def create_and_check_auto_model(self, config: MiniMaxConfig): + model = AutoModelForCausalLM.from_config(config) + self.parent.assertIsInstance(model, MiniMaxForCausalLM) + + def prepare_config_and_inputs_for_common(self): + config, input_ids, input_mask = self.prepare_config_and_inputs() + return config, {"input_ids": input_ids, "attention_mask": input_mask} + + +class MiniMaxModelTest(ModelTesterMixin, unittest.TestCase): + base_model_class = MiniMaxModel + return_dict = False + use_labels = False + use_test_model_name_list = False + + all_model_classes = (MiniMaxModel, MiniMaxForCausalLM) + all_generative_model_classes = {MiniMaxForCausalLM: (MiniMaxModel, "minimax")} + + @gpu_device_initializer(log_prefix="MiniMaxModelTest") + def setUp(self): + super().setUp() + + self.model_tester = MiniMaxModelTester(self) + self.config_tester = ConfigTester(self, config_class=MiniMaxConfig, vocab_size=256, hidden_size=24) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_model_causal_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + + def test_model_training_step(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_training_step(*config_and_inputs) + + def test_auto_model_for_causal_lm(self): + config = self.model_tester.get_config() + self.model_tester.create_and_check_auto_model(config)