Tokens per second logging [skip-e2e] (#3072)

This commit is contained in:
salman
2025-08-27 09:10:14 +01:00
committed by GitHub
parent e1131e9619
commit d0d2fc5606
6 changed files with 109 additions and 15 deletions

View File

@@ -24,9 +24,7 @@ from pathlib import Path
from typing import Any
import torch
from transformers import (
TrainerCallback,
)
from transformers import TrainerCallback
from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager
@@ -38,6 +36,7 @@ from axolotl.utils.callbacks import (
SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.distributed import build_parallelism_config
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
@@ -146,6 +145,12 @@ class TrainerBuilderBase(abc.ABC):
profiler_steps_start=self.cfg.profiler_steps_start,
)
)
if self.cfg.include_tkps:
callbacks.append(
TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
)
)
return callbacks
@@ -512,6 +517,7 @@ class TrainerBuilderBase(abc.ABC):
self.cfg.eval_batch_size
)
training_args_kwargs["include_tkps"] = self.cfg.include_tkps
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs

View File

@@ -88,7 +88,6 @@ class AxolotlTrainer(
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
@@ -337,6 +336,17 @@ class AxolotlTrainer(
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
# track number of tokens for tokens per second calculation
if self.args.include_tkps:
inputs_key = "labels" if "labels" in inputs else "input_ids"
if hasattr(self.state, "num_tokens"):
self.state.num_tokens = (
self.state.num_tokens + (inputs[inputs_key] != -100).sum()
)
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum()
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -536,9 +546,6 @@ class AxolotlTrainer(
super().create_accelerator_and_postprocess()
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
if self.is_fsdp_enabled:
if (
"limit_all_gathers" in self.args.fsdp_config
@@ -586,12 +593,19 @@ class AxolotlTrainer(
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_mem_active(gib)"] = round(active, 2)
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
logs["memory/max_active (GiB)"] = round(active, 2)
logs["memory/max_allocated (GiB)"] = round(allocated, 2)
logs["memory/device_reserved (GiB)"] = round(reserved, 2)
except (ValueError, TypeError, FileNotFoundError):
pass
if self.args.include_tkps and train_eval == "train":
# each rank will log its own tokens per second
# for logging_steps > 1 we obtain a moving average of this metric
logs["tokens_per_second_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
del self._stored_metrics[train_eval]
return super().log(logs, start_time)

View File

@@ -49,6 +49,12 @@ class AxolotlTrainingMixins:
default=False,
metadata={"help": "Use real batches for efficient training."},
)
include_tkps: bool = field(
default=True,
metadata={
"help": "Whether to include tokens per second in the training metrics."
},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},

View File

@@ -60,13 +60,14 @@ def gpu_memory_usage_all(device=0):
active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
torch.cuda.reset_peak_memory_stats(device)
return active, allocated, reserved
def mps_memory_usage_all():
    """Return ``(active, allocated, reserved)`` memory usage in GiB on the
    MPS (Apple Silicon) backend.

    Mirrors the 3-tuple shape of ``gpu_memory_usage_all`` so callers can
    treat backends uniformly. MPS exposes no reserved-memory counter, so
    the last element is always 0.
    """
    active = torch.mps.current_allocated_memory() / 1024.0**3
    allocated = torch.mps.driver_allocated_memory() / 1024.0**3
    # NOTE: unlike the CUDA variant these are current (not peak) counters —
    # torch.mps has no peak-stats API to reset.
    return active, allocated, 0
def npu_memory_usage_all(device=0):

View File

@@ -0,0 +1,62 @@
"""A callback for calculating tokens per second during training."""
import time
import torch
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
class TokensPerSecondCallback(TrainerCallback):
    """
    A callback to measure and log tokens per second during training.

    The trainer accumulates the number of non-padding tokens seen on this
    rank since the last log into ``state.num_tokens``; this callback times
    each optimizer step and publishes the per-rank throughput as
    ``state.last_tokens_per_second`` for the trainer's ``log`` to read.
    """

    def __init__(self, tensor_parallel_size, context_parallel_size):
        """
        Args:
            tensor_parallel_size: tensor-parallel degree, or ``None``/unset
                when tensor parallelism is not used.
            context_parallel_size: context-parallel degree, or ``None``/unset
                when context parallelism is not used.
        """
        super().__init__()
        self.start_time = 0.0
        # Ranks in the same tensor/context parallel group process the same
        # tokens, so we divide them out to avoid double counting when each
        # rank reports its own throughput.
        self.non_data_parallel_size = 1
        if tensor_parallel_size is not None:
            self.non_data_parallel_size *= tensor_parallel_size
        if context_parallel_size is not None:
            self.non_data_parallel_size *= context_parallel_size

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):  # pylint: disable=unused-argument
        """Start the step timer and reset the published throughput."""
        self.start_time = time.perf_counter()
        state.last_tokens_per_second = torch.zeros(1)

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):  # pylint: disable=unused-argument
        """Compute this rank's tokens/sec for the tokens accumulated so far."""
        step_time = time.perf_counter() - self.start_time
        # ``state.num_tokens`` is set by the trainer's compute_loss and reset
        # to the int 0 in on_log, so coerce to a tensor in case no loss was
        # computed since the last log (avoids AttributeError on ``int``).
        num_tokens = getattr(state, "num_tokens", 0)
        if not torch.is_tensor(num_tokens):
            num_tokens = torch.tensor(float(num_tokens))
        # non data parallel groups have duplicated tokens, so we avoid
        # double-counting; the division already yields a fresh tensor, so no
        # explicit clone is needed.
        num_tokens_per_device = num_tokens / self.non_data_parallel_size
        state.last_tokens_per_second = num_tokens_per_device / step_time

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs=None,
        **kwargs,
    ):  # pylint: disable=unused-argument
        # after logging, clear the running metrics for the next window
        state.last_tokens_per_second.zero_()
        state.num_tokens = 0

View File

@@ -830,10 +830,15 @@ class AxolotlInputConfig(
include_tokens_per_second: bool | None = Field(
default=None,
json_schema_extra={
"description": "bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time."
"description": "bool of whether to report tokens per second at the end of training. This is not supported with pre-training datasets."
},
)
include_tkps: bool | None = Field(
default=None,
json_schema_extra={
"description": "bool of whether to report tokens per second during training by measuring throughput of non-padding tokens."
},
)
neftune_noise_alpha: float | None = Field(
default=None,
json_schema_extra={