compute loss only if training and update token metric naming (#3293) [skip ci]

* compute loss only if training

* save total_tokens for checkpoint

* check if string

* refactor total_tokens / num_tokens

* refactor 2

* replace trainable_step/train_per_sec_per_gpu

* lint + log trainable/tokens

* consolidate it in the callback.

* test for total_tokens after resume

* check if tokens_state exists after ckpt

---------

Co-authored-by: Ved <ved.work2024@gmail.com>
VED
2025-12-25 17:08:17 +05:30
committed by GitHub
parent 4f5e8a328a
commit a6080df73c
4 changed files with 128 additions and 20 deletions

View File

@@ -72,7 +72,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.include_tkps:
             callbacks.append(
                 TokensPerSecondCallback(
-                    self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
+                    self.cfg.tensor_parallel_size,
+                    self.cfg.context_parallel_size,
+                    resume_from_checkpoint=self.cfg.resume_from_checkpoint,
                 )
             )
         return callbacks
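
The resume_from_checkpoint value threaded through here is whatever the config holds; in Hugging Face Trainer conventions it may be True or None rather than a path string, which is why the callback below guards with isinstance(..., str) before touching any state file (the commit's "check if string" bullet). A minimal sketch of that guard, with illustrative values:

```python
# Illustrative values only; mirrors the isinstance(str) guard in the callback.
for value in (None, True, "outputs/checkpoint-500"):
    if not isinstance(value, str):
        print(f"{value!r}: nothing to restore")
    else:
        print(f"{value!r}: would look for tokens_state.json in this directory")
```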

View File

@@ -2,6 +2,7 @@
 from __future__ import annotations

+import json
 import math
 import os

 from collections import defaultdict
@@ -50,6 +51,8 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

 LOG = get_logger(__name__)

+TOKENS_STATE_FILE = "tokens_state.json"
+
 REDUCTION_FNS = {
     "mean": torch.mean,
     "min": torch.min,
@@ -349,24 +352,34 @@ class AxolotlTrainer(
         # return (loss, outputs) if return_outputs else loss

         # track number of tokens for tokens per second calculation
-        if self.args.include_tkps:
+        if self.args.include_tkps and model.training:
             inputs_key = "labels" if "labels" in inputs else "input_ids"
-            num_tokens = (inputs[inputs_key] != -100).sum()
+            trainable_tokens = (inputs[inputs_key] != -100).sum()
+            total_tokens = inputs[inputs_key].numel()
             if is_distributed():
                 torch.distributed.all_reduce(
-                    num_tokens, op=torch.distributed.ReduceOp.SUM
+                    trainable_tokens, op=torch.distributed.ReduceOp.SUM
                 )
-            if hasattr(self.state, "num_tokens"):
-                self.state.num_tokens = (
-                    self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
+                torch.distributed.all_reduce(
+                    total_tokens, op=torch.distributed.ReduceOp.SUM
                 )
-            else:
-                self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
-            if hasattr(self.state, "total_tokens"):
-                self.state.total_tokens += num_tokens
-            else:
-                self.state.total_tokens = num_tokens
+            if not hasattr(self.state, "tokens"):
+                self.state.tokens = {
+                    "trainable": torch.zeros(1),
+                    "total": torch.zeros(1),
+                }
+            # trainable tokens for throughput and total token slots for summaries
+            self.state.tokens["trainable"] = (
+                self.state.tokens["trainable"] + trainable_tokens.detach().cpu()
+            )
+            self.state.tokens["total"] = (
+                self.state.tokens["total"] + torch.as_tensor(total_tokens).cpu()
+            )
+            # Store per-step trainable tokens for throughput calculation
+            self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()

         if self.args.orpo_alpha:
             return self.orpo_compute_loss(
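
To make the trainable/total distinction concrete: positions labeled -100 are excluded from the loss, so they count toward total token slots but not toward trainable tokens. A standalone sketch with made-up tensor values:

```python
import torch

# Labels use -100 to mark positions ignored by the loss.
labels = torch.tensor([[-100, -100, 15, 27], [8, 4, -100, -100]])
trainable_tokens = (labels != -100).sum()  # tensor(4): supervised positions only
total_tokens = labels.numel()              # 8: every token slot in the batch
print(trainable_tokens.item(), total_tokens)  # 4 8
```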
@@ -638,10 +651,14 @@ class AxolotlTrainer(
         except (ValueError, TypeError, FileNotFoundError):
             pass
-        if self.args.include_tkps and train_eval == "train":
+        if (
+            self.args.include_tkps
+            and train_eval == "train"
+            and hasattr(self.state, "tokens")
+        ):
             # each rank will log its own tokens per second
             # for logging_steps > 1 we obtain a moving average of this metric
-            logs["tokens_per_second_per_gpu"] = round(
+            logs["tokens/train_per_sec_per_gpu"] = round(
                 self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
             )

         if (
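
As the in-line comment notes, with logging_steps > 1 the logged rate is averaged over the logging window. A toy calculation with made-up numbers:

```python
import torch

last_tokens_per_second = torch.tensor([24000.0])  # accumulated over the window
logging_steps = 4
print(round(last_tokens_per_second.item() / logging_steps, 2))  # 6000.0
```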
@@ -683,6 +700,19 @@
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
+
+        # Save total_tokens state if tracking is enabled
+        if self.args.include_tkps and hasattr(self.state, "tokens"):
+            tokens_state = {
+                "total": int(torch.as_tensor(self.state.tokens.get("total", 0)).item()),
+                "trainable": int(
+                    torch.as_tensor(self.state.tokens.get("trainable", 0)).item()
+                ),
+            }
+            tokens_state_path = os.path.join(output_dir, TOKENS_STATE_FILE)
+            with open(tokens_state_path, "w", encoding="utf-8") as f:
+                json.dump(tokens_state, f)
+
         return super()._save_checkpoint(model, trial, **kwargs)

     # TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
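
A standalone sketch of the round trip between this save and the callback's restore below (directory and counts are illustrative):

```python
import json
import os
import tempfile

import torch

tokens = {"trainable": torch.tensor([12345]), "total": torch.tensor([20000])}
ckpt_dir = tempfile.mkdtemp()  # stands in for a checkpoint-<step> directory
path = os.path.join(ckpt_dir, "tokens_state.json")
with open(path, "w", encoding="utf-8") as f:  # mirrors _save_checkpoint
    json.dump({k: int(torch.as_tensor(v).item()) for k, v in tokens.items()}, f)
with open(path, "r", encoding="utf-8") as f:  # mirrors on_train_begin
    restored = {k: torch.tensor(v) for k, v in json.load(f).items()}
assert restored["total"].item() == 20000 and restored["trainable"].item() == 12345
```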

View File

@@ -1,5 +1,7 @@
 """A callback for calculating tokens per second during training."""

+import json
+import os
 import time

 import torch
@@ -10,22 +12,52 @@ from transformers import (
     TrainingArguments,
 )

 from axolotl.utils.logging import get_logger

 LOG = get_logger(__name__)

+TOKENS_STATE_FILE = "tokens_state.json"
+

 class TokensPerSecondCallback(TrainerCallback):
     """
     A callback to measure and log tokens per second during training.
+
+    Also handles saving/restoring total_tokens state across checkpoint resumes.
     """

-    def __init__(self, tensor_parallel_size, context_parallel_size):
+    def __init__(
+        self, tensor_parallel_size, context_parallel_size, resume_from_checkpoint=None
+    ):
         super().__init__()
         self.step_time = 0.0
         self.start_time = 0.0
         self.non_data_parallel_size = 1
+        self.resume_from_checkpoint = resume_from_checkpoint
         if tensor_parallel_size is not None:
             self.non_data_parallel_size *= tensor_parallel_size
         if context_parallel_size is not None:
             self.non_data_parallel_size *= context_parallel_size

+    def on_train_begin(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):  # pylint: disable=unused-argument
+        """Restore total_tokens state when resuming from checkpoint."""
+        if not isinstance(self.resume_from_checkpoint, str):
+            return
+
+        tokens_state_path = os.path.join(self.resume_from_checkpoint, TOKENS_STATE_FILE)
+        if os.path.isfile(tokens_state_path):
+            with open(tokens_state_path, "r", encoding="utf-8") as f:
+                tokens_state = json.load(f)
+            state.tokens = {
+                "total": torch.tensor(tokens_state.get("total", 0)),
+                "trainable": torch.tensor(tokens_state.get("trainable", 0)),
+            }
+            LOG.info(f"Restored total_tokens: {state.tokens['total']}")
+
     def on_step_begin(
         self,
         args: TrainingArguments,
@@ -33,6 +65,8 @@ class TokensPerSecondCallback(TrainerCallback):
         control: TrainerControl,
         **kwargs,
     ):  # pylint: disable=unused-argument
+        if not hasattr(state, "tokens"):
+            state.tokens = {"trainable": torch.zeros(1), "total": torch.zeros(1)}
         self.start_time = time.perf_counter()
         state.last_tokens_per_second = torch.zeros(1)
@@ -43,9 +77,10 @@ class TokensPerSecondCallback(TrainerCallback):
         control: TrainerControl,
         **kwargs,
     ):  # pylint: disable=unused-argument
-        if hasattr(state, "num_tokens"):
+        tokens = getattr(state, "tokens", None)
+        if tokens and "trainable_tokens" in tokens:
             step_time = time.perf_counter() - self.start_time
-            num_tokens_per_device = state.num_tokens.clone()
+            num_tokens_per_device = tokens["trainable_tokens"].clone()
             # non data parallel groups have duplicated tokens, so we avoid double-counting
             num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
             state.last_tokens_per_second = num_tokens_per_device / step_time
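
The division by non_data_parallel_size reflects that ranks in the same tensor or context parallel group see duplicated tokens. A sketch with illustrative numbers:

```python
import torch

trainable_tokens_this_step = torch.tensor([8192.0])
non_data_parallel_size = 2 * 1  # e.g. tensor_parallel_size=2, context_parallel_size=1
step_time = 0.5                 # seconds for this optimizer step
print((trainable_tokens_this_step / non_data_parallel_size / step_time).item())
# -> 8192.0 tokens/sec per GPU, without double-counting the TP replica
```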
@@ -60,5 +95,15 @@ class TokensPerSecondCallback(TrainerCallback):
     ):  # pylint: disable=unused-argument
         # after logging, clear the running metrics
         if hasattr(state, "last_tokens_per_second"):
+            logs["tokens/train_per_sec_per_gpu"] = state.last_tokens_per_second.item()
             state.last_tokens_per_second.zero_()
-            state.num_tokens = torch.zeros(1)
+
+        tokens = getattr(state, "tokens", None)
+        # Clear per-step tokens after logging
+        if tokens and "trainable_tokens" in tokens:
+            tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"])
+
+        if tokens and "total" in tokens:
+            logs["tokens/total"] = tokens["total"].item()
+        if tokens and "trainable" in tokens:
+            logs["tokens/trainable"] = tokens["trainable"].item()