Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions paddleformers/cli/train/sft/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,11 +761,11 @@ def fetch_and_serialize(generator, dtype):
total_tokens / train_result.metrics["train_runtime"] / training_args.world_size
)
logger.info(f"Total_Tokens_per_second_per_gpu: {total_tokens_per_second_per_gpu} ")
if not training_args.autotuner_benchmark:
trainer.save_model(merge_tensor_parallel=training_args.tensor_model_parallel_size > 1, last_fc_to_hf=True)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
# if not training_args.autotuner_benchmark:
# trainer.save_model(merge_tensor_parallel=training_args.tensor_model_parallel_size > 1, last_fc_to_hf=True)
# trainer.log_metrics("train", train_result.metrics)
# trainer.save_metrics("train", train_result.metrics)
# trainer.save_state()


def create_peft_model(model_args, training_args, dtype, model):
Expand Down
47 changes: 21 additions & 26 deletions paddleformers/datasets/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import json
import math
import os
from typing import List
Expand Down Expand Up @@ -494,9 +495,28 @@ def collate_fn(
if padding_free:
batch = [sum(batch, [])]
max_seq_len = sum(len(item.token_ids) for sequence in batch for item in sequence)
fixed_tokens_json_path = os.environ.get("DSV4_FLEET_FIXED_TOKENS")
fixed_tokens_path = os.environ.get("LOAD_FIXED_DATA_PATH")
fixed_tokens = None
if fixed_tokens_path:
if fixed_tokens_json_path:
with open(fixed_tokens_json_path, "r", encoding="utf-8") as f:
fixed_payload = json.load(f)
fixed_token_ids = fixed_payload["tokens"] if isinstance(fixed_payload, dict) else fixed_payload
fixed_token_ids = [int(token) for token in fixed_token_ids]
expected_token_count = training_args.max_seq_len + mtp_depth
if len(fixed_token_ids) != expected_token_count:
raise ValueError(
f"DSV4_FLEET_FIXED_TOKENS expects {expected_token_count} tokens "
f"for max_seq_len={training_args.max_seq_len} and "
f"num_nextn_predict_layers={mtp_depth}, "
f"got {len(fixed_token_ids)} from {fixed_tokens_json_path}"
)
fixed_input_ids = fixed_token_ids[:-mtp_depth] if mtp_depth > 0 else fixed_token_ids
fixed_labels = fixed_token_ids[mtp_depth:] if mtp_depth > 0 else fixed_token_ids
fixed_position_ids = list(range(len(fixed_input_ids)))
max_seq_len = calc_padding_size(len(fixed_input_ids), training_args)
fixed_tokens = True
elif fixed_tokens_path:
rank = paddle.distributed.get_rank() if paddle.distributed.is_initialized() else 0
seq_len = training_args.max_seq_len
suffix = f"step0_rank{rank}_seq{seq_len}.npy"
Expand Down Expand Up @@ -605,31 +625,6 @@ def collate_fn(

return_list = [np.concatenate(tensor_list) for tensor_list in zip(*return_list)]
input_dict = dict(zip(input_keys, return_list))
if fixed_tokens is not None and (
os.environ.get("LOG_DATA_MD5", "0") == "1" or os.environ.get("LOG_LAYER_MD5", "0") == "1"
):
import hashlib

try:
rank = paddle.distributed.get_rank()
except Exception:
rank = 0
main_input = np.asarray([fixed_input_ids], dtype=np.int64)
main_labels = np.asarray([fixed_labels], dtype=np.int64)
print(
f"[LOAD_FIXED_DATA_PATH] loaded from {fixed_tokens_path}",
flush=True,
)
print(
f"[DATA_PATH_MD5] rank={rank} input_ids shape={list(main_input.shape)} "
f"md5={hashlib.md5(main_input.tobytes()).hexdigest()}",
flush=True,
)
print(
f"[DATA_PATH_MD5] rank={rank} labels shape={list(main_labels.shape)} "
f"md5={hashlib.md5(main_labels.tobytes()).hexdigest()}",
flush=True,
)
return input_dict


Expand Down
Loading
Loading