Compare commits

3 Commits

Author      SHA1        Message       Date
Wing Lian   d46d7dfe30  wip           2024-02-01 00:28:16 -05:00
Wing Lian   047d9e1d5b  helper utils  2024-01-31 12:49:29 -05:00
Wing Lian   88a0c05d2c  wip           2024-01-31 12:07:39 -05:00
3 changed files with 268 additions and 133 deletions

View File

@@ -8,15 +8,17 @@ import importlib
import logging
import math
import sys
import typing
from abc import abstractmethod
from dataclasses import dataclass, field
-from functools import wraps
+from functools import wraps, partial
from pathlib import Path
-from typing import List, Optional, Type, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
import torch
import transformers
from datasets import Dataset
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -29,6 +31,7 @@ from transformers.trainer_utils import seed_worker
from trl import DPOTrainer
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.callbacks import (
    EvalFirstStepCallback,
    GPUStatsCallback,
@@ -50,12 +53,20 @@ from axolotl.utils.schedulers import (
    get_cosine_schedule_with_min_lr,
    get_cosine_schedule_with_quadratic_warmup,
)
from axolotl.utils.tensors import keep_unpacked_data, split_and_pad_packed
try:
    import torch._dynamo  # pylint: disable=ungrouped-imports
except ImportError:
    pass

if typing.TYPE_CHECKING:
    # hacky, but recommended per https://github.com/python/mypy/issues/5837
    _MixinTrainerBase = Trainer
else:
    _MixinTrainerBase = object
LOG = logging.getLogger("axolotl.core.trainer_builder")
@@ -153,7 +164,142 @@ class AxolotlTrainingArguments(TrainingArguments):
    )


-class AxolotlTrainer(Trainer):
+class AxolotlMultiPackTrainerMixin(_MixinTrainerBase):  # type: ignore
    """Trainer Mixin class for dataloaders and samplers"""

    args = None  # type: AxolotlTrainingArguments

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.args.sample_packing and not self.args.pretraining:
            return MultipackBatchSampler(
                RandomSampler(self.train_dataset),
                self.args.train_batch_size,
                drop_last=True,
                batch_max_len=self._train_batch_size * self.args.max_seq_length,
                lengths=get_dataset_lengths(self.train_dataset),
                packing_efficiency_estimate=self.args.sample_packing_efficiency,
            )
        return super()._get_train_sampler()

    def get_train_dataloader(self) -> DataLoader:
        if self.args.sample_packing and not self.args.pretraining:
            train_dataset = self.train_dataset
            if "length" in train_dataset.features.keys():
                train_dataset = train_dataset.remove_columns(["length"])
            data_collator = self.data_collator
            dataloader_params = {
                "batch_size": self._train_batch_size,
                "collate_fn": data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
            }
            if self.args.dataloader_prefetch_factor:
                dataloader_params[
                    "prefetch_factor"
                ] = self.args.dataloader_prefetch_factor

            sampler = self._get_train_sampler()
            if isinstance(sampler, BatchSampler):
                dataloader_params["batch_sampler"] = sampler
                del dataloader_params["batch_size"]
            else:
                dataloader_params["sampler"] = sampler
                dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker

            self.accelerator.even_batches = False
            return self.accelerator.prepare_data_loader(
                DataLoader(train_dataset, **dataloader_params)
            )
        return super().get_train_dataloader()

    def _get_eval_sampler(
        self, eval_dataset: Dataset
    ) -> Optional[torch.utils.data.Sampler]:
        if self.args.sample_packing and self.args.eval_sample_packing is not False:
            return MultipackBatchSampler(
                SequentialSampler(eval_dataset),
                self.args.per_device_eval_batch_size,
                drop_last=True,
                batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
                lengths=get_dataset_lengths(eval_dataset),
                packing_efficiency_estimate=self.args.sample_packing_efficiency,
            )
        return super()._get_eval_sampler(eval_dataset)

    def get_eval_dataloader(
        self, eval_dataset: Optional[Dataset] = None
    ) -> DataLoader:
        if self.args.sample_packing and self.args.eval_sample_packing is False:
            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                self.eval_data_collator
            )
            dataloader = super().get_eval_dataloader(eval_dataset)
            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                self.train_data_collator
            )
            return dataloader

        if self.args.sample_packing and self.args.eval_sample_packing is not False:
            eval_dataset = (
                eval_dataset if eval_dataset is not None else self.eval_dataset
            )

            eval_sampler = self._get_eval_sampler(eval_dataset)
            eval_dataset = eval_dataset.remove_columns(["length"])
            data_collator = self.data_collator
            dataloader_params = {
                "batch_size": self.args.eval_batch_size,
                "collate_fn": data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
            }
            if self.args.dataloader_prefetch_factor:
                dataloader_params[
                    "prefetch_factor"
                ] = self.args.dataloader_prefetch_factor

            if isinstance(eval_sampler, BatchSampler):
                dataloader_params["batch_sampler"] = eval_sampler
                del dataloader_params["batch_size"]
            else:
                dataloader_params["sampler"] = eval_sampler
                dataloader_params["drop_last"] = self.args.dataloader_drop_last

            self.accelerator.even_batches = False
            return self.accelerator.prepare_data_loader(
                DataLoader(eval_dataset, **dataloader_params)
            )
        return super().get_eval_dataloader(eval_dataset)

    def _get_bench_sampler(
        self, bench_dataset: Dataset
    ) -> Optional[torch.utils.data.Sampler]:
        if self.args.world_size <= 1:
            return SequentialSampler(bench_dataset)
        return None

    def get_bench_dataloader(
        self,
        bench_dataset: Dataset,
    ) -> DataLoader:
        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": self.bench_data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }
        if self.args.dataloader_prefetch_factor:
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
        return DataLoader(bench_dataset, **dataloader_params)
        # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))


+class AxolotlTrainer(AxolotlMultiPackTrainerMixin, Trainer):
    """
    Extend the base Trainer for axolotl helpers
    """
@@ -227,135 +373,6 @@ class AxolotlTrainer(Trainer):
        return self.lr_scheduler
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.args.sample_packing and not self.args.pretraining:
            return MultipackBatchSampler(
                RandomSampler(self.train_dataset),
                self.args.train_batch_size,
                drop_last=True,
                batch_max_len=self._train_batch_size * self.args.max_seq_length,
                lengths=get_dataset_lengths(self.train_dataset),
                packing_efficiency_estimate=self.args.sample_packing_efficiency,
            )
        return super()._get_train_sampler()

    def _get_eval_sampler(
        self, eval_dataset: Dataset
    ) -> Optional[torch.utils.data.Sampler]:
        if self.args.sample_packing and self.args.eval_sample_packing is not False:
            return MultipackBatchSampler(
                SequentialSampler(eval_dataset),
                self.args.per_device_eval_batch_size,
                drop_last=True,
                batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
                lengths=get_dataset_lengths(eval_dataset),
                packing_efficiency_estimate=self.args.sample_packing_efficiency,
            )
        return super()._get_eval_sampler(eval_dataset)

    def get_train_dataloader(self) -> DataLoader:
        if self.args.sample_packing and not self.args.pretraining:
            train_dataset = self.train_dataset
            if "length" in train_dataset.features.keys():
                train_dataset = train_dataset.remove_columns(["length"])
            data_collator = self.data_collator
            dataloader_params = {
                "batch_size": self._train_batch_size,
                "collate_fn": data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
            }
            if self.args.dataloader_prefetch_factor:
                dataloader_params[
                    "prefetch_factor"
                ] = self.args.dataloader_prefetch_factor

            sampler = self._get_train_sampler()
            if isinstance(sampler, BatchSampler):
                dataloader_params["batch_sampler"] = sampler
                del dataloader_params["batch_size"]
            else:
                dataloader_params["sampler"] = sampler
                dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker

            self.accelerator.even_batches = False
            return self.accelerator.prepare_data_loader(
                DataLoader(train_dataset, **dataloader_params)
            )
        return super().get_train_dataloader()

    def get_eval_dataloader(
        self, eval_dataset: Optional[Dataset] = None
    ) -> DataLoader:
        if self.args.sample_packing and self.args.eval_sample_packing is False:
            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                self.eval_data_collator
            )
            dataloader = super().get_eval_dataloader(eval_dataset)
            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
                self.train_data_collator
            )
            return dataloader

        if self.args.sample_packing and self.args.eval_sample_packing is not False:
            eval_dataset = (
                eval_dataset if eval_dataset is not None else self.eval_dataset
            )

            eval_sampler = self._get_eval_sampler(eval_dataset)
            eval_dataset = eval_dataset.remove_columns(["length"])
            data_collator = self.data_collator
            dataloader_params = {
                "batch_size": self.args.eval_batch_size,
                "collate_fn": data_collator,
                "num_workers": self.args.dataloader_num_workers,
                "pin_memory": self.args.dataloader_pin_memory,
            }
            if self.args.dataloader_prefetch_factor:
                dataloader_params[
                    "prefetch_factor"
                ] = self.args.dataloader_prefetch_factor

            if isinstance(eval_sampler, BatchSampler):
                dataloader_params["batch_sampler"] = eval_sampler
                del dataloader_params["batch_size"]
            else:
                dataloader_params["sampler"] = eval_sampler
                dataloader_params["drop_last"] = self.args.dataloader_drop_last

            self.accelerator.even_batches = False
            return self.accelerator.prepare_data_loader(
                DataLoader(eval_dataset, **dataloader_params)
            )
        return super().get_eval_dataloader(eval_dataset)

    def _get_bench_sampler(
        self, bench_dataset: Dataset
    ) -> Optional[torch.utils.data.Sampler]:
        if self.args.world_size <= 1:
            return SequentialSampler(bench_dataset)
        return None

    def get_bench_dataloader(
        self,
        bench_dataset: Dataset,
    ) -> DataLoader:
        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": self.bench_data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }
        if self.args.dataloader_prefetch_factor:
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
        return DataLoader(bench_dataset, **dataloader_params)
        # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
    def compute_loss(self, model, inputs, return_outputs=False):
        # use our own weighted cross-entropy loss calc
        # if self.args.sample_packing:
@@ -470,7 +487,7 @@ class ReLoRATrainer(AxolotlTrainer):
        return self.lr_scheduler


-class AxolotlDPOTrainer(DPOTrainer):
+class AxolotlDPOTrainer(AxolotlMultiPackTrainerMixin, DPOTrainer):
    """
    Extend the base DPOTrainer for axolotl helpers
    """
@@ -487,6 +504,59 @@ class AxolotlDPOTrainer(DPOTrainer):
        return super().push_to_hub(*args, **kwargs)

    def tokenize_row(self, feature, *args, **kwargs) -> Dict:
        # check if dataset is already tokenized
        if not self.is_encoder_decoder:
            keys = [
                "chosen_input_ids",
                "chosen_attention_mask",
                "chosen_labels",
                "rejected_input_ids",
                "rejected_attention_mask",
                "rejected_labels",
            ]
            if all(k in feature.keys() for k in keys):
                return feature
        else:
            keys = [
                "chosen_labels",
                "rejected_labels",
                "prompt_input_ids",
                "prompt_attention_mask",
            ]
            if all(k in feature.keys() for k in keys):
                return feature
        return super().tokenize_row(feature, *args, **kwargs)
    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[
        torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
    ]:
        # single forward pass over the packed batch; each row packs multiple
        # sequences whose boundaries are encoded in position_ids
        all_logits = model(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            position_ids=batch["position_ids"],
        ).logits
        cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
        # logits carry no sentinel pad value to filter on, so pad_val=None
        logits_keep_fn = partial(keep_unpacked_data, pad_val=None, pairs=True)
        unpacked_logits = split_and_pad_packed(
            all_logits, cu_seqlens, max_seqlen, logits_keep_fn
        )
        # drop label rows that are entirely -100, and pad with -100 so the
        # padded positions stay masked in get_batch_logps
        labels_keep_fn = partial(keep_unpacked_data, pad_val=-100, pairs=True)
        unpacked_labels = split_and_pad_packed(
            batch["labels"], cu_seqlens, max_seqlen, labels_keep_fn, pad_value=-100
        )
        unpacked_logps = self.get_batch_logps(
            unpacked_logits,
            unpacked_labels,
            average_log_prob=self.loss_type == "ipo",
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )
        # sequences are packed as interleaved (chosen, rejected) pairs, so the
        # even entries are chosen and the odd entries are rejected
        chosen_logps = unpacked_logps[::2]
        rejected_logps = unpacked_logps[1::2]
        chosen_logits = unpacked_logits[::2]
        rejected_logits = unpacked_logits[1::2]
        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
class TrainerBuilderBase(abc.ABC):
"""
@@ -1108,6 +1178,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
            callbacks=self.get_callbacks(),
            **dpo_trainer_kwargs,
        )
        setattr(dpo_trainer, "use_dpo_data_collator", True)
        dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
        for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
            dpo_trainer.add_callback(callback)
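
To make the packed DPO path concrete: the collator packs several chosen/rejected pairs into each batch row, position_ids encode the per-sequence boundaries, and concatenated_forward unpacks logits and labels back into per-sequence tensors before computing log-probs. Below is a minimal sketch of the unpack-and-split step, with toy shapes and a hand-built cu_seqlens (in training these come from get_cu_seqlens_from_pos_ids); the values are illustrative only:

import torch
from functools import partial

from axolotl.utils.tensors import keep_unpacked_data, split_and_pad_packed

# one packed row holding two (chosen, rejected) pairs with lengths
# 3, 2, 3, 2 -> cumulative boundaries [0, 3, 5, 8, 10]
batch_size, packed_len, vocab = 1, 10, 8
logits = torch.randn(batch_size, packed_len, vocab)
cu_seqlens = torch.tensor([[0, 3, 5, 8, 10]])
max_seqlen = 3

keep_fn = partial(keep_unpacked_data, pad_val=None, pairs=True)
unpacked = split_and_pad_packed(logits, cu_seqlens, max_seqlen, keep_fn)
print(unpacked.shape)  # torch.Size([4, 3, 8]): four sequences padded to max_seqlen

chosen, rejected = unpacked[::2], unpacked[1::2]  # even/odd interleaving

Because pairs are interleaved at pack time, the even/odd split is enough to separate chosen from rejected without any extra bookkeeping.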

View File

@@ -178,6 +178,9 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
        features = [chunked_data]
        return super().__call__(features, return_tensors=return_tensors)


+@dataclass
+class BatchSamplerDPODataCollatorWithPadding:


@dataclass
class MambaDataCollator:

View File

@@ -0,0 +1,61 @@
import torch
import torch.nn.functional as F


def keep_unpacked_data(
    data: torch.Tensor, index=None, nonzero_total=None, pad_val=None, pairs=False
):
    # pad_val could be the padding token (input_ids), -100 (labels), or 0 (attention_mask)
    if index >= nonzero_total:
        # trailing entries beyond the nonzero sequences are packing filler
        return False
    if pairs and (index // 2) >= (nonzero_total // 2):
        # for paired data, also drop a trailing unpaired sequence
        return False
    # `is not None` so that a pad_val of 0 (attention_mask) still filters
    if pad_val is not None and (data == pad_val).all():
        return False
    return True
def split_and_pad_packed(tensor, cu_seqlens, max_seqlen, keep_fn=None, pad_value=0):
    split_tensors = []
    counts = count_nonzero_sequences(cu_seqlens)

    # Iterate over each batch
    for i in range(tensor.size(0)):
        seq_lens = cu_seqlens[i]
        start_idx = 0
        # Iterate over the cumulative sequence lengths
        for j, end_idx in enumerate(seq_lens[1:]):
            if end_idx == start_idx:
                # zero-length entry: the remaining boundaries are padding
                break
            # Extract the current sequence
            current_seq = tensor[i, start_idx:end_idx]
            # Advance to the next sequence before any filtering below
            start_idx = end_idx
            if keep_fn and not keep_fn(current_seq, index=j, nonzero_total=counts[i]):
                continue
            # Right-pad the sequence dimension (dim 0) up to max_seqlen; F.pad
            # consumes pairs starting from the last dim, so the trailing dims
            # get (0, 0)
            padding_size = max_seqlen - current_seq.size(0)
            padded_seq = F.pad(
                current_seq,
                (0, 0) * (current_seq.dim() - 1) + (0, padding_size),
                value=pad_value,
            )
            # Append the padded sequence to the list
            split_tensors.append(padded_seq)

    # Stack the padded tensors
    return torch.stack(split_tensors, dim=0)
def count_nonzero_sequences(cu_seqlens: torch.Tensor) -> torch.LongTensor:
    # a cu_seqlens row is padded by repeating the last boundary, so a diff of
    # zero marks the end of the real sequences
    diffs = torch.diff(
        cu_seqlens,
        dim=1,
        prepend=torch.zeros(
            cu_seqlens.shape[0], 1, dtype=cu_seqlens.dtype, device=cu_seqlens.device
        ),
    )
    valid_lengths = diffs != 0
    counts = valid_lengths.sum(dim=1).long()
    return counts
# Example usage
# Example tensor with dimensions [batch_size, seq_len, other_dimensions...]
# example_tensor = torch.randn(batch_size, seq_len, other_dimensions...)
# cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
# split_padded_tensor = split_and_pad_packed(example_tensor, cu_seqlens, max_seqlen)
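
A short worked example of the label path through these helpers, with hypothetical values; it uses the pad_value parameter added above so padded label positions stay masked. A fully masked sequence is filtered out by keep_unpacked_data, kept sequences are right-padded with -100, and with pairs=True the trailing unpaired sequence would be dropped as well:

import torch
from functools import partial

from axolotl.utils.tensors import (
    count_nonzero_sequences,
    keep_unpacked_data,
    split_and_pad_packed,
)

# one packed row of labels: sequences [5, 6, 7], [-100, -100] (fully masked),
# and [8, 9]; the final cu_seqlens boundary is repeated padding
labels = torch.tensor([[5, 6, 7, -100, -100, 8, 9]])
cu_seqlens = torch.tensor([[0, 3, 5, 7, 7]])
print(count_nonzero_sequences(cu_seqlens))  # tensor([3])

keep_fn = partial(keep_unpacked_data, pad_val=-100, pairs=False)
out = split_and_pad_packed(labels, cu_seqlens, 3, keep_fn, pad_value=-100)
print(out)  # [[5, 6, 7], [8, 9, -100]]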