fix total/trainable tokens log (#3344)
* fix total/trainable tokens log * fix total/trainable tokens log
This commit is contained in:
@@ -660,11 +660,10 @@ class AxolotlTrainer(
|
|||||||
logs["tokens/train_per_sec_per_gpu"] = round(
|
logs["tokens/train_per_sec_per_gpu"] = round(
|
||||||
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
||||||
)
|
)
|
||||||
if (
|
if "total" in self.state.tokens:
|
||||||
hasattr(self.state, "total_tokens")
|
logs["tokens/total"] = int(self.state.tokens["total"].item())
|
||||||
and self.state.total_tokens is not None
|
if "trainable" in self.state.tokens:
|
||||||
):
|
logs["tokens/trainable"] = int(self.state.tokens["trainable"].item())
|
||||||
logs["total_tokens"] = int(self.state.total_tokens.item())
|
|
||||||
|
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
|
|
||||||
|
|||||||
@@ -101,9 +101,3 @@ class TokensPerSecondCallback(TrainerCallback):
|
|||||||
# Clear per-step tokens after logging
|
# Clear per-step tokens after logging
|
||||||
if tokens and "trainable_tokens" in tokens:
|
if tokens and "trainable_tokens" in tokens:
|
||||||
tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"])
|
tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"])
|
||||||
|
|
||||||
if tokens and "total" in tokens:
|
|
||||||
logs["tokens/total"] = tokens["total"].item()
|
|
||||||
|
|
||||||
if tokens and "trainable" in tokens:
|
|
||||||
logs["tokens/trainable"] = tokens["trainable"].item()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user