wip
This commit is contained in:
@@ -11,7 +11,7 @@ import sys
|
|||||||
import typing
|
import typing
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import wraps
|
from functools import wraps, partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
@@ -53,6 +53,7 @@ from axolotl.utils.schedulers import (
|
|||||||
get_cosine_schedule_with_min_lr,
|
get_cosine_schedule_with_min_lr,
|
||||||
get_cosine_schedule_with_quadratic_warmup,
|
get_cosine_schedule_with_quadratic_warmup,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.tensors import keep_unpacked_data, split_and_pad_packed
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||||
@@ -538,37 +539,23 @@ class AxolotlDPOTrainer(AxolotlMultiPackTrainerMixin, DPOTrainer):
|
|||||||
position_ids=batch["position_ids"],
|
position_ids=batch["position_ids"],
|
||||||
).logits
|
).logits
|
||||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
|
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
|
||||||
|
logits_keep_fn = partial(keep_unpacked_data, pad_val=None, pairs=True)
|
||||||
|
unpacked_logits = split_and_pad_packed(all_logits, cu_seqlens, max_seqlen, logits_keep_fn)
|
||||||
|
labels_keep_fn = partial(keep_unpacked_data, pad_val=-100, pairs=True)
|
||||||
|
unpacked_labels = split_and_pad_packed(batch["labels"], cu_seqlens, max_seqlen, labels_keep_fn)
|
||||||
|
unpacked_logps = self.get_batch_logps(
|
||||||
|
unpacked_logits,
|
||||||
|
unpacked_labels,
|
||||||
|
average_log_prob=self.loss_type == "ipo",
|
||||||
|
is_encoder_decoder=self.is_encoder_decoder,
|
||||||
|
label_pad_token_id=self.label_pad_token_id,
|
||||||
|
)
|
||||||
|
chosen_logps = unpacked_logps[::2]
|
||||||
|
rejected_logps = unpacked_logps[1::2]
|
||||||
|
chosen_logits = unpacked_logits[::2]
|
||||||
|
rejected_logits = unpacked_logits[1::2]
|
||||||
|
|
||||||
|
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
|
||||||
return super().concatenated_forward(model, batch)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_batch_logps_multipack(
|
|
||||||
logits: torch.FloatTensor,
|
|
||||||
labels: torch.LongTensor,
|
|
||||||
position_ids: torch.LongTensor,
|
|
||||||
average_log_prob: bool = False,
|
|
||||||
label_pad_token_id: int = -100,
|
|
||||||
is_encoder_decoder: bool = False,
|
|
||||||
) -> torch.FloatTensor:
|
|
||||||
if is_encoder_decoder:
|
|
||||||
raise ValueError("unhandled get_batch_logps_multipack(...) for is_encoder_decoder")
|
|
||||||
if logits.shape[:-1] != labels.shape:
|
|
||||||
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
|
||||||
|
|
||||||
labels = labels[:, 1:].clone()
|
|
||||||
logits = logits[:, :-1, :]
|
|
||||||
loss_mask = labels != label_pad_token_id
|
|
||||||
|
|
||||||
# dummy token; we'll ignore the losses on these tokens later
|
|
||||||
labels[labels == label_pad_token_id] = 0
|
|
||||||
|
|
||||||
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
|
|
||||||
|
|
||||||
if average_log_prob:
|
|
||||||
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
|
||||||
else:
|
|
||||||
return (per_token_logps * loss_mask).sum(-1)
|
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
|
|||||||
@@ -178,6 +178,9 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
features = [chunked_data]
|
features = [chunked_data]
|
||||||
return super().__call__(features, return_tensors=return_tensors)
|
return super().__call__(features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BatchSamplerDPODataCollatorWithPadding:
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MambaDataCollator:
|
class MambaDataCollator:
|
||||||
|
|||||||
@@ -1,6 +1,18 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def keep_unpacked_data(data: torch.Tensor, index=None, nonzero_total=None, pad_val= None, pairs=False):
|
||||||
|
# pad val could be padding token (input_ids), -100 (labels), or 0 (attention_mask)
|
||||||
|
if index >= nonzero_total:
|
||||||
|
return False
|
||||||
|
if pairs and (index // 2) >= (nonzero_total // 2):
|
||||||
|
return False
|
||||||
|
if pad_val and (data == pad_val).all(dim=0).all():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def split_and_pad_packed(tensor, cu_seqlens, max_seqlen, keep_fn=None):
|
def split_and_pad_packed(tensor, cu_seqlens, max_seqlen, keep_fn=None):
|
||||||
split_tensors = []
|
split_tensors = []
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user