fix: also check grad norm

This commit is contained in:
NanoCode012
2025-10-27 16:10:54 +07:00
parent d4f50806cd
commit 8cfc09d958

View File

@@ -153,12 +153,13 @@ class TelemetryCallback(TrainerCallback):
self.last_report_step = step
def _extract_last_metrics(self, state: TrainerState) -> dict:
"""Extract last loss and learning_rate from log history."""
"""Extract last loss, learning_rate, and grad_norm from log history."""
if not state.log_history:
return {"loss": 0, "learning_rate": 0}
return {"loss": 0, "learning_rate": 0, "grad_norm": 0}
last_log = state.log_history[-1]
return {
"loss": last_log.get("loss", 0),
"learning_rate": last_log.get("learning_rate", 0),
"grad_norm": last_log.get("grad_norm", 0),
}