diff --git a/examples/llama-3/lora-1b-sample-packing-sequentially.yml b/examples/llama-3/lora-1b-sample-packing-sequentially.yml new file mode 100644 index 000000000..79f5a2ba1 --- /dev/null +++ b/examples/llama-3/lora-1b-sample-packing-sequentially.yml @@ -0,0 +1,78 @@ +base_model: meta-llama/Llama-3.2-1B +# optionally might have model_type or tokenizer_type +model_type: LlamaForCausalLM +tokenizer_type: AutoTokenizer +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: true +load_in_4bit: false +strict: false + +datasets: + - path: mhenrichsen/alpaca_2k_test + type: alpaca + - path: mhenrichsen/alpaca_2k_test + type: alpaca +dataset_prepared_path: +val_set_size: 0.0 +output_dir: ./outputs/lora-out + +sequence_len: 4096 +sample_packing: true +sample_packing_sequentially: true +curriculum_sampling: true +eval_sample_packing: false +pad_to_sequence_len: true + +adapter: lora +lora_model_dir: +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +lora_target_linear: true +lora_fan_in_fan_out: +lora_modules_to_save: + - embed_tokens + - lm_head + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 2 +num_epochs: 4 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +train_on_inputs: false +group_by_length: false +bf16: auto +fp16: +tf32: false + +gradient_checkpointing: true +early_stopping_patience: +resume_from_checkpoint: +local_rank: +logging_steps: 1 +xformers_attention: +flash_attention: true +s2_attention: + +warmup_steps: 10 +evals_per_epoch: 4 +eval_table_size: +eval_max_new_tokens: 128 +saves_per_epoch: 1 +debug: +deepspeed: +weight_decay: 0.0 +fsdp: +fsdp_config: +special_tokens: + pad_token: <|end_of_text|> diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 9267dd040..f5679431a 100644 --- a/src/axolotl/core/trainers/base.py +++ 
b/src/axolotl/core/trainers/base.py @@ -112,6 +112,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai packing_efficiency_estimate=self.args.sample_packing_efficiency, batch_max_len=batch_max_len, batch_size=batch_size, + sequential=self.args.sample_packing_sequentially, drop_last=True, ) diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index fbb363492..18843abb4 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -34,6 +34,12 @@ class AxolotlTrainingMixins: default=False, metadata={"help": "Use sample packing for efficient training."}, ) + sample_packing_sequentially: bool = field( + default=False, + metadata={ + "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing." + }, + ) multipack_real_batches: bool = field( default=False, metadata={"help": "Use real batches for efficient training."}, diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 41095152e..ef47aca87 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -8,7 +8,7 @@ from typing import Any, Iterable, List, Union import numba import numpy as np -from torch.utils.data import BatchSampler, Sampler +from torch.utils.data import BatchSampler, Sampler, SequentialSampler from axolotl.utils.distributed import reduce_and_broadcast @@ -103,6 +103,55 @@ def allocate( return result, s, len(result) * c * n +@numba.njit +def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int): + """ + Sequential allocator that preserves example order + + Parameters: + - lengths: The lengths of all examples + - rank: The current rank (for distributed training) + - c: The capacity of each bin (maximum sequence length) + - n: Number of ranks + + Returns: + - result: List of batches for the current rank + 
- total_used: Number of actual example tokens + - total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity) + """ + result = [] + total_used = 0 + + # First, do sequential packing into bins + all_bins = [] + current_bin = [0 for i in range(0)] # numba hint + remaining_capacity = c + + for idx, size in enumerate(lengths): + if size <= remaining_capacity: + # Example fits in current bin + current_bin.append(idx) + remaining_capacity -= size + total_used += size + else: + # Example doesn't fit, start a new bin + if current_bin: # Add non-empty bin to all_bins + all_bins.append(current_bin) + current_bin = [idx] + remaining_capacity = c - size + total_used += size + + # Add the last bin if not empty + if current_bin: + all_bins.append(current_bin) + + # Assign bins to ranks - each rank gets every n-th bin + for bin_idx in range(rank, len(all_bins), n): + result.append(all_bins[bin_idx]) + + return result, total_used, len(all_bins) * c + + class MultipackBatchSampler(BatchSampler): """Batch sampler class for multipack""" @@ -115,6 +164,7 @@ class MultipackBatchSampler(BatchSampler): packing_efficiency_estimate: float = 1.0, drop_last: bool = False, num_count_samples: int = 16, + sequential: bool = False, **kwargs, ): super().__init__(sampler, batch_size, drop_last) @@ -122,6 +172,7 @@ class MultipackBatchSampler(BatchSampler): self.batch_max_len = batch_max_len self.lengths: np.ndarray = lengths self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 + self.sequential = sequential assert isinstance(self.lengths, np.ndarray) @@ -136,6 +187,11 @@ class MultipackBatchSampler(BatchSampler): # the minimum packed dataset length across all ranks determined by a gather/broadcast self.len_across_ranks = None + if self.sequential and not isinstance(sampler, SequentialSampler): + LOG.warning( + "using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?" 
+ ) + def set_epoch(self, epoch: int): self.epoch = epoch @@ -145,13 +201,23 @@ class MultipackBatchSampler(BatchSampler): lengths = self.lengths[indices] lengths_cumsum = np.cumsum(lengths) - batches, total_used, total_slots = allocate( - lengths=lengths, - lengths_cumsum=lengths_cumsum, - rank=0, - c=self.batch_max_len, - n=1, - ) + if self.sequential: + LOG.debug("using sequential sample packing algorithm") + batches, total_used, total_slots = allocate_sequentially( + lengths=lengths, + rank=0, + c=self.batch_max_len, + n=1, + ) + else: + LOG.debug("using non-sequential sample packing algorithm") + batches, total_used, total_slots = allocate( + lengths=lengths, + lengths_cumsum=lengths_cumsum, + rank=0, + c=self.batch_max_len, + n=1, + ) batches = [ [ diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index c7be33ab3..837d2ca69 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -192,6 +192,7 @@ class AxolotlInputConfig( sample_packing: bool | None = None sample_packing_group_size: int | None = 100_000 sample_packing_bin_size: int | None = 200 + sample_packing_sequentially: bool | None = None eval_sample_packing: bool | None = None pad_to_sequence_len: bool | None = None curriculum_sampling: bool | None = None diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index d2b211bbc..646fb4c87 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -13,7 +13,7 @@ import torch import torch.cuda from accelerate.logging import get_logger from datasets import IterableDataset, disable_caching, enable_caching -from torch.utils.data import DataLoader, RandomSampler +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.utils import is_torch_bf16_gpu_available from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder @@ -456,13 +456,18 @@ def calculate_total_num_steps(cfg, train_dataset, 
update=True): else: sampler_batch_size = cfg.micro_batch_size batch_max_len = cfg.sequence_len + if cfg.curriculum_sampling: + sampler = SequentialSampler(train_dataset) + else: + sampler = RandomSampler(train_dataset) sampler = MultipackBatchSampler( - sampler=RandomSampler(train_dataset), + sampler=sampler, lengths=get_dataset_lengths(train_dataset), batch_size=sampler_batch_size, batch_max_len=batch_max_len, group_size=cfg.sample_packing_group_size, bin_size=cfg.sample_packing_bin_size, + sequential=cfg.sample_packing_sequentially, drop_last=True, ) diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py index 061b64b09..faba86931 100644 --- a/tests/test_packed_batch_sampler.py +++ b/tests/test_packed_batch_sampler.py @@ -38,8 +38,11 @@ class TestBatchedSamplerPacking: ], ) @pytest.mark.parametrize("max_seq_length", [4096, 512]) + @pytest.mark.parametrize("sequential", [True, False]) @enable_hf_offline - def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length): + def test_packing( + self, batch_size, num_workers, tokenizer, max_seq_length, sequential + ): import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 dataset = load_dataset( @@ -75,6 +78,7 @@ class TestBatchedSamplerPacking: batch_max_len=max_seq_length, group_size=100000, bin_size=200, + sequential=sequential, ) loader = DataLoader(