Skip the parent trainer's log() and call the grandparent's log() directly via two-argument super() (admittedly a janky workaround)
This commit is contained in:
@@ -1165,7 +1165,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
return super().log(logs, start_time)
|
||||
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
@@ -1183,7 +1185,9 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
return super().log(logs, start_time)
|
||||
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
@@ -1228,7 +1232,9 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
return super().log(logs, start_time)
|
||||
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
@@ -1246,7 +1252,9 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
return super().log(logs, start_time)
|
||||
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
@@ -1264,7 +1272,9 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
return super().log(logs, start_time)
|
||||
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
Reference in New Issue
Block a user