Sequential sample packing (#2404) [skip ci]

* add sequential sample packing

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
DreamGenX
2025-03-31 21:48:20 +02:00
committed by GitHub
parent 7acf93b59f
commit 4d36ecc724
7 changed files with 174 additions and 11 deletions

View File

@@ -0,0 +1,80 @@
base_model: meta-llama/Llama-3.2-1B
# Optionally set model_type and/or tokenizer_type explicitly
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/lora-out
test_value: true
sequence_len: 4096
sample_packing: true
sample_packing_sequentially: true
curriculum_sampling: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -112,6 +112,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len, batch_max_len=batch_max_len,
batch_size=batch_size, batch_size=batch_size,
sequential=self.args.sample_packing_sequentially,
drop_last=True, drop_last=True,
) )

View File

@@ -34,6 +34,12 @@ class AxolotlTrainingMixins:
default=False, default=False,
metadata={"help": "Use sample packing for efficient training."}, metadata={"help": "Use sample packing for efficient training."},
) )
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field( multipack_real_batches: bool = field(
default=False, default=False,
metadata={"help": "Use real batches for efficient training."}, metadata={"help": "Use real batches for efficient training."},

View File

@@ -8,7 +8,7 @@ from typing import Any, Iterable, List, Union
import numba import numba
import numpy as np import numpy as np
from torch.utils.data import BatchSampler, Sampler from torch.utils.data import BatchSampler, Sampler, SequentialSampler
from axolotl.utils.distributed import reduce_and_broadcast from axolotl.utils.distributed import reduce_and_broadcast
@@ -103,6 +103,55 @@ def allocate(
return result, s, len(result) * c * n return result, s, len(result) * c * n
@numba.njit
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
    """
    Next-fit packing that keeps examples in the order they are given.

    Examples are walked front to back and appended to the currently open bin
    for as long as they fit; the first example that does not fit closes that
    bin and opens a new one. Earlier bins are never revisited, so the relative
    order of examples is preserved. Note that an example larger than ``c`` is
    not rejected: it still opens (and over-fills) a bin of its own.

    Parameters:
    - lengths: The lengths of all examples
    - rank: The current rank (for distributed training)
    - c: The capacity of each bin (maximum sequence length)
    - n: Number of ranks

    Returns:
    - result: List of bins (each a list of positions into ``lengths``) that
      belong to the current rank
    - total_used: Number of actual example tokens across all bins
    - total_slots: Maximum theoretical number of example tokens
      (number of bins * bin capacity)
    """
    packed_bins = []
    # An empty comprehension lets numba infer "list of int" for the open bin.
    open_bin = [0 for _ in range(0)]
    free_space = c
    token_count = 0

    for position, length in enumerate(lengths):
        # Every example is counted, whichever bin it ends up in.
        token_count += length

        if length > free_space:
            # No room left: close the open bin (if it holds anything) and
            # start a fresh one seeded with this example.
            if open_bin:
                packed_bins.append(open_bin)
            open_bin = [position]
            free_space = c - length
        else:
            open_bin.append(position)
            free_space -= length

    # Flush the trailing, partially filled bin.
    if open_bin:
        packed_bins.append(open_bin)

    # Round-robin the bins over ranks: this rank takes every n-th bin,
    # starting at its own rank offset.
    result = []
    for bin_position in range(rank, len(packed_bins), n):
        result.append(packed_bins[bin_position])

    return result, token_count, len(packed_bins) * c
class MultipackBatchSampler(BatchSampler): class MultipackBatchSampler(BatchSampler):
"""Batch sampler class for multipack""" """Batch sampler class for multipack"""
@@ -115,6 +164,7 @@ class MultipackBatchSampler(BatchSampler):
packing_efficiency_estimate: float = 1.0, packing_efficiency_estimate: float = 1.0,
drop_last: bool = False, drop_last: bool = False,
num_count_samples: int = 16, num_count_samples: int = 16,
sequential: bool = False,
**kwargs, **kwargs,
): ):
super().__init__(sampler, batch_size, drop_last) super().__init__(sampler, batch_size, drop_last)
@@ -122,6 +172,7 @@ class MultipackBatchSampler(BatchSampler):
self.batch_max_len = batch_max_len self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths self.lengths: np.ndarray = lengths
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.sequential = sequential
assert isinstance(self.lengths, np.ndarray) assert isinstance(self.lengths, np.ndarray)
@@ -136,6 +187,11 @@ class MultipackBatchSampler(BatchSampler):
# the minimum packed dataset length across all ranks determined by a gather/broadcast # the minimum packed dataset length across all ranks determined by a gather/broadcast
self.len_across_ranks = None self.len_across_ranks = None
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warn(
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
)
def set_epoch(self, epoch: int): def set_epoch(self, epoch: int):
self.epoch = epoch self.epoch = epoch
@@ -145,13 +201,23 @@ class MultipackBatchSampler(BatchSampler):
lengths = self.lengths[indices] lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths) lengths_cumsum = np.cumsum(lengths)
batches, total_used, total_slots = allocate( if self.sequential:
lengths=lengths, LOG.debug("using sequential sample packing algorithm")
lengths_cumsum=lengths_cumsum, batches, total_used, total_slots = allocate_sequentially(
rank=0, lengths=lengths,
c=self.batch_max_len, rank=0,
n=1, c=self.batch_max_len,
) n=1,
)
else:
LOG.debug("using non-sequential sample packing algorithm")
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
batches = [ batches = [
[ [

View File

@@ -192,6 +192,7 @@ class AxolotlInputConfig(
sample_packing: bool | None = None sample_packing: bool | None = None
sample_packing_group_size: int | None = 100_000 sample_packing_group_size: int | None = 100_000
sample_packing_bin_size: int | None = 200 sample_packing_bin_size: int | None = 200
sample_packing_sequentially: bool | None = None
eval_sample_packing: bool | None = None eval_sample_packing: bool | None = None
pad_to_sequence_len: bool | None = None pad_to_sequence_len: bool | None = None
curriculum_sampling: bool | None = None curriculum_sampling: bool | None = None

View File

@@ -13,7 +13,7 @@ import torch
import torch.cuda import torch.cuda
from accelerate.logging import get_logger from accelerate.logging import get_logger
from datasets import IterableDataset, disable_caching, enable_caching from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -456,13 +456,18 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
else: else:
sampler_batch_size = cfg.micro_batch_size sampler_batch_size = cfg.micro_batch_size
batch_max_len = cfg.sequence_len batch_max_len = cfg.sequence_len
if cfg.curriculum_sampling:
sampler = SequentialSampler(train_dataset)
else:
sampler = RandomSampler(train_dataset)
sampler = MultipackBatchSampler( sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset), sampler=sampler,
lengths=get_dataset_lengths(train_dataset), lengths=get_dataset_lengths(train_dataset),
batch_size=sampler_batch_size, batch_size=sampler_batch_size,
batch_max_len=batch_max_len, batch_max_len=batch_max_len,
group_size=cfg.sample_packing_group_size, group_size=cfg.sample_packing_group_size,
bin_size=cfg.sample_packing_bin_size, bin_size=cfg.sample_packing_bin_size,
sequential=cfg.sample_packing_sequentially,
drop_last=True, drop_last=True,
) )

View File

@@ -38,8 +38,11 @@ class TestBatchedSamplerPacking:
], ],
) )
@pytest.mark.parametrize("max_seq_length", [4096, 512]) @pytest.mark.parametrize("max_seq_length", [4096, 512])
@pytest.mark.parametrize("sequential", [True, False])
@enable_hf_offline @enable_hf_offline
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length): def test_packing(
self, batch_size, num_workers, tokenizer, max_seq_length, sequential
):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
dataset = load_dataset( dataset = load_dataset(
@@ -75,6 +78,7 @@ class TestBatchedSamplerPacking:
batch_max_len=max_seq_length, batch_max_len=max_seq_length,
group_size=100000, group_size=100000,
bin_size=200, bin_size=200,
sequential=sequential,
) )
loader = DataLoader( loader = DataLoader(