diff --git a/README.md b/README.md
index 76e7ea002a5..f610ba184bf 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ PaddleFormers 是基于百度深度学习框架 PaddlePaddle 搭建的 Transform
结合业界主流优化方法与飞桨在业务实践中积累的高效特性,PaddleFormers 致力于打造**高性能、低资源占用**的训练体验,帮助用户高效便捷地完成大模型训练,而无需关注底层复杂的优化细节。
## 🆕最新更新
-* 2026.03.31 - PaddleFormers v1.1 正式发布!在这个版本中我们支持了 GLM-4.5 系列模型的单步与多步 MTP 训练能力。依托 MTP 架构优势,开发者可显著提升推理效率;同时针对 MTP 模块训练场景,我们新增主干网络冻结开关,灵活满足各类模型精细化调优需求。此外,我们对视觉理解类模型进行了深度优化,Qwen3-VL 30B-A3B 模型性能相比上个版本提升48%,领先Megatron-LM 6%。
+* 2026.03.31 - PaddleFormers v1.1 正式发布!在这个版本中我们支持了 GLM-4.5 系列模型的单步与多步 MTP 训练能力。依托 MTP 架构优势,开发者可显著提升推理效率;同时针对 MTP 模块训练场景,我们新增主干网络冻结开关,灵活满足各类模型精细化调优需求。此外,我们对视觉理解类模型进行了深度优化,Qwen3-VL 30B-A3B 模型性能相比上个版本提升48%,领先 Megatron-LM 6%。
* 2026.01.21 - PaddleFormers v1.0版本发布啦!我们提供了针对 LLM 和 VLM 等模型的训练能力,针对 DeepSeek-V3模型和 GLM-4.5-Air 等重点模型,我们实现了极致性能优化(训练性能明显超越 Megatron-LM)。针对 PaddleOCR-VL,我们在昆仑芯 P800、天数天垓150等国产计算芯片上进行了适配,更好地满足国内用户需求。
## ✨特性
@@ -50,7 +50,7 @@ PaddleFormers 是基于百度深度学习框架 PaddlePaddle 搭建的 Transform
- | LLM |
+ LLM |
DeepSeekv3 |
deepseek-ai/DeepSeek-V3-Base、deepseek-ai/DeepSeek-V3、deepseek-ai/DeepSeek-V3-0324 |
deepseek3 |
@@ -80,6 +80,11 @@ PaddleFormers 是基于百度深度学习框架 PaddlePaddle 搭建的 Transform
meta-llama/Meta-Llama-3-8B、meta-llama/Meta-Llama-3-8B-Instruct、meta-llama/Meta-Llama-3-70B、meta-llama/Meta-Llama-3-70B-Instruct、meta-llama/Llama-3.1-8B、meta-llama/Llama-3.1-8B-Instruct、meta-llama/Llama-3.1-70B、meta-llama/Llama-3.1-70B-Instruct、meta-llama/Llama-3.1-405B、meta-llama/Llama-3.1-405B-Instruct、meta-llama/Llama-3.2-1B、meta-llama/Llama-3.2-1B-Instruct、meta-llama/Llama-3.2-3B、meta-llama/Llama-3.2-3B-Instruct、meta-llama/Llama-3.3-70B-Instruct |
llama3 |
+
+ | MiniMax-Text-01 |
+ MiniMaxAI/MiniMax-Text-01 |
+ minimax |
+
| phi-4 |
microsoft/phi-4 |
diff --git a/docs/zh/model_capability.md b/docs/zh/model_capability.md
index 3b402f9d4b2..7e2b0b21374 100644
--- a/docs/zh/model_capability.md
+++ b/docs/zh/model_capability.md
@@ -7,6 +7,7 @@
|GLM-4.5|✓|✓|✓|✓|✓|
|GPT-OSS|✓|✓|✓|x|x|
|LLaMA3|✓|✓|✓|✓|✓|
+|MiniMax-Text-01|✓|✓|✓|x|x|
|Phi4|✓|✓|✓|✓|✓|
|Qwen2|✓|✓|✓|✓|✓|
|Qwen3|✓|✓|✓|✓|✓|
@@ -25,6 +26,7 @@
|GLM-4.5|✓|✓|✓|✓|✓|✓|
|GPT-OSS|✓|✓|x|x|✓|✓|
|LLaMA3|✓|✓|-|x|✓|✓|
+|MiniMax-Text-01|x|x|x|x|✓|✓|
|Phi4|✓|✓|-|x|✓|✓|
|Qwen2|✓|✓|x|x|✓|✓|
|Qwen3|✓|✓|✓|✓|✓|✓|
diff --git a/paddleformers/datasets/template/template.py b/paddleformers/datasets/template/template.py
index 6d0d42cd39c..5e4f4f4d546 100644
--- a/paddleformers/datasets/template/template.py
+++ b/paddleformers/datasets/template/template.py
@@ -638,6 +638,22 @@ def get_template_and_fix_tokenizer(dataset_config) -> "Template":
thought_words=("\n", "\n\n\n"),
)
+# Chat template registered under the name "minimax" (MiniMax-Text-01).
+# A user turn is rendered as a "user name=user" header, the content, and then the
+# "ai name=assistant" header, so the text that follows is the assistant's reply.
+# A system turn is rendered as a "system ai_setting=assistant" header plus the content.
+# NOTE(review): no begin/end-of-sentence marker appears in any slot and `suffix` is an
+# empty string -- confirm against the tokenizer's own chat template before relying on it.
+register_template(
+ name="minimax",
+ format_user=StringFormatter(
+ slots=[
+ # The two adjacent literals below are concatenated by Python into ONE slot string.
+ "user name=user\n{{content}}\n"
+ "ai name=assistant\n"
+ ]
+ ),
+ format_assistant=StringFormatter(slots=["{{content}}"]),
+ format_system=StringFormatter(
+ slots=["system ai_setting=assistant\n{{content}}\n"]
+ ),
+ chat_sep="\n",
+ suffix=[""],
+)
+
register_template(
name="paddleocr_vl",
format_user=StringFormatter(slots=["User: {{content}}\nAssistant: "]),
diff --git a/paddleformers/transformers/__init__.py b/paddleformers/transformers/__init__.py
index ae1c4f20c73..068135cff2b 100644
--- a/paddleformers/transformers/__init__.py
+++ b/paddleformers/transformers/__init__.py
@@ -313,6 +313,8 @@
"minimax_m2": ["MiniMaxM2ForCausalLMPipe", "MiniMaxM2ForCausalLM"],
"deepseek_v4.configuration": ["DeepseekV4Config"],
"deepseek_v4": ["DeepseekV4ForCausalLMPipe", "DeepseekV4ForCausalLM"],
+ "minimax.configuration": ["MiniMaxConfig"],
+ "minimax": ["MiniMaxModel", "MiniMaxForCausalLM", "MiniMaxForCausalLMPipe"],
"glm4v_moe.image_processor": ["Glm4vImageProcessor"],
"glm4v_moe.image_processor_fast": ["Glm4vImageProcessorFast"],
"auto": ["AutoModelForCausalLM"],
@@ -414,6 +416,7 @@
from .glm_moe_dsa import *
from .minimax_m2 import *
from .deepseek_v4 import *
+ from .minimax import *
from .gpt_oss import *
from .phi3 import *
from .gemma3_text import *
diff --git a/paddleformers/transformers/auto/configuration.py b/paddleformers/transformers/auto/configuration.py
index c04e1f34a5a..296e3eb6757 100644
--- a/paddleformers/transformers/auto/configuration.py
+++ b/paddleformers/transformers/auto/configuration.py
@@ -54,6 +54,7 @@
("qwen3_vl_moe_text", "Qwen3VLMoeTextConfig"),
("glm4_moe", "Glm4MoeConfig"),
("glm_moe_dsa", "GlmMoeDsaConfig"),
+ ("minimax", "MiniMaxConfig"),
("minimax_m2", "MiniMaxM2Config"),
("deepseek_v4", "DeepseekV4Config"),
("gpt_oss", "GptOssConfig"),
@@ -89,6 +90,7 @@
("qwen3_vl_moe", "Qwen3VLMoe"),
("qwen3_vl_moe_text", "Qwen3VLMoeText"),
("glm_ocr", "GlmOcrForConditionalGeneration"),
+ ("minimax", "MiniMaxForCausalLM"),
("qwen3_5_moe", "Qwen3_5MoEForConditionalGeneration"),
("qwen3_5", "Qwen3_5ForConditionalGeneration"),
]
diff --git a/paddleformers/transformers/auto/modeling.py b/paddleformers/transformers/auto/modeling.py
index 11321baba1f..ea5a350fbdf 100644
--- a/paddleformers/transformers/auto/modeling.py
+++ b/paddleformers/transformers/auto/modeling.py
@@ -73,6 +73,7 @@
("Qwen3_5", "qwen3_5"),
("Glm4Moe", "glm4_moe"),
("GlmMoeDsa", "glm_moe_dsa"),
+ ("MiniMax", "minimax"),
("MiniMaxM2", "minimax_m2"),
("DeepseekV4", "deepseek_v4"),
("GptOss", "gpt_oss"),
diff --git a/paddleformers/transformers/minimax/__init__.py b/paddleformers/transformers/minimax/__init__.py
new file mode 100644
index 00000000000..0e968840f32
--- /dev/null
+++ b/paddleformers/transformers/minimax/__init__.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Package"""
+import sys
+from typing import TYPE_CHECKING
+
+from ...utils.lazy_import import _LazyModule
+
+# Maps each submodule of this package to the public names it exports.
+# The mapping is handed to `_LazyModule` below.
+import_structure = {
+ "configuration": ["MiniMaxConfig"],
+ "modeling": [
+ "MiniMaxPretrainedModel",
+ "MiniMaxModel",
+ "MiniMaxForCausalLM",
+ "MiniMaxForCausalLMPipe",
+ ],
+}
+
+if TYPE_CHECKING:
+ # Static type checkers see the real symbols through eager star-imports.
+ from .configuration import *
+ from .modeling import *
+else:
+ # At runtime the package object in `sys.modules` is replaced by a `_LazyModule`
+ # (presumably deferring the submodule imports until first attribute access --
+ # see `utils.lazy_import` to confirm).
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ import_structure,
+ module_spec=__spec__,
+ )
diff --git a/paddleformers/transformers/minimax/configuration.py b/paddleformers/transformers/minimax/configuration.py
new file mode 100644
index 00000000000..70697549671
--- /dev/null
+++ b/paddleformers/transformers/minimax/configuration.py
@@ -0,0 +1,224 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" MiniMax (Text-01) model configuration"""
+
+from ..configuration_utils import PretrainedConfig
+from ..modeling_rope_utils import rope_config_validation, standardize_rope_params
+
+
+class MiniMaxConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`MiniMaxModel`]. It is used to instantiate a
+    MiniMax (Text-01) model according to the specified arguments, defining the model architecture.
+
+    MiniMax (Text-01) uses a hybrid attention mechanism:
+    - Some layers are full attention (standard causal self-attention with RoPE)
+    - Some layers are linear attention ("lightning attention") with intra-/inter-block attention
+
+    The layer type is controlled by `layer_types`. When neither `layer_types` nor `attn_type_list` is given, the
+    pattern alternates starting with full attention: 0-based layers 0, 2, 4, ... use `full_attention` and layers
+    1, 3, 5, ... use `linear_attention`.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 200064):
+            Vocabulary size of the MiniMax model.
+        hidden_size (`int`, *optional*, defaults to 6144):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 9216):
+            Dimension of the routed expert MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 80):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 64):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*, defaults to 8):
+            Number of key_value heads for implementing Grouped Query Attention.
+        head_dim (`int`, *optional*):
+            Dimension of each attention head. If None, defaults to hidden_size // num_attention_heads.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 10240000):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions.
+        pad_token_id (`int`, *optional*):
+            Padding token id, forwarded to [`PretrainedConfig`].
+        bos_token_id (`int`, *optional*):
+            Beginning-of-sequence token id, forwarded to [`PretrainedConfig`].
+        eos_token_id (`int`, *optional*):
+            End-of-sequence token id, forwarded to [`PretrainedConfig`].
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        sliding_window (`int`, *optional*):
+            Forwarded unchanged to [`PretrainedConfig`].
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        num_experts_per_tok (`int`, *optional*, defaults to 2):
+            Number of selected experts per token.
+        num_local_experts (`int`, *optional*, defaults to 32):
+            Number of routed experts.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether the router logits should be returned by the model.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            Coefficient for the auxiliary load-balancing loss.
+        router_jitter_noise (`float`, *optional*, defaults to 0.0):
+            Jitter noise for the router.
+        attn_type_list (`list[int]`, *optional*):
+            Legacy per-layer attention selector, only consulted when `layer_types` is None. An entry equal to 0
+            selects `"linear_attention"`, any other value selects `"full_attention"`. Its length must equal
+            `num_hidden_layers`. The stored `attn_type_list` attribute is always re-derived from `layer_types`.
+        rope_theta (`float`, *optional*, defaults to 10000000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. A legacy `"type"` key is
+            mirrored to `"rope_type"` on an internal copy; the dictionary passed by the caller is not modified.
+        layer_types (`list[str]`, *optional*):
+            A list that maps each layer index to its attention type. Can be `"full_attention"` or `"linear_attention"`.
+            Its length must equal `num_hidden_layers`.
+        block_size (`int`, *optional*, defaults to 256):
+            The length of each attention block for the lightning attention.
+        full_attn_alpha_factor (`float`, *optional*, defaults to 1):
+            Weight for residual value in residual connection after full attention.
+            Overridden by the legacy keyword `layernorm_full_attention_alpha` when present.
+        full_attn_beta_factor (`float`, *optional*, defaults to 1):
+            Weight for hidden state value in residual connection after full attention.
+            Overridden by the legacy keyword `layernorm_full_attention_beta` when present.
+        linear_attn_alpha_factor (`float`, *optional*, defaults to 1):
+            Weight for residual value in residual connection after lightning attention.
+            Overridden by the legacy keyword `layernorm_linear_attention_alpha` when present.
+        linear_attn_beta_factor (`float`, *optional*, defaults to 1):
+            Weight for hidden state value in residual connection after lightning attention.
+            Overridden by the legacy keyword `layernorm_linear_attention_beta` when present.
+        mlp_alpha_factor (`float`, *optional*, defaults to 1):
+            Weight for residual value in residual connection after MLP.
+            Overridden by the legacy keyword `layernorm_mlp_alpha` when present.
+        mlp_beta_factor (`float`, *optional*, defaults to 1):
+            Weight for hidden state value in residual connection after MLP.
+            Overridden by the legacy keyword `layernorm_mlp_beta` when present.
+
+    Raises:
+        ValueError: If `layer_types` (or, when it is None, `attn_type_list`) does not contain exactly
+            `num_hidden_layers` entries.
+
+    ```python
+    >>> from paddleformers.transformers import MiniMaxModel, MiniMaxConfig
+
+    >>> # Initializing a MiniMax (Text-01) style configuration
+    >>> configuration = MiniMaxConfig()
+
+    >>> # Initializing a model from the MiniMax (Text-01) style configuration
+    >>> model = MiniMaxModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "minimax"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=200064,
+        hidden_size=6144,
+        intermediate_size=9216,
+        num_hidden_layers=80,
+        num_attention_heads=64,
+        num_key_value_heads=8,
+        head_dim=None,
+        hidden_act="silu",
+        max_position_embeddings=10240000,
+        initializer_range=0.02,
+        rms_norm_eps=1e-5,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=None,
+        eos_token_id=None,
+        tie_word_embeddings=False,
+        sliding_window=None,
+        attention_dropout=0.0,
+        num_experts_per_tok=2,
+        num_local_experts=32,
+        output_router_logits=False,
+        router_aux_loss_coef=0.001,
+        router_jitter_noise=0.0,
+        attn_type_list=None,
+        rope_theta=10000000.0,
+        rope_scaling=None,
+        layer_types=None,
+        block_size=256,
+        full_attn_alpha_factor=1.0,
+        full_attn_beta_factor=1.0,
+        linear_attn_alpha_factor=1.0,
+        linear_attn_beta_factor=1.0,
+        mlp_alpha_factor=1.0,
+        mlp_beta_factor=1.0,
+        **kwargs,
+    ):
+        # Legacy checkpoints spell the residual weights as `layernorm_*`; when such a key is
+        # present it wins over the corresponding explicit argument.
+        full_attn_alpha_factor = kwargs.pop("layernorm_full_attention_alpha", full_attn_alpha_factor)
+        full_attn_beta_factor = kwargs.pop("layernorm_full_attention_beta", full_attn_beta_factor)
+        linear_attn_alpha_factor = kwargs.pop("layernorm_linear_attention_alpha", linear_attn_alpha_factor)
+        linear_attn_beta_factor = kwargs.pop("layernorm_linear_attention_beta", linear_attn_beta_factor)
+        mlp_alpha_factor = kwargs.pop("layernorm_mlp_alpha", mlp_alpha_factor)
+        mlp_beta_factor = kwargs.pop("layernorm_mlp_beta", mlp_beta_factor)
+
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.attention_dropout = attention_dropout
+        self.rope_theta = rope_theta
+        # Shallow-copy so that the "type" -> "rope_type" aliasing below never mutates the
+        # dictionary object owned by the caller.
+        self.rope_scaling = dict(rope_scaling) if rope_scaling is not None else None
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+        self.rope_parameters = self.rope_scaling
+        standardize_rope_params(self, rope_theta=rope_theta)
+        rope_config_validation(self)
+
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_local_experts = num_local_experts
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.router_jitter_noise = router_jitter_noise
+
+        if layer_types is None:
+            if attn_type_list is not None:
+                if len(attn_type_list) != num_hidden_layers:
+                    raise ValueError(
+                        f"attn_type_list length ({len(attn_type_list)}) must equal "
+                        f"num_hidden_layers ({num_hidden_layers})."
+                    )
+                # Legacy encoding: 0 means linear attention, everything else full attention.
+                self.layer_types = [
+                    "linear_attention" if int(attn_type) == 0 else "full_attention" for attn_type in attn_type_list
+                ]
+            else:
+                # Default pattern: layer 0 is full attention, then the two types alternate.
+                self.layer_types = [
+                    "full_attention" if bool((i + 1) % 2) else "linear_attention"
+                    for i in range(self.num_hidden_layers)
+                ]
+        else:
+            if len(layer_types) != num_hidden_layers:
+                raise ValueError(
+                    f"layer_types length ({len(layer_types)}) must equal num_hidden_layers ({num_hidden_layers})."
+                )
+            self.layer_types = list(layer_types)
+        # Keep the legacy integer view in sync with `layer_types`, whichever input produced it.
+        self.attn_type_list = [0 if layer_type == "linear_attention" else 1 for layer_type in self.layer_types]
+        self.block_size = block_size
+        self.full_attn_alpha_factor = full_attn_alpha_factor
+        self.full_attn_beta_factor = full_attn_beta_factor
+        self.linear_attn_alpha_factor = linear_attn_alpha_factor
+        self.linear_attn_beta_factor = linear_attn_beta_factor
+        self.mlp_alpha_factor = mlp_alpha_factor
+        self.mlp_beta_factor = mlp_beta_factor
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            sliding_window=sliding_window,
+            **kwargs,
+        )
+
+
+__all__ = ["MiniMaxConfig"]
diff --git a/paddleformers/transformers/minimax/modeling.py b/paddleformers/transformers/minimax/modeling.py
new file mode 100644
index 00000000000..504019a860f
--- /dev/null
+++ b/paddleformers/transformers/minimax/modeling.py
@@ -0,0 +1,1066 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Paddle MiniMax (Text-01) model."""
+
+from __future__ import annotations
+
+import os
+from typing import Callable
+
+import paddle
+import paddle.nn.functional as F
+from paddle import nn
+
+from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
+from ...nn.criterion.interface import CriterionLayer
+from ...nn.embedding import Embedding as GeneralEmbedding
+from ...nn.linear import Linear as GeneralLinear
+from ...nn.lm_head import LMHead as GeneralLMHead
+from ...nn.pp_model import GeneralModelForCausalLMPipe
+from ...utils.log import logger
+from ..cache_utils import Cache, DynamicCache
+from ..masking_utils import create_causal_mask_and_row_indices
+from ..model_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPast
+from ..model_utils import PretrainedModel, register_base_model
+from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from .configuration import MiniMaxConfig
+
+
+def rotate_half(x: paddle.Tensor) -> paddle.Tensor:
+    """Swap the two halves of the last axis, negating the half that moves to the front.
+
+    With ``x`` split at the midpoint of its last dimension into ``[a, b]``, the result is ``[-b, a]``.
+    """
+    midpoint = x.shape[-1] // 2
+    leading, trailing = x[..., :midpoint], x[..., midpoint:]
+    return paddle.cat((-trailing, leading), axis=-1)
+
+
+def apply_rotary_pos_emb(
+    q: paddle.Tensor,
+    k: paddle.Tensor,
+    cos: paddle.Tensor,
+    sin: paddle.Tensor,
+    position_ids: paddle.Tensor | None = None,
+    unsqueeze_dim: int = 1,
+) -> tuple[paddle.Tensor, paddle.Tensor]:
+    """Rotate the query and key tensors with the given RoPE cosine/sine tables.
+
+    Args:
+        q: query tensor with shape [..., head_dim]
+        k: key tensor with shape [..., head_dim]
+        cos: cosine values
+        sin: sine values
+        position_ids: accepted for signature compatibility; not read by this function
+        unsqueeze_dim: axis inserted into cos/sin so that they broadcast against q and k
+
+    Returns:
+        The rotated ``(q, k)`` pair, each cast back to the dtype of its input.
+    """
+    cos_b = cos.unsqueeze(unsqueeze_dim)
+    sin_b = sin.unsqueeze(unsqueeze_dim)
+
+    def _rotate(states: paddle.Tensor) -> paddle.Tensor:
+        # x * cos + rotate_half(x) * sin, then restore the caller's dtype.
+        rotated = (states * cos_b) + (rotate_half(states) * sin_b)
+        return rotated.astype(states.dtype)
+
+    return _rotate(q), _rotate(k)
+
+
+class MiniMaxRMSNorm(nn.Layer):
+    """Root-mean-square layer norm with a learned per-channel scale (no mean subtraction, no bias).
+
+    The statistics are computed in float32 and the normalized activations are cast back to the
+    input dtype before the scale is applied.
+    """
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6):
+        super().__init__()
+        self.variance_epsilon = eps
+        # The scale starts at one, so a fresh layer only normalizes.
+        self.weight = self.create_parameter(
+            shape=[hidden_size],
+            default_initializer=nn.initializer.Constant(1.0),
+        )
+
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        orig_dtype = hidden_states.dtype
+        upcast = hidden_states.astype("float32")
+        mean_square = upcast.pow(2).mean(-1, keepdim=True)
+        normalized = upcast * paddle.rsqrt(mean_square + self.variance_epsilon)
+        return self.weight * normalized.astype(orig_dtype)
+
+    def extra_repr(self) -> str:
+        weight_shape = tuple(self.weight.shape)
+        return f"{weight_shape}, eps={self.variance_epsilon}"
+
+
+class MiniMaxCache(DynamicCache):
+    """`DynamicCache` extended with one extra slot per layer for linear-attention state.
+
+    Full-attention layers use the inherited key/value cache; lightning-attention layers store
+    their running KV statistic through `set_linear_cache` / `get_linear_cache`.
+    """
+
+    def __init__(self, config: MiniMaxConfig):
+        super().__init__(config=config)
+        # Indexed by layer; entries stay None for layers that never stored a statistic.
+        self.linear_cache: list[paddle.Tensor | None] = []
+
+    def set_linear_cache(self, layer_idx: int, linear_cache: paddle.Tensor):
+        """Store `linear_cache` for `layer_idx`, padding the list with None up to that index."""
+        missing = layer_idx + 1 - len(self.linear_cache)
+        if missing > 0:
+            self.linear_cache.extend([None] * missing)
+        self.linear_cache[layer_idx] = linear_cache
+
+    def get_linear_cache(self, layer_idx: int):
+        """Return the statistic stored for `layer_idx`, or None when the list is shorter than that."""
+        return self.linear_cache[layer_idx] if layer_idx < len(self.linear_cache) else None
+
+    def __len__(self):
+        kv_layers = super().__len__()
+        return max(kv_layers, len(self.linear_cache))
+
+    def crop(self, max_length: int):
+        raise RuntimeError("MiniMaxCache does not support `crop` method")
+
+
+class MiniMaxLightningAttention(nn.Layer):
+ """Linear attention ("lightning attention") for MiniMax.
+
+ The sequence is processed in blocks of `config.block_size` tokens. Each block output is the sum of an
+ intra-block term (decayed q @ k^T applied to v inside the block) and an inter-block term (q applied
+ to a running per-head KV statistic carried over from earlier blocks or from the cache).
+ Q, K and V come from one fused projection followed by SiLU; the normalized output is gated by a
+ sigmoid of a separate projection of the input before the final output projection.
+ """
+
+ def __init__(self, config: MiniMaxConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ # Falls back to hidden_size // num_attention_heads when config.head_dim is missing or None/0.
+ self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+ self.num_attention_heads = config.num_attention_heads
+ self.num_hidden_layers = config.num_hidden_layers
+ self.block_size = config.block_size
+
+ hidden_size = config.hidden_size
+ # Fused projection width: q, k and v each take num_heads * head_dim channels.
+ qkv_out = self.num_attention_heads * self.head_dim * 3
+ attn_out = self.num_attention_heads * self.head_dim
+
+ self.qkv_proj = GeneralLinear.create(
+ hidden_size,
+ qkv_out,
+ has_bias=False,
+ config=config,
+ tp_plan="colwise",
+ )
+ self.out_proj = GeneralLinear.create(
+ attn_out,
+ hidden_size,
+ has_bias=False,
+ config=config,
+ tp_plan="rowwise",
+ )
+ self.output_gate = GeneralLinear.create(
+ hidden_size,
+ attn_out,
+ has_bias=False,
+ config=config,
+ tp_plan="colwise",
+ )
+ self.norm = MiniMaxRMSNorm(self.head_dim * self.num_attention_heads, eps=config.rms_norm_eps)
+
+ self.act_fn = F.silu
+
+ # Decay tables depend only on the layer index, head count and block size, so they are
+ # computed once here and registered as non-persistable buffers.
+ slope_rate = self.get_slope_rate()
+ query_decay, key_decay, diagonal_decay = self.decay_factors(slope_rate)
+
+ self.register_buffer("slope_rate", slope_rate, persistable=False)
+ self.register_buffer("query_decay", query_decay, persistable=False)
+ self.register_buffer("key_decay", key_decay, persistable=False)
+ self.register_buffer("diagonal_decay", diagonal_decay, persistable=False)
+
+ def get_slope_rate(self) -> paddle.Tensor:
+ """Return the per-head decay slope, shaped [num_attention_heads, 1, 1].
+
+ Head h (0-based) gets base ** (h + 1) with base = 2 ** (-8 / num_attention_heads), multiplied by
+ a layer factor that shrinks as `layer_idx` grows.
+ """
+ base = 1.0 / (2.0 ** (8.0 / self.num_attention_heads))
+ dtype = paddle.get_default_dtype()
+ exponent = paddle.arange(self.num_attention_heads).astype(dtype) + 1
+ # The 1e-5 terms keep the denominator non-zero for a single-layer model and the factor
+ # strictly positive for the last layer.
+ factor = 1.0 - self.layer_idx / (self.num_hidden_layers - 1 + 1e-5) + 1e-5
+
+ rate = paddle.pow(paddle.to_tensor(base, dtype=dtype), exponent)
+ rate = rate * factor
+ rate = rate.unsqueeze(-1).unsqueeze(-1)
+ return rate
+
+ def decay_factors(self, slope_rate: paddle.Tensor) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+ """Precompute the three decay tables used by the block-wise forward pass.
+
+ Returns:
+ query_decay: exp(-slope * t) for in-block position t = 1..block_size.
+ key_decay: exp(-slope * (block_size - t)) for the same positions.
+ diagonal_decay: exp(-slope * (i - j)) where i >= j and 0 elsewhere (causal, in-block).
+ """
+ block_size_range = paddle.arange(self.block_size).astype(slope_rate.dtype) + 1
+
+ query_decay = paddle.exp(-slope_rate * block_size_range.unsqueeze(-1))
+ key_decay = paddle.exp(-slope_rate * (self.block_size - block_size_range.unsqueeze(-1)))
+
+ # diff[i, j] = i - j; the two leading unsqueezes make it broadcast against slope_rate.
+ diff = block_size_range.unsqueeze(-1) - block_size_range.unsqueeze(0)
+ diff = diff.unsqueeze(0).unsqueeze(0)
+ decay = slope_rate * diff
+ # Entries with i < j are set to -inf so that exp() turns them into exact zeros.
+ neg_inf = paddle.full_like(decay, fill_value=float("-inf"))
+ decay = paddle.where(decay >= 0, -decay, neg_inf)
+ diagonal_decay = paddle.exp(decay)
+
+ return query_decay, key_decay, diagonal_decay
+
+ def forward(
+ self,
+ hidden_states: paddle.Tensor,
+ position_embeddings: tuple[paddle.Tensor, paddle.Tensor] | None = None,
+ attention_mask: paddle.Tensor | None = None,
+ past_key_values: Cache | None = None,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> tuple[paddle.Tensor, paddle.Tensor | None]:
+ """Run lightning attention over `hidden_states` ([batch, seq_len, hidden]).
+
+ `position_embeddings`, `use_cache` and `**kwargs` are accepted but not read here.
+
+ Returns:
+ (attn_output, attn_weights_inter): the projected attention output and the final running
+ KV statistic, which is also written to `past_key_values` when that is a `MiniMaxCache`.
+ """
+ batch_size, seq_len, _ = hidden_states.shape
+ # Ceiling division: the last block may be shorter than block_size.
+ num_blocks = (seq_len + self.block_size - 1) // self.block_size
+ slope_rate = self.slope_rate.astype(hidden_states.dtype)
+
+ qkv_states = self.act_fn(self.qkv_proj(hidden_states))
+ qkv_states = qkv_states.reshape([batch_size, seq_len, self.num_attention_heads, 3 * self.head_dim])
+ query_states, key_states, value_states = paddle.split(qkv_states, num_or_sections=3, axis=-1)
+
+ # [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
+ query_states = query_states.transpose([0, 2, 1, 3])
+ key_states = key_states.transpose([0, 2, 1, 3])
+ value_states = value_states.transpose([0, 2, 1, 3])
+
+ # Running KV statistic from a previous call, if any; only a MiniMaxCache carries it.
+ attn_weights_inter = None
+ if past_key_values is not None and isinstance(past_key_values, MiniMaxCache):
+ attn_weights_inter = past_key_values.get_linear_cache(self.layer_idx)
+
+ # NOTE(review): indentation is not recoverable in this view. The `else:` further down presumably
+ # pairs with this `if`: no cached statistic -> block-wise pass, cached statistic -> per-token
+ # recurrence. Confirm against the original file.
+ if attn_weights_inter is None:
+ attn_weights_inter = paddle.zeros(
+ [batch_size, self.num_attention_heads, self.head_dim, self.head_dim],
+ dtype=hidden_states.dtype,
+ )
+
+ # Zero the values at masked positions so they add nothing to the outputs or the statistic.
+ # assumes attention_mask is [batch, seq_len] with truthy = keep -- TODO confirm
+ if attention_mask is not None:
+ bool_mask = attention_mask.astype("bool")
+ expanded = bool_mask.unsqueeze(1).unsqueeze(-1)
+ value_states = paddle.where(expanded, value_states, paddle.zeros_like(value_states))
+
+ attn_output = []
+ for i in range(num_blocks):
+ start_idx = i * self.block_size
+ end_idx = min(start_idx + self.block_size, seq_len)
+ cur_bs = end_idx - start_idx
+
+ cur_q = query_states[:, :, start_idx:end_idx]
+ cur_k = key_states[:, :, start_idx:end_idx]
+ cur_v = value_states[:, :, start_idx:end_idx]
+
+ # A short final block uses the first cur_bs query decays and the last cur_bs key decays.
+ cur_qd = self.query_decay[:, :cur_bs].astype(cur_q.dtype)
+ cur_kd = self.key_decay[:, -cur_bs:].astype(cur_k.dtype)
+ cur_dd = self.diagonal_decay[:, :, :cur_bs, :cur_bs].astype(cur_q.dtype)
+ # Decay applied to the carried statistic for advancing cur_bs positions.
+ block_decay = paddle.exp(-slope_rate * cur_bs).astype(attn_weights_inter.dtype)
+
+ # Intra-block term: causal, distance-decayed q @ k^T applied to v.
+ attn_intra = paddle.matmul(cur_q, cur_k.transpose([0, 1, 3, 2]))
+ attn_output_intra = paddle.matmul(attn_intra * cur_dd, cur_v)
+
+ # Inter-block term: decayed queries read the statistic accumulated before this block.
+ attn_output_inter = paddle.matmul(cur_q * cur_qd, attn_weights_inter)
+
+ cur_out = attn_output_inter + attn_output_intra
+ attn_output.append(cur_out)
+
+ # Fold this block's keys/values into the statistic for the next block.
+ next_stat = paddle.matmul((cur_k * cur_kd).transpose([0, 1, 3, 2]), cur_v)
+ attn_weights_inter = attn_weights_inter * block_decay + next_stat
+ else:
+ # Per-token recurrence: decay the statistic by one step, add k^T v, then read it with q.
+ # attention_mask is not applied on this path.
+ ratio = paddle.exp(-slope_rate).astype(attn_weights_inter.dtype)
+ attn_output = []
+ for i in range(seq_len):
+ cur_q = query_states[:, :, i : i + 1]
+ cur_k = key_states[:, :, i : i + 1]
+ cur_v = value_states[:, :, i : i + 1]
+
+ cur_stat = paddle.matmul(cur_k.transpose([0, 1, 3, 2]), cur_v)
+ attn_weights_inter = ratio * attn_weights_inter + cur_stat
+ cur_out = paddle.matmul(cur_q, attn_weights_inter)
+ attn_output.append(cur_out)
+
+ # Stitch the per-block / per-token outputs back together along the sequence axis.
+ attn_output = paddle.cat(attn_output, axis=-2)
+
+ # [batch, heads, seq, head_dim] -> [batch, seq, heads * head_dim]
+ attn_output = attn_output.transpose([0, 2, 1, 3])
+ attn_output = attn_output.reshape([batch_size, seq_len, self.num_attention_heads * self.head_dim])
+ attn_output = self.norm(attn_output)
+ # Sigmoid gate computed from the layer input, applied element-wise to the normalized output.
+ attn_output = F.sigmoid(self.output_gate(hidden_states)).astype(attn_output.dtype) * attn_output
+ attn_output = attn_output.astype(hidden_states.dtype)
+ attn_output = self.out_proj(attn_output)
+
+ # Persist the final statistic so the next call can continue from it.
+ if past_key_values is not None and isinstance(past_key_values, MiniMaxCache):
+ past_key_values.set_linear_cache(self.layer_idx, attn_weights_inter)
+
+ return attn_output, attn_weights_inter
+
+
+class MiniMaxRotaryEmbedding(nn.Layer):
+    """Produces the RoPE cosine/sine tables consumed by the full-attention layers.
+
+    The inverse frequencies are derived from the config once at construction time (using the
+    initializer selected by ``config.rope_parameters["rope_type"]`` when present, otherwise the
+    default formula) and stored as a non-persistable buffer.
+    """
+
+    inv_freq: paddle.Tensor  # for `register_buffer` typing
+
+    def __init__(self, config: MiniMaxConfig):
+        super().__init__()
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+        self.config = config
+        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+
+        if hasattr(config, "rope_parameters") and isinstance(config.rope_parameters, dict):
+            self.rope_type = config.rope_parameters.get("rope_type", "default")
+        else:
+            self.rope_type = "default"
+
+        rope_init_fn = self.compute_default_rope_parameters
+        if self.rope_type != "default":
+            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = rope_init_fn(self.config)
+
+        self.register_buffer("inv_freq", inv_freq, persistable=False)
+        # Kept alongside the buffer; presumably read by `dynamic_rope_update` -- TODO confirm.
+        self.original_inv_freq = inv_freq
+
+    @staticmethod
+    def compute_default_rope_parameters(
+        config: MiniMaxConfig | None = None,
+        seq_len: int | None = None,
+    ) -> tuple["paddle.Tensor", float]:
+        """Compute default RoPE inverse frequencies.
+
+        Args:
+            config: model configuration; ``rope_parameters["rope_theta"]`` supplies the base and the
+                head dimension is ``head_dim`` or ``hidden_size // num_attention_heads``.
+            seq_len: accepted for signature compatibility; not read here.
+
+        Returns:
+            ``(inv_freq, attention_factor)`` where ``inv_freq[i] = 1 / base ** (2 * i / dim)`` in float32
+            and ``attention_factor`` is always 1.0.
+        """
+        base = config.rope_parameters["rope_theta"]
+        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+        attention_factor = 1.0
+        inv_freq = 1.0 / (base ** (paddle.arange(0, dim, 2, dtype=paddle.int64).astype(dtype=paddle.float32) / dim))
+        return inv_freq, attention_factor
+
+    @dynamic_rope_update
+    def forward(self, x: paddle.Tensor, position_ids: paddle.Tensor) -> tuple[paddle.Tensor, paddle.Tensor]:
+        """Return ``(cos, sin)`` for `position_ids`, cast to the dtype of `x`.
+
+        `x` is only used for its dtype. All trigonometry is done in float32 with autocast disabled.
+        """
+        with paddle.amp.auto_cast(enable=False):
+            # Keep the frequencies in float32 end to end. Casting them to x.dtype first (e.g.
+            # bfloat16) would round them before the matmul, and the resulting phase error grows
+            # linearly with the position index.
+            inv_freq_expanded = self.inv_freq[None, :, None].float().expand([position_ids.shape[0], -1, 1])
+            position_ids_expanded = position_ids[:, None, :].float()
+
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose([0, 2, 1])
+            # Duplicate the angles so that they line up with the two halves used by `rotate_half`.
+            emb = paddle.concat((freqs, freqs), axis=-1)
+
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
+def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
+    """Duplicate every key/value head `n_rep` times along the head axis.
+
+    Takes a ``[batch, kv_heads, seq_len, head_dim]`` tensor and returns
+    ``[batch, kv_heads * n_rep, seq_len, head_dim]``; the input is returned as-is when `n_rep` is 1.
+    """
+    batch, kv_heads, seq_len, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    # Insert a repeat axis after the head axis, broadcast it, then fold it into the head axis.
+    tiled = hidden_states[:, :, None, :, :].expand([batch, kv_heads, n_rep, seq_len, head_dim])
+    return tiled.reshape([batch, kv_heads * n_rep, seq_len, head_dim])
+
+
+def eager_attention_forward(
+    module: nn.Layer,
+    query: paddle.Tensor,
+    key: paddle.Tensor,
+    value: paddle.Tensor,
+    attention_mask: paddle.Tensor | None,
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    """Standard (eager) scaled dot-product attention.
+
+    Args:
+        module: the calling attention layer; supplies `num_key_value_groups` and the training flag.
+        query: query states, [batch, heads, q_len, head_dim].
+        key: key states, [batch, kv_heads, kv_len, head_dim].
+        value: value states, same layout as `key`.
+        attention_mask: additive mask broadcastable to the attention scores, or None.
+        scaling: factor applied to the raw q @ k^T scores.
+        dropout: drop probability for the attention probabilities.
+        **kwargs: ignored (e.g. `attn_mask_startend_row_indices` used by other backends).
+
+    Returns:
+        ``(attn_output, attn_weights)`` with `attn_output` laid out as [batch, q_len, heads, head_dim].
+    """
+    key_states = repeat_kv(key, module.num_key_value_groups)
+    value_states = repeat_kv(value, module.num_key_value_groups)
+
+    attn_weights = paddle.matmul(query, key_states.transpose([0, 1, 3, 2])) * scaling
+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask
+
+    # Softmax in float32 for numerical stability, then return to the compute dtype.
+    attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query.dtype)
+    if dropout > 0:
+        # Real dropout (random masking with rescaling). Multiplying by (1 - dropout) instead
+        # would drop nothing and just shrink every attention output deterministically.
+        attn_weights = F.dropout(attn_weights, p=dropout, training=getattr(module, "training", True))
+
+    attn_output = paddle.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose([0, 2, 1, 3]).contiguous()
+    return attn_output, attn_weights
+
+
+class MiniMaxAttention(nn.Layer):
+    """Full (softmax) self-attention layer of MiniMax with rotary position embeddings.
+
+    Keys and values use `num_key_value_heads` heads, so the layer performs Grouped Query Attention
+    whenever that differs from `num_attention_heads`.
+    """
+
+    def __init__(self, config: MiniMaxConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+
+        self.num_attention_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+
+        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.attention_dropout = config.attention_dropout
+
+        query_width = self.num_attention_heads * self.head_dim
+        kv_width = self.num_key_value_heads * self.head_dim
+
+        def _bias_free_linear(in_features: int, out_features: int, plan: str):
+            # All four projections share everything except their sizes and tensor-parallel plan.
+            return GeneralLinear.create(
+                in_features,
+                out_features,
+                has_bias=False,
+                config=config,
+                tp_plan=plan,
+            )
+
+        self.q_proj = _bias_free_linear(config.hidden_size, query_width, "colwise")
+        self.k_proj = _bias_free_linear(config.hidden_size, kv_width, "colwise")
+        self.v_proj = _bias_free_linear(config.hidden_size, kv_width, "colwise")
+        self.o_proj = _bias_free_linear(query_width, config.hidden_size, "rowwise")
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        position_embeddings: tuple[paddle.Tensor, paddle.Tensor] | None = None,
+        attention_mask: paddle.Tensor | None = None,
+        past_key_values: Cache | None = None,
+        use_cache: bool = False,
+        attn_mask_startend_row_indices: paddle.Tensor | None = None,
+        **kwargs,
+    ) -> tuple[paddle.Tensor, paddle.Tensor | None]:
+        """Attend over `hidden_states` ([batch, q_len, hidden]) and return ``(output, attn_weights)``.
+
+        `position_embeddings` is unpacked as ``(cos, sin)``; `use_cache` and `**kwargs` are not read.
+        """
+        batch, q_len, _ = hidden_states.shape
+        per_head_shape = [batch, q_len, -1, self.head_dim]
+
+        def _split_heads(projection) -> paddle.Tensor:
+            # [batch, q_len, heads * head_dim] -> [batch, heads, q_len, head_dim]
+            return projection(hidden_states).reshape(per_head_shape).transpose([0, 2, 1, 3])
+
+        query_states = _split_heads(self.q_proj)
+        key_states = _split_heads(self.k_proj)
+        value_states = _split_heads(self.v_proj)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_values is not None:
+            # The cache returns the keys/values to attend over for this layer.
+            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
+
+        # Eager is the default; any other configured implementation comes from the shared registry.
+        implementation = getattr(self.config, "_attn_implementation", "eager")
+        if implementation == "eager":
+            attention_interface: Callable = eager_attention_forward
+        else:
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query=query_states,
+            key=key_states,
+            value=value_states,
+            attention_mask=attention_mask,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            dropout=self.attention_dropout if self.training else 0.0,
+            scaling=self.scaling,
+        )
+
+        merged = attn_output.reshape([batch, q_len, -1]).contiguous()
+        return self.o_proj(merged), attn_weights
+
+
+class MiniMaxTopKRouter(nn.Layer):
+    """Softmax gate that selects ``top_k`` experts per token for the MiniMax MoE."""
+
+    def __init__(self, config: MiniMaxConfig):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+        # Gate matrix stored as [hidden, experts].
+        self.weight = self.create_parameter(
+            shape=[self.hidden_dim, self.num_experts],
+            dtype=paddle.get_default_dtype(),
+            is_bias=False,
+        )
+
+    def forward(self, hidden_states: paddle.Tensor) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Score every expert and pick the best ``top_k`` per token.
+
+        Returns:
+            A triple ``(probs, top_weights, top_indices)``:
+            ``probs`` is the float32 softmax over all experts (note: these are
+            probabilities, not raw logits); ``top_weights`` are the selected
+            probabilities renormalised to sum to 1 and cast back to the input
+            dtype; ``top_indices`` are the selected expert ids.
+        """
+        gate_scores = F.linear(hidden_states, self.weight)
+        expert_probs = F.softmax(gate_scores.astype("float32"), axis=-1)
+
+        top_weights, top_indices = paddle.topk(expert_probs, self.top_k, axis=-1)
+        normaliser = top_weights.sum(axis=-1, keepdim=True)
+        top_weights = (top_weights / normaliser).astype(hidden_states.dtype)
+
+        return expert_probs, top_weights, top_indices
+
+
+class MiniMaxBlockSparseTop2MLP(nn.Layer):
+    """Single MoE expert: a bias-free gated MLP computing ``w2(silu(w1(x)) * w3(x))``."""
+
+    def __init__(self, config: MiniMaxConfig):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.intermediate_dim = config.intermediate_size
+
+        def build_linear(in_features, out_features, plan):
+            return GeneralLinear.create(
+                in_features,
+                out_features,
+                has_bias=False,
+                config=config,
+                tp_plan=plan,
+            )
+
+        # Creation order (w1, w2, w3) is kept so sublayer registration order is unchanged.
+        self.w1 = build_linear(self.hidden_dim, self.intermediate_dim, "colwise")
+        self.w2 = build_linear(self.intermediate_dim, self.hidden_dim, "rowwise")
+        self.w3 = build_linear(self.hidden_dim, self.intermediate_dim, "colwise")
+        self.act_fn = F.silu
+
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        """Apply the gated MLP to ``hidden_states`` and return the projected result."""
+        gated = self.act_fn(self.w1(hidden_states))
+        linear_branch = self.w3(hidden_states)
+        return self.w2(gated * linear_branch)
+
+
+class MiniMaxSparseMoeBlock(nn.Layer):
+    """Sparse MoE block (router + experts) for MiniMax.
+
+    Each token is routed by ``self.gate`` to ``top_k`` experts; the expert
+    outputs are scaled by the routing weights and summed per token.
+    """
+
+    def __init__(self, config: MiniMaxConfig):
+        super().__init__()
+        self.top_k = config.num_experts_per_tok
+        self.jitter_noise = config.router_jitter_noise
+        self.gate = MiniMaxTopKRouter(config)
+        self.experts = nn.LayerList([MiniMaxBlockSparseTop2MLP(config) for _ in range(config.num_local_experts)])
+
+    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+        """Run the mixture of experts over ``hidden_states``.
+
+        Args:
+            hidden_states: Tensor of shape ``[batch, seq_len, hidden]``.
+
+        Returns:
+            Tensor with the same ``[batch, seq_len, hidden]`` shape.
+        """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        # Training-only multiplicative jitter, sampled uniformly from [1 - noise, 1 + noise].
+        if self.training and self.jitter_noise > 0:
+            hidden_states = hidden_states * paddle.uniform(
+                hidden_states.shape,
+                dtype=hidden_states.dtype,
+                min=1.0 - self.jitter_noise,
+                max=1.0 + self.jitter_noise,
+            )
+        # Flatten to [tokens, hidden] so tokens can be gathered per expert.
+        hidden_states_flat = hidden_states.reshape([-1, hidden_states.shape[-1]])
+        final_hidden_states = paddle.zeros_like(hidden_states_flat)
+        # The first router output (full softmax over experts) is not needed here.
+        _, routing_weights, selected_experts = self.gate(hidden_states_flat)
+
+        with paddle.no_grad():
+            # one_hot gives [tokens, top_k, experts]; the transpose makes it [experts, top_k, tokens]
+            # so that expert_mask[e] marks which (top_k slot, token) pairs chose expert e.
+            expert_mask = F.one_hot(selected_experts, num_classes=self.gate.num_experts)
+            expert_mask = expert_mask.transpose([2, 1, 0])
+
+        # Both switches are read from the environment on every forward call.
+        # FIXED_EXPERT_ORDER=1: visit every expert index in order, even experts with no tokens.
+        # Otherwise only experts that received at least one token are visited.
+        fixed_expert_order = os.environ.get("PADDLE_MINIMAX_MOE_FIXED_EXPERT_ORDER", "0") == "1"
+        connect_empty_experts = os.environ.get("PADDLE_MINIMAX_MOE_CONNECT_EMPTY_EXPERTS", "1") == "1"
+        expert_indices = (
+            range(self.gate.num_experts)
+            if fixed_expert_order
+            else [
+                # Convert each hit expert index to a plain Python int for LayerList indexing.
+                int(expert_idx[0].item())
+                for expert_idx in paddle.greater(
+                    expert_mask.sum(axis=(-1, -2)),
+                    paddle.to_tensor(0, dtype="int64"),
+                ).nonzero()
+            ]
+        )
+
+        for expert_idx in expert_indices:
+            # Positions (top_k slot, token row) routed to this expert.
+            top_k_pos, token_idx = paddle.where(expert_mask[expert_idx])
+            current_state = hidden_states_flat[token_idx]
+            current_hidden_states = self.experts[expert_idx](current_state)
+            if token_idx.shape[0] > 0:
+                # Scale by the routing weight of the matching slot, then accumulate in place per token.
+                current_hidden_states = current_hidden_states * routing_weights[token_idx, top_k_pos, None]
+                final_hidden_states = final_hidden_states.index_add_(
+                    axis=0,
+                    index=token_idx,
+                    value=current_hidden_states.astype(final_hidden_states.dtype),
+                )
+            elif fixed_expert_order and connect_empty_experts:
+                # Adds a zero-valued term that still depends on the idle expert's output.
+                # NOTE(review): presumably keeps unused experts in the autograd graph — confirm.
+                final_hidden_states = (
+                    final_hidden_states + current_hidden_states.sum().astype(final_hidden_states.dtype) * 0.0
+                )
+
+        return final_hidden_states.reshape([batch_size, sequence_length, hidden_dim])
+
+
+class MiniMaxDecoderLayer(nn.Layer):
+    """One MiniMax decoder layer: attention followed by a sparse MoE block.
+
+    ``config.layer_types[layer_idx]`` picks the attention flavour
+    (``"full_attention"`` or ``"linear_attention"``). Both sub-blocks combine
+    their input and output as ``residual * alpha + output * beta``.
+    """
+
+    def __init__(self, config: MiniMaxConfig, layer_idx: int):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.layer_type = config.layer_types[layer_idx]
+        self.mlp_alpha_factor = config.mlp_alpha_factor
+        self.mlp_beta_factor = config.mlp_beta_factor
+
+        self.input_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        # Each attention flavour has its own residual mixing coefficients.
+        if self.layer_type == "full_attention":
+            self.self_attn = MiniMaxAttention(config, layer_idx)
+            self.attn_alpha_factor = config.full_attn_alpha_factor
+            self.attn_beta_factor = config.full_attn_beta_factor
+        elif self.layer_type == "linear_attention":
+            self.self_attn = MiniMaxLightningAttention(config, layer_idx)
+            self.attn_alpha_factor = config.linear_attn_alpha_factor
+            self.attn_beta_factor = config.linear_attn_beta_factor
+        else:
+            raise ValueError(
+                f"Unknown layer_type '{self.layer_type}'. Expected 'full_attention' or 'linear_attention'."
+            )
+
+        self.block_sparse_moe = MiniMaxSparseMoeBlock(config)
+
+    def forward(
+        self,
+        hidden_states: paddle.Tensor,
+        attention_mask: paddle.Tensor | None = None,
+        position_ids: paddle.Tensor | None = None,
+        position_embeddings: tuple[paddle.Tensor, paddle.Tensor] | None = None,
+        past_key_values: Cache | None = None,
+        use_cache: bool | None = False,
+        attn_mask_startend_row_indices: paddle.Tensor | None = None,
+        **kwargs,
+    ) -> paddle.Tensor:
+        """Apply attention and the MoE block, each with a weighted residual.
+
+        Note that the residual is taken *after* the layer norm in both
+        sub-blocks. ``position_ids`` is accepted but not used by this layer.
+
+        Returns:
+            The layer output, same shape as ``hidden_states``.
+        """
+        # --- attention sub-block ---
+        normed_input = self.input_layernorm(hidden_states)
+        attn_result = self.self_attn(
+            hidden_states=normed_input,
+            position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            **kwargs,
+        )[0]
+        after_attention = normed_input * self.attn_alpha_factor + attn_result * self.attn_beta_factor
+
+        # --- MoE sub-block ---
+        normed_attention = self.post_attention_layernorm(after_attention)
+        moe_result = self.block_sparse_moe(normed_attention)
+        return normed_attention * self.mlp_alpha_factor + moe_result * self.mlp_beta_factor
+
+
+def load_balancing_loss_func(
+    gate_logits: tuple[paddle.Tensor] | list[paddle.Tensor] | None,
+    num_experts: int | None = None,
+    top_k: int = 2,
+    attention_mask: paddle.Tensor | None = None,
+) -> paddle.Tensor | int:
+    """Auxiliary load-balancing loss for MoE.
+
+    For every expert the loss multiplies the fraction of tokens routed to it by
+    the mean router probability it receives, sums the products and scales the
+    sum by ``num_experts``.
+
+    Args:
+        gate_logits: Per-layer router logits, each ``[tokens, num_experts]``.
+            ``None``, a non-sequence or an empty sequence yields ``0``.
+        num_experts: Number of experts. When ``None`` it is inferred from the
+            last dimension of the gate logits.
+        top_k: Number of experts selected per token.
+        attention_mask: Optional ``[batch, seq_len]`` mask; masked-out
+            positions are excluded from both averages. Any dtype is accepted,
+            it is cast to floating point before use.
+
+    Returns:
+        A scalar tensor with the loss, or the integer ``0`` when there are no
+        gate logits.
+    """
+    if gate_logits is None or not isinstance(gate_logits, (tuple, list)) or len(gate_logits) == 0:
+        return 0
+
+    concatenated_gate_logits = paddle.concat(list(gate_logits), axis=0)
+    if num_experts is None:
+        # The default used to crash in one_hot / the final multiply; infer it instead.
+        num_experts = concatenated_gate_logits.shape[-1]
+
+    routing_weights = F.softmax(concatenated_gate_logits, axis=-1)
+    _, selected_experts = paddle.topk(routing_weights, top_k, axis=-1)
+    expert_mask = F.one_hot(selected_experts, num_experts).astype("float32")
+
+    if attention_mask is None:
+        tokens_per_expert = paddle.mean(expert_mask, axis=0)
+        router_prob_per_expert = paddle.mean(routing_weights, axis=0)
+    else:
+        batch_size, sequence_length = attention_mask.shape
+        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+        # Bool/int masks cannot be multiplied with float tensors reliably, so cast once up front.
+        float_mask = attention_mask.astype("float32")
+
+        expert_attention_mask = (
+            float_mask[None, :, :, None, None]
+            .expand([num_hidden_layers, batch_size, sequence_length, top_k, num_experts])
+            .reshape([-1, top_k, num_experts])
+        )
+        tokens_per_expert = paddle.sum(expert_mask * expert_attention_mask, axis=0) / paddle.sum(
+            expert_attention_mask, axis=0
+        )
+
+        router_per_expert_attention_mask = (
+            float_mask[None, :, :, None]
+            .expand([num_hidden_layers, batch_size, sequence_length, num_experts])
+            .reshape([-1, num_experts])
+            .astype(routing_weights.dtype)
+        )
+        router_prob_per_expert = paddle.sum(routing_weights * router_per_expert_attention_mask, axis=0) / paddle.sum(
+            router_per_expert_attention_mask, axis=0
+        )
+
+    overall_loss = paddle.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+    return overall_loss * num_experts
+
+
+class MiniMaxPretrainedModel(PretrainedModel):
+    """Common base of the MiniMax models: config binding and checkpoint key mappings (AOA)."""
+
+    config_class = MiniMaxConfig
+    config: MiniMaxConfig
+    base_model_prefix = "model"
+
+    # Parameter-name fragments whose weights are stored transposed.
+    # NOTE(review): presumably consumed by the PretrainedModel loading code — confirm.
+    transpose_weight_keys = [
+        "q_proj",
+        "k_proj",
+        "v_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        "gate",
+        "w1",
+        "w2",
+        "w3",
+        "output_gate",
+        "qkv_proj",
+        "out_proj",
+    ]
+
+    @classmethod
+    def _gen_aoa_config(cls, config: MiniMaxConfig):
+        """Build the AOA statements mapping HF safetensors keys to PaddleFormers keys.
+
+        Linear weights are mapped with ``^T``; norm and embedding weights are
+        mapped as-is. Destination keys carry a ``model.`` prefix unless ``cls``
+        is the base model class. Setting ``PADDLE_MINIMAX_DEBUG_AOA=1`` prints
+        the class name, layer count and layer types.
+
+        Returns:
+            ``{"aoa_statements": [...]}``.
+        """
+        if os.environ.get("PADDLE_MINIMAX_DEBUG_AOA", "0") == "1":
+            print(
+                "[PADDLE_MINIMAX_DEBUG_AOA] "
+                f"cls={cls.__name__} num_hidden_layers={config.num_hidden_layers} "
+                f"layer_types={getattr(config, 'layer_types', None)}",
+                flush=True,
+            )
+
+        has_lm_head = cls != cls.base_model_class
+        model_prefix = "model." if has_lm_head else ""
+
+        # Transposed attention projections per layer type, in emission order.
+        transposed_attention = {
+            "full_attention": ("q_proj", "k_proj", "v_proj", "o_proj"),
+            "linear_attention": ("qkv_proj", "output_gate", "out_proj"),
+        }
+
+        statements = [
+            f"model.embed_tokens.weight -> {model_prefix}embed_tokens.weight",
+            f"model.norm.weight -> {model_prefix}norm.weight",
+        ]
+
+        for layer_idx in range(config.num_hidden_layers):
+            layer_type = config.layer_types[layer_idx]
+            hf_layer = f"model.layers.{layer_idx}"
+            pd_layer = f"{model_prefix}layers.{layer_idx}"
+
+            # Layer norms exist for every layer type.
+            for norm_name in ("input_layernorm", "post_attention_layernorm"):
+                statements.append(f"{hf_layer}.{norm_name}.weight -> {pd_layer}.{norm_name}.weight")
+
+            for proj_name in transposed_attention.get(layer_type, ()):
+                statements.append(
+                    f"{hf_layer}.self_attn.{proj_name}.weight^T -> {pd_layer}.self_attn.{proj_name}.weight"
+                )
+            if layer_type == "linear_attention":
+                # The linear-attention norm is a plain (non-transposed) weight.
+                statements.append(f"{hf_layer}.self_attn.norm.weight -> {pd_layer}.self_attn.norm.weight")
+
+            statements.append(
+                f"{hf_layer}.block_sparse_moe.gate.weight^T -> {pd_layer}.block_sparse_moe.gate.weight"
+            )
+            for expert_idx in range(config.num_local_experts):
+                hf_expert = f"{hf_layer}.block_sparse_moe.experts.{expert_idx}"
+                pd_expert = f"{pd_layer}.block_sparse_moe.experts.{expert_idx}"
+                statements.extend(
+                    f"{hf_expert}.{weight_name}.weight^T -> {pd_expert}.{weight_name}.weight"
+                    for weight_name in ("w1", "w2", "w3")
+                )
+
+        # lm_head only exists on wrapper classes; with tied embeddings it is fed from the embedding.
+        if has_lm_head:
+            if config.tie_word_embeddings:
+                statements.append("model.embed_tokens.weight -> lm_head.weight")
+            else:
+                statements.append("lm_head.weight^T -> lm_head.weight")
+
+        return {"aoa_statements": statements}
+
+    @classmethod
+    def _gen_inv_aoa_config(cls, config: MiniMaxConfig):
+        """Build the inverse AOA statements mapping PaddleFormers keys back to HF keys.
+
+        Mirrors ``_gen_aoa_config`` with source and destination swapped. An
+        ``lm_head`` statement is emitted only for wrapper classes with untied
+        embeddings.
+
+        Returns:
+            ``{"aoa_statements": [...]}``.
+        """
+        has_lm_head = cls != cls.base_model_class
+        model_prefix = "model." if has_lm_head else ""
+
+        transposed_attention = {
+            "full_attention": ("q_proj", "k_proj", "v_proj", "o_proj"),
+            "linear_attention": ("qkv_proj", "output_gate", "out_proj"),
+        }
+
+        statements = [
+            f"{model_prefix}embed_tokens.weight -> model.embed_tokens.weight",
+            f"{model_prefix}norm.weight -> model.norm.weight",
+        ]
+
+        for layer_idx in range(config.num_hidden_layers):
+            layer_type = config.layer_types[layer_idx]
+            hf_layer = f"model.layers.{layer_idx}"
+            pd_layer = f"{model_prefix}layers.{layer_idx}"
+
+            for norm_name in ("input_layernorm", "post_attention_layernorm"):
+                statements.append(f"{pd_layer}.{norm_name}.weight -> {hf_layer}.{norm_name}.weight")
+
+            for proj_name in transposed_attention.get(layer_type, ()):
+                statements.append(
+                    f"{pd_layer}.self_attn.{proj_name}.weight^T -> {hf_layer}.self_attn.{proj_name}.weight"
+                )
+            if layer_type == "linear_attention":
+                statements.append(f"{pd_layer}.self_attn.norm.weight -> {hf_layer}.self_attn.norm.weight")
+
+            statements.append(
+                f"{pd_layer}.block_sparse_moe.gate.weight^T -> {hf_layer}.block_sparse_moe.gate.weight"
+            )
+            for expert_idx in range(config.num_local_experts):
+                hf_expert = f"{hf_layer}.block_sparse_moe.experts.{expert_idx}"
+                pd_expert = f"{pd_layer}.block_sparse_moe.experts.{expert_idx}"
+                statements.extend(
+                    f"{pd_expert}.{weight_name}.weight^T -> {hf_expert}.{weight_name}.weight"
+                    for weight_name in ("w1", "w2", "w3")
+                )
+
+        if has_lm_head and not config.tie_word_embeddings:
+            statements.append("lm_head.weight^T -> lm_head.weight")
+
+        return {"aoa_statements": statements}
+
+
+@register_base_model
+class MiniMaxModel(MiniMaxPretrainedModel):
+    """Bare MiniMax (Text-01) decoder: token embedding, decoder layers and a final RMSNorm."""
+
+    def __init__(self, config: MiniMaxConfig):
+        super().__init__(config)
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.hidden_size = config.hidden_size
+        self.padding_idx = config.pad_token_id
+
+        self.embed_tokens = GeneralEmbedding.create(
+            config=config,
+            num_embeddings=self.vocab_size,
+            embedding_dim=self.hidden_size,
+            padding_idx=self.padding_idx,
+        )
+        decoder_layers = [MiniMaxDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        self.layers = nn.LayerList(decoder_layers)
+        self.norm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = MiniMaxRotaryEmbedding(config=config)
+
+    def forward(
+        self,
+        input_ids: paddle.Tensor | None = None,
+        attention_mask: paddle.Tensor | None = None,
+        position_ids: paddle.Tensor | None = None,
+        past_key_values: MiniMaxCache | None = None,
+        inputs_embeds: paddle.Tensor | None = None,
+        use_cache: bool | None = None,
+        output_hidden_states: bool | None = None,
+        return_dict: bool | None = None,
+        attn_mask_startend_row_indices: paddle.Tensor | None = None,
+        **kwargs,
+    ):
+        """Run the decoder stack.
+
+        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
+        Full-attention layers receive the causal mask and row indices built by
+        ``create_causal_mask_and_row_indices``; all other layers receive the
+        caller's raw ``attention_mask`` and no row indices.
+
+        Returns:
+            A ``MoEModelOutputWithPast`` when ``return_dict`` is true, otherwise
+            a tuple ``(last_hidden_state[, hidden_states][, past_key_values])``.
+
+        Raises:
+            ValueError: If neither or both inputs are given, or if caching is
+                enabled with a cache that is not a ``MiniMaxCache``.
+        """
+        if output_hidden_states is None:
+            output_hidden_states = False
+        if use_cache is None:
+            use_cache = self.config.use_cache
+        if return_dict is None:
+            return_dict = self.config.use_return_dict
+
+        has_ids = input_ids is not None
+        has_embeds = inputs_embeds is not None
+        if not has_ids and not has_embeds:
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+        if has_ids and has_embeds:
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds (not both)")
+
+        if not has_embeds:
+            inputs_embeds = self.embed_tokens(input_ids).astype(self.embed_tokens.weight.dtype)
+
+        bsz, seq_length, _ = inputs_embeds.shape
+
+        if use_cache:
+            if past_key_values is None:
+                past_key_values = MiniMaxCache(config=self.config)
+            elif not isinstance(past_key_values, MiniMaxCache):
+                raise ValueError(
+                    f"MiniMax uses cache of its own and is not compatible with `past_key_values` of type {type(past_key_values)}."
+                )
+
+        # Number of tokens already held by the cache (0 without a cache).
+        cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+
+        if position_ids is None:
+            # Default positions continue right after the cached tokens.
+            new_positions = paddle.arange(cache_length, seq_length + cache_length, dtype=paddle.int64)
+            position_ids = new_positions.unsqueeze(0).expand([bsz, -1])
+
+        causal_mask, attn_mask_startend_row_indices = create_causal_mask_and_row_indices(
+            config=self.config,
+            inputs_embeds=inputs_embeds,
+            batch_size=bsz,
+            seq_length=seq_length,
+            cache_length=cache_length,
+            attention_mask=attention_mask,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            prepare_decoder_attention_mask=self._prepare_decoder_attention_mask,
+        )
+        position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
+
+        all_hidden_states = [] if output_hidden_states else None
+        hidden_states = inputs_embeds
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states.append(hidden_states)
+
+            uses_full_attention = self.config.layer_types[idx] == "full_attention"
+            layer_mask = causal_mask if uses_full_attention else attention_mask
+            layer_row_indices = attn_mask_startend_row_indices if uses_full_attention else None
+
+            hidden_states = decoder_layer(
+                hidden_states,
+                attention_mask=layer_mask,
+                position_ids=position_ids,
+                position_embeddings=position_embeddings,
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                attn_mask_startend_row_indices=layer_row_indices,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+        if output_hidden_states:
+            all_hidden_states.append(hidden_states)
+
+        collected_states = tuple(all_hidden_states) if all_hidden_states else None
+
+        if return_dict:
+            return MoEModelOutputWithPast(
+                last_hidden_state=hidden_states,
+                past_key_values=past_key_values,
+                hidden_states=collected_states,
+            )
+
+        outputs = (hidden_states,)
+        if output_hidden_states:
+            outputs += (collected_states,)
+        if use_cache:
+            outputs += (past_key_values,)
+        return outputs
+
+
+class MiniMaxForCausalLM(MiniMaxPretrainedModel):
+    """MiniMax (Text-01) model with a language modeling head."""
+
+    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+    def __init__(self, config: MiniMaxConfig):
+        super().__init__(config)
+        self.config = config
+        self.model = MiniMaxModel(config)
+        self.lm_head = GeneralLMHead(config)
+        self.criterion = CriterionLayer(config)
+        self.router_weights: list[paddle.Tensor] = []
+        self.tie_weights()
+
+    def forward(
+        self,
+        input_ids: paddle.Tensor,
+        position_ids: paddle.Tensor | None = None,
+        attention_mask: paddle.Tensor | None = None,
+        attn_mask_startend_row_indices: paddle.Tensor | None = None,
+        inputs_embeds: paddle.Tensor | None = None,
+        labels: paddle.Tensor | None = None,
+        loss_mask: paddle.Tensor | None = None,
+        use_cache: bool = False,
+        past_key_values: Cache | None = None,
+        output_hidden_states: bool | None = False,
+        output_router_logits: bool | None = None,
+        return_dict: bool = False,
+        **kwargs,
+    ):
+        """Run the decoder and LM head, optionally computing the training loss.
+
+        Args:
+            input_ids: Token ids; pass ``None`` together with ``inputs_embeds``.
+            position_ids: Optional explicit positions.
+            attention_mask: Optional mask; non-bool masks are cast to bool. It is
+                dropped (with a warning) when ``attn_mask_startend_row_indices``
+                is also given.
+            attn_mask_startend_row_indices: Optional sparse mask description,
+                forwarded to the decoder.
+            inputs_embeds: Optional precomputed embeddings.
+            labels: When given, the loss is computed by ``self.criterion``.
+            loss_mask: Optional mask handed to the criterion together with
+                ``labels``.
+            use_cache: Whether the decoder should build/extend a cache.
+            past_key_values: Optional existing cache.
+            output_hidden_states: Whether to return per-layer hidden states.
+            output_router_logits: Accepted for interface compatibility. The
+                decoder does not return router logits, so no auxiliary loss is
+                computed and ``aux_loss`` is always ``None``.
+            return_dict: Return a ``MoECausalLMOutputWithPast`` instead of a tuple.
+            **kwargs: Accepted but not forwarded to the decoder.
+
+        Returns:
+            ``MoECausalLMOutputWithPast`` when ``return_dict`` is true; otherwise
+            ``(loss, logits, ...)`` or ``(logits, ...)`` when no loss was computed.
+        """
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if attention_mask is not None and attention_mask.dtype != paddle.bool:
+            attention_mask = paddle.cast(attention_mask, paddle.bool)
+
+        if attn_mask_startend_row_indices is not None and attention_mask is not None:
+            logger.warning(
+                "You have provided both attn_mask_startend_row_indices and attention_mask. "
+                "The attn_mask_startend_row_indices will be used."
+            )
+            attention_mask = None
+
+        # Always request the structured output; it is converted to a tuple below if needed.
+        outputs = self.model(
+            input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            past_key_values=past_key_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+        )
+
+        hidden_states = outputs.last_hidden_state
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            if loss_mask is None:
+                loss, _ = self.criterion(logits, labels)
+            else:
+                # Previously the caller's loss_mask was accepted but silently ignored.
+                loss, _ = self.criterion(logits, labels, loss_mask)
+
+        # No router logits are available from the decoder, so there is no auxiliary loss.
+        aux_loss = None
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return MoECausalLMOutputWithPast(
+            loss=loss,
+            aux_loss=aux_loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            router_logits=outputs.router_logits,
+        )
+
+
+class MiniMaxForCausalLMPipe(GeneralModelForCausalLMPipe):
+    """Pipeline-parallel wrapper for MiniMax (Text-01).
+
+    The class body only binds attributes: it points the generic pipeline base at
+    the MiniMax config and decoder layer, and reuses the hooks already defined
+    on ``MiniMaxModel`` / ``MiniMaxForCausalLM``.
+    """
+
+    config_class = MiniMaxConfig
+    # Layer class used for each decoder layer of the pipeline.
+    _decoder_layer_cls = MiniMaxDecoderLayer
+    # Hooks borrowed from the non-pipeline base model.
+    _get_tensor_parallel_mappings = MiniMaxModel._get_tensor_parallel_mappings
+    _init_weights = MiniMaxModel._init_weights
+    _keep_in_fp32_modules = MiniMaxModel._keep_in_fp32_modules
+    _tied_weights_keys = ["lm_head.weight"]
+    transpose_weight_keys = MiniMaxModel.transpose_weight_keys
+    # AOA mappings are taken from the causal-LM class, i.e. the variant that
+    # prefixes keys with "model." and includes the lm_head statement.
+    _gen_aoa_config = MiniMaxForCausalLM._gen_aoa_config
+    _gen_inv_aoa_config = MiniMaxForCausalLM._gen_inv_aoa_config
+
+
+# Names exported by `from <this module> import *`.
+__all__ = [
+    "MiniMaxConfig",
+    "MiniMaxPretrainedModel",
+    "MiniMaxModel",
+    "MiniMaxForCausalLM",
+    "MiniMaxForCausalLMPipe",
+]
diff --git a/tests/transformers/minimax/__init__.py b/tests/transformers/minimax/__init__.py
new file mode 100644
index 00000000000..290f972cf31
--- /dev/null
+++ b/tests/transformers/minimax/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/transformers/minimax/test_modeling.py b/tests/transformers/minimax/test_modeling.py
new file mode 100644
index 00000000000..f3ef22b513e
--- /dev/null
+++ b/tests/transformers/minimax/test_modeling.py
@@ -0,0 +1,170 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import unittest
+
+import paddle
+
+from paddleformers.transformers import MiniMaxConfig, MiniMaxForCausalLM, MiniMaxModel
+from paddleformers.transformers.auto.modeling import AutoModelForCausalLM
+from tests.testing_utils import gpu_device_initializer
+from tests.transformers.test_configuration_common import ConfigTester
+from tests.transformers.test_modeling_common import (
+ ModelTesterMixin,
+ ids_tensor,
+ random_attention_mask,
+)
+
+
+class MiniMaxModelTester:
+    """Helper that builds a tiny MiniMax config plus random inputs and hosts the shared checks."""
+
+    def __init__(
+        self,
+        parent,
+        batch_size=2,
+        seq_length=7,
+        is_training=True,
+        use_input_mask=True,
+        vocab_size=99,
+        hidden_size=32,
+        intermediate_size=37,
+        num_hidden_layers=2,
+        num_attention_heads=4,
+        num_key_value_heads=2,
+        head_dim=8,
+        num_local_experts=4,
+        num_experts_per_tok=2,
+        block_size=4,
+        rms_norm_eps=1e-5,
+        initializer_range=0.02,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+    ):
+        # Test case that owns this helper; its assert* methods are used by the checks.
+        self.parent: MiniMaxModelTest = parent
+
+        # Input geometry and flags.
+        self.batch_size = batch_size
+        self.seq_length = seq_length
+        self.is_training = is_training
+        self.use_input_mask = use_input_mask
+
+        # Model dimensions.
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.head_dim = head_dim
+
+        # MoE / linear-attention settings.
+        self.num_local_experts = num_local_experts
+        self.num_experts_per_tok = num_experts_per_tok
+        self.block_size = block_size
+
+        # Numerics and special tokens.
+        self.rms_norm_eps = rms_norm_eps
+        self.initializer_range = initializer_range
+        self.pad_token_id = pad_token_id
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+
+    def prepare_config_and_inputs(self):
+        """Return ``(config, input_ids, input_mask)``; the mask is ``None`` when masks are disabled."""
+        ids_shape = [self.batch_size, self.seq_length]
+        input_ids = ids_tensor(ids_shape, self.vocab_size, dtype=paddle.int64)
+        input_mask = random_attention_mask([self.batch_size, self.seq_length]) if self.use_input_mask else None
+        return self.get_config(), input_ids, input_mask
+
+    def get_config(self) -> MiniMaxConfig:
+        """Build a two-layer config with one full-attention and one linear-attention layer."""
+        config_kwargs = {
+            "vocab_size": self.vocab_size,
+            "hidden_size": self.hidden_size,
+            "intermediate_size": self.intermediate_size,
+            "num_hidden_layers": self.num_hidden_layers,
+            "num_attention_heads": self.num_attention_heads,
+            "num_key_value_heads": self.num_key_value_heads,
+            "head_dim": self.head_dim,
+            "layer_types": ["full_attention", "linear_attention"],
+            "block_size": self.block_size,
+            "num_local_experts": self.num_local_experts,
+            "num_experts_per_tok": self.num_experts_per_tok,
+            "rms_norm_eps": self.rms_norm_eps,
+            "initializer_range": self.initializer_range,
+            "use_cache": False,
+            "pad_token_id": self.pad_token_id,
+            "bos_token_id": self.bos_token_id,
+            "eos_token_id": self.eos_token_id,
+        }
+        return MiniMaxConfig(**config_kwargs)
+
+    def create_and_check_model(self, config: MiniMaxConfig, input_ids, input_mask):
+        """The base model must return hidden states of shape ``[batch, seq, hidden]``."""
+        base_model = MiniMaxModel(config)
+        base_model.eval()
+        outputs = base_model(input_ids, attention_mask=input_mask)
+        expected_shape = [self.batch_size, self.seq_length, self.hidden_size]
+        self.parent.assertEqual(outputs[0].shape, expected_shape)
+
+    def create_and_check_for_causal_lm(self, config: MiniMaxConfig, input_ids, input_mask):
+        """The causal LM must return vocabulary-sized logits and a loss when labels are given."""
+        lm_model = MiniMaxForCausalLM(config)
+        lm_model.eval()
+        outputs = lm_model(input_ids, attention_mask=input_mask, labels=input_ids, return_dict=True)
+        expected_shape = [self.batch_size, self.seq_length, self.vocab_size]
+        self.parent.assertEqual(outputs.logits.shape, expected_shape)
+        self.parent.assertIsNotNone(outputs.loss)
+
+    def create_and_check_training_step(self, config: MiniMaxConfig, input_ids, input_mask):
+        """A backward pass in train mode must produce a gradient on the token embedding."""
+        lm_model = MiniMaxForCausalLM(config)
+        lm_model.train()
+        outputs = lm_model(input_ids, attention_mask=input_mask, labels=input_ids, return_dict=True)
+        outputs.loss.backward()
+        expected_shape = [self.batch_size, self.seq_length, self.vocab_size]
+        self.parent.assertEqual(outputs.logits.shape, expected_shape)
+        self.parent.assertIsNotNone(lm_model.model.embed_tokens.weight.grad)
+
+    def create_and_check_auto_model(self, config: MiniMaxConfig):
+        """``AutoModelForCausalLM`` must resolve a MiniMax config to ``MiniMaxForCausalLM``."""
+        auto_model = AutoModelForCausalLM.from_config(config)
+        self.parent.assertIsInstance(auto_model, MiniMaxForCausalLM)
+
+    def prepare_config_and_inputs_for_common(self):
+        """Return ``(config, inputs_dict)`` in the form expected by the common test mixin."""
+        config, input_ids, input_mask = self.prepare_config_and_inputs()
+        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
+        return config, inputs_dict
+
+
+class MiniMaxModelTest(ModelTesterMixin, unittest.TestCase):
+    """Unit tests for the MiniMax base model and its causal-LM head."""
+
+    base_model_class = MiniMaxModel
+    return_dict = False
+    use_labels = False
+    use_test_model_name_list = False
+
+    all_model_classes = (MiniMaxModel, MiniMaxForCausalLM)
+    all_generative_model_classes = {MiniMaxForCausalLM: (MiniMaxModel, "minimax")}
+
+    @gpu_device_initializer(log_prefix="MiniMaxModelTest")
+    def setUp(self):
+        """Create the model tester and the config tester used by the tests below."""
+        super().setUp()
+
+        self.model_tester = MiniMaxModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=MiniMaxConfig, vocab_size=256, hidden_size=24)
+
+    def test_config(self):
+        """Run the shared configuration checks."""
+        self.config_tester.run_common_tests()
+
+    def test_model(self):
+        """Base model forward pass yields the expected hidden-state shape."""
+        tester = self.model_tester
+        tester.create_and_check_model(*tester.prepare_config_and_inputs())
+
+    def test_model_causal_lm(self):
+        """Causal LM forward pass yields logits and a loss."""
+        tester = self.model_tester
+        tester.create_and_check_for_causal_lm(*tester.prepare_config_and_inputs())
+
+    def test_model_training_step(self):
+        """One forward/backward step in train mode produces gradients."""
+        tester = self.model_tester
+        tester.create_and_check_training_step(*tester.prepare_config_and_inputs())
+
+    def test_auto_model_for_causal_lm(self):
+        """The auto class maps a MiniMax config to the MiniMax causal LM."""
+        tester = self.model_tester
+        tester.create_and_check_auto_model(tester.get_config())