From 8428b3f2c7296ea5a08eeee45c7e3df3f598355e Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Mon, 16 Dec 2024 22:19:27 +0700
Subject: [PATCH] feat: add dpo liger

---
Review note: the added hunks below were revised after the original commit, so
the commit id above and the blob ids on the "index" line predate these edits.

 src/axolotl/core/trainer_builder.py | 350 +++++++++++++++++++++++++++-
 1 file changed, 349 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 73d9e0e65..903496154 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -14,17 +14,22 @@ import os
 import sys
 from abc import abstractmethod
 from collections import defaultdict
+from contextlib import nullcontext
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional, Type, Union
 
 import torch
+import torch.nn.functional as F
 import transformers
 from datasets import Dataset
+from liger_kernel.chunked_loss.fused_linear_preference import (
+    LigerFusedLinearPreferenceBase,
+)
 from packaging import version
 from peft.optimizers import create_loraplus_optimizer
-from torch import nn
+from torch import amp, nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
@@ -1077,6 +1082,15 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
         self.dataset_tags = dataset_tags
         self.optimizer = None
 
+        from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
+
+        self.liger_loss = LigerFusedLinearDPOLoss(
+            ignore_index=self.label_pad_token_id,
+            beta=self.beta,
+            compute_nll_loss=True,  # TODO: not the same as rpo_alpha; revisit the gating
+            use_ref_model=not self.reference_free,
+        )
+
     def create_optimizer(self):
         if self.args.loraplus_lr_ratio is None:
             return super().create_optimizer()
@@ -1180,6 +1194,332 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
         # transformers<=4.46
         return super(DPOTrainer, self).log(logs)  # pylint: disable=bad-super-call
 
+    def get_batch_loss_metrics(
+        self,
+        model,
+        batch: dict[str, Union[list, torch.LongTensor]],
+        train_eval: Literal["train", "eval"] = "train",
+    ):
+        """Compute the DPO loss and logging metrics using the Liger fused kernel.
+
+        Args:
+            model: Policy model; its ``lm_head`` weight (and bias, if present) is
+                handed to the fused loss together with the last hidden states.
+            batch: Batch forwarded to ``concatenated_forward``.
+            train_eval: ``"eval"`` prefixes every metric key with ``eval_``.
+
+        Returns:
+            A ``(loss, metrics)`` tuple.
+
+        Raises:
+            ValueError: If ``self.liger_loss`` has not been initialized.
+        """
+        if getattr(self, "liger_loss", None) is None:
+            raise ValueError("Liger loss not initialized")
+
+        metrics = {}
+
+        model_output = self.concatenated_forward(model, batch)
+
+        # Get the lm_head weights and bias
+        lin_weight = model.lm_head.weight
+        lin_bias = getattr(model.lm_head, "bias", None)
+
+        hidden_states = model_output["hidden_states"]
+        labels = model_output["labels"]
+
+        if not self.reference_free:
+            # Adapted from DPO's compute_ref_log_probs
+            compute_ref_context_manager = (
+                amp.autocast("cuda")
+                if self._peft_has_been_casted_to_bf16
+                else nullcontext()
+            )
+            with torch.no_grad(), compute_ref_context_manager:  # type: ignore
+                if self.ref_model is None:
+                    with self.null_ref_context():
+                        ref_model_output = self.concatenated_forward(self.model, batch)
+                        ref_weight = self.model.lm_head.weight
+                        ref_bias = getattr(self.model.lm_head, "bias", None)
+
+                        ref_hidden_states = ref_model_output["hidden_states"]
+
+                else:
+                    ref_model_output = self.concatenated_forward(self.ref_model, batch)
+                    ref_weight = self.ref_model.lm_head.weight
+                    ref_bias = getattr(self.ref_model.lm_head, "bias", None)
+
+                    ref_hidden_states = ref_model_output["hidden_states"]
+                # Reference log-probs are only needed for the reward metrics below.
+                (
+                    ref_chosen_logps,
+                    ref_rejected_logps,
+                    _ref_chosen_logits,
+                    _ref_rejected_logits,
+                    _ref_chosen_nll_loss,
+                ) = LigerFusedLinearPreferenceBase.chunk_forward(
+                    input_chunk=ref_hidden_states,
+                    weight=ref_weight,
+                    target_chunk=labels,
+                    bias=ref_bias,
+                    ignore_index=self.label_pad_token_id,
+                    compute_nll_loss=False,
+                )
+
+        else:
+            ref_hidden_states = None
+            ref_weight = None
+            ref_bias = None
+
+        # Compute loss using Liger kernel
+        loss, return_vars = self.liger_loss(
+            lin_weight=lin_weight,
+            _input=hidden_states,
+            target=labels,
+            bias=lin_bias,  # TODO: check whether to pass bias as FCLE doesn't
+            ref_input=ref_hidden_states,
+            ref_weight=ref_weight,
+            ref_bias=ref_bias,
+        )
+
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits_mean,
+            policy_rejected_logits_mean,
+            policy_nll_loss,
+        ) = return_vars
+
+        # Calculate rewards
+        if not self.reference_free:
+            chosen_rewards = (
+                self.beta * (policy_chosen_logps - (ref_chosen_logps)).detach()
+            )
+            rejected_rewards = (
+                self.beta * (policy_rejected_logps - (ref_rejected_logps)).detach()
+            )
+
+        else:
+            # Detach so the logged rewards never keep the autograd graph alive.
+            chosen_rewards = self.beta * policy_chosen_logps.detach()
+            rejected_rewards = self.beta * policy_rejected_logps.detach()
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        prefix = "eval_" if train_eval == "eval" else ""
+        metrics.update(
+            {
+                f"{prefix}rewards/chosen": chosen_rewards.mean().cpu(),
+                f"{prefix}rewards/rejected": rejected_rewards.mean().cpu(),
+                f"{prefix}rewards/accuracies": reward_accuracies.mean().cpu(),
+                f"{prefix}rewards/margins": (chosen_rewards - rejected_rewards)
+                .mean()
+                .cpu(),
+                f"{prefix}logps/chosen": policy_chosen_logps.detach().mean().cpu(),
+                f"{prefix}logps/rejected": policy_rejected_logps.detach().mean().cpu(),
+                f"{prefix}logits/chosen": policy_chosen_logits_mean.detach().cpu(),
+                f"{prefix}logits/rejected": policy_rejected_logits_mean.detach().cpu(),
+            }
+        )
+
+        if hasattr(self.args, "rpo_alpha") and self.args.rpo_alpha is not None:
+            metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().cpu()
+
+        # TODO: Handle use_weighting, aux_loss_enabled as in upstream
+
+        return loss, metrics
+
+    def concatenated_forward(
+        self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
+    ):
+        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+
+        We do this to avoid doing two forward passes, because it's faster for FSDP.
+
+        Overridden base function to return the hidden states and labels for the loss calculation.
+
+        Returns:
+            dict with ``chosen_logps``, ``rejected_logps``, ``mean_chosen_logits``,
+            ``mean_rejected_logits``, ``hidden_states`` and ``labels``; plus
+            ``policy_weights``, ``nll_loss`` and ``aux_loss`` when the matching
+            options are enabled. ``labels`` holds ``label_pad_token_id`` at every
+            position excluded from the loss.
+        """
+        num_examples = batch["prompt_input_ids"].shape[0]  # type: ignore
+
+        concatenated_batch = self.concatenated_inputs(
+            batch, padding_value=self.padding_value
+        )
+
+        model_kwargs = {}
+        if self.aux_loss_enabled:
+            model_kwargs["output_router_logits"] = True
+
+        # Add to get the hidden states for the loss
+        model_kwargs["output_hidden_states"] = True
+
+        # Add the pixel values and attention masks for vision models
+        if "pixel_values" in concatenated_batch:
+            model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
+        if "pixel_attention_mask" in concatenated_batch:
+            model_kwargs["pixel_attention_mask"] = concatenated_batch[
+                "pixel_attention_mask"
+            ]
+        if "image_sizes" in concatenated_batch:
+            model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
+
+        prompt_input_ids = concatenated_batch["prompt_input_ids"]
+        prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
+        completion_input_ids = concatenated_batch["completion_input_ids"]
+        completion_attention_mask = concatenated_batch["completion_attention_mask"]
+        if self.is_encoder_decoder:
+            labels = completion_input_ids
+            labels[completion_attention_mask == 0] = self.label_pad_token_id
+            outputs = model(
+                input_ids=prompt_input_ids,
+                attention_mask=prompt_attention_mask,
+                labels=labels,  # we need the labels for the logits to be returned
+                **model_kwargs,
+            )
+            logits = outputs.logits
+            hidden_states = outputs.decoder_hidden_states[-1]
+            loss_mask = completion_attention_mask.bool()
+        else:
+            # Concatenate the prompt and completion inputs
+            input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
+            attention_mask = torch.cat(
+                (prompt_attention_mask, completion_attention_mask), dim=1
+            )
+            # Mask the prompt but not the completion for the loss
+            loss_mask = torch.cat(
+                (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
+                dim=1,
+            )
+
+            # Flush left to reduce the memory usage
+            # [[0, 0, x, x, x, x],  ->  [[x, x, x, x],
+            #  [0, x, x, x, 0, 0]]       [x, x, x, 0]]
+            for i in range(attention_mask.size(0)):
+                first_one_idx = torch.nonzero(attention_mask[i])[0].item()
+                input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx)  # type: ignore
+                attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)  # type: ignore
+                loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx)  # type: ignore
+
+            # Get the first column idx that is all zeros and remove every column after that
+            empty_cols = torch.sum(attention_mask, dim=0) == 0
+            first_empty_col = (
+                torch.nonzero(empty_cols)[0].item()
+                if empty_cols.any()
+                else attention_mask.size(1)
+            )
+            input_ids = input_ids[:, :first_empty_col]  # type: ignore
+            attention_mask = attention_mask[:, :first_empty_col]  # type: ignore
+            loss_mask = loss_mask[:, :first_empty_col]  # type: ignore
+
+            # Truncate right
+            if self.args.max_length is not None:
+                input_ids = input_ids[:, : self.args.max_length]
+                attention_mask = attention_mask[:, : self.args.max_length]
+                loss_mask = loss_mask[:, : self.args.max_length]
+
+            # if self.use_num_logits_to_keep:
+            #     # Compute num_logits_to_keep based on loss_mask pattern:
+            #     # [[0, 0, 0, x, x, x, x],
+            #     #  [0, 0, 0, x, x, x, 0]]
+            #     #         ^ start computing logits from here ([:, -(7-3+1):])
+            #     first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
+            #     num_logits_to_keep = loss_mask.shape[1] - first_compute_index
+            #     model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1  # +1 for the first label
+
+            outputs = model(
+                input_ids=input_ids, attention_mask=attention_mask, **model_kwargs
+            )
+
+            # Offset the logits by one to align with the labels
+            logits = outputs.logits[:, :-1, :]
+            hidden_states = outputs.hidden_states[-1][:, :-1, :]
+            labels = input_ids[:, 1:].clone()
+            loss_mask = loss_mask[:, 1:].bool()
+
+            # if self.use_num_logits_to_keep:
+            #     # Align labels with logits
+            #     # logits:    -,  -, [x2, x3, x4, x5, x6]
+            #     #                     ^ --------- ^       after logits[:, :-1, :]
+            #     # labels:   [y0, y1, y2, y3, y4, y5, y6]
+            #     #                         ^ --------- ^   with num_logits_to_keep=4, [:, -4:]
+            #     # loss_mask: [0,  0,  0,  1,  1,  1,  1]
+            #     labels = labels[:, -num_logits_to_keep:]
+            #     loss_mask = loss_mask[:, -num_logits_to_keep:]
+            #     hidden_states = hidden_states[:, -num_logits_to_keep:, :]
+
+        if logits.shape[:2] != labels.shape[:2]:
+            # for llava, the returned logits include the image tokens (placed before the text tokens)
+            seq_len = labels.shape[1]
+            logits = logits[:, -seq_len:]
+            hidden_states = hidden_states[:, -seq_len:]
+
+        # Compute the log probabilities of the labels
+        labels[
+            ~loss_mask
+        ] = 0  # dummy token; we'll ignore the losses on these tokens later
+        per_token_logps = torch.gather(
+            logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
+        ).squeeze(2)
+        per_token_logps[~loss_mask] = 0
+        all_logps = per_token_logps.sum(-1)
+
+        output = {}
+
+        if self.use_weighting:
+            with torch.no_grad():
+                # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827
+                logprobs = F.log_softmax(logits, dim=-1)
+                weights_adjustment_factor = torch.logsumexp(
+                    2 * logprobs, dim=-1
+                )  # same as sum(probs**2) in log space
+                per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
+                all_weights = (per_token_logps_adjusted * loss_mask).sum(
+                    -1
+                ) / loss_mask.sum(-1)
+                chosen_weights = all_weights[:num_examples]
+                rejected_weights = all_weights[num_examples:]
+                output["policy_weights"] = torch.clamp(
+                    torch.exp(chosen_weights + rejected_weights), max=1
+                )
+
+        if self.args.rpo_alpha is not None:
+            # Only use the chosen logits for the RPO loss
+            chosen_logits = logits[:num_examples]
+            chosen_labels = labels[:num_examples]
+
+            # Compute the log probabilities of the labels
+            output["nll_loss"] = F.cross_entropy(
+                torch.flatten(chosen_logits, end_dim=1),
+                torch.flatten(chosen_labels, end_dim=1),
+                ignore_index=0,
+            )
+
+        if self.loss_type == "ipo":
+            all_logps = all_logps / loss_mask.sum(-1)
+
+        output["chosen_logps"] = all_logps[:num_examples]
+        output["rejected_logps"] = all_logps[num_examples:]
+        output["mean_chosen_logits"] = logits[:num_examples][
+            loss_mask[:num_examples]
+        ].mean()
+        output["mean_rejected_logits"] = logits[num_examples:][
+            loss_mask[num_examples:]
+        ].mean()
+        output["hidden_states"] = hidden_states
+        # Masked positions must carry ``label_pad_token_id`` (not the dummy 0 used
+        # for the gather above) so the Liger loss's ``ignore_index`` skips them.
+        output["labels"] = labels.masked_fill(~loss_mask, self.label_pad_token_id)
+
+        if self.aux_loss_enabled:
+            output["aux_loss"] = outputs.aux_loss
+
+        return output
+
 
 class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
     """
@@ -2163,6 +2503,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             if self.cfg.dpo_use_weighting is not None:
                 training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
 
+        report_to = []
+        if self.cfg.use_wandb:
+            report_to.append("wandb")
+        if self.cfg.wandb_name:
+            training_args_kwargs["run_name"] = self.cfg.wandb_name
+
+        training_args_kwargs["report_to"] = report_to
+
         training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
             output_dir=self.cfg.output_dir,
             per_device_train_batch_size=self.cfg.micro_batch_size,