From a6080df73c8ab991ec9ee2e000a2df1be14bb493 Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Thu, 25 Dec 2025 17:08:17 +0530 Subject: [PATCH] compute loss only if training and update token metric naming (#3293) [skip ci] * compute loss only if training * save total_tokens for checkpoint * check if string * refactor total_tokens/num_tokens * refactor 2 * replace trainable_step/train_per_sec_per_gpu * lint + log trainable/tokens * consolidate it in the callback. * test for total_tokens after resume * check if tokens state exists after ckpt --------- Co-authored-by: Ved --- src/axolotl/core/builders/causal.py | 4 +- src/axolotl/core/trainers/base.py | 58 ++++++++++++++----- .../utils/callbacks/tokens_per_second.py | 53 +++++++++++++++-- tests/e2e/patched/test_resume.py | 33 ++++++++++- 4 files changed, 128 insertions(+), 20 deletions(-) diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 7a06431dc..cda98087f 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -72,7 +72,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if self.cfg.include_tkps: callbacks.append( TokensPerSecondCallback( - self.cfg.tensor_parallel_size, self.cfg.context_parallel_size + self.cfg.tensor_parallel_size, + self.cfg.context_parallel_size, + resume_from_checkpoint=self.cfg.resume_from_checkpoint, ) ) return callbacks diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index aae3d28fb..850517ded 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import math import os from collections import defaultdict @@ -50,6 +51,8 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths LOG = get_logger(__name__) +TOKENS_STATE_FILE = "tokens_state.json"
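# Illustrative sketch, not part of the patch: TOKENS_STATE_FILE names the small
# sidecar JSON that _save_checkpoint() (below) writes into each checkpoint directory
# and that TokensPerSecondCallback.on_train_begin() reads back on resume. The shape
# is assumed from the hunks below: two integer counters. The path and values here
# are hypothetical.
import json

tokens_state = {"total": 123456, "trainable": 98765}  # all-rank running sums, as ints
with open("tokens_state.json", "w", encoding="utf-8") as f:
    json.dump(tokens_state, f)
with open("tokens_state.json", encoding="utf-8") as f:
    restored = json.load(f)
assert restored == tokens_state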
+ REDUCTION_FNS = { "mean": torch.mean, "min": torch.min, @@ -349,24 +352,34 @@ class AxolotlTrainer( # return (loss, outputs) if return_outputs else loss # track number of tokens for tokens per second calculation - if self.args.include_tkps: + if self.args.include_tkps and model.training: inputs_key = "labels" if "labels" in inputs else "input_ids" - num_tokens = (inputs[inputs_key] != -100).sum() + trainable_tokens = (inputs[inputs_key] != -100).sum() + total_tokens = inputs[inputs_key].numel() + if is_distributed(): torch.distributed.all_reduce( - num_tokens, op=torch.distributed.ReduceOp.SUM + trainable_tokens, op=torch.distributed.ReduceOp.SUM ) - if hasattr(self.state, "num_tokens"): - self.state.num_tokens = ( - self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu() + torch.distributed.all_reduce( + total_tokens, op=torch.distributed.ReduceOp.SUM ) - else: - self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu() - if hasattr(self.state, "total_tokens"): - self.state.total_tokens += num_tokens - else: - self.state.total_tokens = num_tokens + if not hasattr(self.state, "tokens"): + self.state.tokens = { + "trainable": torch.zeros(1), + "total": torch.zeros(1), + } + + # trainable tokens for throughput and total token slots for summaries + self.state.tokens["trainable"] = ( + self.state.tokens["trainable"] + trainable_tokens.detach().cpu() + ) + self.state.tokens["total"] = ( + self.state.tokens["total"] + torch.as_tensor(total_tokens).cpu() + ) + # Store per-step trainable tokens for throughput calculation + self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu() if self.args.orpo_alpha: return self.orpo_compute_loss( @@ -638,10 +651,14 @@ class AxolotlTrainer( except (ValueError, TypeError, FileNotFoundError): pass - if self.args.include_tkps and train_eval == "train": + if ( + self.args.include_tkps + and train_eval == "train" + and hasattr(self.state, "tokens") + ): # each rank will log its own tokens per second # for logging_steps > 1 we obtain a moving average of this metric - logs["tokens_per_second_per_gpu"] = round( + logs["tokens/train_per_sec_per_gpu"] = round( self.state.last_tokens_per_second.item() / self.args.logging_steps, 2 ) if ( @@ -683,6 +700,19 @@ class AxolotlTrainer( run_dir = self._get_output_dir(trial=trial) output_dir = os.path.join(run_dir, checkpoint_folder) os.makedirs(output_dir, exist_ok=True) + + # Save total_tokens state if tracking is enabled + if self.args.include_tkps and hasattr(self.state, "tokens"): + tokens_state = { + "total": int(torch.as_tensor(self.state.tokens.get("total", 0)).item()), + "trainable": int( + torch.as_tensor(self.state.tokens.get("trainable", 0)).item() + ), + } + tokens_state_path = os.path.join(output_dir, TOKENS_STATE_FILE) + with open(tokens_state_path, "w", encoding="utf-8") as f: + json.dump(tokens_state, f) + return super()._save_checkpoint(model, trial, **kwargs) # TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged diff --git a/src/axolotl/utils/callbacks/tokens_per_second.py b/src/axolotl/utils/callbacks/tokens_per_second.py index ead129240..a1b955a74 100644 --- a/src/axolotl/utils/callbacks/tokens_per_second.py +++ b/src/axolotl/utils/callbacks/tokens_per_second.py @@ -1,5 +1,7 @@ """A callback for calculating tokens per second during training.""" +import json +import os import time import torch @@ -10,22 +12,52 @@ from transformers import ( TrainingArguments, ) +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + 
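# Illustrative sketch, not part of the patch: how the compute_loss hunk above counts
# tokens for one hypothetical batch. "trainable" counts label positions that are not
# the ignore index (-100); "total" counts every token slot, padding included. Under
# torch.distributed both counts are additionally all-reduced across ranks.
import torch

labels = torch.tensor([[17, 23, -100, -100], [42, -100, -100, -100]])
trainable_tokens = (labels != -100).sum()  # tensor(3)
total_tokens = labels.numel()  # 8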
+TOKENS_STATE_FILE = "tokens_state.json" + class TokensPerSecondCallback(TrainerCallback): """ A callback to measure and log tokens per second during training. + Also handles saving/restoring total_tokens state across checkpoint resumes. """ - def __init__(self, tensor_parallel_size, context_parallel_size): + def __init__( + self, tensor_parallel_size, context_parallel_size, resume_from_checkpoint=None + ): super().__init__() self.step_time = 0.0 self.start_time = 0.0 self.non_data_parallel_size = 1 + self.resume_from_checkpoint = resume_from_checkpoint if tensor_parallel_size is not None: self.non_data_parallel_size *= tensor_parallel_size if context_parallel_size is not None: self.non_data_parallel_size *= context_parallel_size + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): # pylint: disable=unused-argument + """Restore total_tokens state when resuming from checkpoint.""" + if not isinstance(self.resume_from_checkpoint, str): + return + tokens_state_path = os.path.join(self.resume_from_checkpoint, TOKENS_STATE_FILE) + if os.path.isfile(tokens_state_path): + with open(tokens_state_path, "r", encoding="utf-8") as f: + tokens_state = json.load(f) + state.tokens = { + "total": torch.tensor(tokens_state.get("total", 0)), + "trainable": torch.tensor(tokens_state.get("trainable", 0)), + } + LOG.info(f"Restored total_tokens: {state.tokens['total']}") + def on_step_begin( self, args: TrainingArguments, @@ -33,6 +65,8 @@ class TokensPerSecondCallback(TrainerCallback): control: TrainerControl, **kwargs, ): # pylint: disable=unused-argument + if not hasattr(state, "tokens"): + state.tokens = {"trainable": torch.zeros(1), "total": torch.zeros(1)} self.start_time = time.perf_counter() state.last_tokens_per_second = torch.zeros(1) @@ -43,9 +77,10 @@ class TokensPerSecondCallback(TrainerCallback): control: TrainerControl, **kwargs, ): # pylint: disable=unused-argument - if hasattr(state, "num_tokens"): + tokens = getattr(state, "tokens", None) + if tokens and "trainable_tokens" in tokens: step_time = time.perf_counter() - self.start_time - num_tokens_per_device = state.num_tokens.clone() + num_tokens_per_device = tokens["trainable_tokens"].clone() # non data parallel groups have duplicated tokens, so we avoid double-counting num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size state.last_tokens_per_second = num_tokens_per_device / step_time @@ -60,5 +95,15 @@ class TokensPerSecondCallback(TrainerCallback): ): # pylint: disable=unused-argument # after logging, clear the running metrics if hasattr(state, "last_tokens_per_second"): + logs["tokens/train_per_sec_per_gpu"] = state.last_tokens_per_second.item() state.last_tokens_per_second.zero_() - state.num_tokens = torch.zeros(1) + tokens = getattr(state, "tokens", None) + # Clear per-step tokens after logging + if tokens and "trainable_tokens" in tokens: + tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"]) + + if tokens and "total" in tokens: + logs["tokens/total"] = tokens["total"].item() + + if tokens and "trainable" in tokens: + logs["tokens/trainable"] = tokens["trainable"].item() diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py index 747b79dc7..e6240f208 100644 --- a/tests/e2e/patched/test_resume.py +++ b/tests/e2e/patched/test_resume.py @@ -2,6 +2,7 @@ E2E tests for resuming training """ +import os import re import subprocess @@ -9,6 +10,7 @@ from transformers.utils import 
is_torch_bf16_gpu_available from axolotl.common.datasets import load_datasets from axolotl.train import train +from axolotl.utils.callbacks.tokens_per_second import TOKENS_STATE_FILE from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault @@ -58,6 +60,7 @@ class TestResumeLlama: "use_tensorboard": True, "save_safetensors": True, "save_first_step": False, + "include_tkps": True, } ) if is_torch_bf16_gpu_available(): @@ -68,8 +71,19 @@ class TestResumeLlama: normalize_config(cfg) dataset_meta = load_datasets(cfg=cfg) + initial_total_num_tokens = cfg.total_num_tokens + assert initial_total_num_tokens is not None, ( + "total_num_tokens should be calculated during load_datasets" + ) + train(cfg=cfg, dataset_meta=dataset_meta) + checkpoint_path = f"{temp_dir}/checkpoint-9" + tokens_state_path = os.path.join(checkpoint_path, TOKENS_STATE_FILE) + assert os.path.isfile(tokens_state_path), ( + f"{TOKENS_STATE_FILE} should exist in checkpoint at {tokens_state_path}" + ) + resume_cfg = cfg | DictDefault( { "resume_from_checkpoint": f"{temp_dir}/checkpoint-9/", @@ -77,7 +91,24 @@ class TestResumeLlama: ) normalize_config(resume_cfg) - train(cfg=resume_cfg, dataset_meta=dataset_meta) + assert resume_cfg.total_num_tokens == initial_total_num_tokens, ( + f"total_num_tokens should be preserved on resume. " + f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}" + ) + + resume_dataset_meta = load_datasets(cfg=resume_cfg) + + assert resume_cfg.total_num_tokens == initial_total_num_tokens, ( + f"total_num_tokens should not be recalculated when resuming. " + f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}" + ) + + train(cfg=resume_cfg, dataset_meta=resume_dataset_meta) + + assert resume_cfg.total_num_tokens == initial_total_num_tokens, ( + f"total_num_tokens should remain unchanged after resume training. " + f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}" + ) check_model_output_exists(temp_dir, cfg) tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
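# Illustrative sketch, not part of the patch: the per-step throughput math the callback
# performs in on_step_end(), with hypothetical numbers. The per-step trainable-token
# count stored by compute_loss is divided by non_data_parallel_size so ranks that see
# duplicated tokens (tensor/context parallel) are not double-counted, then by the step
# time; when logging_steps > 1 the trainer divides the logged value by logging_steps.
import torch

trainable_tokens = torch.tensor([4096.0])  # hypothetical per-step, all-reduced count
non_data_parallel_size = 2  # tensor_parallel_size * context_parallel_size
step_time = 0.25  # seconds, measured with time.perf_counter()
tokens_per_sec_per_gpu = (trainable_tokens / non_data_parallel_size) / step_time
print(tokens_per_sec_per_gpu)  # tensor([8192.])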