From 5eb01f3df194ce6d663cebf2de8f3fb8fe7ec8e0 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Fri, 23 May 2025 21:16:51 -0400 Subject: [PATCH] Fix quarto (#2717) * missing modules * fix quarto complaints --- _quarto.yml | 4 +- docs/getting-started.qmd | 2 +- src/axolotl/integrations/base.py | 46 ++++--- src/axolotl/utils/samplers/multipack.py | 158 ++++++++++++------------ 4 files changed, 106 insertions(+), 104 deletions(-) diff --git a/_quarto.yml b/_quarto.yml index a530e380a..df6992d92 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -129,6 +129,8 @@ quartodoc: - monkeypatch.attention.mllama - monkeypatch.data.batch_dataset_fetcher - monkeypatch.mixtral + - monkeypatch.gradient_checkpointing.offload_cpu + - monkeypatch.gradient_checkpointing.offload_disk - title: Utils desc: Utility functions contents: @@ -145,8 +147,6 @@ quartodoc: - utils.optimizers.adopt - utils.data.pretraining - utils.data.sft - - utils.gradient_checkpointing.offload_cpu - - utils.gradient_checkpointing.offload_disk - title: Schemas desc: Pydantic data models for Axolotl config contents: diff --git a/docs/getting-started.qmd b/docs/getting-started.qmd index 064985e35..6f1b54348 100644 --- a/docs/getting-started.qmd +++ b/docs/getting-started.qmd @@ -180,7 +180,7 @@ Now that you have the basics, you might want to: Check our other guides for details on these topics: - [Configuration Guide](config.qmd) - Full configuration options -- [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources +- [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources - [Dataset Formats](dataset-formats) - Working with different data formats - [Multi-GPU Training](multi-gpu.qmd) - [Multi-Node Training](multi-node.qmd) diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 2beaf667a..eb2b29cbe 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -39,31 +39,39 @@ if TYPE_CHECKING: class BasePlugin: """Base 
class for all plugins. Defines the interface for plugin methods. - Methods: - register(cfg): Registers the plugin with the given configuration. - load_datasets(cfg): Loads and preprocesses the dataset for training. - pre_model_load(cfg): Performs actions before the model is loaded. - post_model_build(cfg, model): Performs actions after the model is loaded, but + A plugin is a reusable, modular, and self-contained piece of code that extends + the functionality of Axolotl. Plugins can be used to integrate third-party models, + modify the training process, or add new features. + + To create a new plugin, you need to inherit from the BasePlugin class and + implement the required methods. + + Note: + Plugin methods include: + - register(cfg): Registers the plugin with the given configuration. + - load_datasets(cfg): Loads and preprocesses the dataset for training. + - pre_model_load(cfg): Performs actions before the model is loaded. + - post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied. - pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded. - post_lora_load(cfg, model): Performs actions after LoRA weights are loaded. - post_model_load(cfg, model): Performs actions after the model is loaded, + - pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded. + - post_lora_load(cfg, model): Performs actions after LoRA weights are loaded. + - post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters. - post_trainer_create(cfg, trainer): Performs actions after the trainer is + - post_trainer_create(cfg, trainer): Performs actions after the trainer is created. - create_optimizer(cfg, trainer): Creates and returns an optimizer for training. - create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and + - create_optimizer(cfg, trainer): Creates and returns an optimizer for training. 
+ - create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler. - add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before + - add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training. - add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after + - add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training. """ def __init__(self): """Initializes the BasePlugin.""" - def register(self, cfg): # pylint: disable=unused-argument + def register(self, cfg: DictDefault): # pylint: disable=unused-argument """Registers the plugin with the given configuration. Args: @@ -275,10 +283,11 @@ class PluginManager: Attributes: plugins: A list of loaded plugins. - Methods: - get_instance(): Static method to get the singleton instance of `PluginManager`. - register(plugin_name: str): Registers a new plugin by its name. - pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. + Note: + Key methods include: + - get_instance(): Static method to get the singleton instance of `PluginManager`. + - register(plugin_name: str): Registers a new plugin by its name. + - pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. """ plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict() @@ -534,7 +543,6 @@ class PluginManager: Args: cfg: The configuration for the plugins. - model: The loaded model. 
""" for plugin in self.plugins.values(): plugin.post_train_unload(cfg) diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index 2df2d9e19..1bfa2ec6e 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -7,7 +7,7 @@ import logging import math from concurrent.futures import ProcessPoolExecutor from multiprocessing import cpu_count, get_context -from typing import Iterable, Union +from typing import Iterable, Iterator, Union import numba import numpy as np @@ -20,19 +20,19 @@ LOG.setLevel(logging.INFO) @numba.njit -def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int): - """ - First-fit-decreasing bin packing algorithm check +def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int) -> bool: + """First-fit-decreasing bin packing algorithm check. - Checks if sequences with the given lengths could fit in the specified number of bins + Checks if sequences with the given lengths could fit in the specified number of + bins. Args: - sequence_lengths: Array of sequence lengths - bin_capacity: Maximum capacity of each bin - num_bins: Number of bins available + sequence_lengths: Array of sequence lengths. + bin_capacity: Maximum capacity of each bin. + num_bins: Number of bins available. Returns: - True if all sequences can be packed, False otherwise + `True` if all sequences can be packed, `False` otherwise. """ # Sort sequence lengths in descending order for optimal packing sequence_lengths = np.sort(sequence_lengths)[::-1] @@ -63,20 +63,19 @@ def pack_group( max_bins: int, bin_size: int, safe_mode: bool = True, -): - """ - Pack a group of sequences into bins using First-Fit Decreasing algorithm +) -> list[list[int]]: + """Pack a group of sequences into bins using First-Fit Decreasing algorithm. 
Args: - sequence_lengths: Array of sequence lengths - group_offset: Offset to apply to indices when returning results - bin_capacity: Maximum capacity of each bin - max_bins: Maximum number of bins to use - bin_size: Maximum number of sequences per bin - safe_mode: If True, use a more conservative packing approach + sequence_lengths: Array of sequence lengths. + group_offset: Offset to apply to indices when returning results. + bin_capacity: Maximum capacity of each bin. + max_bins: Maximum number of bins to use. + bin_size: Maximum number of sequences per bin. + safe_mode: If True, use a more conservative packing approach. Returns: - List of bins, where each bin contains indices of sequences assigned to it + List of bins, where each bin contains indices of sequences assigned to it. """ bins_remaining_space: list = [] # Tracks remaining capacity in each bin bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin @@ -111,8 +110,10 @@ def pack_group( return bins_assigned_sequences -# Define a standalone function for multiprocessing -def _process_group(args): +def _process_group( + args: tuple[np.ndarray, int, int, int, int, bool], +) -> list[list[int]]: + """Standalone function for multiprocessing.""" group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode = args return pack_group( group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode @@ -127,22 +128,21 @@ def pack_parallel( num_processes: int | None = None, safe_mode: bool = True, mp_start_method: str | None = "spawn", -): - """ - Pack sequences into bins using parallel processing +) -> list[list[int]]: + """Pack sequences into bins using parallel processing. 
Args: - sequence_lengths: Array of sequence lengths - bin_capacity: Maximum capacity of each bin as total number of tokens - group_size: Number of sequences to process in each group - bin_size: Maximum number of bins to use - num_processes: Number of parallel processes to use - safe_mode: If True, use a more conservative packing approach + sequence_lengths: Array of sequence lengths. + bin_capacity: Maximum capacity of each bin as total number of tokens. + group_size: Number of sequences to process in each group. + bin_size: Maximum number of sequences per bin. + num_processes: Number of parallel processes to use. + safe_mode: If True, use a more conservative packing approach. mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver'). 'spawn' is often safer with Numba/PyTorch. Set to None to use system default. Returns: - List of bins, where each bin contains indices of sequences assigned to it + List of bins, where each bin contains indices of sequences assigned to it. """ num_items = len(sequence_lengths) if num_processes is None: @@ -191,20 +191,20 @@ def pack_parallel( @numba.njit def allocate_sequentially( sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int -): - """ - Sequential allocator that preserves example order +) -> tuple[list[list[int]], int, int]: + """Sequential allocator that preserves example order. Args: - sequence_lengths: The lengths of all examples - rank: The current rank (for distributed training) - bin_capacity: The capacity of each bin (maximum sequence length) - num_ranks: Number of ranks (processes/GPUs) + sequence_lengths: The lengths of all examples. + rank: The current rank (for distributed training). + bin_capacity: The capacity of each bin (maximum sequence length). + num_ranks: Number of ranks (processes / GPUs). 
Returns: - rank_batches: List of batches for the current rank - total_tokens_used: Number of actual example tokens - total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity) + rank_batches: List of batches for the current rank. + total_tokens_used: Number of actual example tokens. + total_token_slots: Maximum theoretical number of example tokens (number of bins + * bin capacity). """ result = [] total_used = 0 @@ -240,8 +240,7 @@ def allocate_sequentially( class MultipackBatchSampler(BatchSampler): - """ - Batch sampler class for efficient packing of variable-length sequences + """Batch sampler class for efficient packing of variable-length sequences This sampler packs sequences into fixed-capacity bins (batches) to maximize GPU memory utilization and training throughput by reducing padding. @@ -250,6 +249,9 @@ class MultipackBatchSampler(BatchSampler): sequential packing (preserving original sequence order). """ + _batches: list[list[list[int]]] | None = None + _len_across_ranks: int | None = None + def __init__( self, sampler: Union[Sampler[int], Iterable[int]], @@ -287,11 +289,6 @@ class MultipackBatchSampler(BatchSampler): # The number of times to calculate batches to determine minimum packed dataset length self.num_count_samples = num_count_samples - # Minimum packed dataset length across all ranks (determined by gather/broadcast) - self.len_across_ranks = None - - # Cache for batches - self._batches = None if self.sequential and not isinstance(sampler, SequentialSampler): LOG.warning( @@ -303,16 +300,15 @@ class MultipackBatchSampler(BatchSampler): self.epoch = epoch self._batches = None # Invalidate batch cache - def generate_batches(self, set_stats=False): - """ - Generate packed batches for training + def generate_batches(self, set_stats: bool = False) -> list[list[list[int]]]: + """Generate packed batches for training. 
Args: - set_stats: Whether to update efficiency statistics + set_stats: Whether to update efficiency statistics. Returns: - List of batches, where each batch contains multiple bins, - and each bin contains multiple sequence indices + List of batches, where each batch contains multiple bins, and each bin + contains multiple sequence indices. """ if self._batches is not None: return self._batches @@ -375,23 +371,21 @@ class MultipackBatchSampler(BatchSampler): self._batches = batches return batches - def __iter__(self): - """ - Return an iterator over batches + def __iter__(self) -> Iterator[list[list[int]]]: + """Return an iterator over batches. - The batches are truncated to match the minimum number of batches across all ranks - to ensure distributed training balance + The batches are truncated to match the minimum number of batches across all + ranks to ensure distributed training balance. """ batches = self.generate_batches(set_stats=True) - if self.len_across_ranks: + if self._len_across_ranks: # Truncate batches to ensure all ranks have the same number of batches - batches = batches[: self.len_across_ranks] + batches = batches[: self._len_across_ranks] return iter(batches) - def efficiency(self): - """ - Calculate the packing efficiency (ratio of tokens used to total token slots) - Higher is better - 1.0 would mean perfect packing with no wasted space + def efficiency(self) -> float: + """Calculate the packing efficiency (ratio of tokens used to total token slots). + Higher is better - 1.0 would mean perfect packing with no wasted space. 
""" if self.total_token_slots == 0: self.generate_batches(set_stats=True) @@ -400,10 +394,12 @@ class MultipackBatchSampler(BatchSampler): # Return a Python float instead of potentially a numpy float return float(self.total_tokens_used / self.total_token_slots) - def gather_efficiency(self): - """ - Gather and synchronize packing efficiency estimates across all distributed ranks - Returns a conservative efficiency estimate based on the measurements + def gather_efficiency(self) -> float: + """Gather and synchronize packing efficiency estimates across all distributed + ranks. + + Returns: + A conservative efficiency estimate based on the measurements. """ def calc_sample_packing_eff_est(estimates: list[float]): @@ -424,13 +420,12 @@ class MultipackBatchSampler(BatchSampler): ) return sample_packing_eff_est - def gather_len_batches(self, num): - """ - Gather and synchronize batch counts across all distributed ranks - Returns the minimum number of batches available on any rank + def gather_len_batches(self, num: int) -> int: + """Gather and synchronize batch counts across all distributed ranks. Returns + the minimum number of batches available on any rank. """ - def calc_min_len(estimates: list[(int, float)]): + def calc_min_len(estimates: list[int]) -> int: LOG.info(f"gather_len_batches: {repr(estimates)}") return math.floor(min(estimates)) @@ -438,22 +433,21 @@ class MultipackBatchSampler(BatchSampler): min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len) return min_len_batches - def __len__(self): - """ - Return the total number of batches that will be yielded by this sampler + def __len__(self) -> int: + """Return the total number of batches that will be yielded by this sampler. - This is calculated as the minimum number of batches available on any rank - to ensure balanced distributed training + This is calculated as the minimum number of batches available on any rank to + ensure balanced distributed training. 
""" if self._batches is None: self._batches = self.generate_batches(set_stats=True) - if self.len_across_ranks is None: + if self._len_across_ranks is None: # Sample multiple times to get stable estimate len_batches = min( # pylint: disable=consider-using-generator [len(self._batches) for _ in range(self.num_count_samples)] ) # Gather minimum across all ranks - self.len_across_ranks = self.gather_len_batches(len_batches) + self._len_across_ranks = self.gather_len_batches(len_batches) - return self.len_across_ranks + return self._len_across_ranks