diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 96054dc50..8611896db 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -8,15 +8,17 @@ import importlib import logging import math import sys +import typing from abc import abstractmethod from dataclasses import dataclass, field from functools import wraps from pathlib import Path -from typing import List, Optional, Type, Union +from typing import Dict, List, Optional, Tuple, Type, Union import torch import transformers from datasets import Dataset +from torch import nn from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import ( @@ -29,6 +31,7 @@ from transformers.trainer_utils import seed_worker from trl import DPOTrainer from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler +from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.utils.callbacks import ( EvalFirstStepCallback, GPUStatsCallback, @@ -56,6 +59,13 @@ try: except ImportError: pass +if typing.TYPE_CHECKING: + # hacky, but recommended per https://github.com/python/mypy/issues/5837 + _MixinTrainerBase = Trainer +else: + _MixinTrainerBase = object + + LOG = logging.getLogger("axolotl.core.trainer_builder") @@ -153,7 +163,142 @@ class AxolotlTrainingArguments(TrainingArguments): ) -class AxolotlTrainer(Trainer): +class AxolotlMultiPackTrainerMixin(_MixinTrainerBase): # type: ignore + """Trainer Mixin class for dataloaders and samplers""" + + args = None # type: AxolotlTrainingArguments + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.args.sample_packing and not self.args.pretraining: + return MultipackBatchSampler( + RandomSampler(self.train_dataset), + self.args.train_batch_size, + drop_last=True, + batch_max_len=self._train_batch_size * self.args.max_seq_length, + lengths=get_dataset_lengths(self.train_dataset), + packing_efficiency_estimate=self.args.sample_packing_efficiency, + ) + return super()._get_train_sampler() + + def get_train_dataloader(self) -> DataLoader: + if self.args.sample_packing and not self.args.pretraining: + train_dataset = self.train_dataset + if "length" in train_dataset.features.keys(): + train_dataset = train_dataset.remove_columns(["length"]) + data_collator = self.data_collator + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } + if self.args.dataloader_prefetch_factor: + dataloader_params[ + "prefetch_factor" + ] = self.args.dataloader_prefetch_factor + + sampler = self._get_train_sampler() + if isinstance(sampler, BatchSampler): + dataloader_params["batch_sampler"] = sampler + del dataloader_params["batch_size"] + else: + dataloader_params["sampler"] = sampler + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + + self.accelerator.even_batches = False + return self.accelerator.prepare_data_loader( + DataLoader(train_dataset, **dataloader_params) + ) + return super().get_train_dataloader() + + def _get_eval_sampler( + self, eval_dataset: Dataset + ) -> Optional[torch.utils.data.Sampler]: + if self.args.sample_packing and self.args.eval_sample_packing is not False: + return MultipackBatchSampler( + SequentialSampler(eval_dataset), + self.args.per_device_eval_batch_size, + 
drop_last=True, + batch_max_len=self.args.eval_batch_size * self.args.max_seq_length, + lengths=get_dataset_lengths(eval_dataset), + packing_efficiency_estimate=self.args.sample_packing_efficiency, + ) + return super()._get_eval_sampler(eval_dataset) + + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + if self.args.sample_packing and self.args.eval_sample_packing is False: + self.data_collator = ( # pylint: disable=attribute-defined-outside-init + self.eval_data_collator + ) + dataloader = super().get_eval_dataloader(eval_dataset) + self.data_collator = ( # pylint: disable=attribute-defined-outside-init + self.train_data_collator + ) + return dataloader + + if self.args.sample_packing and self.args.eval_sample_packing is not False: + eval_dataset = ( + eval_dataset if eval_dataset is not None else self.eval_dataset + ) + + eval_sampler = self._get_eval_sampler(eval_dataset) + eval_dataset = eval_dataset.remove_columns(["length"]) + data_collator = self.data_collator + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } + if self.args.dataloader_prefetch_factor: + dataloader_params[ + "prefetch_factor" + ] = self.args.dataloader_prefetch_factor + + if isinstance(eval_sampler, BatchSampler): + dataloader_params["batch_sampler"] = eval_sampler + del dataloader_params["batch_size"] + else: + dataloader_params["sampler"] = eval_sampler + dataloader_params["drop_last"] = self.args.dataloader_drop_last + + self.accelerator.even_batches = False + return self.accelerator.prepare_data_loader( + DataLoader(eval_dataset, **dataloader_params) + ) + + return super().get_eval_dataloader(eval_dataset) + + def _get_bench_sampler( + self, bench_dataset: Dataset + ) -> Optional[torch.utils.data.Sampler]: + if self.args.world_size <= 1: + return SequentialSampler(bench_dataset) + return None + + def get_bench_dataloader( + self, + bench_dataset: Dataset, + ) -> DataLoader: + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": self.bench_data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } + if self.args.dataloader_prefetch_factor: + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + if not isinstance(bench_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + + return DataLoader(bench_dataset, **dataloader_params) + # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) + + +class AxolotlTrainer(AxolotlMultiPackTrainerMixin, Trainer): """ Extend the base Trainer for axolotl helpers """ @@ -227,135 +372,6 @@ class AxolotlTrainer(Trainer): return self.lr_scheduler - def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if self.args.sample_packing and not self.args.pretraining: - return MultipackBatchSampler( - RandomSampler(self.train_dataset), - self.args.train_batch_size, - drop_last=True, - batch_max_len=self._train_batch_size * self.args.max_seq_length, - lengths=get_dataset_lengths(self.train_dataset), - packing_efficiency_estimate=self.args.sample_packing_efficiency, - ) - return super()._get_train_sampler() - - def _get_eval_sampler( - self, eval_dataset: Dataset - ) -> Optional[torch.utils.data.Sampler]: - if 
self.args.sample_packing and self.args.eval_sample_packing is not False: - return MultipackBatchSampler( - SequentialSampler(eval_dataset), - self.args.per_device_eval_batch_size, - drop_last=True, - batch_max_len=self.args.eval_batch_size * self.args.max_seq_length, - lengths=get_dataset_lengths(eval_dataset), - packing_efficiency_estimate=self.args.sample_packing_efficiency, - ) - return super()._get_eval_sampler(eval_dataset) - - def get_train_dataloader(self) -> DataLoader: - if self.args.sample_packing and not self.args.pretraining: - train_dataset = self.train_dataset - if "length" in train_dataset.features.keys(): - train_dataset = train_dataset.remove_columns(["length"]) - data_collator = self.data_collator - dataloader_params = { - "batch_size": self._train_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } - if self.args.dataloader_prefetch_factor: - dataloader_params[ - "prefetch_factor" - ] = self.args.dataloader_prefetch_factor - - sampler = self._get_train_sampler() - if isinstance(sampler, BatchSampler): - dataloader_params["batch_sampler"] = sampler - del dataloader_params["batch_size"] - else: - dataloader_params["sampler"] = sampler - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = seed_worker - - self.accelerator.even_batches = False - return self.accelerator.prepare_data_loader( - DataLoader(train_dataset, **dataloader_params) - ) - return super().get_train_dataloader() - - def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: - if self.args.sample_packing and self.args.eval_sample_packing is False: - self.data_collator = ( # pylint: disable=attribute-defined-outside-init - self.eval_data_collator - ) - dataloader = super().get_eval_dataloader(eval_dataset) - self.data_collator = ( # pylint: disable=attribute-defined-outside-init - self.train_data_collator - ) - return dataloader - - if self.args.sample_packing and self.args.eval_sample_packing is not False: - eval_dataset = ( - eval_dataset if eval_dataset is not None else self.eval_dataset - ) - - eval_sampler = self._get_eval_sampler(eval_dataset) - eval_dataset = eval_dataset.remove_columns(["length"]) - data_collator = self.data_collator - dataloader_params = { - "batch_size": self.args.eval_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } - if self.args.dataloader_prefetch_factor: - dataloader_params[ - "prefetch_factor" - ] = self.args.dataloader_prefetch_factor - - if isinstance(eval_sampler, BatchSampler): - dataloader_params["batch_sampler"] = eval_sampler - del dataloader_params["batch_size"] - else: - dataloader_params["sampler"] = eval_sampler - dataloader_params["drop_last"] = self.args.dataloader_drop_last - - self.accelerator.even_batches = False - return self.accelerator.prepare_data_loader( - DataLoader(eval_dataset, **dataloader_params) - ) - - return super().get_eval_dataloader(eval_dataset) - - def _get_bench_sampler( - self, bench_dataset: Dataset - ) -> Optional[torch.utils.data.Sampler]: - if self.args.world_size <= 1: - return SequentialSampler(bench_dataset) - return None - - def get_bench_dataloader( - self, - bench_dataset: Dataset, - ) -> DataLoader: - dataloader_params = { - "batch_size": self.args.eval_batch_size, - "collate_fn": self.bench_data_collator, - "num_workers": self.args.dataloader_num_workers, - 
"pin_memory": self.args.dataloader_pin_memory, - } - if self.args.dataloader_prefetch_factor: - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - - if not isinstance(bench_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset) - dataloader_params["drop_last"] = self.args.dataloader_drop_last - - return DataLoader(bench_dataset, **dataloader_params) - # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) - def compute_loss(self, model, inputs, return_outputs=False): # use one's weighted cross entropy loss calc # if self.args.sample_packing: @@ -470,7 +486,7 @@ class ReLoRATrainer(AxolotlTrainer): return self.lr_scheduler -class AxolotlDPOTrainer(DPOTrainer): +class AxolotlDPOTrainer(AxolotlMultiPackTrainerMixin, DPOTrainer): """ Extend the base DPOTrainer for axolotl helpers """ @@ -487,6 +503,73 @@ class AxolotlDPOTrainer(DPOTrainer): return super().push_to_hub(*args, **kwargs) + def tokenize_row(self, feature, *args, **kwargs) -> Dict: + # check if dataset is already tokenized + if not self.is_encoder_decoder: + keys = [ + "chosen_input_ids", + "chosen_attention_mask", + "chosen_labels", + "rejected_input_ids", + "rejected_attention_mask", + "rejected_labels", + ] + if all(k in feature.keys() for k in keys): + return feature + else: + keys = [ + "chosen_labels", + "rejected_labels", + "prompt_input_ids", + "prompt_attention_mask", + ] + if all(k in feature.keys() for k in keys): + return feature + return super().tokenize_row(feature, *args, **kwargs) + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[ + torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor + ]: + all_logits = model( + batch["input_ids"], + attention_mask=batch["attention_mask"], + position_ids=batch["position_ids"], + ).logits + cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"]) + + + return super().concatenated_forward(model, batch) + + @staticmethod + def get_batch_logps_multipack( + logits: torch.FloatTensor, + labels: torch.LongTensor, + position_ids: torch.LongTensor, + average_log_prob: bool = False, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, + ) -> torch.FloatTensor: + if is_encoder_decoder: + raise ValueError("unhandled get_batch_logps_multipack(...) for is_encoder_decoder") + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != label_pad_token_id + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == label_pad_token_id] = 0 + + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + class TrainerBuilderBase(abc.ABC): """ @@ -1108,6 +1191,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase): callbacks=self.get_callbacks(), **dpo_trainer_kwargs, ) + setattr(dpo_trainer, "use_dpo_data_collator", True) dpo_trainer = self.hook_post_create_trainer(dpo_trainer) for callback in self.get_post_trainer_create_callbacks(dpo_trainer): dpo_trainer.add_callback(callback)