fix total/trainable tokens log (#3344)

* fix total/trainable tokens log

* fix total/trainable tokens log
This commit is contained in:
VED
2026-01-06 19:55:17 +05:30
committed by GitHub
parent 8aab807e67
commit 7bf6f70e96
2 changed files with 4 additions and 11 deletions

View File

@@ -660,11 +660,10 @@ class AxolotlTrainer(
logs["tokens/train_per_sec_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
if (
hasattr(self.state, "total_tokens")
and self.state.total_tokens is not None
):
logs["total_tokens"] = int(self.state.total_tokens.item())
if "total" in self.state.tokens:
logs["tokens/total"] = int(self.state.tokens["total"].item())
if "trainable" in self.state.tokens:
logs["tokens/trainable"] = int(self.state.tokens["trainable"].item())
del self._stored_metrics[train_eval]

View File

@@ -101,9 +101,3 @@ class TokensPerSecondCallback(TrainerCallback):
# Clear per-step tokens after logging
if tokens and "trainable_tokens" in tokens:
tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"])
if tokens and "total" in tokens:
logs["tokens/total"] = tokens["total"].item()
if tokens and "trainable" in tokens:
logs["tokens/trainable"] = tokens["trainable"].item()