Compare commits

...

5 Commits

Author SHA1 Message Date
Dan Saunders
e910e3e164 Revert "Multipack parallel bin packing (#2631)"
This reverts commit 8e4158cc0b.
2025-05-09 17:33:31 +00:00
Wing Lian
0f3587174d swap tinymodels that have safetensors for some ci tests (#2641) 2025-05-07 15:06:07 -04:00
xzuyn
25e6c5f9bd Add CAME Optimizer (#2385) 2025-05-07 10:31:46 -04:00
NanoCode012
32f51bca35 fix(doc): clarify instruction to delinearize llama4 similar to cli doc (#2644) [skip ci] 2025-05-07 10:29:47 -04:00
NanoCode012
9daa04da90 Fix: improve error message on failed dataset load (#2637) [skip ci]
* fix(log): clarify error on dataset loading failed

* fix: add path for easy tracking of broken config

* fix: improve error message based on pr feedback
2025-05-07 10:29:05 -04:00
24 changed files with 340 additions and 310 deletions

View File

@@ -18,9 +18,96 @@ jobs:
         env:
           SKIP: no-commit-to-branch

+  preload-cache:
+    name: Preload HF cache
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python_version: ["3.11"]
+        pytorch_version: ["2.6.0"]
+    timeout-minutes: 20
+    env:
+      AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
+    steps:
+      - name: Check out repository code
+        uses: actions/checkout@v4
+      - name: Restore HF cache
+        id: hf-cache-restore
+        uses: actions/cache/restore@v4
+        with:
+          path: |
+            /home/runner/.cache/huggingface/hub/datasets--*
+            /home/runner/.cache/huggingface/hub/models--*
+          key: ${{ runner.os }}-hf-hub-cache-v2
+      - name: Setup Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python_version }}
+          cache: 'pip' # caching pip dependencies
+      - name: upgrade pip
+        run: |
+          pip3 install --upgrade pip
+          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
+      - name: Install PyTorch
+        run: |
+          pip3 install torch==${{ matrix.pytorch_version }}
+      - name: Install dependencies
+        run: |
+          pip3 show torch
+          pip3 install --no-build-isolation -U -e .
+          python scripts/unsloth_install.py | sh
+          python scripts/cutcrossentropy_install.py | sh
+          pip3 install -r requirements-dev.txt -r requirements-tests.txt
+      - name: Make sure PyTorch version wasn't clobbered
+        run: |
+          python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
+      - name: Ensure axolotl CLI was installed
+        run: |
+          axolotl --help
+      - name: Pre-Download dataset fixture
+        run: |
+          huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
+      - name: Run tests
+        run: |
+          pytest -v tests/conftest.py
+      - name: Upload coverage to Codecov
+        uses: codecov/codecov-action@v5
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          files: ./coverage.xml
+          flags: unittests,pytorch-${{ matrix.pytorch_version }}
+          fail_ci_if_error: false
+      - name: cleanup pip cache
+        run: |
+          find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
+      - name: Save HF cache
+        id: hf-cache
+        uses: actions/cache/save@v4
+        with:
+          path: |
+            /home/runner/.cache/huggingface/hub/datasets--*
+            /home/runner/.cache/huggingface/hub/models--*
+          key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
+
   pytest:
     name: PyTest
     runs-on: ubuntu-latest
+    needs: [preload-cache]
     strategy:
       fail-fast: false
       max-parallel: 2

View File

@@ -612,6 +612,7 @@ lr_div_factor: # Learning rate div factor
 # - optimi_adamw
 # - ao_adamw_8bit
 # - ao_adamw_fp8
+# - came_pytorch
 optimizer:

 # Dictionary of arguments to pass to the optimizer
 optim_args:
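For reference, a config that opts into the new optimizer might look like the following sketch. The `adam_beta3` and `adam_epsilon2` keys come from the schema changes later in this compare, and the values shown are simply the defaults wired up in the trainer builder, not tuned recommendations:

```yaml
optimizer: came_pytorch
# CAME consumes three betas and two epsilons; they map onto the adam_* keys below
adam_beta1: 0.9
adam_beta2: 0.999
adam_beta3: 0.9999
adam_epsilon: 1.0e-30
adam_epsilon2: 1.0e-16
```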

View File

@@ -34,3 +34,5 @@ We provide a script to delinearize Llama 4 linearized models into regular Huggin
 ```bash
 axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
 ```
+
+Note: This only works with the non-quantized linearized model. If you have an adapter, merge it with the *non-quantized linearized* model before delinearizing.
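A sketch of that workflow, assuming axolotl's standard `merge-lora` entry point (command name and flag spelling assumed; all paths are placeholders):

```bash
# 1. Merge the adapter into the non-quantized linearized base model
axolotl merge-lora your_config.yml --lora-model-dir path/to/adapter_dir

# 2. Then delinearize the merged model
axolotl delinearize-llama4 --model path/to/merged_model_dir --output path/to/output_dir
```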

View File

@@ -11,6 +11,7 @@ liger-kernel==0.5.9
 packaging==23.2
+huggingface_hub==0.31.0
 peft==0.15.2
 transformers==4.51.3
 tokenizers>=0.21.1

View File

@@ -142,6 +142,7 @@ extras_require = {
         "apollo-torch",
         "lomo-optim==0.1.1",
         "torch-optimi==0.2.1",
+        "came_pytorch==0.1.3",
     ],
     "ray": [
         "ray[train]",

View File

@@ -708,6 +708,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             optimizer_cls = ADOPT
             adam_kwargs["decouple"] = True
             optimizer_kwargs.update(adam_kwargs)
+        elif self.cfg.optimizer == "came_pytorch":
+            from came_pytorch import CAME
+
+            optimizer_cls = CAME
+
+            beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
+            beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
+            beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999)
+            eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
+            eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
+
+            adam_kwargs["betas"] = (beta1, beta2, beta3)
+            adam_kwargs["eps"] = (eps1, eps2)
+
+            optimizer_kwargs.update(adam_kwargs)

         # Parse any additional optimizer args from config
         if self.cfg.optim_args:
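Outside the trainer builder, this wiring reduces to a direct CAME construction. A minimal sketch, assuming the `came_pytorch` package's documented signature and a `model` already in scope:

```python
from came_pytorch import CAME

# betas/eps tuples mirror the kwargs assembled above (defaults from this diff)
optimizer = CAME(
    model.parameters(),
    lr=1e-5,
    betas=(0.9, 0.999, 0.9999),  # beta1, beta2, beta3
    eps=(1e-30, 1e-16),          # eps1, eps2
)
```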

View File

@@ -114,8 +114,6 @@ class AxolotlTrainer(
                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
                 batch_max_len=batch_max_len,
                 batch_size=batch_size,
-                group_size=self.args.sample_packing_group_size,
-                bin_size=self.args.sample_packing_bin_size,
                 sequential=self.args.sample_packing_sequentially,
                 drop_last=True,
             )

View File

@@ -2,6 +2,7 @@
 import importlib
 import inspect
+import logging
 import os
 import signal
 import sys
@@ -12,7 +13,6 @@ from typing import Any, Dict

 import torch
 import transformers.modelcard
-from accelerate.logging import get_logger
 from accelerate.utils import save_fsdp_model
 from datasets import Dataset
 from huggingface_hub.errors import OfflineModeIsEnabled
@@ -42,7 +42,7 @@ try:
 except ImportError:
     BetterTransformer = None

-LOG = get_logger(__name__)
+LOG = logging.getLogger(__name__)


 def setup_model_and_tokenizer(
@@ -63,7 +63,6 @@ def setup_model_and_tokenizer(
     # Load tokenizer
     LOG.debug(
         f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
-        main_process_only=True,
     )
     tokenizer = load_tokenizer(cfg)

View File

@@ -281,6 +281,10 @@ def load_dataset_w_config(
             **load_ds_kwargs,
         )
     if not ds:
-        raise ValueError("unhandled dataset load")
+        raise ValueError(
+            "The dataset could not be loaded. This could be due to a misconfigured dataset path "
+            f"({config_dataset.path}). Try double-check your path / name / data_files. "
+            "This is not caused by the dataset type."
+        )
     return ds
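For context, the `config_dataset.path` in the new message comes from the `datasets` section of the user config; a minimal entry that resolves cleanly (reusing the dataset that appears in the tests later in this compare) would be:

```yaml
datasets:
  - path: mhenrichsen/alpaca_2k_test  # HF Hub repo id, or a local path / data_files
    type: alpaca
```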

View File

@@ -1,15 +1,36 @@
 """custom checkpointing utils"""

+import importlib
 from functools import partial

+from packaging import version
+
 from axolotl.utils.gradient_checkpointing.unsloth import (
     Unsloth_Offloaded_Gradient_Checkpointer,
 )

+transformers_version = version.parse(importlib.metadata.version("transformers"))
+
+if transformers_version > version.parse("4.51.3"):
+    from transformers.modeling_layers import GradientCheckpointingLayer
+
+    def uses_gc_layers(decoder_layer):
+        return isinstance(decoder_layer.func.__self__, GradientCheckpointingLayer)
+
+else:
+
+    def uses_gc_layers(_):
+        return False
+

 def hf_grad_checkpoint_offload_wrapper(
     decoder_layer, *args, use_reentrant=None
 ):  # pylint: disable=unused-argument
+    if uses_gc_layers(decoder_layer):
+        return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+            decoder_layer,
+            *args,
+        )
     return Unsloth_Offloaded_Gradient_Checkpointer.apply(
         (
             decoder_layer.func.__self__
View File

@@ -1,13 +1,10 @@
-# pylint: skip-file
 """
-Multipack Batch Sampler - An efficient batch sampler for packing variable-length sequences
-into fixed-capacity batches to optimize memory usage and training throughput.
+Multipack Batch Sampler
 """

 import logging
 import math
-from concurrent.futures import ProcessPoolExecutor
-from multiprocessing import cpu_count
-from typing import Iterable, Union
+from typing import Any, Iterable, List, Union

 import numba
 import numpy as np
@@ -16,39 +13,26 @@ from torch.utils.data import BatchSampler, Sampler, SequentialSampler
 from axolotl.utils.distributed import reduce_and_broadcast

 LOG = logging.getLogger(__name__)
 LOG.setLevel(logging.INFO)


 @numba.njit
-def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
-    """
-    First-fit-decreasing bin packing algorithm check
-
-    Checks if sequences with the given lengths could fit in the specified number of bins
-
-    Args:
-        sequence_lengths: Array of sequence lengths
-        bin_capacity: Maximum capacity of each bin
-        num_bins: Number of bins available
-
-    Returns:
-        True if all sequences can be packed, False otherwise
-    """
-    # Sort sequence lengths in descending order for optimal packing
-    sequence_lengths = np.sort(sequence_lengths)[::-1]
-
-    # Initialize all bins with full capacity
-    bins = np.full((num_bins,), bin_capacity, dtype=sequence_lengths.dtype)
-
-    # Try to place each sequence in the first bin it fits
-    for size in sequence_lengths:
+def ffd_check(a: np.ndarray, c: int, n: int):
+    # First-fit-decreasing bin packing
+    # Check if a[] could fit in n bins with capacity c
+    # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
+
+    a = np.sort(a)[::-1]
+    bins = np.full((n,), c, dtype=a.dtype)
+    for size in a:
         not_found = True
-        for idx in range(num_bins):
+        for idx in range(n):
             if bins[idx] >= size:
                 bins[idx] -= size
                 not_found = False
                 break

-        # If no bin could fit this sequence, packing failed
         if not_found:
             return False
@@ -56,132 +40,86 @@ def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
 @numba.njit
-def pack_group(
-    sequence_lengths: np.ndarray,
-    group_offset: int,
-    bin_capacity: int,
-    max_bins: int,
-    bin_size: int,
-    safe_mode: bool = True,
-):
-    """
-    Pack a group of sequences into bins using First-Fit Decreasing algorithm
-
-    Args:
-        sequence_lengths: Array of sequence lengths
-        group_offset: Offset to apply to indices when returning results
-        bin_capacity: Maximum capacity of each bin
-        max_bins: Maximum number of bins to use
-        bin_size: Maximum number of sequences per bin
-        safe_mode: If True, use a more conservative packing approach
-
-    Returns:
-        List of bins, where each bin contains indices of sequences assigned to it
-    """
-    # Get sorting indices and sort lengths in descending order
-    indices = np.argsort(sequence_lengths)[::-1]
-    sorted_lengths = sequence_lengths[indices]
-
-    bins_remaining_space: list = []  # Tracks remaining capacity in each bin
-    bins_assigned_sequences: list = []  # Tracks sequence indices assigned to each bin
-
-    for seq_id, size in enumerate(sorted_lengths):
-        global_idx = indices[seq_id] + group_offset
-
-        # Try to place sequence in existing bins
-        add_new_bin = True
-        for bin_idx, _ in enumerate(bins_remaining_space):
-            if (
-                bins_remaining_space[bin_idx] >= size
-                and len(bins_assigned_sequences[bin_idx]) < bin_size
-            ):
-                bins_remaining_space[bin_idx] -= size
-                bins_assigned_sequences[bin_idx].append(global_idx)
-                add_new_bin = False
-                break
-
-        # Create a new bin if needed and if we haven't reached the limit
-        if add_new_bin:
-            if len(bins_remaining_space) >= max_bins and safe_mode:
-                # In safe mode, skip items that would exceed max_bins
-                continue
-            bins_remaining_space.append(bin_capacity - size)
-            bins_assigned_sequences.append([global_idx])
-
-        # Safety check to avoid infinite bins
-        if len(bins_remaining_space) > len(sequence_lengths):
-            break
-
-    return bins_assigned_sequences
-
-
-# Define a standalone function for multiprocessing
-def _process_group(args):
-    group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode = args
-    return pack_group(
-        group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode
-    )
-
-
-def pack_parallel(
-    sequence_lengths: np.ndarray,
-    bin_capacity: int,
-    group_size: int,
-    bin_size: int,
-    num_processes: int | None = None,
-    safe_mode: bool = True,
-):
-    """
-    Pack sequences into bins using parallel processing
-
-    Args:
-        sequence_lengths: Array of sequence lengths
-        bin_capacity: Maximum capacity of each bin as total number of tokens
-        group_size: Number of sequences to process in each group
-        bin_size: Maximum number of bins to use
-        num_processes: Number of parallel processes to use
-        safe_mode: If True, use a more conservative packing approach
-
-    Returns:
-        List of bins, where each bin contains indices of sequences assigned to it
-    """
-    num_items = len(sequence_lengths)
-    if num_processes is None:
-        num_processes = max(1, min(num_items // group_size, cpu_count()))
-
-    # Create tasks for parallel processing
-    tasks = []
-    for i in range(0, num_items, group_size):
-        group_lengths = sequence_lengths[i : i + group_size]
-        max_bins = len(group_lengths)  # Allow as many bins as items in the group
-        tasks.append((group_lengths, i, bin_capacity, max_bins, bin_size, safe_mode))
-
-    # Process groups in parallel
-    all_bins = []
-    with ProcessPoolExecutor(max_workers=num_processes) as executor:
-        for group_bins in executor.map(_process_group, tasks):
-            all_bins.extend(group_bins)
-
-    return all_bins
+def ffd_with_result(a: np.ndarray, c: int, start_index: int):
+    # First-fit-decreasing bin packing (with result return)
+
+    indices = np.argsort(a)[::-1]
+    a = a[indices]
+
+    bins: List[Any] = []
+    bins_result: List[Any] = []
+    for a_id, size in enumerate(a):
+        add_new = True
+        for idx in range(len(bins)):
+            if bins[idx] >= size:
+                bins[idx] -= size
+                bins_result[idx].append(indices[a_id] + start_index)
+                add_new = False
+                break
+
+        if add_new:
+            bins.append(c - size)
+            bins_result.append([indices[a_id] + start_index])
+
+    return bins_result


 @numba.njit
-def allocate_sequentially(
-    sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int
-):
+def allocate(
+    lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
+):
+    # Dynamic batch allocator, similar to Multifit
+    # https://en.wikipedia.org/wiki/Multifit_algorithm
+    # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
+
+    s = 0
+    start_index = 0
+    result = []
+
+    while True:
+        # binary search [l, r)
+        left = 1
+        right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
+
+        while right - left > 1:
+            mid = (left + right) // 2
+            if ffd_check(lengths[start_index : start_index + mid], c, n):
+                left = mid
+            else:
+                right = mid
+
+        # use length l
+        batch = ffd_with_result(
+            lengths[start_index : start_index + left], c, start_index
+        )
+        assert len(batch) <= n
+        if len(batch) < n:
+            break
+
+        start_index += left
+        s = lengths_cumsum[start_index - 1]
+
+        # add local rank
+        result.append(batch[rank])
+
+    return result, s, len(result) * c * n
+
+
+@numba.njit
+def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
     """
     Sequential allocator that preserves example order

-    Args:
-        sequence_lengths: The lengths of all examples
-        rank: The current rank (for distributed training)
-        bin_capacity: The capacity of each bin (maximum sequence length)
-        num_ranks: Number of ranks (processes/GPUs)
+    Parameters:
+    - lengths: The lengths of all examples
+    - rank: The current rank (for distributed training)
+    - c: The capacity of each bin (maximum sequence length)
+    - n: Number of ranks

     Returns:
-        rank_batches: List of batches for the current rank
-        total_tokens_used: Number of actual example tokens
-        total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
+    - result: List of batches for the current rank
+    - total_used: Number of actual example tokens
+    - total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
     """
     result = []
     total_used = 0
@@ -189,9 +127,9 @@ def allocate_sequentially(
     # First, do sequential packing into bins
     all_bins = []
     current_bin = [0 for i in range(0)]  # numba hint
-    remaining_capacity = bin_capacity
+    remaining_capacity = c

-    for idx, size in enumerate(sequence_lengths):
+    for idx, size in enumerate(lengths):
         if size <= remaining_capacity:
             # Example fits in current bin
             current_bin.append(idx)
@@ -202,7 +140,7 @@ def allocate_sequentially(
             if current_bin:  # Add non-empty bin to all_bins
                 all_bins.append(current_bin)
             current_bin = [idx]
-            remaining_capacity = bin_capacity - size
+            remaining_capacity = c - size
         total_used += size

     # Add the last bin if not empty
@@ -210,227 +148,132 @@ def allocate_sequentially(
         all_bins.append(current_bin)

     # Assign bins to ranks - each rank gets every n-th bin
-    for bin_idx in range(rank, len(all_bins), num_ranks):
+    for bin_idx in range(rank, len(all_bins), n):
         result.append(all_bins[bin_idx])

-    return result, total_used, len(all_bins) * bin_capacity
+    return result, total_used, len(all_bins) * c


 class MultipackBatchSampler(BatchSampler):
-    """
-    Batch sampler class for efficient packing of variable-length sequences
-
-    This sampler packs sequences into fixed-capacity bins (batches) to maximize
-    GPU memory utilization and training throughput by reducing padding.
-    It supports both parallel packing (using FFD algorithm) and
-    sequential packing (preserving original sequence order).
-    """
+    """Batch sampler class for multipack"""

     def __init__(
         self,
         sampler: Union[Sampler[int], Iterable[int]],
-        batch_size: int,  # Number of bins per batch
-        batch_max_len: int,  # Maximum sequence length (bin capacity)
-        lengths: np.ndarray,  # Sequence lengths
-        packing_efficiency_estimate: float = 1.0,  # Initial efficiency estimate
-        drop_last: bool = False,  # Whether to drop final batches (might be incomplete)
-        num_count_samples: int = 16,  # Number of times to estimate batch count
-        sequential: bool = False,  # Whether to use sequential packing
-        group_size: int = 100_000,  # Size of groups for parallel packing
-        bin_size: int = 200,  # The max number of samples that can be packed in a single bin
-        num_processes: int | None = None,  # Number of processes for parallel packing
-        safe_mode: bool = True,  # Conservative packing to prevent training instability
-        **kwargs,  # pylint: disable=unused-argument
+        batch_size: int,
+        batch_max_len: int,
+        lengths: np.ndarray,
+        packing_efficiency_estimate: float = 1.0,
+        drop_last: bool = False,
+        num_count_samples: int = 16,
+        sequential: bool = False,
+        **kwargs,
     ):
         super().__init__(sampler, batch_size, drop_last)
         self.batch_size = batch_size
         self.batch_max_len = batch_max_len
-        self.lengths = np.array(lengths, dtype=np.int32)
+        self.lengths: np.ndarray = lengths
         self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
         self.sequential = sequential
-        self.group_size = group_size
-        self.bin_size = bin_size
-        self.num_processes = num_processes
-        self.safe_mode = safe_mode

         assert isinstance(self.lengths, np.ndarray)

         self.epoch = 0

-        # Efficiency statistics tracking
-        self.total_tokens_used = 0
-        self.total_token_slots = 0
+        # statistics
+        self.eff_total_used = 0
+        self.eff_total_slots = 0

-        # The number of times to calculate batches to determine minimum packed dataset length
+        # The number of times to calculate the batches to determine the minimum packed dataset length for the local rank
         self.num_count_samples = num_count_samples

-        # Minimum packed dataset length across all ranks (determined by gather/broadcast)
+        # the minimum packed dataset length across all ranks determined by a gather/broadcast
         self.len_across_ranks = None

-        # Cache for batches
-        self._batches = None
-
         if self.sequential and not isinstance(sampler, SequentialSampler):
             LOG.warning(
                 "using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
             )

     def set_epoch(self, epoch: int):
-        """Set the epoch number, used for reproducible shuffling across epochs"""
         self.epoch = epoch
-        self._batches = None  # Invalidate batch cache

     def generate_batches(self, set_stats=False):
-        """
-        Generate packed batches for training
-
-        Args:
-            set_stats: Whether to update efficiency statistics
-
-        Returns:
-            List of batches, where each batch contains multiple bins,
-            and each bin contains multiple sequence indices
-        """
-        if self._batches is not None:
-            return self._batches
-
-        # Get indices from the sampler
-        indices = [  # pylint: disable=unnecessary-comprehension
-            idx for idx in self.sampler
-        ]
-
-        # Get lengths of the selected sequences
+        indices = [idx for idx in self.sampler]
+
         lengths = self.lengths[indices]
+        lengths_cumsum = np.cumsum(lengths)

-        # Pack sequences into bins using either sequential or parallel packing
         if self.sequential:
-            bins, total_used, total_slots = allocate_sequentially(
-                lengths,
+            batches, total_used, total_slots = allocate_sequentially(
+                lengths=lengths,
                 rank=0,
-                bin_capacity=self.batch_max_len,
-                num_ranks=1,
+                c=self.batch_max_len,
+                n=1,
             )
-            # Map bin indices back to original indices
-            bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
         else:
-            # Use parallel packing
-            all_bins = pack_parallel(
-                lengths,
-                bin_capacity=self.batch_max_len,
-                group_size=self.group_size,
-                bin_size=self.bin_size,
-                num_processes=self.num_processes,
-                safe_mode=self.safe_mode,
-            )
-            # Map bin indices back to original indices
-            bins = [
-                [indices[b_idx] for b_idx in bin_indices] for bin_indices in all_bins
-            ]
-
-            # Calculate efficiency statistics
-            total_used = lengths.sum()
-            total_slots = len(all_bins) * self.batch_max_len
-
-        # Group bins into batches (each batch contains batch_size bins)
+            batches, total_used, total_slots = allocate(
+                lengths=lengths,
+                lengths_cumsum=lengths_cumsum,
+                rank=0,
+                c=self.batch_max_len,
+                n=1,
+            )
+
         batches = [
-            bins[i : i + self.batch_size] for i in range(0, len(bins), self.batch_size)
+            [
+                [indices[b_idx] for b_idx in batch]
+                for batch in batches[i : i + self.batch_size]
+            ]
+            for i in range(0, len(batches), self.batch_size)
         ]

-        # Drop last batch if requested and it's incomplete
-        if self.drop_last and len(batches[-1]) < self.batch_size:
-            batches = batches[:-1]
-            # Adjust total_slots if we dropped a batch
-            if not self.sequential:
-                total_slots -= (self.batch_size - len(batches[-1])) * self.batch_max_len
-
-        # Update statistics if requested
+        # statistics
         if set_stats:
-            self.total_tokens_used += total_used
-            self.total_token_slots += total_slots
-
-        self._batches = batches
+            self.eff_total_used += total_used
+            self.eff_total_slots += total_slots

         return batches

     def __iter__(self):
-        """
-        Return an iterator over batches
-
-        The batches are truncated to match the minimum number of batches across all ranks
-        to ensure distributed training balance
-        """
         batches = self.generate_batches(set_stats=True)
         if self.len_across_ranks:
-            # Truncate batches to ensure all ranks have the same number of batches
+            # make sure the batches we iterate over is truncated to the same min length across all ranks
             batches = batches[: self.len_across_ranks]
         return iter(batches)

+    def num_batches(self):
+        batches = self.generate_batches(set_stats=True)
+        return len(batches)
+
     def efficiency(self):
-        """
-        Calculate the packing efficiency (ratio of tokens used to total token slots)
-
-        Higher is better - 1.0 would mean perfect packing with no wasted space
-        """
-        if self.total_token_slots == 0:
-            self.generate_batches(set_stats=True)
-        if self.total_token_slots == 0:
-            return 0.0
-        # Return a Python float instead of potentially a numpy float
-        return float(self.total_tokens_used / self.total_token_slots)
+        return self.eff_total_used / self.eff_total_slots

     def gather_efficiency(self):
-        """
-        Gather and synchronize packing efficiency estimates across all distributed ranks
-
-        Returns a conservative efficiency estimate based on the measurements
-        """
-
-        def calc_sample_packing_eff_est(estimates: list[float]):
+        def calc_sample_packing_eff_est(estimates: List[float]):
             LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
-            # Use 99.7% of max observed efficiency as a safe estimate
-            max_eff = max(float(eff) for eff in estimates)
-            return math.floor(0.997 * max_eff)
+            return math.floor(0.997 * max(estimates))

-        # Gather efficiency from all ranks and apply the calculation function
         sample_packing_actual_eff_all = reduce_and_broadcast(
-            lambda: float(self.efficiency()),  # pylint: disable=unnecessary-lambda
+            lambda: self.efficiency(),  # pylint: disable=unnecessary-lambda
             calc_sample_packing_eff_est,
         )
-        # Quantize to 0.5% intervals for stability
         sample_packing_eff_est = (
             math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
         )
         return sample_packing_eff_est

     def gather_len_batches(self, num):
-        """
-        Gather and synchronize batch counts across all distributed ranks
-
-        Returns the minimum number of batches available on any rank
-        """
-
         def calc_min_len(estimates: list[(int, float)]):
             LOG.info(f"gather_len_batches: {repr(estimates)}")
             return math.floor(min(estimates))

-        # Find minimum batch count across ranks to ensure balance
         min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
         return min_len_batches

     def __len__(self):
-        """
-        Return the total number of batches that will be yielded by this sampler
-
-        This is calculated as the minimum number of batches available on any rank
-        to ensure balanced distributed training
-        """
-        if self._batches is None:
-            self._batches = self.generate_batches(set_stats=True)
-        if self.len_across_ranks is None:
-            # Sample multiple times to get stable estimate
-            len_batches = min(  # pylint: disable=consider-using-generator
-                [len(self._batches) for _ in range(self.num_count_samples)]
+        if not self.len_across_ranks:
+            len_batches = min(
+                [self.num_batches() for _ in range(self.num_count_samples)]
             )
-            # Gather minimum across all ranks
             self.len_across_ranks = self.gather_len_batches(len_batches)
         return self.len_across_ranks
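As a usage sketch (not part of the diff): the restored sampler can be driven directly. The module path and the toy `token_lengths` array are assumptions here, chosen to satisfy the constructor's `np.ndarray` assert:

```python
import numpy as np
from torch.utils.data import SequentialSampler

from axolotl.utils.samplers.multipack import MultipackBatchSampler  # assumed module path

token_lengths = np.array([512, 1024, 256, 2048, 128], dtype=np.int64)  # per-example token counts

sampler = MultipackBatchSampler(
    sampler=SequentialSampler(range(len(token_lengths))),
    batch_size=2,          # bins per batch
    batch_max_len=2048,    # bin capacity in tokens
    lengths=token_lengths,
    drop_last=True,
    sequential=True,       # order-preserving allocate_sequentially path
)

for batch in sampler:
    ...  # batch is a list of bins; each bin is a list of dataset indices
```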

View File

@@ -53,4 +53,5 @@ class CustomSupportedOptimizers(str, Enum):
     ao_adamw_8bit = "ao_adamw_8bit"  # pylint: disable=invalid-name
     ao_adamw_fp8 = "ao_adamw_fp8"  # pylint: disable=invalid-name
     adopt_adamw = "adopt_adamw"  # pylint: disable=invalid-name
+    came_pytorch = "came_pytorch"  # pylint: disable=invalid-name
     muon = "muon"  # pylint: disable=invalid-name

View File

@@ -75,8 +75,10 @@ class HyperparametersConfig(BaseModel):
     lr_groups: list[LrGroup] | None = None
     adam_epsilon: float | None = None
+    adam_epsilon2: float | None = None
     adam_beta1: float | None = None
     adam_beta2: float | None = None
+    adam_beta3: float | None = None
     max_grad_norm: float | None = None
     num_epochs: float = Field(default=1.0)

View File

@@ -479,7 +479,7 @@ class TestMultiGPULlama:
             "sample_packing": True,
             "pad_to_sequence_len": True,
             "sequence_len": 2048,
-            "val_set_size": 0.05,
+            "val_set_size": 0.1,
             "special_tokens": {
                 "pad_token": "<|endoftext|>",
             },

View File

@@ -29,12 +29,12 @@ from axolotl.utils.dict import DictDefault
 MODEL_CONFIGS = [
     {
-        "name": "openaccess-ai-collective/tiny-mistral",
+        "name": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
         "expected_activation": apply_lora_mlp_swiglu,
         "dtype": torch.float16,
     },
     {
-        "name": "Qwen/Qwen2-7B",
+        "name": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
         "expected_activation": apply_lora_mlp_swiglu,
         "dtype": torch.float16,
     },
@@ -44,7 +44,7 @@ MODEL_CONFIGS = [
         "dtype": torch.float32,
     },
     {
-        "name": "mhenrichsen/gemma-2b",
+        "name": "trl-internal-testing/tiny-Gemma2ForCausalLM",
         "expected_activation": apply_lora_mlp_geglu,
         "dtype": torch.float16,
     },
@@ -156,7 +156,9 @@ def test_swiglu_mlp_integration(small_llama_model):
 def test_geglu_model_integration():
     """Test GeGLU activation with Gemma model."""
     model = AutoModelForCausalLM.from_pretrained(
-        "mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
+        "trl-internal-testing/tiny-Gemma2ForCausalLM",
+        torch_dtype=torch.float16,
+        device_map="cuda:0",
     )
     peft_config = get_peft_config(
         {

View File

@@ -6,6 +6,8 @@ import logging
 import os
 import unittest

+import pytest
+
 from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
@@ -23,6 +25,7 @@ class TestFalconPatched(unittest.TestCase):
     Test case for Falcon models
     """

+    @pytest.mark.skip(reason="no tiny models for testing with safetensors")
     @with_temp_dir
     def test_qlora(self, temp_dir):
         # pylint: disable=duplicate-code
@@ -71,6 +74,7 @@ class TestFalconPatched(unittest.TestCase):
         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

+    @pytest.mark.skip(reason="no tiny models for testing with safetensors")
     @with_temp_dir
     def test_ft(self, temp_dir):
         # pylint: disable=duplicate-code

View File

@@ -28,7 +28,7 @@ class TestMistral(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "openaccess-ai-collective/tiny-mistral",
+                "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
                 "flash_attention": True,
                 "sample_packing": True,
                 "sequence_len": 1024,
@@ -76,7 +76,7 @@ class TestMistral(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "openaccess-ai-collective/tiny-mistral",
+                "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
                 "flash_attention": True,
                 "sample_packing": True,
                 "sequence_len": 1024,

View File

@@ -56,7 +56,7 @@ class TestModelPatches(unittest.TestCase):
     def test_mistral_multipack(self, temp_dir):
         cfg = DictDefault(
             {
-                "base_model": "openaccess-ai-collective/tiny-mistral",
+                "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
                 "flash_attention": True,
                 "sample_packing": True,
                 "sequence_len": 2048,

View File

@@ -15,7 +15,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists, most_recent_subdir
+from ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -26,6 +26,7 @@ class TestResumeLlama:
     Test case for resuming training of llama models
     """

+    @require_torch_2_6_0
     def test_resume_lora_packed(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
@@ -62,6 +63,7 @@ class TestResumeLlama:
                 "save_total_limit": 5,
                 "max_steps": 15,
                 "use_tensorboard": True,
+                "save_safetensors": True,
             }
         )
         if is_torch_bf16_gpu_available():

View File

@@ -19,14 +19,11 @@ class TestE2eEvaluate:
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "JackFram/llama-68m",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
                 "sequence_len": 1024,
                 "val_set_size": 0.02,
                 "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
+                    "pad_token": "<|endoftext|>",
                 },
                 "datasets": [
                     {

View File

@@ -6,6 +6,8 @@ import logging
 import os
 import unittest

+import pytest
+
 from axolotl.cli.args import TrainerCliArgs
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
@@ -23,6 +25,7 @@ class TestFalcon(unittest.TestCase):
     Test case for falcon
     """

+    @pytest.mark.skip(reason="no tiny models for testing with safetensors")
     @with_temp_dir
     def test_lora(self, temp_dir):
         # pylint: disable=duplicate-code
@@ -74,6 +77,7 @@ class TestFalcon(unittest.TestCase):
         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

+    @pytest.mark.skip(reason="no tiny models for testing with safetensors")
     @with_temp_dir
     def test_lora_added_vocab(self, temp_dir):
         # pylint: disable=duplicate-code
@@ -129,6 +133,7 @@ class TestFalcon(unittest.TestCase):
         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

+    @pytest.mark.skip(reason="no tiny models for testing with safetensors")
     @with_temp_dir
     def test_ft(self, temp_dir):
         # pylint: disable=duplicate-code

View File

@@ -30,7 +30,7 @@ class TestMistral(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "openaccess-ai-collective/tiny-mistral",
+                "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
                 "flash_attention": True,
                 "sequence_len": 1024,
                 "load_in_8bit": True,
@@ -77,7 +77,7 @@ class TestMistral(unittest.TestCase):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "openaccess-ai-collective/tiny-mistral",
+                "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
                 "flash_attention": True,
                 "sequence_len": 1024,
                 "val_set_size": 0.02,

View File

@@ -199,3 +199,50 @@ class TestCustomOptimizers(unittest.TestCase):
         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)

+    @with_temp_dir
+    def test_came_pytorch(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "came_pytorch",
+                "adam_beta3": 0.9999,
+                "adam_epsilon2": 1e-16,
+                "max_steps": 5,
+                "lr_scheduler": "cosine",
+            }
+        )
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(temp_dir, cfg)

View File

@@ -414,7 +414,6 @@ class TestDatasetPreparation:
         snapshot_path = snapshot_download(
             repo_id="mhenrichsen/alpaca_2k_test",
             repo_type="dataset",
-            local_dir=tmp_ds_path,
         )
         shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)