add override of upstream fix for multi-gpu orpo (#2440)
* add override of upstream fix * override batch loss metrics for CPO/Simpo as well
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
"""Module for TRL PPO trainer"""
|
"""Module for TRL PPO trainer"""
|
||||||
|
|
||||||
|
from typing import Literal, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from trl import (
|
from trl import (
|
||||||
@@ -79,6 +81,78 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
def get_batch_loss_metrics(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
batch: dict[str, Union[list, torch.LongTensor]],
|
||||||
|
train_eval: Literal["train", "eval"] = "train",
|
||||||
|
):
|
||||||
|
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
|
||||||
|
|
||||||
|
# TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
|
||||||
|
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
forward_output = self.concatenated_forward(model, batch)
|
||||||
|
(
|
||||||
|
policy_chosen_logps,
|
||||||
|
policy_rejected_logps,
|
||||||
|
policy_chosen_logits,
|
||||||
|
policy_rejected_logits,
|
||||||
|
policy_nll_loss,
|
||||||
|
) = forward_output[:5]
|
||||||
|
if self.aux_loss_enabled:
|
||||||
|
aux_loss = forward_output[5]
|
||||||
|
|
||||||
|
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
|
||||||
|
self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
|
||||||
|
)
|
||||||
|
# full ORPO loss
|
||||||
|
loss = policy_nll_loss - losses.mean()
|
||||||
|
|
||||||
|
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||||
|
|
||||||
|
prefix = "eval_" if train_eval == "eval" else ""
|
||||||
|
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(
|
||||||
|
chosen_rewards
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(
|
||||||
|
rejected_rewards
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(
|
||||||
|
reward_accuracies
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
|
||||||
|
chosen_rewards - rejected_rewards
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}logps/rejected"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logps/chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
|
||||||
|
policy_rejected_logits.detach().mean()
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
|
||||||
|
policy_chosen_logits.detach().mean()
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}nll_loss"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}log_odds_ratio"] = (
|
||||||
|
self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}log_odds_chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
|
||||||
|
)
|
||||||
|
for k, v in metrics.items():
|
||||||
|
metrics[k] = v.item()
|
||||||
|
if self.aux_loss_enabled:
|
||||||
|
loss += self.aux_loss_coef * aux_loss
|
||||||
|
|
||||||
|
return loss, metrics
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -95,6 +169,80 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "cpo"]
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
|
def get_batch_loss_metrics(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
batch: dict[str, Union[list, torch.LongTensor]],
|
||||||
|
train_eval: Literal["train", "eval"] = "train",
|
||||||
|
):
|
||||||
|
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
forward_output = self.concatenated_forward(model, batch)
|
||||||
|
(
|
||||||
|
policy_chosen_logps,
|
||||||
|
policy_rejected_logps,
|
||||||
|
policy_chosen_logits,
|
||||||
|
policy_rejected_logits,
|
||||||
|
policy_nll_loss,
|
||||||
|
) = forward_output[:5]
|
||||||
|
if self.aux_loss_enabled:
|
||||||
|
aux_loss = forward_output[5]
|
||||||
|
|
||||||
|
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
|
||||||
|
policy_chosen_logps,
|
||||||
|
policy_rejected_logps,
|
||||||
|
)
|
||||||
|
|
||||||
|
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
|
||||||
|
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||||
|
|
||||||
|
prefix = "eval_" if train_eval == "eval" else ""
|
||||||
|
metrics[f"{prefix}rewards/chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}rewards/rejected"] = (
|
||||||
|
self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}rewards/accuracies"] = (
|
||||||
|
self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}rewards/margins"] = (
|
||||||
|
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards)
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logps/rejected"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_rejected_logps)
|
||||||
|
.detach()
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logps/chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_chosen_logps)
|
||||||
|
.detach()
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logits/rejected"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean())
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logits/chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean())
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}nll_loss"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.aux_loss_enabled:
|
||||||
|
loss += self.aux_loss_coef * aux_loss
|
||||||
|
|
||||||
|
return loss, metrics
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user