diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py
index a9b1f5e89..724c33154 100644
--- a/src/axolotl/utils/dataloader.py
+++ b/src/axolotl/utils/dataloader.py
@@ -1,10 +1,10 @@
 # pylint: skip-file
-from typing import Any, Callable, List
+from typing import Any, Callable, List, Union

 import numba
 import numpy as np
-from torch.utils.data import DistributedSampler
+from torch.utils.data import DistributedSampler, Sampler


 @numba.njit
@@ -108,7 +108,7 @@ class MultipackDistributedDataloader:
         collate_fn: Callable,
         seq_max_length: int = 2048,
         batch_size: int = 1,
-        sampler: DistributedSampler = None,
+        sampler: Union[Sampler, DistributedSampler] = None,
         seed: int = 0,
     ):
         # Dataset
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 6a39eede1..227960e6e 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -134,7 +134,9 @@ def load_model(
         LOG.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()

-    if cfg.is_llama_derived_model and cfg.max_packed_sequence_len:
+    if cfg.is_llama_derived_model and (
+        cfg.max_packed_sequence_len or cfg.sample_packing
+    ):
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask

         LOG.info("patching _expand_mask")
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 4a2fc1e6d..722aeffa7 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -14,7 +14,7 @@ import transformers
 from datasets import Dataset
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import DataLoader, RandomSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names

@@ -106,13 +106,18 @@ class AxolotlTrainer(Trainer):
         self, eval_dataset: Optional[Dataset] = None
     ) -> Union[DataLoader, MultipackDistributedDataloader]:
         if self.args.sample_packing:
+            eval_dataset = (
+                eval_dataset if eval_dataset is not None else self.eval_dataset
+            )
             eval_sampler = self._get_eval_sampler(eval_dataset)
-            return MultipackDistributedDataloader(
-                eval_dataset,
-                batch_size=self.args.per_device_eval_batch_size,
-                seq_max_length=self.args.max_seq_length,
-                collate_fn=self.data_collator,
-                sampler=eval_sampler,
+            return self.accelerator.prepare(
+                MultipackDistributedDataloader(
+                    eval_dataset,
+                    batch_size=self.args.per_device_eval_batch_size,
+                    seq_max_length=self.args.max_seq_length,
+                    collate_fn=self.data_collator,
+                    sampler=eval_sampler,
+                )
             )
         return super().get_eval_dataloader(eval_dataset)

@@ -153,8 +158,8 @@ def add_position_ids(sample):

 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.sample_packing:
-        train_dataset = train_dataset.map(add_position_ids)
-        eval_dataset = eval_dataset.map(add_position_ids)
+        # train_dataset = train_dataset.map(add_position_ids)
+        # eval_dataset = eval_dataset.map(add_position_ids)
         if cfg.sample_packing_eff_est:
             total_num_tokens = sum(len(s["input_ids"]) for s in train_dataset)
             total_num_steps = math.ceil(
@@ -165,12 +170,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
                 / cfg.batch_size
             )
         else:
-            sampler = DistributedSampler(
-                train_dataset,
-                num_replicas=1,
-                rank=0,
-                seed=cfg.seed or 42,
-            )
+            sampler = RandomSampler(train_dataset)
             data_loader = MultipackDistributedDataloader(
                 train_dataset,
                 batch_size=cfg.micro_batch_size,