train_per_sec_per_gpu metric (#3364) [skip ci]

* fix token count

* guard for None and zero
This commit is contained in:
VED
2026-02-10 16:14:55 +05:30
committed by GitHub
parent 530a0c0bf0
commit 86a5803212

View File

@@ -78,12 +78,19 @@ class TokensPerSecondCallback(TrainerCallback):
**kwargs,
): # pylint: disable=unused-argument
tokens = getattr(state, "tokens", None)
if tokens and "trainable_tokens" in tokens:
step_time = time.perf_counter() - self.start_time
num_tokens_per_device = tokens["trainable_tokens"].clone()
# non data parallel groups have duplicated tokens, so we avoid double-counting
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
state.last_tokens_per_second = num_tokens_per_device / step_time
if not (tokens and "trainable_tokens" in tokens):
return
step_time = time.perf_counter() - self.start_time
if step_time <= 0:
return
num_tokens = tokens["trainable_tokens"].clone() / self.non_data_parallel_size
if torch.distributed.is_initialized():
dp_size = max(
1, torch.distributed.get_world_size() // self.non_data_parallel_size
)
num_tokens = num_tokens / dp_size
state.last_tokens_per_second = num_tokens / step_time
def on_log(
self,