@@ -39,31 +39,39 @@ if TYPE_CHECKING:
|
||||
class BasePlugin:
|
||||
"""Base class for all plugins. Defines the interface for plugin methods.
|
||||
|
||||
Methods:
|
||||
register(cfg): Registers the plugin with the given configuration.
|
||||
load_datasets(cfg): Loads and preprocesses the dataset for training.
|
||||
pre_model_load(cfg): Performs actions before the model is loaded.
|
||||
post_model_build(cfg, model): Performs actions after the model is loaded, but
|
||||
A plugin is a reusable, modular, and self-contained piece of code that extends
|
||||
the functionality of Axolotl. Plugins can be used to integrate third-party models,
|
||||
modify the training process, or add new features.
|
||||
|
||||
To create a new plugin, you need to inherit from the BasePlugin class and
|
||||
implement the required methods.
|
||||
|
||||
Note:
|
||||
Plugin methods include:
|
||||
- register(cfg): Registers the plugin with the given configuration.
|
||||
- load_datasets(cfg): Loads and preprocesses the dataset for training.
|
||||
- pre_model_load(cfg): Performs actions before the model is loaded.
|
||||
- post_model_build(cfg, model): Performs actions after the model is loaded, but
|
||||
before LoRA adapters are applied.
|
||||
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
|
||||
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
|
||||
post_model_load(cfg, model): Performs actions after the model is loaded,
|
||||
- pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
|
||||
- post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
|
||||
- post_model_load(cfg, model): Performs actions after the model is loaded,
|
||||
inclusive of any adapters.
|
||||
post_trainer_create(cfg, trainer): Performs actions after the trainer is
|
||||
- post_trainer_create(cfg, trainer): Performs actions after the trainer is
|
||||
created.
|
||||
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
|
||||
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and
|
||||
- create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
|
||||
- create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and
|
||||
returns a learning rate scheduler.
|
||||
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before
|
||||
- add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before
|
||||
training.
|
||||
add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after
|
||||
- add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after
|
||||
training.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes the BasePlugin."""
|
||||
|
||||
def register(self, cfg): # pylint: disable=unused-argument
|
||||
def register(self, cfg: DictDefault): # pylint: disable=unused-argument
|
||||
"""Registers the plugin with the given configuration.
|
||||
|
||||
Args:
|
||||
@@ -275,10 +283,11 @@ class PluginManager:
|
||||
Attributes:
|
||||
plugins: A list of loaded plugins.
|
||||
|
||||
Methods:
|
||||
get_instance(): Static method to get the singleton instance of `PluginManager`.
|
||||
register(plugin_name: str): Registers a new plugin by its name.
|
||||
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
|
||||
Note:
|
||||
Key methods include:
|
||||
- get_instance(): Static method to get the singleton instance of `PluginManager`.
|
||||
- register(plugin_name: str): Registers a new plugin by its name.
|
||||
- pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
|
||||
"""
|
||||
|
||||
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
|
||||
@@ -534,7 +543,6 @@ class PluginManager:
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the plugins.
|
||||
model: The loaded model.
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_train_unload(cfg)
|
||||
|
||||
@@ -7,7 +7,7 @@ import logging
|
||||
import math
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from multiprocessing import cpu_count, get_context
|
||||
from typing import Iterable, Union
|
||||
from typing import Iterable, Iterator, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
@@ -20,19 +20,19 @@ LOG.setLevel(logging.INFO)
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
|
||||
"""
|
||||
First-fit-decreasing bin packing algorithm check
|
||||
def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int) -> bool:
|
||||
"""First-fit-decreasing bin packing algorithm check.
|
||||
|
||||
Checks if sequences with the given lengths could fit in the specified number of bins
|
||||
Checks if sequences with the given lengths could fit in the specified number of
|
||||
bins.
|
||||
|
||||
Args:
|
||||
sequence_lengths: Array of sequence lengths
|
||||
bin_capacity: Maximum capacity of each bin
|
||||
num_bins: Number of bins available
|
||||
sequence_lengths: Array of sequence lengths.
|
||||
bin_capacity: Maximum capacity of each bin.
|
||||
num_bins: Number of bins available.
|
||||
|
||||
Returns:
|
||||
True if all sequences can be packed, False otherwise
|
||||
`True` if all sequences can be packed, `False` otherwise.
|
||||
"""
|
||||
# Sort sequence lengths in descending order for optimal packing
|
||||
sequence_lengths = np.sort(sequence_lengths)[::-1]
|
||||
@@ -63,20 +63,19 @@ def pack_group(
|
||||
max_bins: int,
|
||||
bin_size: int,
|
||||
safe_mode: bool = True,
|
||||
):
|
||||
"""
|
||||
Pack a group of sequences into bins using First-Fit Decreasing algorithm
|
||||
) -> list[list[int]]:
|
||||
"""Pack a group of sequences into bins using First-Fit Decreasing algorithm.
|
||||
|
||||
Args:
|
||||
sequence_lengths: Array of sequence lengths
|
||||
group_offset: Offset to apply to indices when returning results
|
||||
bin_capacity: Maximum capacity of each bin
|
||||
max_bins: Maximum number of bins to use
|
||||
bin_size: Maximum number of sequences per bin
|
||||
safe_mode: If True, use a more conservative packing approach
|
||||
sequence_lengths: Array of sequence lengths.
|
||||
group_offset: Offset to apply to indices when returning results.
|
||||
bin_capacity: Maximum capacity of each bin.
|
||||
max_bins: Maximum number of bins to use.
|
||||
bin_size: Maximum number of sequences per bin.
|
||||
safe_mode: If True, use a more conservative packing approach.
|
||||
|
||||
Returns:
|
||||
List of bins, where each bin contains indices of sequences assigned to it
|
||||
List of bins, where each bin contains indices of sequences assigned to it.
|
||||
"""
|
||||
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
|
||||
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
|
||||
@@ -111,8 +110,10 @@ def pack_group(
|
||||
return bins_assigned_sequences
|
||||
|
||||
|
||||
# Define a standalone function for multiprocessing
|
||||
def _process_group(args):
|
||||
def _process_group(
|
||||
args: tuple[np.ndarray, int, int, int, int, bool],
|
||||
) -> list[list[int]]:
|
||||
"""Standalone function for multiprocessing."""
|
||||
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode = args
|
||||
return pack_group(
|
||||
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode
|
||||
@@ -127,22 +128,21 @@ def pack_parallel(
|
||||
num_processes: int | None = None,
|
||||
safe_mode: bool = True,
|
||||
mp_start_method: str | None = "spawn",
|
||||
):
|
||||
"""
|
||||
Pack sequences into bins using parallel processing
|
||||
) -> list[list[int]]:
|
||||
"""Pack sequences into bins using parallel processing.
|
||||
|
||||
Args:
|
||||
sequence_lengths: Array of sequence lengths
|
||||
bin_capacity: Maximum capacity of each bin as total number of tokens
|
||||
group_size: Number of sequences to process in each group
|
||||
bin_size: Maximum number of bins to use
|
||||
num_processes: Number of parallel processes to use
|
||||
safe_mode: If True, use a more conservative packing approach
|
||||
sequence_lengths: Array of sequence lengths.
|
||||
bin_capacity: Maximum capacity of each bin as total number of tokens.
|
||||
group_size: Number of sequences to process in each group.
|
||||
bin_size: Maximum number of sequences per bin.
|
||||
num_processes: Number of parallel processes to use.
|
||||
safe_mode: If True, use a more conservative packing approach.
|
||||
mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver').
|
||||
'spawn' is often safer with Numba/PyTorch.
|
||||
Set to None to use system default.
|
||||
Returns:
|
||||
List of bins, where each bin contains indices of sequences assigned to it
|
||||
List of bins, where each bin contains indices of sequences assigned to it.
|
||||
"""
|
||||
num_items = len(sequence_lengths)
|
||||
if num_processes is None:
|
||||
@@ -191,20 +191,20 @@ def pack_parallel(
|
||||
@numba.njit
|
||||
def allocate_sequentially(
|
||||
sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int
|
||||
):
|
||||
"""
|
||||
Sequential allocator that preserves example order
|
||||
) -> tuple[list[list[int]], int, int]:
|
||||
"""Sequential allocator that preserves example order.
|
||||
|
||||
Args:
|
||||
sequence_lengths: The lengths of all examples
|
||||
rank: The current rank (for distributed training)
|
||||
bin_capacity: The capacity of each bin (maximum sequence length)
|
||||
num_ranks: Number of ranks (processes/GPUs)
|
||||
sequence_lengths: The lengths of all examples.
|
||||
rank: The current rank (for distributed training).
|
||||
bin_capacity: The capacity of each bin (maximum sequence length).
|
||||
num_ranks: Number of ranks (processes / GPUs).
|
||||
|
||||
Returns:
|
||||
rank_batches: List of batches for the current rank
|
||||
total_tokens_used: Number of actual example tokens
|
||||
total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
|
||||
rank_batches: List of batches for the current rank.
|
||||
total_tokens_used: Number of actual example tokens.
|
||||
total_token_slots: Maximum theoretical number of example tokens (number of bins
|
||||
* bin capacity).
|
||||
"""
|
||||
result = []
|
||||
total_used = 0
|
||||
@@ -240,8 +240,7 @@ def allocate_sequentially(
|
||||
|
||||
|
||||
class MultipackBatchSampler(BatchSampler):
|
||||
"""
|
||||
Batch sampler class for efficient packing of variable-length sequences
|
||||
"""Batch sampler class for efficient packing of variable-length sequences.
|
||||
|
||||
This sampler packs sequences into fixed-capacity bins (batches) to maximize
|
||||
GPU memory utilization and training throughput by reducing padding.
|
||||
@@ -250,6 +249,9 @@ class MultipackBatchSampler(BatchSampler):
|
||||
sequential packing (preserving original sequence order).
|
||||
"""
|
||||
|
||||
_batches: list[list[list[int]]] | None = None
|
||||
_len_across_ranks: int | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sampler: Union[Sampler[int], Iterable[int]],
|
||||
@@ -287,11 +289,6 @@ class MultipackBatchSampler(BatchSampler):
|
||||
|
||||
# The number of times to calculate batches to determine minimum packed dataset length
|
||||
self.num_count_samples = num_count_samples
|
||||
# Minimum packed dataset length across all ranks (determined by gather/broadcast)
|
||||
self.len_across_ranks = None
|
||||
|
||||
# Cache for batches
|
||||
self._batches = None
|
||||
|
||||
if self.sequential and not isinstance(sampler, SequentialSampler):
|
||||
LOG.warning(
|
||||
@@ -303,16 +300,15 @@ class MultipackBatchSampler(BatchSampler):
|
||||
self.epoch = epoch
|
||||
self._batches = None # Invalidate batch cache
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
"""
|
||||
Generate packed batches for training
|
||||
def generate_batches(self, set_stats: bool = False) -> list[list[list[int]]]:
|
||||
"""Generate packed batches for training.
|
||||
|
||||
Args:
|
||||
set_stats: Whether to update efficiency statistics
|
||||
set_stats: Whether to update efficiency statistics.
|
||||
|
||||
Returns:
|
||||
List of batches, where each batch contains multiple bins,
|
||||
and each bin contains multiple sequence indices
|
||||
List of batches, where each batch contains multiple bins, and each bin
|
||||
contains multiple sequence indices.
|
||||
"""
|
||||
if self._batches is not None:
|
||||
return self._batches
|
||||
@@ -375,23 +371,21 @@ class MultipackBatchSampler(BatchSampler):
|
||||
self._batches = batches
|
||||
return batches
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Return an iterator over batches
|
||||
def __iter__(self) -> Iterator[list[list[int]]]:
|
||||
"""Return an iterator over batches.
|
||||
|
||||
The batches are truncated to match the minimum number of batches across all ranks
|
||||
to ensure distributed training balance
|
||||
The batches are truncated to match the minimum number of batches across all
|
||||
ranks to ensure distributed training balance.
|
||||
"""
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
if self.len_across_ranks:
|
||||
if self._len_across_ranks:
|
||||
# Truncate batches to ensure all ranks have the same number of batches
|
||||
batches = batches[: self.len_across_ranks]
|
||||
batches = batches[: self._len_across_ranks]
|
||||
return iter(batches)
|
||||
|
||||
def efficiency(self):
|
||||
"""
|
||||
Calculate the packing efficiency (ratio of tokens used to total token slots)
|
||||
Higher is better - 1.0 would mean perfect packing with no wasted space
|
||||
def efficiency(self) -> float:
|
||||
"""Calculate the packing efficiency (ratio of tokens used to total token slots).
|
||||
Higher is better - 1.0 would mean perfect packing with no wasted space.
|
||||
"""
|
||||
if self.total_token_slots == 0:
|
||||
self.generate_batches(set_stats=True)
|
||||
@@ -400,10 +394,12 @@ class MultipackBatchSampler(BatchSampler):
|
||||
# Return a Python float instead of potentially a numpy float
|
||||
return float(self.total_tokens_used / self.total_token_slots)
|
||||
|
||||
def gather_efficiency(self):
|
||||
"""
|
||||
Gather and synchronize packing efficiency estimates across all distributed ranks
|
||||
Returns a conservative efficiency estimate based on the measurements
|
||||
def gather_efficiency(self) -> float:
|
||||
"""Gather and synchronize packing efficiency estimates across all distributed
|
||||
ranks.
|
||||
|
||||
Returns:
|
||||
A conservative efficiency estimate based on the measurements.
|
||||
"""
|
||||
|
||||
def calc_sample_packing_eff_est(estimates: list[float]):
|
||||
@@ -424,13 +420,12 @@ class MultipackBatchSampler(BatchSampler):
|
||||
)
|
||||
return sample_packing_eff_est
|
||||
|
||||
def gather_len_batches(self, num):
|
||||
"""
|
||||
Gather and synchronize batch counts across all distributed ranks
|
||||
Returns the minimum number of batches available on any rank
|
||||
def gather_len_batches(self, num: int) -> int:
|
||||
"""Gather and synchronize batch counts across all distributed ranks. Returns
|
||||
the minimum number of batches available on any rank.
|
||||
"""
|
||||
|
||||
def calc_min_len(estimates: list[int]) -> int:
    """Log the gathered per-rank batch counts and return the smallest, floored."""
    LOG.info(f"gather_len_batches: {repr(estimates)}")
    smallest = min(estimates)
    return math.floor(smallest)
|
||||
|
||||
@@ -438,22 +433,21 @@ class MultipackBatchSampler(BatchSampler):
|
||||
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
|
||||
return min_len_batches
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Return the total number of batches that will be yielded by this sampler
|
||||
def __len__(self) -> int:
|
||||
"""Return the total number of batches that will be yielded by this sampler.
|
||||
|
||||
This is calculated as the minimum number of batches available on any rank
|
||||
to ensure balanced distributed training
|
||||
This is calculated as the minimum number of batches available on any rank to
|
||||
ensure balanced distributed training.
|
||||
"""
|
||||
if self._batches is None:
|
||||
self._batches = self.generate_batches(set_stats=True)
|
||||
|
||||
if self.len_across_ranks is None:
|
||||
if self._len_across_ranks is None:
|
||||
# Sample multiple times to get stable estimate
|
||||
len_batches = min( # pylint: disable=consider-using-generator
|
||||
[len(self._batches) for _ in range(self.num_count_samples)]
|
||||
)
|
||||
# Gather minimum across all ranks
|
||||
self.len_across_ranks = self.gather_len_batches(len_batches)
|
||||
self._len_across_ranks = self.gather_len_batches(len_batches)
|
||||
|
||||
return self.len_across_ranks
|
||||
return self._len_across_ranks
|
||||
|
||||
Reference in New Issue
Block a user