diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 3a08d0574..799dcf02e 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -660,11 +660,10 @@ class AxolotlTrainer( logs["tokens/train_per_sec_per_gpu"] = round( self.state.last_tokens_per_second.item() / self.args.logging_steps, 2 ) - if ( - hasattr(self.state, "total_tokens") - and self.state.total_tokens is not None - ): - logs["total_tokens"] = int(self.state.total_tokens.item()) + if "total" in self.state.tokens: + logs["tokens/total"] = int(self.state.tokens["total"].item()) + if "trainable" in self.state.tokens: + logs["tokens/trainable"] = int(self.state.tokens["trainable"].item()) del self._stored_metrics[train_eval] diff --git a/src/axolotl/utils/callbacks/tokens_per_second.py b/src/axolotl/utils/callbacks/tokens_per_second.py index a1b955a74..679b1c864 100644 --- a/src/axolotl/utils/callbacks/tokens_per_second.py +++ b/src/axolotl/utils/callbacks/tokens_per_second.py @@ -101,9 +101,3 @@ class TokensPerSecondCallback(TrainerCallback): # Clear per-step tokens after logging if tokens and "trainable_tokens" in tokens: tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"]) - - if tokens and "total" in tokens: - logs["tokens/total"] = tokens["total"].item() - - if tokens and "trainable" in tokens: - logs["tokens/trainable"] = tokens["trainable"].item()