compute loss only if training and update token metric naming (#3293) [skip ci]

* compute loss only if training

* save total_tokens for checkpiont

* check if string

* refactor total_tokens/ num_tokens

* refactor 2

* rplc trainable_step/trian_per_sec_per_gpu

* lint + log trainable/tokens

* consolidate it in the callback.

* test for total_tokes aftr remuse

* check if tokenstate exist after ckpt

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
This commit is contained in:
VED
2025-12-25 17:08:17 +05:30
committed by GitHub
parent 4f5e8a328a
commit a6080df73c
4 changed files with 128 additions and 20 deletions

View File

@@ -72,7 +72,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.include_tkps: if self.cfg.include_tkps:
callbacks.append( callbacks.append(
TokensPerSecondCallback( TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size self.cfg.tensor_parallel_size,
self.cfg.context_parallel_size,
resume_from_checkpoint=self.cfg.resume_from_checkpoint,
) )
) )
return callbacks return callbacks

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import json
import math import math
import os import os
from collections import defaultdict from collections import defaultdict
@@ -50,6 +51,8 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__) LOG = get_logger(__name__)
TOKENS_STATE_FILE = "tokens_state."
REDUCTION_FNS = { REDUCTION_FNS = {
"mean": torch.mean, "mean": torch.mean,
"min": torch.min, "min": torch.min,
@@ -349,24 +352,34 @@ class AxolotlTrainer(
# return (loss, outputs) if return_outputs else loss # return (loss, outputs) if return_outputs else loss
# track number of tokens for tokens per second calculation # track number of tokens for tokens per second calculation
if self.args.include_tkps: if self.args.include_tkps and model.training:
inputs_key = "labels" if "labels" in inputs else "input_ids" inputs_key = "labels" if "labels" in inputs else "input_ids"
num_tokens = (inputs[inputs_key] != -100).sum() trainable_tokens = (inputs[inputs_key] != -100).sum()
total_tokens = inputs[inputs_key].numel()
if is_distributed(): if is_distributed():
torch.distributed.all_reduce( torch.distributed.all_reduce(
num_tokens, op=torch.distributed.ReduceOp.SUM trainable_tokens, op=torch.distributed.ReduceOp.SUM
) )
if hasattr(self.state, "num_tokens"): torch.distributed.all_reduce(
self.state.num_tokens = ( total_tokens, op=torch.distributed.ReduceOp.SUM
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
) )
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
if hasattr(self.state, "total_tokens"): if not hasattr(self.state, "tokens"):
self.state.total_tokens += num_tokens self.state.tokens = {
else: "trainable": torch.zeros(1),
self.state.total_tokens = num_tokens "total": torch.zeros(1),
}
# trainable tokens for throughput and total token slots for summaries
self.state.tokens["trainable"] = (
self.state.tokens["trainable"] + trainable_tokens.detach().cpu()
)
self.state.tokens["total"] = (
self.state.tokens["total"] + torch.as_tensor(total_tokens).cpu()
)
# Store per-step trainable tokens for throughput calculation
self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()
if self.args.orpo_alpha: if self.args.orpo_alpha:
return self.orpo_compute_loss( return self.orpo_compute_loss(
@@ -638,10 +651,14 @@ class AxolotlTrainer(
except (ValueError, TypeError, FileNotFoundError): except (ValueError, TypeError, FileNotFoundError):
pass pass
if self.args.include_tkps and train_eval == "train": if (
self.args.include_tkps
and train_eval == "train"
and hasattr(self.state, "tokens")
):
# each rank will log its own tokens per second # each rank will log its own tokens per second
# for logging_steps > 1 we obtain a moving average of this metric # for logging_steps > 1 we obtain a moving average of this metric
logs["tokens_per_second_per_gpu"] = round( logs["tokens/train_per_sec_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2 self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
) )
if ( if (
@@ -683,6 +700,19 @@ class AxolotlTrainer(
run_dir = self._get_output_dir(trial=trial) run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder) output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
# Save total_tokens state if tracking is enabled
if self.args.include_tkps and hasattr(self.state, "tokens"):
tokens_state = {
"total": int(torch.as_tensor(self.state.tokens.get("total", 0)).item()),
"trainable": int(
torch.as_tensor(self.state.tokens.get("trainable", 0)).item()
),
}
tokens_state_path = os.path.join(output_dir, TOKENS_STATE_FILE)
with open(tokens_state_path, "w", encoding="utf-8") as f:
json.dump(tokens_state, f)
return super()._save_checkpoint(model, trial, **kwargs) return super()._save_checkpoint(model, trial, **kwargs)
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged # TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged

View File

@@ -1,5 +1,7 @@
"""A callback for calculating tokens per second during training.""" """A callback for calculating tokens per second during training."""
import json
import os
import time import time
import torch import torch
@@ -10,22 +12,52 @@ from transformers import (
TrainingArguments, TrainingArguments,
) )
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
TOKENS_STATE_FILE = "tokens_state.json"
class TokensPerSecondCallback(TrainerCallback): class TokensPerSecondCallback(TrainerCallback):
""" """
A callback to measure and log tokens per second during training. A callback to measure and log tokens per second during training.
Also handles saving/restoring total_tokens state across checkpoint resumes.
""" """
def __init__(self, tensor_parallel_size, context_parallel_size): def __init__(
self, tensor_parallel_size, context_parallel_size, resume_from_checkpoint=None
):
super().__init__() super().__init__()
self.step_time = 0.0 self.step_time = 0.0
self.start_time = 0.0 self.start_time = 0.0
self.non_data_parallel_size = 1 self.non_data_parallel_size = 1
self.resume_from_checkpoint = resume_from_checkpoint
if tensor_parallel_size is not None: if tensor_parallel_size is not None:
self.non_data_parallel_size *= tensor_parallel_size self.non_data_parallel_size *= tensor_parallel_size
if context_parallel_size is not None: if context_parallel_size is not None:
self.non_data_parallel_size *= context_parallel_size self.non_data_parallel_size *= context_parallel_size
def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
): # pylint: disable=unused-argument
"""Restore total_tokens state when resuming from checkpoint."""
if not isinstance(self.resume_from_checkpoint, str):
return
tokens_state_path = os.path.join(self.resume_from_checkpoint, TOKENS_STATE_FILE)
if os.path.isfile(tokens_state_path):
with open(tokens_state_path, "r", encoding="utf-8") as f:
tokens_state = json.load(f)
state.tokens = {
"total": torch.tensor(tokens_state.get("total", 0)),
"trainable": torch.tensor(tokens_state.get("trainable", 0)),
}
LOG.info(f"Restored total_tokens: {state.tokens['total']}")
def on_step_begin( def on_step_begin(
self, self,
args: TrainingArguments, args: TrainingArguments,
@@ -33,6 +65,8 @@ class TokensPerSecondCallback(TrainerCallback):
control: TrainerControl, control: TrainerControl,
**kwargs, **kwargs,
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
if not hasattr(state, "tokens"):
state.tokens = {"trainable": torch.zeros(1), "total": torch.zeros(1)}
self.start_time = time.perf_counter() self.start_time = time.perf_counter()
state.last_tokens_per_second = torch.zeros(1) state.last_tokens_per_second = torch.zeros(1)
@@ -43,9 +77,10 @@ class TokensPerSecondCallback(TrainerCallback):
control: TrainerControl, control: TrainerControl,
**kwargs, **kwargs,
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
if hasattr(state, "num_tokens"): tokens = getattr(state, "tokens", None)
if tokens and "trainable_tokens" in tokens:
step_time = time.perf_counter() - self.start_time step_time = time.perf_counter() - self.start_time
num_tokens_per_device = state.num_tokens.clone() num_tokens_per_device = tokens["trainable_tokens"].clone()
# non data parallel groups have duplicated tokens, so we avoid double-counting # non data parallel groups have duplicated tokens, so we avoid double-counting
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
state.last_tokens_per_second = num_tokens_per_device / step_time state.last_tokens_per_second = num_tokens_per_device / step_time
@@ -60,5 +95,15 @@ class TokensPerSecondCallback(TrainerCallback):
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
# after logging, clear the running metrics # after logging, clear the running metrics
if hasattr(state, "last_tokens_per_second"): if hasattr(state, "last_tokens_per_second"):
logs["tokens/train_per_sec_per_gpu"] = state.last_tokens_per_second.item()
state.last_tokens_per_second.zero_() state.last_tokens_per_second.zero_()
state.num_tokens = torch.zeros(1) tokens = getattr(state, "tokens", None)
# Clear per-step tokens after logging
if tokens and "trainable_tokens" in tokens:
tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"])
if tokens and "total" in tokens:
logs["tokens/total"] = tokens["total"].item()
if tokens and "trainable" in tokens:
logs["tokens/trainable"] = tokens["trainable"].item()

View File

@@ -2,6 +2,7 @@
E2E tests for resuming training E2E tests for resuming training
""" """
import os
import re import re
import subprocess import subprocess
@@ -9,6 +10,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.callbacks.tokens_per_second import TOKENS_STATE_FILE
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -58,6 +60,7 @@ class TestResumeLlama:
"use_tensorboard": True, "use_tensorboard": True,
"save_safetensors": True, "save_safetensors": True,
"save_first_step": False, "save_first_step": False,
"include_tkps": True,
} }
) )
if is_torch_bf16_gpu_available(): if is_torch_bf16_gpu_available():
@@ -68,8 +71,19 @@ class TestResumeLlama:
normalize_config(cfg) normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg) dataset_meta = load_datasets(cfg=cfg)
initial_total_num_tokens = cfg.total_num_tokens
assert initial_total_num_tokens is not None, (
"total_num_tokens should be calculated during load_datasets"
)
train(cfg=cfg, dataset_meta=dataset_meta) train(cfg=cfg, dataset_meta=dataset_meta)
checkpoint_path = f"{temp_dir}/checkpoint-9"
tokens_state_path = os.path.join(checkpoint_path, TOKENS_STATE_FILE)
assert os.path.isfile(tokens_state_path), (
f"{TOKENS_STATE_FILE} should exist in checkpoint at {tokens_state_path}"
)
resume_cfg = cfg | DictDefault( resume_cfg = cfg | DictDefault(
{ {
"resume_from_checkpoint": f"{temp_dir}/checkpoint-9/", "resume_from_checkpoint": f"{temp_dir}/checkpoint-9/",
@@ -77,7 +91,24 @@ class TestResumeLlama:
) )
normalize_config(resume_cfg) normalize_config(resume_cfg)
train(cfg=resume_cfg, dataset_meta=dataset_meta) assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
f"total_num_tokens should be preserved on resume. "
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
)
resume_dataset_meta = load_datasets(cfg=resume_cfg)
assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
f"total_num_tokens should not be recalculated when resuming. "
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
)
train(cfg=resume_cfg, dataset_meta=resume_dataset_meta)
assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
f"total_num_tokens should remain unchanged after resume training. "
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
)
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs") tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")