Sequential sample packing (#2404) [skip ci]
* add sequential sample packing * chore: lint --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
80
examples/llama-3/lora-1b-sample-packing-sequentially.yml
Normal file
80
examples/llama-3/lora-1b-sample-packing-sequentially.yml
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
base_model: meta-llama/Llama-3.2-1B
|
||||||
|
# optionally might have model_type or tokenizer_type
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
test_value: true
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
sample_packing_sequentially: true
|
||||||
|
curriculum_sampling: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
lora_modules_to_save:
|
||||||
|
- embed_tokens
|
||||||
|
- lm_head
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
s2_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|end_of_text|>
|
||||||
@@ -112,6 +112,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
|
|||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
batch_max_len=batch_max_len,
|
batch_max_len=batch_max_len,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
sequential=self.args.sample_packing_sequentially,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,12 @@ class AxolotlTrainingMixins:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use sample packing for efficient training."},
|
metadata={"help": "Use sample packing for efficient training."},
|
||||||
)
|
)
|
||||||
|
sample_packing_sequentially: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||||
|
},
|
||||||
|
)
|
||||||
multipack_real_batches: bool = field(
|
multipack_real_batches: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use real batches for efficient training."},
|
metadata={"help": "Use real batches for efficient training."},
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import Any, Iterable, List, Union
|
|||||||
|
|
||||||
import numba
|
import numba
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import BatchSampler, Sampler
|
from torch.utils.data import BatchSampler, Sampler, SequentialSampler
|
||||||
|
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
|
|
||||||
@@ -103,6 +103,55 @@ def allocate(
|
|||||||
return result, s, len(result) * c * n
|
return result, s, len(result) * c * n
|
||||||
|
|
||||||
|
|
||||||
|
@numba.njit
|
||||||
|
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
|
||||||
|
"""
|
||||||
|
Sequential allocator that preserves example order
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- lengths: The lengths of all examples
|
||||||
|
- rank: The current rank (for distributed training)
|
||||||
|
- c: The capacity of each bin (maximum sequence length)
|
||||||
|
- n: Number of ranks
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- result: List of batches for the current rank
|
||||||
|
- total_used: Number of actual example tokens
|
||||||
|
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
total_used = 0
|
||||||
|
|
||||||
|
# First, do sequential packing into bins
|
||||||
|
all_bins = []
|
||||||
|
current_bin = [0 for i in range(0)] # numba hint
|
||||||
|
remaining_capacity = c
|
||||||
|
|
||||||
|
for idx, size in enumerate(lengths):
|
||||||
|
if size <= remaining_capacity:
|
||||||
|
# Example fits in current bin
|
||||||
|
current_bin.append(idx)
|
||||||
|
remaining_capacity -= size
|
||||||
|
total_used += size
|
||||||
|
else:
|
||||||
|
# Example doesn't fit, start a new bin
|
||||||
|
if current_bin: # Add non-empty bin to all_bins
|
||||||
|
all_bins.append(current_bin)
|
||||||
|
current_bin = [idx]
|
||||||
|
remaining_capacity = c - size
|
||||||
|
total_used += size
|
||||||
|
|
||||||
|
# Add the last bin if not empty
|
||||||
|
if current_bin:
|
||||||
|
all_bins.append(current_bin)
|
||||||
|
|
||||||
|
# Assign bins to ranks - each rank gets every n-th bin
|
||||||
|
for bin_idx in range(rank, len(all_bins), n):
|
||||||
|
result.append(all_bins[bin_idx])
|
||||||
|
|
||||||
|
return result, total_used, len(all_bins) * c
|
||||||
|
|
||||||
|
|
||||||
class MultipackBatchSampler(BatchSampler):
|
class MultipackBatchSampler(BatchSampler):
|
||||||
"""Batch sampler class for multipack"""
|
"""Batch sampler class for multipack"""
|
||||||
|
|
||||||
@@ -115,6 +164,7 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
packing_efficiency_estimate: float = 1.0,
|
packing_efficiency_estimate: float = 1.0,
|
||||||
drop_last: bool = False,
|
drop_last: bool = False,
|
||||||
num_count_samples: int = 16,
|
num_count_samples: int = 16,
|
||||||
|
sequential: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(sampler, batch_size, drop_last)
|
super().__init__(sampler, batch_size, drop_last)
|
||||||
@@ -122,6 +172,7 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
self.batch_max_len = batch_max_len
|
self.batch_max_len = batch_max_len
|
||||||
self.lengths: np.ndarray = lengths
|
self.lengths: np.ndarray = lengths
|
||||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||||
|
self.sequential = sequential
|
||||||
|
|
||||||
assert isinstance(self.lengths, np.ndarray)
|
assert isinstance(self.lengths, np.ndarray)
|
||||||
|
|
||||||
@@ -136,6 +187,11 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
# the minimum packed dataset length across all ranks determined by a gather/broadcast
|
# the minimum packed dataset length across all ranks determined by a gather/broadcast
|
||||||
self.len_across_ranks = None
|
self.len_across_ranks = None
|
||||||
|
|
||||||
|
if self.sequential and not isinstance(sampler, SequentialSampler):
|
||||||
|
LOG.warn(
|
||||||
|
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
|
||||||
|
)
|
||||||
|
|
||||||
def set_epoch(self, epoch: int):
|
def set_epoch(self, epoch: int):
|
||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
|
|
||||||
@@ -145,13 +201,23 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
lengths = self.lengths[indices]
|
lengths = self.lengths[indices]
|
||||||
lengths_cumsum = np.cumsum(lengths)
|
lengths_cumsum = np.cumsum(lengths)
|
||||||
|
|
||||||
batches, total_used, total_slots = allocate(
|
if self.sequential:
|
||||||
lengths=lengths,
|
LOG.debug("using sequential sample packing algorithm")
|
||||||
lengths_cumsum=lengths_cumsum,
|
batches, total_used, total_slots = allocate_sequentially(
|
||||||
rank=0,
|
lengths=lengths,
|
||||||
c=self.batch_max_len,
|
rank=0,
|
||||||
n=1,
|
c=self.batch_max_len,
|
||||||
)
|
n=1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.debug("using non-sequential sample packing algorithm")
|
||||||
|
batches, total_used, total_slots = allocate(
|
||||||
|
lengths=lengths,
|
||||||
|
lengths_cumsum=lengths_cumsum,
|
||||||
|
rank=0,
|
||||||
|
c=self.batch_max_len,
|
||||||
|
n=1,
|
||||||
|
)
|
||||||
|
|
||||||
batches = [
|
batches = [
|
||||||
[
|
[
|
||||||
|
|||||||
@@ -192,6 +192,7 @@ class AxolotlInputConfig(
|
|||||||
sample_packing: bool | None = None
|
sample_packing: bool | None = None
|
||||||
sample_packing_group_size: int | None = 100_000
|
sample_packing_group_size: int | None = 100_000
|
||||||
sample_packing_bin_size: int | None = 200
|
sample_packing_bin_size: int | None = 200
|
||||||
|
sample_packing_sequentially: bool | None = None
|
||||||
eval_sample_packing: bool | None = None
|
eval_sample_packing: bool | None = None
|
||||||
pad_to_sequence_len: bool | None = None
|
pad_to_sequence_len: bool | None = None
|
||||||
curriculum_sampling: bool | None = None
|
curriculum_sampling: bool | None = None
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import torch
|
|||||||
import torch.cuda
|
import torch.cuda
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from datasets import IterableDataset, disable_caching, enable_caching
|
from datasets import IterableDataset, disable_caching, enable_caching
|
||||||
from torch.utils.data import DataLoader, RandomSampler
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
@@ -456,13 +456,18 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
else:
|
else:
|
||||||
sampler_batch_size = cfg.micro_batch_size
|
sampler_batch_size = cfg.micro_batch_size
|
||||||
batch_max_len = cfg.sequence_len
|
batch_max_len = cfg.sequence_len
|
||||||
|
if cfg.curriculum_sampling:
|
||||||
|
sampler = SequentialSampler(train_dataset)
|
||||||
|
else:
|
||||||
|
sampler = RandomSampler(train_dataset)
|
||||||
sampler = MultipackBatchSampler(
|
sampler = MultipackBatchSampler(
|
||||||
sampler=RandomSampler(train_dataset),
|
sampler=sampler,
|
||||||
lengths=get_dataset_lengths(train_dataset),
|
lengths=get_dataset_lengths(train_dataset),
|
||||||
batch_size=sampler_batch_size,
|
batch_size=sampler_batch_size,
|
||||||
batch_max_len=batch_max_len,
|
batch_max_len=batch_max_len,
|
||||||
group_size=cfg.sample_packing_group_size,
|
group_size=cfg.sample_packing_group_size,
|
||||||
bin_size=cfg.sample_packing_bin_size,
|
bin_size=cfg.sample_packing_bin_size,
|
||||||
|
sequential=cfg.sample_packing_sequentially,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -38,8 +38,11 @@ class TestBatchedSamplerPacking:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("max_seq_length", [4096, 512])
|
@pytest.mark.parametrize("max_seq_length", [4096, 512])
|
||||||
|
@pytest.mark.parametrize("sequential", [True, False])
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
def test_packing(
|
||||||
|
self, batch_size, num_workers, tokenizer, max_seq_length, sequential
|
||||||
|
):
|
||||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||||
|
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
@@ -75,6 +78,7 @@ class TestBatchedSamplerPacking:
|
|||||||
batch_max_len=max_seq_length,
|
batch_max_len=max_seq_length,
|
||||||
group_size=100000,
|
group_size=100000,
|
||||||
bin_size=200,
|
bin_size=200,
|
||||||
|
sequential=sequential,
|
||||||
)
|
)
|
||||||
|
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
|
|||||||
Reference in New Issue
Block a user