From 84a14fc604632ec047d1655c6ebf60ae68957691 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 6 Dec 2024 10:35:29 -0500
Subject: [PATCH] fix trl trainer.log interfaces

---
 src/axolotl/core/trainer_builder.py | 77 +++++++++++++++++++++++++++++
 1 file changed, 77 insertions(+)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index f681f0622..cfc412573 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1157,6 +1157,16 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
         torch.cuda.empty_cache()
         return loss
 
+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
+        # TODO remove once trl supports the update to the Trainer.log method
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+        return super().log(logs, start_time)
+
 
 class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
     """
@@ -1165,6 +1175,16 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
 
     tag_names = ["axolotl", "orpo"]
 
+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
+        # TODO remove once trl supports the update to the Trainer.log method
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+        return super().log(logs, start_time)
+
 
 class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
     """
@@ -1173,6 +1193,43 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
 
     tag_names = ["axolotl", "kto"]
 
+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
+        # TODO remove once trl supports the update to the Trainer.log method
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # train metrics should have no prefix, eval should have 'eval_'
+        prefix = "eval_" if train_eval == "eval" else ""
+        # accumulate average metrics from sums and lengths
+        for split in ["chosen", "rejected"]:
+            if f"count/{split}" in self._stored_metrics[train_eval]:
+                count_sum = (
+                    torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
+                    .sum()
+                    .item()
+                )
+                for metric in ["rewards", "logps", "logits"]:
+                    logs[f"{prefix}{metric}/{split}"] = (
+                        torch.Tensor(
+                            self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
+                        )
+                        .sum()
+                        .item()
+                        / count_sum
+                    )
+                    # delete obsolete metric
+                    del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
+                del self._stored_metrics[train_eval][f"count/{split}"]
+        # calculate reward margin
+        if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
+            logs[f"{prefix}rewards/margins"] = (
+                logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
+            )
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+        return super().log(logs, start_time)
+
 
 class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
     """
@@ -1181,6 +1238,16 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
 
     tag_names = ["axolotl", "cpo"]
 
+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
+        # TODO remove once trl supports the update to the Trainer.log method
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+        return super().log(logs, start_time)
+
 
 class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
     """
@@ -1189,6 +1256,16 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
 
     tag_names = ["axolotl", "reward"]
 
+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
+        # TODO remove once trl supports the update to the Trainer.log method
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+        return super().log(logs, start_time)
+
 
 class TrainerBuilderBase(abc.ABC):
     """