wip
@@ -11,7 +11,7 @@ import sys
 import typing
 from abc import abstractmethod
 from dataclasses import dataclass, field
-from functools import wraps
+from functools import wraps, partial
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Type, Union
 
@@ -53,6 +53,7 @@ from axolotl.utils.schedulers import (
     get_cosine_schedule_with_min_lr,
     get_cosine_schedule_with_quadratic_warmup,
 )
+from axolotl.utils.tensors import keep_unpacked_data, split_and_pad_packed
 
 try:
     import torch._dynamo  # pylint: disable=ungrouped-imports
@@ -538,37 +539,23 @@ class AxolotlDPOTrainer(AxolotlMultiPackTrainerMixin, DPOTrainer):
             position_ids=batch["position_ids"],
         ).logits
         cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
+        logits_keep_fn = partial(keep_unpacked_data, pad_val=None, pairs=True)
+        unpacked_logits = split_and_pad_packed(all_logits, cu_seqlens, max_seqlen, logits_keep_fn)
+        labels_keep_fn = partial(keep_unpacked_data, pad_val=-100, pairs=True)
+        unpacked_labels = split_and_pad_packed(batch["labels"], cu_seqlens, max_seqlen, labels_keep_fn)
+        unpacked_logps = self.get_batch_logps(
+            unpacked_logits,
+            unpacked_labels,
+            average_log_prob=self.loss_type == "ipo",
+            is_encoder_decoder=self.is_encoder_decoder,
+            label_pad_token_id=self.label_pad_token_id,
+        )
         chosen_logps = unpacked_logps[::2]
         rejected_logps = unpacked_logps[1::2]
         chosen_logits = unpacked_logits[::2]
         rejected_logits = unpacked_logits[1::2]
 
+        return super().concatenated_forward(model, batch)
 
-    @staticmethod
-    def get_batch_logps_multipack(
-        logits: torch.FloatTensor,
-        labels: torch.LongTensor,
-        position_ids: torch.LongTensor,
-        average_log_prob: bool = False,
-        label_pad_token_id: int = -100,
-        is_encoder_decoder: bool = False,
-    ) -> torch.FloatTensor:
-        if is_encoder_decoder:
-            raise ValueError("unhandled get_batch_logps_multipack(...) for is_encoder_decoder")
-        if logits.shape[:-1] != labels.shape:
-            raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
-
-        labels = labels[:, 1:].clone()
-        logits = logits[:, :-1, :]
-        loss_mask = labels != label_pad_token_id
-
-        # dummy token; we'll ignore the losses on these tokens later
-        labels[labels == label_pad_token_id] = 0
-
-        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
-
-        if average_log_prob:
-            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
-        else:
-            return (per_token_logps * loss_mask).sum(-1)
         return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
 
 
 class TrainerBuilderBase(abc.ABC):
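A note on the concatenated_forward rewrite above: it relies on get_cu_seqlens_from_pos_ids recovering per-sequence boundaries from position_ids, and on chosen/rejected completions being packed adjacently, so even/odd slicing of the unpacked tensors recovers the pairs. A standalone toy sketch of that mechanic, with made-up lengths and values (not the axolotl implementation):

# Toy illustration of the unpack-then-slice mechanic used above; the
# lengths, values, and padding here are hypothetical.
import torch
import torch.nn.functional as F

# One packed row holding 4 sequences (c0, r0, c1, r1) of lengths 3, 2, 4, 1;
# cu_seqlens marks the cumulative boundaries, as get_cu_seqlens_from_pos_ids would.
packed = torch.arange(10).unsqueeze(0)  # (1, total_tokens)
cu_seqlens = torch.tensor([0, 3, 5, 9, 10])
max_seqlen = 4

# Split at the boundaries and right-pad each sequence to max_seqlen
# (labels would pad with -100, matching the labels_keep_fn above).
seqs = [packed[0, s:e] for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:])]
padded = torch.stack([F.pad(s, (0, max_seqlen - s.numel()), value=-100) for s in seqs])

# Chosen/rejected are interleaved, so even/odd slicing recovers the pairs.
chosen, rejected = padded[::2], padded[1::2]
print(chosen.shape, rejected.shape)  # torch.Size([2, 4]) torch.Size([2, 4])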
@@ -178,6 +178,9 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features = [chunked_data]
         return super().__call__(features, return_tensors=return_tensors)
 
+
+@dataclass
+class BatchSamplerDPODataCollatorWithPadding:
 
 @dataclass
 class MambaDataCollator:
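The hunk above adds BatchSamplerDPODataCollatorWithPadding, but its body is elided in this view. For orientation only, a minimal sketch of what a padding collator for DPO pairs typically looks like; every name below (pad_token_id, label_pad_token_id, the chosen_*/rejected_* feature keys, and the Sketch class itself) is an assumption, not taken from the diff:

from dataclasses import dataclass
from typing import Any, Dict, List

import torch


@dataclass
class DPOPaddingCollatorSketch:  # hypothetical stand-in, not the class from the diff
    pad_token_id: int = 0
    label_pad_token_id: int = -100

    def _pad(self, seqs: List[List[int]], pad_val: int) -> torch.Tensor:
        # right-pad every sequence in the batch to the longest one
        max_len = max(len(s) for s in seqs)
        return torch.tensor([s + [pad_val] * (max_len - len(s)) for s in seqs])

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        batch: Dict[str, torch.Tensor] = {}
        for key in ("chosen_input_ids", "rejected_input_ids"):
            batch[key] = self._pad([f[key] for f in features], self.pad_token_id)
        for key in ("chosen_labels", "rejected_labels"):
            batch[key] = self._pad([f[key] for f in features], self.label_pad_token_id)
        return batch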
@@ -1,6 +1,18 @@
 import torch
 import torch.nn.functional as F
 
 
+def keep_unpacked_data(data: torch.Tensor, index=None, nonzero_total=None, pad_val=None, pairs=False):
+    # pad_val may be the padding token (input_ids), -100 (labels), or 0 (attention_mask),
+    # so test against None explicitly rather than for truthiness
+    if index >= nonzero_total:
+        # drop segments past the count of real (non-padding) sequences
+        return False
+    if pairs and (index // 2) >= (nonzero_total // 2):
+        # with chosen/rejected pairs, drop any trailing unpaired segment
+        return False
+    if pad_val is not None and (data == pad_val).all():
+        # drop segments that are entirely padding
+        return False
+    return True
+
+
 def split_and_pad_packed(tensor, cu_seqlens, max_seqlen, keep_fn=None):
     split_tensors = []
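The diff is truncated at split_tensors = [], so the body of split_and_pad_packed is not shown. A hedged sketch of a plausible continuation, inferred only from the call sites in concatenated_forward and from keep_unpacked_data's signature; the pad_value parameter and the way nonzero_total is derived here are assumptions:

import torch
import torch.nn.functional as F


def split_and_pad_packed_sketch(tensor, cu_seqlens, max_seqlen, keep_fn=None, pad_value=0):
    # Plausible continuation (assumption): split the packed tensor at the
    # cu_seqlens boundaries, drop segments the keep_fn rejects, and
    # right-pad the survivors along the sequence dim to max_seqlen.
    split_tensors = []
    starts, ends = cu_seqlens[:-1], cu_seqlens[1:]
    total = len(starts)  # assumption: the real code may count only non-padding sequences
    for i, (start, end) in enumerate(zip(starts, ends)):
        chunk = tensor[:, start:end]  # (1, T) for ids/labels, (1, T, V) for logits
        if keep_fn is not None and not keep_fn(chunk, index=i, nonzero_total=total):
            continue
        # F.pad consumes pairs from the last dim backwards; pad only the sequence dim.
        pad_spec = [0, 0] * (chunk.dim() - 2) + [0, int(max_seqlen - (end - start))]
        split_tensors.append(F.pad(chunk, pad_spec, value=pad_value))
    return torch.cat(split_tensors, dim=0)  # (kept_seqs, max_seqlen, ...)

For labels the caller would presumably pass pad_value=-100 so padded positions stay masked as label_pad_token_id, while the keep_fn hook lets trailing all-padding segments be dropped before padding.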