fix total/trainable tokens log (#3344)
* fix total/trainable tokens log
This commit is contained in:
@@ -660,11 +660,10 @@ class AxolotlTrainer(
|
||||
logs["tokens/train_per_sec_per_gpu"] = round(
|
||||
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
||||
)
|
||||
if (
|
||||
hasattr(self.state, "total_tokens")
|
||||
and self.state.total_tokens is not None
|
||||
):
|
||||
logs["total_tokens"] = int(self.state.total_tokens.item())
|
||||
if "total" in self.state.tokens:
|
||||
logs["tokens/total"] = int(self.state.tokens["total"].item())
|
||||
if "trainable" in self.state.tokens:
|
||||
logs["tokens/trainable"] = int(self.state.tokens["trainable"].item())
|
||||
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
|
||||
@@ -101,9 +101,3 @@ class TokensPerSecondCallback(TrainerCallback):
|
||||
# Clear per-step tokens after logging
|
||||
if tokens and "trainable_tokens" in tokens:
|
||||
tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"])
|
||||
|
||||
if tokens and "total" in tokens:
|
||||
logs["tokens/total"] = tokens["total"].item()
|
||||
|
||||
if tokens and "trainable" in tokens:
|
||||
logs["tokens/trainable"] = tokens["trainable"].item()
|
||||
|
||||
Reference in New Issue
Block a user