Fix quarto (#2717)

* missing modules

* fix quarto complaints
This commit is contained in:
Dan Saunders
2025-05-23 21:16:51 -04:00
committed by GitHub
parent d27c35ac44
commit 5eb01f3df1
4 changed files with 106 additions and 104 deletions

View File

@@ -129,6 +129,8 @@ quartodoc:
- monkeypatch.attention.mllama - monkeypatch.attention.mllama
- monkeypatch.data.batch_dataset_fetcher - monkeypatch.data.batch_dataset_fetcher
- monkeypatch.mixtral - monkeypatch.mixtral
- monkeypatch.gradient_checkpointing.offload_cpu
- monkeypatch.gradient_checkpointing.offload_disk
- title: Utils - title: Utils
desc: Utility functions desc: Utility functions
contents: contents:
@@ -145,8 +147,6 @@ quartodoc:
- utils.optimizers.adopt - utils.optimizers.adopt
- utils.data.pretraining - utils.data.pretraining
- utils.data.sft - utils.data.sft
- utils.gradient_checkpointing.offload_cpu
- utils.gradient_checkpointing.offload_disk
- title: Schemas - title: Schemas
desc: Pydantic data models for Axolotl config desc: Pydantic data models for Axolotl config
contents: contents:

View File

@@ -180,7 +180,7 @@ Now that you have the basics, you might want to:
Check our other guides for details on these topics: Check our other guides for details on these topics:
- [Configuration Guide](config.qmd) - Full configuration options - [Configuration Guide](config.qmd) - Full configuration options
- [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources - [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources
- [Dataset Formats](dataset-formats) - Working with different data formats - [Dataset Formats](dataset-formats) - Working with different data formats
- [Multi-GPU Training](multi-gpu.qmd) - [Multi-GPU Training](multi-gpu.qmd)
- [Multi-Node Training](multi-node.qmd) - [Multi-Node Training](multi-node.qmd)

View File

@@ -39,31 +39,39 @@ if TYPE_CHECKING:
class BasePlugin: class BasePlugin:
"""Base class for all plugins. Defines the interface for plugin methods. """Base class for all plugins. Defines the interface for plugin methods.
Methods: A plugin is a reusable, modular, and self-contained piece of code that extends
register(cfg): Registers the plugin with the given configuration. the functionality of Axolotl. Plugins can be used to integrate third-party models,
load_datasets(cfg): Loads and preprocesses the dataset for training. modify the training process, or add new features.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_build(cfg, model): Performs actions after the model is loaded, but To create a new plugin, you need to inherit from the BasePlugin class and
implement the required methods.
Note:
Plugin methods include:
- register(cfg): Registers the plugin with the given configuration.
- load_datasets(cfg): Loads and preprocesses the dataset for training.
- pre_model_load(cfg): Performs actions before the model is loaded.
- post_model_build(cfg, model): Performs actions after the model is loaded, but
before LoRA adapters are applied. before LoRA adapters are applied.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded. - pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded. - post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded, - post_model_load(cfg, model): Performs actions after the model is loaded,
inclusive of any adapters. inclusive of any adapters.
post_trainer_create(cfg, trainer): Performs actions after the trainer is - post_trainer_create(cfg, trainer): Performs actions after the trainer is
created. created.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training. - create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and - create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and
returns a learning rate scheduler. returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before - add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before
training. training.
add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after - add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after
training. training.
""" """
def __init__(self): def __init__(self):
"""Initializes the BasePlugin.""" """Initializes the BasePlugin."""
def register(self, cfg): # pylint: disable=unused-argument def register(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Registers the plugin with the given configuration. """Registers the plugin with the given configuration.
Args: Args:
@@ -275,10 +283,11 @@ class PluginManager:
Attributes: Attributes:
plugins: A list of loaded plugins. plugins: A list of loaded plugins.
Methods: Note:
get_instance(): Static method to get the singleton instance of `PluginManager`. Key methods include:
register(plugin_name: str): Registers a new plugin by its name. - get_instance(): Static method to get the singleton instance of `PluginManager`.
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. - register(plugin_name: str): Registers a new plugin by its name.
- pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
""" """
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict() plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
@@ -534,7 +543,6 @@ class PluginManager:
Args: Args:
cfg: The configuration for the plugins. cfg: The configuration for the plugins.
model: The loaded model.
""" """
for plugin in self.plugins.values(): for plugin in self.plugins.values():
plugin.post_train_unload(cfg) plugin.post_train_unload(cfg)

View File

@@ -7,7 +7,7 @@ import logging
import math import math
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context from multiprocessing import cpu_count, get_context
from typing import Iterable, Union from typing import Iterable, Iterator, Union
import numba import numba
import numpy as np import numpy as np
@@ -20,19 +20,19 @@ LOG.setLevel(logging.INFO)
@numba.njit @numba.njit
def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int): def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int) -> bool:
""" """First-fit-decreasing bin packing algorithm check.
First-fit-decreasing bin packing algorithm check
Checks if sequences with the given lengths could fit in the specified number of bins Checks if sequences with the given lengths could fit in the specified number of
bins.
Args: Args:
sequence_lengths: Array of sequence lengths sequence_lengths: Array of sequence lengths.
bin_capacity: Maximum capacity of each bin bin_capacity: Maximum capacity of each bin.
num_bins: Number of bins available num_bins: Number of bins available.
Returns: Returns:
True if all sequences can be packed, False otherwise `True` if all sequences can be packed, `False` otherwise.
""" """
# Sort sequence lengths in descending order for optimal packing # Sort sequence lengths in descending order for optimal packing
sequence_lengths = np.sort(sequence_lengths)[::-1] sequence_lengths = np.sort(sequence_lengths)[::-1]
@@ -63,20 +63,19 @@ def pack_group(
max_bins: int, max_bins: int,
bin_size: int, bin_size: int,
safe_mode: bool = True, safe_mode: bool = True,
): ) -> list[list[int]]:
""" """Pack a group of sequences into bins using First-Fit Decreasing algorithm.
Pack a group of sequences into bins using First-Fit Decreasing algorithm
Args: Args:
sequence_lengths: Array of sequence lengths sequence_lengths: Array of sequence lengths.
group_offset: Offset to apply to indices when returning results group_offset: Offset to apply to indices when returning results.
bin_capacity: Maximum capacity of each bin bin_capacity: Maximum capacity of each bin.
max_bins: Maximum number of bins to use max_bins: Maximum number of bins to use.
bin_size: Maximum number of sequences per bin bin_size: Maximum number of sequences per bin.
safe_mode: If True, use a more conservative packing approach safe_mode: If True, use a more conservative packing approach.
Returns: Returns:
List of bins, where each bin contains indices of sequences assigned to it List of bins, where each bin contains indices of sequences assigned to it.
""" """
bins_remaining_space: list = [] # Tracks remaining capacity in each bin bins_remaining_space: list = [] # Tracks remaining capacity in each bin
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
@@ -111,8 +110,10 @@ def pack_group(
return bins_assigned_sequences return bins_assigned_sequences
# Define a standalone function for multiprocessing def _process_group(
def _process_group(args): args: tuple[np.ndarray, int, int, int, int, bool],
) -> list[list[int]]:
"""Standalone function for multiprocessing."""
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode = args group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode = args
return pack_group( return pack_group(
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode
@@ -127,22 +128,21 @@ def pack_parallel(
num_processes: int | None = None, num_processes: int | None = None,
safe_mode: bool = True, safe_mode: bool = True,
mp_start_method: str | None = "spawn", mp_start_method: str | None = "spawn",
): ) -> list[list[int]]:
""" """Pack sequences into bins using parallel processing.
Pack sequences into bins using parallel processing
Args: Args:
sequence_lengths: Array of sequence lengths sequence_lengths: Array of sequence lengths.
bin_capacity: Maximum capacity of each bin as total number of tokens bin_capacity: Maximum capacity of each bin as total number of tokens.
group_size: Number of sequences to process in each group group_size: Number of sequences to process in each group.
bin_size: Maximum number of bins to use bin_size: Maximum number of bins to use.
num_processes: Number of parallel processes to use num_processes: Number of parallel processes to use.
safe_mode: If True, use a more conservative packing approach safe_mode: If True, use a more conservative packing approach.
mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver'). mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver').
'spawn' is often safer with Numba/PyTorch. 'spawn' is often safer with Numba/PyTorch.
Set to None to use system default. Set to None to use system default.
Returns: Returns:
List of bins, where each bin contains indices of sequences assigned to it List of bins, where each bin contains indices of sequences assigned to it.
""" """
num_items = len(sequence_lengths) num_items = len(sequence_lengths)
if num_processes is None: if num_processes is None:
@@ -191,20 +191,20 @@ def pack_parallel(
@numba.njit @numba.njit
def allocate_sequentially( def allocate_sequentially(
sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int
): ) -> tuple[list[list[int]], int, int]:
""" """Sequential allocator that preserves example order.
Sequential allocator that preserves example order
Args: Args:
sequence_lengths: The lengths of all examples sequence_lengths: The lengths of all examples.
rank: The current rank (for distributed training) rank: The current rank (for distributed training).
bin_capacity: The capacity of each bin (maximum sequence length) bin_capacity: The capacity of each bin (maximum sequence length).
num_ranks: Number of ranks (processes/GPUs) num_ranks: Number of ranks (processes / GPUs).
Returns: Returns:
rank_batches: List of batches for the current rank rank_batches: List of batches for the current rank.
total_tokens_used: Number of actual example tokens total_tokens_used: Number of actual example tokens.
total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity) total_token_slots: Maximum theoretical number of example tokens (number of bins
* bin capacity).
""" """
result = [] result = []
total_used = 0 total_used = 0
@@ -240,8 +240,7 @@ def allocate_sequentially(
class MultipackBatchSampler(BatchSampler): class MultipackBatchSampler(BatchSampler):
""" """Batch sampler class for efficient packing of variable-length sequences
Batch sampler class for efficient packing of variable-length sequences
This sampler packs sequences into fixed-capacity bins (batches) to maximize This sampler packs sequences into fixed-capacity bins (batches) to maximize
GPU memory utilization and training throughput by reducing padding. GPU memory utilization and training throughput by reducing padding.
@@ -250,6 +249,9 @@ class MultipackBatchSampler(BatchSampler):
sequential packing (preserving original sequence order). sequential packing (preserving original sequence order).
""" """
_batches: list[list[list[int]]] | None = None
_len_across_ranks: int | None = None
def __init__( def __init__(
self, self,
sampler: Union[Sampler[int], Iterable[int]], sampler: Union[Sampler[int], Iterable[int]],
@@ -287,11 +289,6 @@ class MultipackBatchSampler(BatchSampler):
# The number of times to calculate batches to determine minimum packed dataset length # The number of times to calculate batches to determine minimum packed dataset length
self.num_count_samples = num_count_samples self.num_count_samples = num_count_samples
# Minimum packed dataset length across all ranks (determined by gather/broadcast)
self.len_across_ranks = None
# Cache for batches
self._batches = None
if self.sequential and not isinstance(sampler, SequentialSampler): if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warning( LOG.warning(
@@ -303,16 +300,15 @@ class MultipackBatchSampler(BatchSampler):
self.epoch = epoch self.epoch = epoch
self._batches = None # Invalidate batch cache self._batches = None # Invalidate batch cache
def generate_batches(self, set_stats=False): def generate_batches(self, set_stats: bool = False) -> list[list[list[int]]]:
""" """Generate packed batches for training.
Generate packed batches for training
Args: Args:
set_stats: Whether to update efficiency statistics set_stats: Whether to update efficiency statistics.
Returns: Returns:
List of batches, where each batch contains multiple bins, List of batches, where each batch contains multiple bins, and each bin
and each bin contains multiple sequence indices contains multiple sequence indices.
""" """
if self._batches is not None: if self._batches is not None:
return self._batches return self._batches
@@ -375,23 +371,21 @@ class MultipackBatchSampler(BatchSampler):
self._batches = batches self._batches = batches
return batches return batches
def __iter__(self): def __iter__(self) -> Iterator[list[list[int]]]:
""" """Return an iterator over batches.
Return an iterator over batches
The batches are truncated to match the minimum number of batches across all ranks The batches are truncated to match the minimum number of batches across all
to ensure distributed training balance ranks to ensure distributed training balance.
""" """
batches = self.generate_batches(set_stats=True) batches = self.generate_batches(set_stats=True)
if self.len_across_ranks: if self._len_across_ranks:
# Truncate batches to ensure all ranks have the same number of batches # Truncate batches to ensure all ranks have the same number of batches
batches = batches[: self.len_across_ranks] batches = batches[: self._len_across_ranks]
return iter(batches) return iter(batches)
def efficiency(self): def efficiency(self) -> float:
""" """Calculate the packing efficiency (ratio of tokens used to total token slots).
Calculate the packing efficiency (ratio of tokens used to total token slots) Higher is better - 1.0 would mean perfect packing with no wasted space.
Higher is better - 1.0 would mean perfect packing with no wasted space
""" """
if self.total_token_slots == 0: if self.total_token_slots == 0:
self.generate_batches(set_stats=True) self.generate_batches(set_stats=True)
@@ -400,10 +394,12 @@ class MultipackBatchSampler(BatchSampler):
# Return a Python float instead of potentially a numpy float # Return a Python float instead of potentially a numpy float
return float(self.total_tokens_used / self.total_token_slots) return float(self.total_tokens_used / self.total_token_slots)
def gather_efficiency(self): def gather_efficiency(self) -> float:
""" """Gather and synchronize packing efficiency estimates across all distributed
Gather and synchronize packing efficiency estimates across all distributed ranks ranks.
Returns a conservative efficiency estimate based on the measurements
Returns:
A conservative efficiency estimate based on the measurements.
""" """
def calc_sample_packing_eff_est(estimates: list[float]): def calc_sample_packing_eff_est(estimates: list[float]):
@@ -424,13 +420,12 @@ class MultipackBatchSampler(BatchSampler):
) )
return sample_packing_eff_est return sample_packing_eff_est
def gather_len_batches(self, num): def gather_len_batches(self, num: int) -> int:
""" """Gather and synchronize batch counts across all distributed ranks. Returns
Gather and synchronize batch counts across all distributed ranks the minimum number of batches available on any rank.
Returns the minimum number of batches available on any rank
""" """
def calc_min_len(estimates: list[(int, float)]): def calc_min_len(estimates: list[int]) -> int:
LOG.info(f"gather_len_batches: {repr(estimates)}") LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(min(estimates)) return math.floor(min(estimates))
@@ -438,22 +433,21 @@ class MultipackBatchSampler(BatchSampler):
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len) min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
return min_len_batches return min_len_batches
def __len__(self): def __len__(self) -> int:
""" """Return the total number of batches that will be yielded by this sampler.
Return the total number of batches that will be yielded by this sampler
This is calculated as the minimum number of batches available on any rank This is calculated as the minimum number of batches available on any rank to
to ensure balanced distributed training ensure balanced distributed training.
""" """
if self._batches is None: if self._batches is None:
self._batches = self.generate_batches(set_stats=True) self._batches = self.generate_batches(set_stats=True)
if self.len_across_ranks is None: if self._len_across_ranks is None:
# Sample multiple times to get stable estimate # Sample multiple times to get stable estimate
len_batches = min( # pylint: disable=consider-using-generator len_batches = min( # pylint: disable=consider-using-generator
[len(self._batches) for _ in range(self.num_count_samples)] [len(self._batches) for _ in range(self.num_count_samples)]
) )
# Gather minimum across all ranks # Gather minimum across all ranks
self.len_across_ranks = self.gather_len_batches(len_batches) self._len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks return self._len_across_ranks