fix: also check grad norm
This commit is contained in:
@@ -153,12 +153,13 @@ class TelemetryCallback(TrainerCallback):
|
||||
self.last_report_step = step
|
||||
|
||||
def _extract_last_metrics(self, state: TrainerState) -> dict:
|
||||
"""Extract last loss and learning_rate from log history."""
|
||||
"""Extract last loss, learning_rate, and grad_norm from log history."""
|
||||
if not state.log_history:
|
||||
return {"loss": 0, "learning_rate": 0}
|
||||
return {"loss": 0, "learning_rate": 0, "grad_norm": 0}
|
||||
|
||||
last_log = state.log_history[-1]
|
||||
return {
|
||||
"loss": last_log.get("loss", 0),
|
||||
"learning_rate": last_log.get("learning_rate", 0),
|
||||
"grad_norm": last_log.get("grad_norm", 0),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user