diff --git a/src/axolotl/core/trainers/trl.py b/src/axolotl/core/trainers/trl.py
index 1199313e8..ebe46f11d 100644
--- a/src/axolotl/core/trainers/trl.py
+++ b/src/axolotl/core/trainers/trl.py
@@ -1,5 +1,7 @@
 """Module for TRL PPO trainer"""
 
+from typing import Literal, Union
+
 import torch
 from tqdm import tqdm
 from trl import (
@@ -79,6 +81,78 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
 
     tag_names = ["axolotl", "orpo"]
 
+    def get_batch_loss_metrics(
+        self,
+        model,
+        batch: dict[str, Union[list, torch.LongTensor]],
+        train_eval: Literal["train", "eval"] = "train",
+    ):
+        """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
+
+        # TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
+
+        metrics = {}
+
+        forward_output = self.concatenated_forward(model, batch)
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+            policy_nll_loss,
+        ) = forward_output[:5]
+        if self.aux_loss_enabled:
+            aux_loss = forward_output[5]
+
+        losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
+            self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
+        )
+        # full ORPO loss
+        loss = policy_nll_loss - losses.mean()
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(
+            chosen_rewards
+        ).mean()
+        metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(
+            rejected_rewards
+        ).mean()
+        metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(
+            reward_accuracies
+        ).mean()
+        metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
+            chosen_rewards - rejected_rewards
+        ).mean()
+        metrics[f"{prefix}logps/rejected"] = (
+            self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
+        )
+        metrics[f"{prefix}logps/chosen"] = (
+            self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
+        )
+        metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
+            policy_rejected_logits.detach().mean()
+        ).mean()
+        metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
+            policy_chosen_logits.detach().mean()
+        ).mean()
+        metrics[f"{prefix}nll_loss"] = (
+            self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
+        )
+        metrics[f"{prefix}log_odds_ratio"] = (
+            self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
+        )
+        metrics[f"{prefix}log_odds_chosen"] = (
+            self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
+        )
+        for k, v in metrics.items():
+            metrics[k] = v.item()
+        if self.aux_loss_enabled:
+            loss += self.aux_loss_coef * aux_loss
+
+        return loss, metrics
+
 
 class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
     """
@@ -95,6 +169,80 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
 
     tag_names = ["axolotl", "cpo"]
 
+    def get_batch_loss_metrics(
+        self,
+        model,
+        batch: dict[str, Union[list, torch.LongTensor]],
+        train_eval: Literal["train", "eval"] = "train",
+    ):
+        """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
+        metrics = {}
+
+        forward_output = self.concatenated_forward(model, batch)
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+            policy_nll_loss,
+        ) = forward_output[:5]
+        if self.aux_loss_enabled:
+            aux_loss = forward_output[5]
+
+        losses, chosen_rewards, rejected_rewards = self.cpo_loss(
+            policy_chosen_logps,
+            policy_rejected_logps,
+        )
+
+        loss = losses.mean() + self.cpo_alpha * policy_nll_loss
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics[f"{prefix}rewards/chosen"] = (
+            self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
+        )
+        metrics[f"{prefix}rewards/rejected"] = (
+            self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
+        )
+        metrics[f"{prefix}rewards/accuracies"] = (
+            self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
+        )
+        metrics[f"{prefix}rewards/margins"] = (
+            self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards)
+            .mean()
+            .item()
+        )
+        metrics[f"{prefix}logps/rejected"] = (
+            self.accelerator.gather_for_metrics(policy_rejected_logps)
+            .detach()
+            .mean()
+            .item()
+        )
+        metrics[f"{prefix}logps/chosen"] = (
+            self.accelerator.gather_for_metrics(policy_chosen_logps)
+            .detach()
+            .mean()
+            .item()
+        )
+        metrics[f"{prefix}logits/rejected"] = (
+            self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean())
+            .mean()
+            .item()
+        )
+        metrics[f"{prefix}logits/chosen"] = (
+            self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean())
+            .mean()
+            .item()
+        )
+        metrics[f"{prefix}nll_loss"] = (
+            self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
+        )
+
+        if self.aux_loss_enabled:
+            loss += self.aux_loss_coef * aux_loss
+
+        return loss, metrics
+
 
 class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
     """
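
For context, the ORPO composition used in the override above is `loss = policy_nll_loss - losses.mean()`, where `losses` comes from `odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)`. Below is a minimal standalone sketch of that objective; the toy log-prob values, the `beta` value, and the explicit log-odds formula are assumptions based on the ORPO paper and TRL's odds-ratio formulation, not something introduced by this patch:

    import torch
    import torch.nn.functional as F

    # Toy per-sequence average log-probs for chosen vs. rejected completions (assumed values).
    policy_chosen_logps = torch.tensor([-0.8, -1.2])
    policy_rejected_logps = torch.tensor([-1.5, -2.0])
    beta = 0.1  # assumed ORPO beta

    # Log-odds of the chosen completion over the rejected one, as in the ORPO objective.
    log_odds = (policy_chosen_logps - policy_rejected_logps) - (
        torch.log1p(-torch.exp(policy_chosen_logps))
        - torch.log1p(-torch.exp(policy_rejected_logps))
    )
    losses = beta * F.logsigmoid(log_odds)  # per-example odds-ratio term

    # Mirrors the composition in the patched ORPO trainer:
    # NLL on the chosen completion minus the mean odds-ratio term.
    policy_nll_loss = torch.tensor(1.3)  # placeholder scalar for the chosen-completion NLL
    loss = policy_nll_loss - losses.mean()
    print(float(loss))

The CPO override follows the same structure with `loss = losses.mean() + self.cpo_alpha * policy_nll_loss`; beyond the loss, the two methods differ only in how they reduce metrics, with ORPO converting the gathered tensors to Python floats in a final loop and CPO calling `.item()` inline.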