fix: also check grad norm
This commit is contained in:
@@ -153,12 +153,13 @@ class TelemetryCallback(TrainerCallback):
|
|||||||
self.last_report_step = step
|
self.last_report_step = step
|
||||||
|
|
||||||
def _extract_last_metrics(self, state: TrainerState) -> dict:
|
def _extract_last_metrics(self, state: TrainerState) -> dict:
|
||||||
"""Extract last loss and learning_rate from log history."""
|
"""Extract last loss, learning_rate, and grad_norm from log history."""
|
||||||
if not state.log_history:
|
if not state.log_history:
|
||||||
return {"loss": 0, "learning_rate": 0}
|
return {"loss": 0, "learning_rate": 0, "grad_norm": 0}
|
||||||
|
|
||||||
last_log = state.log_history[-1]
|
last_log = state.log_history[-1]
|
||||||
return {
|
return {
|
||||||
"loss": last_log.get("loss", 0),
|
"loss": last_log.get("loss", 0),
|
||||||
"learning_rate": last_log.get("learning_rate", 0),
|
"learning_rate": last_log.get("learning_rate", 0),
|
||||||
|
"grad_norm": last_log.get("grad_norm", 0),
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user