From 86cf62ca464cb3af7a5bdf40a2eb73fcdb70f2e1 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 25 Nov 2024 18:31:43 +0700 Subject: [PATCH] fix: update trainer.log signature --- src/axolotl/core/trainer_builder.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 93384189e..f681f0622 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -957,13 +957,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer): return res - def log(self, logs: Dict[str, float]) -> None: + def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: """ Log `logs` on the various objects watching training, including stored metrics. Args: logs (`Dict[str, float]`): The values to log. + start_time (`Optional[float]`): + The start of training. """ # logs either has 'loss' or 'eval_loss' train_eval = "train" if "loss" in logs else "eval" @@ -971,7 +973,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer): for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super().log(logs) + return super().log(logs, start_time) def store_metrics( self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"