diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index cfc412573..5418e53bd 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1165,7 +1165,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super().log(logs, start_time) + return super(DPOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): @@ -1183,7 +1185,9 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super().log(logs, start_time) + return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): @@ -1228,7 +1232,9 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): for key, metrics in self._stored_metrics[train_eval].items(): logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super().log(logs, start_time) + return super(KTOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): @@ -1246,7 +1252,9 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super().log(logs, start_time) + return super(CPOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): @@ -1264,7 +1272,9 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super().log(logs, start_time) + return super(RewardTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) class TrainerBuilderBase(abc.ABC):