Sequential sample packing (#2404) [skip ci]

* add sequential sample packing
* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
examples/llama-3/lora-1b-sample-packing-sequentially.yml (new file, 80 lines)

@@ -0,0 +1,80 @@
+base_model: meta-llama/Llama-3.2-1B
+# optionally might have model_type or tokenizer_type
+model_type: LlamaForCausalLM
+tokenizer_type: AutoTokenizer
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./outputs/lora-out
+
+test_value: true
+
+sequence_len: 4096
+sample_packing: true
+sample_packing_sequentially: true
+curriculum_sampling: true
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+lora_modules_to_save:
+  - embed_tokens
+  - lm_head
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  pad_token: <|end_of_text|>
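The example pairs sample_packing_sequentially with curriculum_sampling so that packing preserves the dataset order end to end. Assuming axolotl's usual CLI entry point, the config could be exercised with e.g. `accelerate launch -m axolotl.cli.train examples/llama-3/lora-1b-sample-packing-sequentially.yml`.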
@@ -112,6 +112,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
             packing_efficiency_estimate=self.args.sample_packing_efficiency,
             batch_max_len=batch_max_len,
             batch_size=batch_size,
+            sequential=self.args.sample_packing_sequentially,
             drop_last=True,
         )

@@ -34,6 +34,12 @@ class AxolotlTrainingMixins:
         default=False,
         metadata={"help": "Use sample packing for efficient training."},
     )
+    sample_packing_sequentially: bool = field(
+        default=False,
+        metadata={
+            "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
+        },
+    )
     multipack_real_batches: bool = field(
         default=False,
         metadata={"help": "Use real batches for efficient training."},

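As the help string says, order-preserving ("next-fit") packing only makes training fully sequential when the sampler feeding it is itself sequential. A minimal illustrative config fragment pairing the two flags (the same combination used in the new example file above):

    sample_packing: true
    sample_packing_sequentially: true
    curriculum_sampling: true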
@@ -8,7 +8,7 @@ from typing import Any, Iterable, List, Union

 import numba
 import numpy as np
-from torch.utils.data import BatchSampler, Sampler
+from torch.utils.data import BatchSampler, Sampler, SequentialSampler

 from axolotl.utils.distributed import reduce_and_broadcast

@@ -103,6 +103,55 @@ def allocate(
     return result, s, len(result) * c * n


+@numba.njit
+def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
+    """
+    Sequential allocator that preserves example order
+
+    Parameters:
+    - lengths: The lengths of all examples
+    - rank: The current rank (for distributed training)
+    - c: The capacity of each bin (maximum sequence length)
+    - n: Number of ranks
+
+    Returns:
+    - result: List of batches for the current rank
+    - total_used: Number of actual example tokens
+    - total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
+    """
+    result = []
+    total_used = 0
+
+    # First, do sequential packing into bins
+    all_bins = []
+    current_bin = [0 for i in range(0)]  # numba hint
+    remaining_capacity = c
+
+    for idx, size in enumerate(lengths):
+        if size <= remaining_capacity:
+            # Example fits in current bin
+            current_bin.append(idx)
+            remaining_capacity -= size
+            total_used += size
+        else:
+            # Example doesn't fit, start a new bin
+            if current_bin:  # Add non-empty bin to all_bins
+                all_bins.append(current_bin)
+            current_bin = [idx]
+            remaining_capacity = c - size
+            total_used += size
+
+    # Add the last bin if not empty
+    if current_bin:
+        all_bins.append(current_bin)
+
+    # Assign bins to ranks - each rank gets every n-th bin
+    for bin_idx in range(rank, len(all_bins), n):
+        result.append(all_bins[bin_idx])
+
+    return result, total_used, len(all_bins) * c
+
+
 class MultipackBatchSampler(BatchSampler):
     """Batch sampler class for multipack"""

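To make the next-fit behavior concrete, here is a small illustrative sketch (not part of the commit) that mirrors allocate_sequentially in plain Python, without the numba compilation. The helper name pack_sequentially and the toy lengths are invented for illustration:

    import numpy as np

    def pack_sequentially(lengths, rank, c, n):
        # Next-fit binning that never reorders examples, followed by
        # round-robin assignment of bins to ranks (bins rank, rank+n, ...).
        all_bins, current_bin, remaining = [], [], c
        for idx, size in enumerate(lengths):
            if size <= remaining:
                current_bin.append(idx)
                remaining -= size
            else:
                if current_bin:
                    all_bins.append(current_bin)
                current_bin, remaining = [idx], c - size
        if current_bin:
            all_bins.append(current_bin)
        return all_bins[rank::n]

    lengths = np.array([3, 4, 2, 5, 1])
    print(pack_sequentially(lengths, rank=0, c=6, n=2))  # [[0], [3, 4]]
    print(pack_sequentially(lengths, rank=1, c=6, n=2))  # [[1, 2]]

Note that next-fit closes a bin as soon as the next example fails to fit, so it can leave more unused capacity per bin than the default allocator; that is the price of preserving order.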
@@ -115,6 +164,7 @@ class MultipackBatchSampler(BatchSampler):
         packing_efficiency_estimate: float = 1.0,
         drop_last: bool = False,
         num_count_samples: int = 16,
+        sequential: bool = False,
         **kwargs,
     ):
         super().__init__(sampler, batch_size, drop_last)

@@ -122,6 +172,7 @@ class MultipackBatchSampler(BatchSampler):
         self.batch_max_len = batch_max_len
         self.lengths: np.ndarray = lengths
         self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
+        self.sequential = sequential

         assert isinstance(self.lengths, np.ndarray)

@@ -136,6 +187,11 @@ class MultipackBatchSampler(BatchSampler):
         # the minimum packed dataset length across all ranks determined by a gather/broadcast
         self.len_across_ranks = None

+        if self.sequential and not isinstance(sampler, SequentialSampler):
+            LOG.warn(
+                "using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
+            )
+
     def set_epoch(self, epoch: int):
         self.epoch = epoch

@@ -145,13 +201,23 @@ class MultipackBatchSampler(BatchSampler):
         lengths = self.lengths[indices]
         lengths_cumsum = np.cumsum(lengths)

-        batches, total_used, total_slots = allocate(
-            lengths=lengths,
-            lengths_cumsum=lengths_cumsum,
-            rank=0,
-            c=self.batch_max_len,
-            n=1,
-        )
+        if self.sequential:
+            LOG.debug("using sequential sample packing algorithm")
+            batches, total_used, total_slots = allocate_sequentially(
+                lengths=lengths,
+                rank=0,
+                c=self.batch_max_len,
+                n=1,
+            )
+        else:
+            LOG.debug("using non-sequential sample packing algorithm")
+            batches, total_used, total_slots = allocate(
+                lengths=lengths,
+                lengths_cumsum=lengths_cumsum,
+                rank=0,
+                c=self.batch_max_len,
+                n=1,
+            )

         batches = [
             [

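Both branches return the same (batches, total_used, total_slots) triple, so the downstream packing-efficiency bookkeeping is unchanged. For instance, with the toy pack above (15 tokens in 3 bins of capacity 6):

    total_used, total_slots = 15, 3 * 6
    print(f"packing efficiency: {total_used / total_slots:.2%}")  # 83.33%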
@@ -192,6 +192,7 @@ class AxolotlInputConfig(
     sample_packing: bool | None = None
     sample_packing_group_size: int | None = 100_000
     sample_packing_bin_size: int | None = 200
+    sample_packing_sequentially: bool | None = None
     eval_sample_packing: bool | None = None
     pad_to_sequence_len: bool | None = None
     curriculum_sampling: bool | None = None

@@ -13,7 +13,7 @@ import torch
 import torch.cuda
 from accelerate.logging import get_logger
 from datasets import IterableDataset, disable_caching, enable_caching
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.utils import is_torch_bf16_gpu_available

 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder

@@ -456,13 +456,18 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
         else:
             sampler_batch_size = cfg.micro_batch_size
             batch_max_len = cfg.sequence_len
+        if cfg.curriculum_sampling:
+            sampler = SequentialSampler(train_dataset)
+        else:
+            sampler = RandomSampler(train_dataset)
         sampler = MultipackBatchSampler(
-            sampler=RandomSampler(train_dataset),
+            sampler=sampler,
             lengths=get_dataset_lengths(train_dataset),
             batch_size=sampler_batch_size,
             batch_max_len=batch_max_len,
             group_size=cfg.sample_packing_group_size,
             bin_size=cfg.sample_packing_bin_size,
+            sequential=cfg.sample_packing_sequentially,
             drop_last=True,
         )

@@ -38,8 +38,11 @@ class TestBatchedSamplerPacking:
         ],
     )
     @pytest.mark.parametrize("max_seq_length", [4096, 512])
+    @pytest.mark.parametrize("sequential", [True, False])
     @enable_hf_offline
-    def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
+    def test_packing(
+        self, batch_size, num_workers, tokenizer, max_seq_length, sequential
+    ):
         import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import # noqa: F401

         dataset = load_dataset(

@@ -75,6 +78,7 @@ class TestBatchedSamplerPacking:
             batch_max_len=max_seq_length,
             group_size=100000,
             bin_size=200,
+            sequential=sequential,
         )

         loader = DataLoader(