diff --git a/docs/config.qmd b/docs/config.qmd
index 787632c50..a68afde04 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -623,6 +623,9 @@ ddp_broadcast_buffers:
 # Sequence parallelism
 # Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
 # Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
+# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
+# subsequences, or set to 4 to split into four equal-sized subsequences.
+# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
 sequence_parallel_degree:
 
 # Path to torch distx for optim 'adamw_anyprecision'
diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd
new file mode 100644
index 000000000..cb297c0e0
--- /dev/null
+++ b/docs/sequence_parallelism.qmd
@@ -0,0 +1,90 @@
+---
+title: Sequence Parallelism
+description: Train with long sequences split across multiple GPUs.
+---
+
+# Sequence Parallelism
+
+Sequence parallelism is a technique that splits sequences across multiple GPUs,
+allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
+GPU processes a different portion of the sequence, and the results are aggregated
+through a ring communication pattern.
+
+## When to Use Sequence Parallelism
+
+Use sequence parallelism when:
+
+- You need to train with sequence lengths that don't fit into a single GPU's memory
+- You have multiple GPUs available
+- You're experiencing OOM (Out Of Memory) errors with long sequences
+
+## Configuration
+
+To enable sequence parallelism, add the following to your configuration file:
+
+```yaml
+# Set to a divisor (> 1) of the number of GPUs available
+sequence_parallel_degree: 4  # Split sequences across 4 GPUs
+```
+
+The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
+
+- With 8 GPUs, valid values would be 2, 4, or 8
+- With 4 GPUs, valid values would be 2 or 4
+
+## Implementation Details
+
+When sequence parallelism is enabled:
+
+1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
+2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
+3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
+4. The trainer uses special ring communication patterns for attention operations
+
+## Requirements
+
+To use sequence parallelism, you need:
+
+- Multiple GPUs (at least 2)
+- The `ring-flash-attn` package. Install with:
+  - `pip install axolotl[ring-flash-attn]` (preferred)
+  - `pip install ring-flash-attn>=0.1.4`
+
+## Limitations
+
+- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
+- May have a small performance overhead due to communication between GPUs
+
+## Example
+
+```yaml
+# Example config with sequence parallelism
+base_model: meta-llama/Llama-3-8B-Instruct
+sequence_len: 8192
+sequence_parallel_degree: 2  # Split each sequence into 2 parts
+flash_attention: true  # Required with sequence parallelism
+...
+```
+
+This will train the Llama 3 8B model with 8K context length, with each sequence split
+into 2 subsequences of length 4096 across 2 GPUs.
+
+## Sample Packing with Sequence Parallelism
+
+Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
+
+1. Samples are first packed together
+2. The packed sequences are then divided across GPUs in the sequence parallel group
+3. Position IDs are automatically adjusted to maintain proper relative positions
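To make the chunking described in the list above concrete, here is a minimal, self-contained sketch of how a packed sequence could be sliced across a sequence-parallel group of two GPUs. It is illustrative only, not Axolotl's actual collator: the `split_for_rank` helper and the toy tensors are made up for this example, and it assumes the sequence length divides evenly by the parallel degree.

```python
import torch


def split_for_rank(batch: dict, rank: int, sp_degree: int) -> dict:
    """Return this rank's contiguous slice of each per-token tensor (illustrative)."""
    seq_len = batch["input_ids"].size(1)
    chunk = seq_len // sp_degree  # assumes seq_len is divisible by sp_degree
    start, end = rank * chunk, (rank + 1) * chunk
    return {key: tensor[:, start:end] for key, tensor in batch.items()}


# A packed row holding two samples (lengths 5 and 3); position_ids reset at the
# boundary between the packed samples.
packed = {
    "input_ids": torch.arange(8).unsqueeze(0),
    "position_ids": torch.tensor([[0, 1, 2, 3, 4, 0, 1, 2]]),
}

for rank in range(2):  # sequence_parallel_degree = 2
    shard = split_for_rank(packed, rank, sp_degree=2)
    print(rank, shard["input_ids"].tolist(), shard["position_ids"].tolist())
# rank 0 -> input_ids [[0, 1, 2, 3]], position_ids [[0, 1, 2, 3]]
# rank 1 -> input_ids [[4, 5, 6, 7]], position_ids [[4, 0, 1, 2]]; the reset at the
# packed-sample boundary is preserved, so relative positions stay correct.
```

In the real collator the details differ (attention_mask and labels are sliced as well, and the ring communication pattern handles cross-GPU attention), but the slicing idea is the same.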
+
+## Effect on Batch Size
+
+When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
+
+- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
+- The number of batches processed per step decreases
+
+For example:
+- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
+- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
+- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
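The arithmetic above can be sanity-checked with a few lines of Python. This is only a sketch of the reasoning, assuming no gradient accumulation; the function name is made up and is not an Axolotl API.

```python
def effective_global_batch_size(
    num_gpus: int, micro_batch_size: int, sequence_parallel_degree: int = 1
) -> int:
    # GPUs in the same sequence-parallel group share one batch (different slices
    # of each sequence), so only num_gpus / sequence_parallel_degree distinct
    # batches are processed per step.
    assert num_gpus % sequence_parallel_degree == 0
    data_parallel_groups = num_gpus // sequence_parallel_degree
    return data_parallel_groups * micro_batch_size


print(effective_global_batch_size(8, 2))                              # 16
print(effective_global_batch_size(8, 2, sequence_parallel_degree=4))  # 4
```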
""" LOG.info("Starting trainer...") - context_manager = train_context_manager(cfg.flash_optimum) + context_manager = train_context_manager(flash_optimum=cfg.flash_optimum) with context_manager: trainer.train(resume_from_checkpoint=resume_from_checkpoint) diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 7dd402a3f..12c8b31d5 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -17,19 +17,14 @@ logger = logging.getLogger(__name__) def adjust_position_ids_for_slice( - position_ids: list | torch.Tensor, start_idx: int + position_ids: torch.Tensor, start_idx: int ) -> torch.Tensor: """ Adjust position IDs for a sliced sequence to maintain proper relative positions. - This handles the case where position IDs might not be contiguous due to sample packing. + This handles the case where position IDs might not be contiguous due to sample + packing. """ # Convert to tensor if not already - if not isinstance(position_ids, torch.Tensor): - position_ids = torch.tensor( - position_ids, - device=position_ids.device if hasattr(position_ids, "device") else None, - ) - # Find the boundaries between samples (where position_ids reset) adjusted_pos_ids = position_ids.clone()