From 86a5803212f8768782d886fe767d97433e298c76 Mon Sep 17 00:00:00 2001
From: VED <146507396+ved1beta@users.noreply.github.com>
Date: Tue, 10 Feb 2026 16:14:55 +0530
Subject: [PATCH] train_per_sec_per_gpu metric (#3364) [skip ci]

* fix token count

* guard for none n zero
---
 .../utils/callbacks/tokens_per_second.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/src/axolotl/utils/callbacks/tokens_per_second.py b/src/axolotl/utils/callbacks/tokens_per_second.py
index 679b1c864..e3a3ce333 100644
--- a/src/axolotl/utils/callbacks/tokens_per_second.py
+++ b/src/axolotl/utils/callbacks/tokens_per_second.py
@@ -78,12 +78,19 @@ class TokensPerSecondCallback(TrainerCallback):
         **kwargs,
     ):  # pylint: disable=unused-argument
         tokens = getattr(state, "tokens", None)
-        if tokens and "trainable_tokens" in tokens:
-            step_time = time.perf_counter() - self.start_time
-            num_tokens_per_device = tokens["trainable_tokens"].clone()
-            # non data parallel groups have duplicated tokens, so we avoid double-counting
-            num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
-            state.last_tokens_per_second = num_tokens_per_device / step_time
+        if not (tokens and "trainable_tokens" in tokens):
+            return
+        step_time = time.perf_counter() - self.start_time
+        if step_time <= 0:
+            return
+
+        num_tokens = tokens["trainable_tokens"].clone() / self.non_data_parallel_size
+        if torch.distributed.is_initialized():
+            dp_size = max(
+                1, torch.distributed.get_world_size() // self.non_data_parallel_size
+            )
+            num_tokens = num_tokens / dp_size
+        state.last_tokens_per_second = num_tokens / step_time
 
     def on_log(
         self,
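
Note (not part of the patch): the new code first undoes token duplication across non-data-parallel ranks (e.g. tensor parallel), then spreads the count over the data-parallel replicas so the rate is per GPU, and guards against a missing token dict and a non-positive step time. Below is a minimal sketch of that arithmetic with plain Python numbers; the helper name tokens_per_second_per_gpu and the example values are made up for illustration and are not part of the axolotl API.

    def tokens_per_second_per_gpu(
        trainable_tokens: float,
        step_time: float,
        world_size: int,
        non_data_parallel_size: int,
    ) -> float:
        # Guard mirrors the patch: a zero or negative step time would divide by zero.
        if step_time <= 0:
            raise ValueError("step_time must be positive")
        # Ranks in the same non-data-parallel group see identical tokens,
        # so divide once to avoid double-counting them.
        tokens = trainable_tokens / non_data_parallel_size
        # Spread the remaining count over the data-parallel replicas
        # to report a per-GPU rate instead of a global one.
        dp_size = max(1, world_size // non_data_parallel_size)
        return tokens / dp_size / step_time

    # Example: 8 GPUs with tensor-parallel size 2 -> 4 data-parallel replicas.
    # 1_048_576 tokens over a 2.0 s step -> 65_536 tokens/s per GPU.
    print(tokens_per_second_per_gpu(1_048_576, 2.0, world_size=8, non_data_parallel_size=2))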