Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions examples/config/musa/ERNIE-4.5-21B-A3B/sft/21b_8_gpus.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
device: musa

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# stage
stage: ernie_pretrain

# data
dataset_type: pretrain
input_dir: "1.0 ./data/eb45_industrycorpus2_94k"
split: "998,1,1"
max_seq_len: 2048
dataloader_num_workers: 8
prefetch_factor: 32
ignore_data_skip: 0

# model
model_name_or_path: model_configs/ERNIE-4p5-21B-A3B/
tokenizer_name_or_path: examples/experiments/ernie_pretrain/ernie/src/tokenizers/tokenizer_model
moe_router_bias_update_rate: 0.001
moe_with_send_router_loss: False
moe_group: ep
use_moe: true

# train
output_dir: ./output/
num_consecutive: 32
global_logging_interval: 1
enable_mtp_magic_send: True
do_train: True
overwrite_output_dir: 1
disable_tqdm: 1
logging_steps: 1
eval_steps: 1000000
eval_iters: -1
save_steps: 50
max_steps: 500
adam_beta1: 0.9
adam_beta2: 0.95
adam_epsilon: 1e-8
learning_rate: 3.14e-4
min_lr: 3.14e-6
gradient_accumulation_steps: 8
per_device_train_batch_size: 1
lr_scheduler: wsd:603000
decay_function: 1-sqrt
max_grad_norm: 1.0
weight_decay: 0.1
warmup_steps: 2000
save_total_limit: 5
bf16: True
fp16_opt_level: "O2"
scale_loss: 4096
seed: 42
pre_alloc_memory: 60
tensorwise_offload_optimizer: true
use_expert_parallel: True
expert_model_parallel_size: 8
sharding: "stage1"
sharding_parallel_size: 8
amp_master_grad: 1
report_to: none
global_batch_size: 128
use_ortho_loss_callback: true
same_data: True
continue_training: False

# sharding_parallel_config
split_param: true
sharding_comm_buffer_size_MB: 2048

# tensor_parallel_config
skip_profile_timer: False
load_sharded_model: True
save_sharded_model: True
ignore_load_lr_and_optim: False
gc_interval: 100000

# ernie_model_config
ernie_model_config:
use_quant_before_a2a: true
use_async_a2a: false
use_rms_qkv_recompute: true
moe_logging: true
use_recompute: false
num_nextn_predict_layers: 1
use_fp8_mlp: false
num_hidden_layers: 28
num_empty_layers_add_in_tail: 0
use_fp8_fuse_node: false
use_ep_comm_overlap: false
use_combine_before_a2a: true
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
device: musa

### data
train_dataset_type: messages
eval_dataset_type: messages
train_dataset_path: ./ocr_vl_sft-train_Bengali.jsonl
train_dataset_prob: "1.0"
eval_dataset_path: ./ocr_vl_sft-test_Bengali.jsonl
eval_dataset_prob: "1.0"
max_seq_len: 16384
padding_free: True
truncate_packing: False
dataloader_num_workers: 8
mix_strategy: concat
template_backend: custom
template: paddleocr_vl

### model
model_name_or_path: PaddlePaddle/PaddleOCR-VL
_attn_implementation: eager
copy_custom_file_list: "configuration_paddleocr_vl.py image_processing_paddleocr_vl.py modeling_paddleocr_vl.py processing_paddleocr_vl.py inference.yml"

### finetuning
# base
stage: VL-SFT
fine_tuning: full
seed: 23
do_train: true
do_eval: true
per_device_eval_batch_size: 1
per_device_train_batch_size: 1
num_train_epochs: 2
max_steps: -1
max_estimate_samples: 500
eval_steps: 400
evaluation_strategy: steps
save_steps: 400
save_strategy: steps
logging_steps: 1
gradient_accumulation_steps: 8
logging_dir: ./PaddleOCR-VL-SFT-Bengali/visualdl_logs/
output_dir: ./PaddleOCR-VL-SFT-Bengali
disable_tqdm: true
eval_accumulation_steps: 16

# train
lr_scheduler_type: cosine
warmup_ratio: 0.01
learning_rate: 5.0e-6
min_lr: 5.0e-7

# optimizer
weight_decay: 0.1
adam_epsilon: 1.0e-8
adam_beta1: 0.9
adam_beta2: 0.95

# performance
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
sharding: stage1
recompute_granularity: full
recompute_method: uniform
recompute_num_layers: 1
bf16: true
fp16_opt_level: O2
pre_alloc_memory: 0

# save
unified_checkpoint: False
save_checkpoint_format: "flex_checkpoint"
load_checkpoint_format: "flex_checkpoint"
6 changes: 5 additions & 1 deletion paddleformers/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ def main():
num_iluvatar_gpus = len(paddle.device.get_available_custom_device())
default_iluvatar_gpus = ",".join(map(str, range(0, num_iluvatar_gpus)))
visible_cards = os.getenv("CUDA_VISIBLE_DEVICES", default_iluvatar_gpus)
elif current_device == "musa":
num_musas = len(paddle.device.get_available_custom_device())
default_musas = ",".join(map(str, range(0, num_musas)))
visible_cards = os.getenv("MUSA_VISIBLE_DEVICES", default_musas)
else:
import GPUtil

Expand Down Expand Up @@ -150,7 +154,7 @@ def main():
# launch distributed training
env = deepcopy(os.environ)
args_to_pass = " ".join(shlex.quote(arg) for arg in sys.argv[1:])
if current_device == "iluvatar_gpu":
if current_device == "iluvatar_gpu" or current_device == "musa":
current_device = "gpu"
command = (
f"python -m paddle.distributed.launch --log_dir {paddleformers_dist_log} "
Expand Down
11 changes: 10 additions & 1 deletion paddleformers/cli/train/ernie_pretrain/models/fp8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,16 @@

import numpy
import paddle
from paddle.incubate.fp8 import deep_gemm
import warnings

try:
from paddle.incubate.fp8 import deep_gemm
except ImportError:
warnings.warn(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议直接使用formers自带的logger

"paddle.incubate.fp8.deep_gemm is not available.",
RuntimeWarning,
)
deep_gemm = None
from paddle.nn.functional import swiglu

# Keep reference to original linear op for fallback if needed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,16 @@

import numpy
import paddle
from paddle.incubate.fp8 import deep_gemm
import warnings

try:
from paddle.incubate.fp8 import deep_gemm
except ImportError:
warnings.warn(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议使用formers自带的logger

"paddle.incubate.fp8.deep_gemm is not available.",
RuntimeWarning,
)
deep_gemm = None
from paddle.nn.functional import swiglu

from paddleformers.cli.train.ernie_pretrain.models.fp8_linear import fp8_gemm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
if not self.args.enable_global_training_logs:
global_training_logs.global_meters_keys = []

if get_env_device() == "gpu":
if get_env_device() == "gpu" or get_env_device() == "musa":
info_callback = global_training_logs.dict(use_async=True)

if hasattr(self, "scaler"):
Expand Down
2 changes: 2 additions & 0 deletions paddleformers/cli/utils/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def detect_device() -> str:
return "xpu"
elif "iluvatar" in place_lower:
return "iluvatar_gpu"
elif "musa" in place_lower:
return "musa"
else:
return "gpu"
except Exception as e:
Expand Down
Loading