From 3b44989205507493961ad66962466852106f8b1a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 6 Dec 2024 12:19:14 -0500 Subject: [PATCH] skip parent, call grandparent - yeah, super janky --- src/axolotl/core/trainer_builder.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index cfc412573..5418e53bd 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1165,7 +1165,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super().log(logs, start_time) + return super(DPOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): @@ -1183,7 +1185,9 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super().log(logs, start_time) + return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): @@ -1228,7 +1232,9 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): for key, metrics in self._stored_metrics[train_eval].items(): logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super().log(logs, start_time) + return super(KTOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): @@ -1246,7 +1252,9 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super().log(logs, start_time) + return super(CPOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): @@ -1264,7 +1272,9 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super().log(logs, start_time) + return super(RewardTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) class TrainerBuilderBase(abc.ABC):