From d46d7dfe30d2603558fc65be89023e37cbc46e2b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 1 Feb 2024 00:28:16 -0500
Subject: [PATCH] wip

---
 src/axolotl/core/trainer_builder.py | 49 +++++++++++------------------
 src/axolotl/utils/collators.py      |  4 +++
 src/axolotl/utils/tensors.py        | 12 +++++++
 3 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 8611896db..62bbb2b25 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -11,7 +11,7 @@ import sys
 import typing
 from abc import abstractmethod
 from dataclasses import dataclass, field
-from functools import wraps
+from functools import partial, wraps
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Type, Union
 
@@ -53,6 +53,7 @@ from axolotl.utils.schedulers import (
     get_cosine_schedule_with_min_lr,
     get_cosine_schedule_with_quadratic_warmup,
 )
+from axolotl.utils.tensors import keep_unpacked_data, split_and_pad_packed
 
 try:
     import torch._dynamo  # pylint: disable=ungrouped-imports
@@ -538,36 +539,22 @@ class AxolotlDPOTrainer(AxolotlMultiPackTrainerMixin, DPOTrainer):
             position_ids=batch["position_ids"],
         ).logits
         cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
+        logits_keep_fn = partial(keep_unpacked_data, pad_val=None, pairs=True)
+        unpacked_logits = split_and_pad_packed(all_logits, cu_seqlens, max_seqlen, logits_keep_fn)
+        labels_keep_fn = partial(keep_unpacked_data, pad_val=-100, pairs=True)
+        unpacked_labels = split_and_pad_packed(batch["labels"], cu_seqlens, max_seqlen, labels_keep_fn)
+        unpacked_logps = self.get_batch_logps(
+            unpacked_logits,
+            unpacked_labels,
+            average_log_prob=self.loss_type == "ipo",
+            is_encoder_decoder=self.is_encoder_decoder,
+            label_pad_token_id=self.label_pad_token_id,
+        )
+        chosen_logps = unpacked_logps[::2]
+        rejected_logps = unpacked_logps[1::2]
+        chosen_logits = unpacked_logits[::2]
+        rejected_logits = unpacked_logits[1::2]
-
-        return super().concatenated_forward(model, batch)
-
-    @staticmethod
-    def get_batch_logps_multipack(
-        logits: torch.FloatTensor,
-        labels: torch.LongTensor,
-        position_ids: torch.LongTensor,
-        average_log_prob: bool = False,
-        label_pad_token_id: int = -100,
-        is_encoder_decoder: bool = False,
-    ) -> torch.FloatTensor:
-        if is_encoder_decoder:
-            raise ValueError("unhandled get_batch_logps_multipack(...) for is_encoder_decoder")
-        if logits.shape[:-1] != labels.shape:
-            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
-
-        labels = labels[:, 1:].clone()
-        logits = logits[:, :-1, :]
-        loss_mask = labels != label_pad_token_id
-
-        # dummy token; we'll ignore the losses on these tokens later
-        labels[labels == label_pad_token_id] = 0
-
-        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
-
-        if average_log_prob:
-            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-        else:
-            return (per_token_logps * loss_mask).sum(-1)
+        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
 
 
 class TrainerBuilderBase(abc.ABC):
diff --git a/src/axolotl/utils/collators.py b/src/axolotl/utils/collators.py
index da1ee392f..4d2804fb1 100644
--- a/src/axolotl/utils/collators.py
+++ b/src/axolotl/utils/collators.py
@@ -178,6 +178,10 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             features = [chunked_data]
         return super().__call__(features, return_tensors=return_tensors)
 
+@dataclass
+class BatchSamplerDPODataCollatorWithPadding:
+    pass  # WIP placeholder body: an empty class body is a SyntaxError
+
 
 @dataclass
 class MambaDataCollator:
diff --git a/src/axolotl/utils/tensors.py b/src/axolotl/utils/tensors.py
index 2198bd31d..aadbf5696 100644
--- a/src/axolotl/utils/tensors.py
+++ b/src/axolotl/utils/tensors.py
@@ -1,6 +1,18 @@
 import torch
 import torch.nn.functional as F
 
+
+def keep_unpacked_data(data: torch.Tensor, index, nonzero_total, pad_val=None, pairs=False):
+    # pad_val may be the padding token (input_ids), -100 (labels), or 0 (attention_mask);
+    # compare with `is not None` (not truthiness) so pad_val=0 is honored
+    if index >= nonzero_total:
+        return False
+    if pairs and (index // 2) >= (nonzero_total // 2):
+        return False
+    if pad_val is not None and (data == pad_val).all(dim=0).all():
+        return False
+    return True
+
 
 def split_and_pad_packed(tensor, cu_seqlens, max_seqlen, keep_fn=None):
     split_tensors = []