Wing Lian
2024-02-01 00:28:16 -05:00
parent 047d9e1d5b
commit d46d7dfe30
3 changed files with 33 additions and 31 deletions

View File

@@ -11,7 +11,7 @@ import sys
import typing
from abc import abstractmethod
from dataclasses import dataclass, field
from functools import wraps
from functools import wraps, partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union
@@ -53,6 +53,7 @@ from axolotl.utils.schedulers import (
    get_cosine_schedule_with_min_lr,
    get_cosine_schedule_with_quadratic_warmup,
)
from axolotl.utils.tensors import keep_unpacked_data, split_and_pad_packed

try:
    import torch._dynamo  # pylint: disable=ungrouped-imports
@@ -538,37 +539,23 @@ class AxolotlDPOTrainer(AxolotlMultiPackTrainerMixin, DPOTrainer):
            position_ids=batch["position_ids"],
        ).logits
        cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
        logits_keep_fn = partial(keep_unpacked_data, pad_val=None, pairs=True)
        unpacked_logits = split_and_pad_packed(all_logits, cu_seqlens, max_seqlen, logits_keep_fn)
        labels_keep_fn = partial(keep_unpacked_data, pad_val=-100, pairs=True)
        unpacked_labels = split_and_pad_packed(batch["labels"], cu_seqlens, max_seqlen, labels_keep_fn)
        unpacked_logps = self.get_batch_logps(
            unpacked_logits,
            unpacked_labels,
            average_log_prob=self.loss_type == "ipo",
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )
        chosen_logps = unpacked_logps[::2]
        rejected_logps = unpacked_logps[1::2]
        chosen_logits = unpacked_logits[::2]
        rejected_logits = unpacked_logits[1::2]

        return super().concatenated_forward(model, batch)

    @staticmethod
    def get_batch_logps_multipack(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        position_ids: torch.LongTensor,
        average_log_prob: bool = False,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> torch.FloatTensor:
        if is_encoder_decoder:
            raise ValueError("unhandled get_batch_logps_multipack(...) for is_encoder_decoder")
        if logits.shape[:-1] != labels.shape:
            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
        loss_mask = labels != label_pad_token_id
        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == label_pad_token_id] = 0
        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)

class TrainerBuilderBase(abc.ABC):
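
A note on the unpacking logic above: multipack DPO batches concatenate many sequences into a single row, and the trainer relies on the convention that unpacked sequences alternate chosen, rejected, chosen, rejected, so the [::2] / [1::2] slicing recovers the pairs. The per-sequence log-probs follow the same shift-and-gather pattern as get_batch_logps_multipack. A minimal, self-contained illustration on toy tensors (shapes and values are invented for the example, not taken from the commit):

import torch

# toy "unpacked" batch: 4 sequences = 2 (chosen, rejected) pairs, seq len 5, vocab 11
logits = torch.randn(4, 5, 11)
labels = torch.randint(0, 11, (4, 5))
labels[:, :2] = -100  # prompt tokens carry the label pad id and are excluded from the loss

# shift: the token at position t is predicted by the logits at position t - 1
shift_labels = labels[:, 1:].clone()
shift_logits = logits[:, :-1, :]
loss_mask = shift_labels != -100
shift_labels[shift_labels == -100] = 0  # dummy index; masked out below

per_token_logps = torch.gather(
    shift_logits.log_softmax(-1), dim=2, index=shift_labels.unsqueeze(2)
).squeeze(2)
seq_logps = (per_token_logps * loss_mask).sum(-1)  # one scalar per sequence

# pairs are interleaved: even indices are chosen, odd indices are rejected
chosen_logps, rejected_logps = seq_logps[::2], seq_logps[1::2]
print(chosen_logps.shape, rejected_logps.shape)  # torch.Size([2]) torch.Size([2])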

View File

@@ -178,6 +178,9 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
        features = [chunked_data]
        return super().__call__(features, return_tensors=return_tensors)

@dataclass
class BatchSamplerDPODataCollatorWithPadding:

@dataclass
class MambaDataCollator:
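
The body of the new BatchSamplerDPODataCollatorWithPadding is not shown in this excerpt. A typical DPO padding collator right-pads every feature in the batch to the longest sequence, using the tokenizer pad id for input_ids and -100 for labels. A hypothetical sketch along those lines (the name PadToLongestDPO and all details are assumptions, not this commit's code):

import torch
from dataclasses import dataclass

@dataclass
class PadToLongestDPO:  # hypothetical stand-in; the commit's collator body is not shown
    pad_token_id: int = 0
    label_pad_token_id: int = -100

    def __call__(self, features, return_tensors="pt"):
        # features: list of dicts with list-valued "input_ids" and "labels"
        max_len = max(len(f["input_ids"]) for f in features)
        batch = {"input_ids": [], "attention_mask": [], "labels": []}
        for f in features:
            pad = max_len - len(f["input_ids"])
            batch["input_ids"].append(f["input_ids"] + [self.pad_token_id] * pad)
            batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * pad)
            batch["labels"].append(f["labels"] + [self.label_pad_token_id] * pad)
        return {k: torch.tensor(v) for k, v in batch.items()}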

View File

@@ -1,6 +1,18 @@
import torch
import torch.nn.functional as F

def keep_unpacked_data(data: torch.Tensor, index=None, nonzero_total=None, pad_val=None, pairs=False):
    # pad_val could be the padding token (input_ids), -100 (labels), or 0 (attention_mask)
    if index >= nonzero_total:
        return False
    # with interleaved chosen/rejected pairs, also drop a trailing unpaired sequence
    if pairs and (index // 2) >= (nonzero_total // 2):
        return False
    # compare against None, not truthiness: a pad_val of 0 (attention_mask) must still be honored
    if pad_val is not None and (data == pad_val).all():
        return False
    return True
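
The keyword arguments mirror the trainer code above: keep_unpacked_data is curried with functools.partial so split_and_pad_packed can invoke it once per unpacked sequence. A quick illustration with toy label tensors (values invented for the example):

import torch
from functools import partial

labels_keep_fn = partial(keep_unpacked_data, pad_val=-100, pairs=True)

all_pad = torch.full((6,), -100)               # entirely label padding
real = torch.tensor([-100, -100, 5, 9, 2, 1])  # has supervised tokens

print(labels_keep_fn(all_pad, index=0, nonzero_total=4))  # False: every element is pad_val
print(labels_keep_fn(real, index=1, nonzero_total=4))     # True
print(labels_keep_fn(real, index=4, nonzero_total=4))     # False: index beyond nonzero_total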

def split_and_pad_packed(tensor, cu_seqlens, max_seqlen, keep_fn=None):
    split_tensors = []
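
The function body is truncated in this view. From its call sites above, it plausibly splits the packed tensor at each cu_seqlens boundary (flash-attention-style cumulative sequence lengths), skips sequences the keep_fn rejects, and right-pads the survivors to max_seqlen before stacking. A sketch under those assumptions; the _sketch suffix marks it as a guess, not the commit's implementation:

def split_and_pad_packed_sketch(tensor, cu_seqlens, max_seqlen, keep_fn=None, pad_value=0):
    # tensor: (1, total_tokens, ...) with all sequences concatenated along dim 1;
    # cu_seqlens: cumulative boundaries, so sequence i spans cu_seqlens[i]:cu_seqlens[i + 1]
    pieces = []
    nonzero_total = len(cu_seqlens) - 1
    for i in range(nonzero_total):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        piece = tensor[0, start:end]
        if keep_fn is not None and not keep_fn(piece, index=i, nonzero_total=nonzero_total):
            continue
        # right-pad the sequence dimension up to max_seqlen (labels would want pad_value=-100)
        pad = [0] * (2 * (piece.dim() - 1)) + [0, max_seqlen - piece.shape[0]]
        pieces.append(F.pad(piece, pad, value=pad_value))
    return torch.stack(pieces)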