Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions paddleformers/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,10 @@
CallbackHandler,
DefaultFlowCallback,
EMAStateAssemblerCallback,
InterleaveGateUpCallback,
InternalMedicineCallback,
PrinterCallback,
ProgressCallback,
SonicMoELayoutSwitchCallback,
SPGradSyncCallback,
TrainerCallback,
TrainerControl,
Expand Down Expand Up @@ -1846,8 +1846,7 @@ def train(
self.add_non_zcc_ema_callback(resume_from_checkpoint, ema_state_assembler)

if self.args.using_sonic_moe:
callback = InterleaveGateUpCallback(self.model, resume_from_checkpoint, self.args.output_dir)
self.add_callback(callback)
self.add_callback(SonicMoELayoutSwitchCallback())

self.log_trainable_numel(model)

Expand Down
59 changes: 59 additions & 0 deletions paddleformers/trainer/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,17 @@
# Conditionally import paddlefleet modules
if is_paddlefleet_available():
from paddlefleet.models.gpt import GPTModel
from paddlefleet.transformer.moe.moe_expert import SonicMoEExpert
from paddlefleet.transformer.moe.moe_layer import MoELayer
from paddlefleet.transformer.moe.moe_router import StandardMoERouter
else:

class GPTModel:
pass

class SonicMoEExpert:
pass

class MoELayer:
pass

Expand Down Expand Up @@ -81,6 +85,7 @@ class StandardMoERouter:
"SPGradSyncCallback",
"EMAStateAssemblerCallback",
"InternalMedicineCallback",
"SonicMoELayoutSwitchCallback",
]


Expand Down Expand Up @@ -716,6 +721,8 @@ def on_step_begin(self, args, state, control, **kwargs):
"""
Quantize expert weights to FP8 before each training step
"""
if args.using_sonic_moe:
return
model = kwargs["model"]
optimizer = kwargs["optimizer"]
global skip_count
Expand Down Expand Up @@ -755,6 +762,8 @@ def on_optimizer_begin(self, args, state, control, **kwargs):
"""
Reload weights before optimizer step
"""
if args.using_sonic_moe:
return
model = kwargs["model"]
optimizer = kwargs["optimizer"]
global skip_count
Expand Down Expand Up @@ -1000,6 +1009,56 @@ def on_step_end(self, args, state, control, **kwargs):
logger.info(f"[EMAStateAssembler] Assembling EMA state took {duration:.3f} seconds.")


class SonicMoELayoutSwitchCallback(TrainerCallback):
    """
    Trainer callback that drives the layout / FP8 hooks of every ``SonicMoEExpert`` in the model.

    When a step begins the experts are asked to switch to their sonic layout (and, if
    ``args.fp8`` is set, to quantize their weights). Right before the optimizer runs they are
    asked to switch back to the grouped layout. Both hooks do nothing unless
    ``args.using_sonic_moe`` is truthy.
    """

    def _apply_to_sonic_moe_experts(self, model, fn_name):
        """
        Call the zero-argument method named ``fn_name`` on each ``SonicMoEExpert`` that
        ``model.apply`` hands to the visitor.

        Args:
            model: object exposing ``apply(fn)`` (presumably a Paddle ``Layer`` -- confirm).
            fn_name: name of the expert method to invoke.
        """

        def visit(sublayer):
            # Anything that is not a sonic expert is left untouched.
            if not isinstance(sublayer, SonicMoEExpert):
                return
            method = getattr(sublayer, fn_name)
            method()

        model.apply(visit)

    def _prepare_sonic_moe_fp8_weights(self, model, ensure_grouped_for_master=False):
        """
        Switch every ``SonicMoEExpert`` to the sonic layout and quantize its weights.

        Args:
            model: object exposing ``apply(fn)``.
            ensure_grouped_for_master: when truthy, each expert is converted back to the
                grouped layout right after quantization.
        """

        def visit(sublayer):
            if not isinstance(sublayer, SonicMoEExpert):
                return
            # Order matters: layout switch first, then quantization, then the optional
            # switch back to the grouped layout.
            sublayer.convert_weights_to_sonic_layout()
            sublayer.quant_weight()
            if ensure_grouped_for_master:
                sublayer.convert_weights_to_grouped_layout()

        model.apply(visit)

    def _optimizer_has_expert_master(self, optimizer):
        """
        Report whether ``optimizer._master_weights`` already holds an entry for an expert
        parameter.

        The name of the first parameter whose ``color`` attribute is a dict with
        ``color["color"] == "moe_expert"`` is looked up once and cached on the callback;
        later calls reuse the cached name (including a cached ``None`` when no such
        parameter was found).

        Returns:
            bool: True only if an expert parameter name is cached, the optimizer has a
            ``_master_weights`` attribute, and that name is contained in it.
        """
        if not hasattr(self, "_cached_expert_param_name"):
            # Set the cache slot before scanning so the scan is attempted only once.
            self._cached_expert_param_name = None
            for candidate in optimizer._inner_opt._parameter_list:
                tag = getattr(candidate, "color", -1)
                if not isinstance(tag, dict):
                    continue
                if tag.get("color") == "moe_expert":
                    self._cached_expert_param_name = candidate.name
                    break

        expert_name = self._cached_expert_param_name
        if expert_name is None:
            return False
        if not hasattr(optimizer, "_master_weights"):
            return False
        return expert_name in optimizer._master_weights

    def on_step_begin(self, args, state, control, **kwargs):
        """
        Put the sonic experts into their training-step state.

        Without FP8 the experts are only switched to the sonic layout. With FP8 they are
        additionally quantized, and ``optimizer.clear_param_storage("moe_expert")`` is called
        afterwards.
        """
        if not args.using_sonic_moe:
            return

        model = kwargs["model"]
        optimizer = kwargs["optimizer"]

        if not args.fp8:
            self._apply_to_sonic_moe_experts(model, "convert_weights_to_sonic_layout")
            return

        # If the optimizer has no master-weight entry for the expert parameter yet, the
        # experts are returned to the grouped layout after quantization (presumably so the
        # master weights are created from grouped-layout weights -- confirm).
        master_missing = not self._optimizer_has_expert_master(optimizer)
        self._prepare_sonic_moe_fp8_weights(model, ensure_grouped_for_master=master_missing)
        optimizer.clear_param_storage("moe_expert")

    def on_optimizer_begin(self, args, state, control, **kwargs):
        """
        Restore the grouped layout on the sonic experts before the optimizer step, after
        clearing their FP8 weights when ``args.fp8`` is set.
        """
        if not args.using_sonic_moe:
            return

        if args.fp8:
            self._apply_to_sonic_moe_experts(kwargs["model"], "clear_fp8_weights")
        self._apply_to_sonic_moe_experts(kwargs["model"], "convert_weights_to_grouped_layout")


class InterleaveGateUpCallback(TrainerCallback):
def __init__(self, model, resume_from_checkpoint=None, output_dir=None):
self.model = model
Expand Down
6 changes: 6 additions & 0 deletions paddleformers/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,12 @@ class LlmMetaConfig:
True,
"Whether to use FP8 for gradient storage during training (only effective if `fp8=True`). Further reduces memory footprint but may introduce minor numerical error. Defaults to False.",
),
(
"use_ue8m0",
bool,
False,
"Whether to use UE8M0 packed scaling factors for FP8 on Blackwell GPUs (SM100+). Enables deep_gemm backend for weight gradient computation. Defaults to False.",
),
]

model_conf = [
Expand Down
Loading