Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions paddleformers/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,10 @@
from .trainer_callback import (
CallbackHandler,
DefaultFlowCallback,
InterleaveGateUpCallback,
InternalMedicineCallback,
PrinterCallback,
ProgressCallback,
SonicMoELayoutSwitchCallback,
SPGradSyncCallback,
TrainerCallback,
TrainerControl,
Expand Down Expand Up @@ -1806,8 +1806,10 @@ def train(
self.add_non_zcc_ema_callback(resume_from_checkpoint)

if self.args.using_sonic_moe:
callback = InterleaveGateUpCallback(self.model, resume_from_checkpoint, self.args.output_dir)
self.add_callback(callback)
# callback = InterleaveGateUpCallback(self.model, resume_from_checkpoint, self.args.output_dir)
# self.add_callback(callback)
print("==== add sonicmoe callback ====")
self.add_callback(SonicMoELayoutSwitchCallback())

self.log_trainable_numel(model)

Expand Down
132 changes: 113 additions & 19 deletions paddleformers/trainer/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,17 @@
# Conditionally import paddlefleet modules
if is_paddlefleet_available():
from paddlefleet.models.gpt import GPTModel
from paddlefleet.transformer.moe.moe_expert import SonicMoEExpert
from paddlefleet.transformer.moe.moe_layer import MoELayer
from paddlefleet.transformer.moe.moe_router import StandardMoERouter
else:

class GPTModel:
pass

class SonicMoEExpert:
pass

class MoELayer:
pass

Expand Down Expand Up @@ -81,6 +85,7 @@ class StandardMoERouter:
"SPGradSyncCallback",
"EMAStateAssemblerCallback",
"InternalMedicineCallback",
"SonicMoELayoutSwitchCallback",
]


Expand Down Expand Up @@ -707,15 +712,64 @@ def enable_in_dict_config(config, key):
skip_count = 0


_FP8_STORAGE_COLORS = (
"moe_expert",
"rms_linear",
"memory_attn",
"attn_out_project",
"shared_expert",
)


def _clear_fp8_param_storage(optimizer, colors=None):
colors = _FP8_STORAGE_COLORS if colors is None else colors
for color in colors:
optimizer.clear_param_storage(color)


def _get_moe_expert_param_names(optimizer):
inner_opt = getattr(optimizer, "_inner_opt", optimizer)
parameters = getattr(inner_opt, "_parameter_list", ())
names = []
for param in parameters:
color = getattr(param, "color", -1)
if isinstance(color, dict) and color.get("color") == "moe_expert":
names.append(param.name)
return names


def _offload_moe_expert_master_weights(optimizer):
    """
    Offload the master weights of all "moe_expert" parameters

    Args:
        optimizer: Optimizer whose `_master_weights` mapping (treated as empty
            when the attribute is missing) is keyed by parameter name.

    Returns:
        list: Every "moe_expert" parameter name reported by
        `_get_moe_expert_param_names`, including names that had no master
        weight on this rank.
    """
    expert_names = _get_moe_expert_param_names(optimizer)
    master_weights = getattr(optimizer, "_master_weights", {})
    # NOTE(Waynezee): when moe_sharding_degree > 1, an expert parameter's
    # master weight may live on a rank of another moe_sharding_rank, so only
    # the names present locally are offloaded.
    local_weights = (master_weights[name] for name in expert_names if name in master_weights)
    for weight in local_weights:
        offload(weight)
    return expert_names


def _reload_moe_expert_master_weights(optimizer, moe_weights_name):
master_weights = getattr(optimizer, "_master_weights", {})
for name in moe_weights_name:
if name in master_weights:
reload(master_weights[name])


class FP8QuantWeightCallback(TrainerCallback):
"""
Callback for FP8 weight quantization during training
"""

def __init__(self):
self.moe_weights_name = []

def on_step_begin(self, args, state, control, **kwargs):
"""
Quantize expert weights to FP8 before each training step
"""
if args.using_sonic_moe:
# SonicMoE does not support offline quantization yet.
return
model = kwargs["model"]
optimizer = kwargs["optimizer"]
global skip_count
Expand All @@ -731,30 +785,21 @@ def on_step_begin(self, args, state, control, **kwargs):
self.use_fp8 = model.use_fp8()
if not self.use_fp8:
return
model.fp8_quant_weight(True, quant_transpose=False)
optimizer.clear_param_storage("moe_expert")
optimizer.clear_param_storage("rms_linear")
optimizer.clear_param_storage("memory_attn")
optimizer.clear_param_storage("attn_out_project")
optimizer.clear_param_storage("shared_expert")
if not args.offload_fp8_expert_master_weight:
model.fp8_quant_weight(True, quant_transpose=True)
_clear_fp8_param_storage(optimizer)
if not getattr(args, "offload_fp8_expert_master_weight", False):
return
for param in optimizer._inner_opt._parameter_list:
color = getattr(param, "color", -1)
if isinstance(color, dict) and color["color"] == "moe_expert":
self.moe_weights_name.append(param.name)

for name in self.moe_weights_name:
# NOTE(Waynezee): when moe_sharding_degree > 1, experts parameter's master_weight may exist in ranks of another moe_sharding_rank.
if name in optimizer._master_weights:
offload(optimizer._master_weights[name])
self.moe_weights_name = _offload_moe_expert_master_weights(optimizer)

skip_count += 1

def on_optimizer_begin(self, args, state, control, **kwargs):
"""
Reload weights before optimizer step
"""
if args.using_sonic_moe:
# SonicMoE does not support offline quantization yet.
return
model = kwargs["model"]
optimizer = kwargs["optimizer"]
global skip_count
Expand All @@ -764,9 +809,7 @@ def on_optimizer_begin(self, args, state, control, **kwargs):
and hasattr(model, "fp8_quant_weight")
and not args.sharding_parallel_size <= 1
):
for name in self.moe_weights_name:
if name in optimizer._master_weights:
reload(optimizer._master_weights[name])
_reload_moe_expert_master_weights(optimizer, self.moe_weights_name)


class MoECorrectionBiasAdjustCallback(TrainerCallback):
Expand Down Expand Up @@ -1000,6 +1043,57 @@ def on_step_end(self, args, state, control, **kwargs):
logger.info(f"[EMAStateAssembler] Assembling EMA state took {duration:.3f} seconds.")


class SonicMoELayoutSwitchCallback(TrainerCallback):
    """
    Callback that switches SonicMoE expert weights between layouts around each step

    At step begin the experts are converted to the "sonic" layout (and
    quantized when `args.fp8` is set). At optimizer begin they are converted
    back to the "grouped" layout unless the "moe_expert" parameter storage was
    cleared during the step. The callback does nothing unless
    `args.using_sonic_moe` is truthy.

    The `SONIC_MOE_CLEAR_FP8_STORAGE` environment variable controls storage
    clearing on the FP8 path: "0" disables it, "1" (the default) clears the
    colors in `_FP8_STORAGE_COLORS`, and any other value is parsed as a
    comma-separated list of colors to clear.
    """

    def __init__(self):
        # True when the current step cleared the "moe_expert" parameter storage.
        self._expert_storage_cleared = False
        # Names returned by the master-weight offload of the current step.
        self.moe_weights_name = []

    def _apply_to_sonic_moe_experts(self, model, fn_name):
        """
        Call the zero-argument method `fn_name` on each SonicMoEExpert visited by `model.apply`
        """

        def invoke_on_expert(sublayer):
            if not isinstance(sublayer, SonicMoEExpert):
                return
            method = getattr(sublayer, fn_name)
            method()

        model.apply(invoke_on_expert)

    def _prepare_sonic_moe_fp8_weights(self, model):
        """
        Clear the FP8 weight cache, then convert and quantize each SonicMoEExpert
        """
        SonicMoEExpert.clear_fp8_weight_cache()

        def convert_and_quantize(sublayer):
            if not isinstance(sublayer, SonicMoEExpert):
                return
            sublayer.convert_weights_to_sonic_layout()
            sublayer.quant_weight()

        model.apply(convert_and_quantize)

    def on_step_begin(self, args, state, control, **kwargs):
        """
        Move expert weights to the sonic layout before each training step
        """
        if not args.using_sonic_moe:
            return

        # Per-step state is reset before any work is done.
        self._expert_storage_cleared = False
        self.moe_weights_name = []

        if not args.fp8:
            self._apply_to_sonic_moe_experts(kwargs["model"], "convert_weights_to_sonic_layout")
            return

        self._prepare_sonic_moe_fp8_weights(kwargs["model"])
        optimizer = kwargs.get("optimizer")
        clear_storage = os.environ.get("SONIC_MOE_CLEAR_FP8_STORAGE", "1")
        if clear_storage == "0" or optimizer is None or args.sharding_parallel_size <= 1:
            return

        if clear_storage == "1":
            effective_colors = _FP8_STORAGE_COLORS
        else:
            # Any other value is a comma-separated color list; blanks are dropped.
            effective_colors = tuple(part.strip() for part in clear_storage.split(",") if part.strip())

        _clear_fp8_param_storage(optimizer, effective_colors)
        self._expert_storage_cleared = "moe_expert" in effective_colors
        wants_offload = getattr(args, "offload_fp8_expert_master_weight", False)
        if self._expert_storage_cleared and wants_offload:
            self.moe_weights_name = _offload_moe_expert_master_weights(optimizer)

    def on_optimizer_begin(self, args, state, control, **kwargs):
        """
        Reload offloaded master weights and restore the grouped layout before the optimizer step
        """
        if not args.using_sonic_moe:
            return
        if getattr(args, "offload_fp8_expert_master_weight", False):
            _reload_moe_expert_master_weights(kwargs["optimizer"], self.moe_weights_name)
        if self._expert_storage_cleared:
            # The expert storage was cleared this step, so no layout switch is applied.
            return
        self._apply_to_sonic_moe_experts(kwargs["model"], "convert_weights_to_grouped_layout")


class InterleaveGateUpCallback(TrainerCallback):
def __init__(self, model, resume_from_checkpoint=None, output_dir=None):
self.model = model
Expand Down
6 changes: 6 additions & 0 deletions paddleformers/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,12 @@ class LlmMetaConfig:
True,
"Whether to use FP8 for gradient storage during training (only effective if `fp8=True`). Further reduces memory footprint but may introduce minor numerical error. Defaults to False.",
),
(
"use_ue8m0",
bool,
False,
"Whether to use UE8M0 packed scaling factors for FP8 on Blackwell GPUs (SM100+). Enables deep_gemm backend for weight gradient computation. Defaults to False.",
),
]

model_conf = [
Expand Down
Loading