From 7219834783a57e7d83c05515b0de5d0cd374debb Mon Sep 17 00:00:00 2001 From: Yichen Zhang Date: Wed, 10 Jun 2026 17:45:59 +0800 Subject: [PATCH] add callback for fp8 sonicmoe --- paddleformers/trainer/trainer.py | 5 +- paddleformers/trainer/trainer_callback.py | 59 +++++++++++++++++++ .../transformers/configuration_utils.py | 6 ++ 3 files changed, 67 insertions(+), 3 deletions(-) diff --git a/paddleformers/trainer/trainer.py b/paddleformers/trainer/trainer.py index 5ea23988c3e..dfd670af557 100644 --- a/paddleformers/trainer/trainer.py +++ b/paddleformers/trainer/trainer.py @@ -171,10 +171,10 @@ CallbackHandler, DefaultFlowCallback, EMAStateAssemblerCallback, - InterleaveGateUpCallback, InternalMedicineCallback, PrinterCallback, ProgressCallback, + SonicMoELayoutSwitchCallback, SPGradSyncCallback, TrainerCallback, TrainerControl, @@ -1846,8 +1846,7 @@ def train( self.add_non_zcc_ema_callback(resume_from_checkpoint, ema_state_assembler) if self.args.using_sonic_moe: - callback = InterleaveGateUpCallback(self.model, resume_from_checkpoint, self.args.output_dir) - self.add_callback(callback) + self.add_callback(SonicMoELayoutSwitchCallback()) self.log_trainable_numel(model) diff --git a/paddleformers/trainer/trainer_callback.py b/paddleformers/trainer/trainer_callback.py index 75fe6120c2a..3a5abe8a2dc 100644 --- a/paddleformers/trainer/trainer_callback.py +++ b/paddleformers/trainer/trainer_callback.py @@ -42,6 +42,7 @@ # Conditionally import paddlefleet modules if is_paddlefleet_available(): from paddlefleet.models.gpt import GPTModel + from paddlefleet.transformer.moe.moe_expert import SonicMoEExpert from paddlefleet.transformer.moe.moe_layer import MoELayer from paddlefleet.transformer.moe.moe_router import StandardMoERouter else: @@ -49,6 +50,9 @@ class GPTModel: pass + class SonicMoEExpert: + pass + class MoELayer: pass @@ -81,6 +85,7 @@ class StandardMoERouter: "SPGradSyncCallback", "EMAStateAssemblerCallback", "InternalMedicineCallback", + "SonicMoELayoutSwitchCallback", ] @@ -716,6 +721,8 @@ def on_step_begin(self, args, state, control, **kwargs): """ Quantize expert weights to FP8 before each training step """ + if args.using_sonic_moe: + return model = kwargs["model"] optimizer = kwargs["optimizer"] global skip_count @@ -755,6 +762,8 @@ def on_optimizer_begin(self, args, state, control, **kwargs): """ Reload weights before optimizer step """ + if args.using_sonic_moe: + return model = kwargs["model"] optimizer = kwargs["optimizer"] global skip_count @@ -1000,6 +1009,56 @@ def on_step_end(self, args, state, control, **kwargs): logger.info(f"[EMAStateAssembler] Assembling EMA state took {duration:.3f} seconds.") +class SonicMoELayoutSwitchCallback(TrainerCallback): + def _apply_to_sonic_moe_experts(self, model, fn_name): + def apply_layout_switch(layer): + if isinstance(layer, SonicMoEExpert): + getattr(layer, fn_name)() + + model.apply(apply_layout_switch) + + def _prepare_sonic_moe_fp8_weights(self, model, ensure_grouped_for_master=False): + def prepare_fp8_weights(layer): + if isinstance(layer, SonicMoEExpert): + layer.convert_weights_to_sonic_layout() + layer.quant_weight() + if ensure_grouped_for_master: + layer.convert_weights_to_grouped_layout() + + model.apply(prepare_fp8_weights) + + def _optimizer_has_expert_master(self, optimizer): + if not hasattr(self, "_cached_expert_param_name"): + self._cached_expert_param_name = None + for param in optimizer._inner_opt._parameter_list: + color = getattr(param, "color", -1) + if isinstance(color, dict) and color.get("color") == "moe_expert": + self._cached_expert_param_name = param.name + break + return ( + self._cached_expert_param_name is not None + and hasattr(optimizer, "_master_weights") + and self._cached_expert_param_name in optimizer._master_weights + ) + + def on_step_begin(self, args, state, control, **kwargs): + if args.using_sonic_moe: + model = kwargs["model"] + optimizer = kwargs["optimizer"] + if args.fp8: + need_master = not self._optimizer_has_expert_master(optimizer) + self._prepare_sonic_moe_fp8_weights(model, ensure_grouped_for_master=need_master) + optimizer.clear_param_storage("moe_expert") + else: + self._apply_to_sonic_moe_experts(model, "convert_weights_to_sonic_layout") + + def on_optimizer_begin(self, args, state, control, **kwargs): + if args.using_sonic_moe: + if args.fp8: + self._apply_to_sonic_moe_experts(kwargs["model"], "clear_fp8_weights") + self._apply_to_sonic_moe_experts(kwargs["model"], "convert_weights_to_grouped_layout") + + class InterleaveGateUpCallback(TrainerCallback): def __init__(self, model, resume_from_checkpoint=None, output_dir=None): self.model = model diff --git a/paddleformers/transformers/configuration_utils.py b/paddleformers/transformers/configuration_utils.py index 2558478b34d..09741739f2b 100644 --- a/paddleformers/transformers/configuration_utils.py +++ b/paddleformers/transformers/configuration_utils.py @@ -445,6 +445,12 @@ class LlmMetaConfig: True, "Whether to use FP8 for gradient storage during training (only effective if `fp8=True`). Further reduces memory footprint but may introduce minor numerical error. Defaults to False.", ), + ( + "use_ue8m0", + bool, + False, + "Whether to use UE8M0 packed scaling factors for FP8 on Blackwell GPUs (SM100+). Enables deep_gemm backend for weight gradient computation. Defaults to False.", + ), ] model_conf = [