From 69aeae80edf60a50255104dc9e0760e81fd7e7f5 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 23 Apr 2025 23:19:52 +0000
Subject: [PATCH] updates

---
 src/axolotl/core/trainers/base.py   |   9 +-
 src/axolotl/core/trainers/sp.py     | 312 +++++++---------------------
 src/axolotl/train.py                |   1 +
 src/axolotl/utils/schemas/config.py |  64 +++---
 4 files changed, 112 insertions(+), 274 deletions(-)

diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index fd72cd6db..54fc5d902 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -371,13 +371,20 @@ class AxolotlTrainer(
                 num_items_in_batch=num_items_in_batch,
             )
 
-        return super().compute_loss(
+        loss = super().compute_loss(
             model,
             inputs,
             return_outputs=return_outputs,
             num_items_in_batch=num_items_in_batch,
         )
 
+        # This is needed due to details of our sequence parallel implementation; the HF
+        # trainer averages the loss over the full sequence length despite our splitting
+        # the data along the sequence dimension.
+        loss *= self.args.sequence_parallel_degree
+
+        return loss
+
     @staticmethod
     def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
         concatenated_batch = {}
diff --git a/src/axolotl/core/trainers/sp.py b/src/axolotl/core/trainers/sp.py
index 5f384bad8..41e31a6ed 100644
--- a/src/axolotl/core/trainers/sp.py
+++ b/src/axolotl/core/trainers/sp.py
@@ -1,11 +1,11 @@
-import contextlib
-import functools
+"""Module for definition of sequence parallel context manager"""
+
+import inspect
 import logging
-from typing import Dict, List, Optional, Set
 
 import torch
 import torch.distributed as dist
-import torch.nn as nn
+from torch import nn
 from torch.utils.hooks import RemovableHandle
 
 from axolotl.monkeypatch.attention.ring_attn.patch import (
@@ -22,17 +22,16 @@ class SequenceParallelContext:
     Context manager for sequence parallelism operations.
 
     This class provides a context that will automatically apply sequence parallelism
-    during model forward passes.
+    during model forward passes using pre-forward hooks.
""" - # Keep track of active contexts to support nested contexts - _active_contexts = [] - def __init__( self, + model: nn.Module, sequence_parallel_degree: int, ring_attn_func: RingAttnFunc, ): + self.model = model self.sequence_parallel_degree = sequence_parallel_degree self.ring_attn_func = ring_attn_func self.process_group = get_ring_attn_group() @@ -42,9 +41,8 @@ class SequenceParallelContext: self.local_world_size = 1 self.active = False - # Will store original methods for restoration - self._original_module_forward = None - self._hooks: List[RemovableHandle] = [] + # Will store hook handles for removal + self.hook_handles: list[RemovableHandle] = [] if self.sequence_parallel_degree > 1: if self.process_group is None: @@ -55,75 +53,64 @@ class SequenceParallelContext: def __enter__(self): self.active = True - SequenceParallelContext._active_contexts.append(self) - # Store the original forward method - if self._original_module_forward is None: - self._original_module_forward = nn.Module.forward + # Define a pre-forward hook to apply sequence parallelism with kwargs support + def sequence_parallel_pre_hook(module, args, kwargs): + if not self.active or self.sequence_parallel_degree <= 1: + return None - # Replace nn.Module.forward with our sequence parallel version - nn.Module.forward = self._make_sequence_parallel_forward(nn.Module.forward) + # Apply sequence parallelism to kwargs + if kwargs: + transformed_kwargs = self.apply_sequence_parallelism(kwargs) + return args, transformed_kwargs + + # If no kwargs but we have args, try to convert them to kwargs + if args and not kwargs: + try: + signature = inspect.signature(module.forward) + param_names = list(signature.parameters.keys())[1:] # Skip 'self' + + # Create kwargs from args + new_kwargs = {} + for i, arg in enumerate(args): + if i < len(param_names): + new_kwargs[param_names[i]] = arg + else: + # If we can't map all args, don't transform + return None + + # Apply sequence parallelism to the new kwargs + transformed_kwargs = self.apply_sequence_parallelism(new_kwargs) + + # Return empty args and the transformed kwargs + return (), transformed_kwargs + except (ValueError, TypeError): + # If conversion fails, don't transform + return None + + # If no args and no kwargs, nothing to transform + return None + + # Register the pre-forward hook on the model + self.hook_handles.append( + self.model.register_forward_pre_hook( + sequence_parallel_pre_hook, with_kwargs=True + ) + ) return self def __exit__(self, exc_type, exc_val, exc_tb): self.active = False - # Only restore original forward if this is the last active context - if ( - SequenceParallelContext._active_contexts - and SequenceParallelContext._active_contexts[-1] == self - ): - SequenceParallelContext._active_contexts.pop() - - # Restore original forward method - if self._original_module_forward is not None: - nn.Module.forward = self._original_module_forward - self._original_module_forward = None - - # Remove any hooks we added - for hook in self._hooks: - hook.remove() - self._hooks = [] - - def _make_sequence_parallel_forward(self, original_forward): - """Create a wrapped forward method that applies sequence parallelism.""" - - @functools.wraps(original_forward) - def sequence_parallel_forward(module_self, *args, **kwargs): - # Skip sequence parallelism if not active or degree is 1 - if not self.active or self.sequence_parallel_degree <= 1: - return original_forward(module_self, *args, **kwargs) - - # Convert args to kwargs if needed - if args: - # Try to map positional args 
-                # Try to map positional args to kwargs based on the forward method signature
-                import inspect
-
-                try:
-                    signature = inspect.signature(original_forward)
-                    param_names = list(signature.parameters.keys())[1:]  # Skip 'self'
-                    for i, arg in enumerate(args):
-                        if i < len(param_names):
-                            kwargs[param_names[i]] = arg
-                        else:
-                            # If we can't map all args, fall back to original forward
-                            return original_forward(module_self, *args, **kwargs)
-                except (ValueError, TypeError):
-                    # If we can't get the signature, just use the original forward
-                    return original_forward(module_self, *args, **kwargs)
-
-            # Apply sequence parallelism to the inputs
-            kwargs = self.apply_sequence_parallelism(kwargs)
-
-            # Call the original forward with modified kwargs
-            return original_forward(module_self, **kwargs)
-
-        return sequence_parallel_forward
+        # Remove all hooks
+        for handle in self.hook_handles:
+            handle.remove()
+        self.hook_handles = []
 
     def apply_sequence_parallelism(
-        self, batch: Dict[str, torch.Tensor]
-    ) -> Dict[str, torch.Tensor]:
+        self, batch: dict[str, torch.Tensor]
+    ) -> dict[str, torch.Tensor]:
         """
         Apply sequence parallelism slicing to a batch.
@@ -136,198 +123,45 @@ class SequenceParallelContext:
         if self.sequence_parallel_degree <= 1 or not self.active:
             return batch
 
-        # Make a copy of the batch to avoid modifying the original
-        new_batch = dict(batch)
-
-        # Get total sequence length from input_ids or inputs_embeds
-        if "input_ids" in new_batch:
-            total_seq_len = new_batch["input_ids"].size(1)
-        elif "inputs_embeds" in new_batch:
-            total_seq_len = new_batch["inputs_embeds"].size(1)
-        else:
-            # If we can't determine sequence length, return the batch as is
-            return new_batch
-
         # Update ring attention params if needed
-        if new_batch.get("position_ids") is not None:
-            update_ring_attn_params(position_ids=new_batch["position_ids"])
+        if batch.get("position_ids") is not None:
+            update_ring_attn_params(position_ids=batch["position_ids"])
 
         # Slice batch for sequence parallel processing
-        for key in new_batch:
+        total_seq_len = batch["input_ids"].size(1)
+        for key in batch:
             if (
-                key in new_batch
-                and isinstance(new_batch[key], torch.Tensor)
-                and new_batch[key].dim() > 1
-                and new_batch[key].size(1) == total_seq_len
+                key in batch
+                and isinstance(batch[key], torch.Tensor)
+                and batch[key].dim() > 1
+                and batch[key].size(1) == total_seq_len
             ):
                 if self.ring_attn_func in [
                     RingAttnFunc.VARLEN_LLAMA3,
                     RingAttnFunc.BATCH_RING,
                 ]:
-                    new_batch[key] = (
-                        new_batch[key]
+                    # Split in sequential fashion and grab this rank's chunk
+                    batch[key] = (
+                        batch[key]
                         .chunk(self.local_world_size, dim=1)[self.local_rank]
                         .contiguous()
                     )
                 elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
-                    chunks = new_batch[key].chunk(2 * self.local_world_size, dim=1)
+                    chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
 
                     # Take rank's chunk and opposing chunk for zigzag pattern
                     selected_chunks = [
                         chunks[self.local_rank],
                         chunks[2 * self.local_world_size - self.local_rank - 1],
                     ]
-                    new_batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
+                    batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
                 elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
                     # Split into striped data and stack
                     tensor = torch.stack(
-                        new_batch[key].split(self.local_world_size, dim=1),
+                        batch[key].split(self.local_world_size, dim=1),
                         dim=1,
                     ).transpose(1, 2)
-                    new_batch[key] = tensor[:, self.local_rank].contiguous()
+                    batch[key] = tensor[:, self.local_rank].contiguous()
 
-        return new_batch
-
-
-@contextlib.contextmanager
-def sequence_parallel(
-    sequence_parallel_degree: int = 1,
-    process_group: Optional[dist.ProcessGroup] = None,
-    ring_attn_func: Optional[RingAttnFunc] = None,
-    buffers: Optional[List[torch.Tensor]] = None,
-    buffer_seq_dims: Optional[List[int]] = None,
-    no_restore_buffers: Optional[Set[torch.Tensor]] = None,
-):
-    """
-    Context manager for sequence parallelism.
-
-    This context manager will apply sequence parallelism to model inputs
-    for all forward passes within its scope.
-
-    Args:
-        sequence_parallel_degree: The degree of sequence parallelism.
-        process_group: The process group to use for communication. Default is the world group.
-        ring_attn_func: The ring attention function to use.
-        buffers: Optional list of buffers to shard (e.g., input tensors, position embeddings).
-        buffer_seq_dims: Sequence dimensions for each buffer to shard.
-        no_restore_buffers: Optional set of buffers that don't need to be restored.
-
-    Yields:
-        The sequence parallel context.
-    """
-    buffers = [] if buffers is None else buffers
-    buffer_seq_dims = [] if buffer_seq_dims is None else buffer_seq_dims
-    no_restore_buffers = set() if no_restore_buffers is None else no_restore_buffers
-
-    if len(buffers) != len(buffer_seq_dims):
-        raise ValueError(
-            "`buffer_seq_dims` must have the same number of elements as `buffers`."
-        )
-
-    # Save original buffer states
-    original_buffers = []
-    for buffer in buffers:
-        if buffer in no_restore_buffers:
-            original_buffers.append(None)
-        else:
-            original_buffers.append(buffer.clone())
-
-    # Create context
-    context = SequenceParallelContext(
-        sequence_parallel_degree=sequence_parallel_degree,
-        process_group=process_group,
-        ring_attn_func=ring_attn_func,
-    )
-
-    # Apply sequence parallelism to buffers if provided
-    if buffers and buffer_seq_dims:
-        for i, (buffer, dim) in enumerate(zip(buffers, buffer_seq_dims)):
-            if context.sequence_parallel_degree > 1:
-                # Get local shard
-                sharded_tensor = context.apply_sequence_parallelism(
-                    {"tensor": buffer.unsqueeze(0)}
-                )["tensor"].squeeze(0)
-
-                # Resize and copy in-place
-                buffer.resize_(sharded_tensor.shape)
-                buffer.copy_(sharded_tensor)
-
-    try:
-        # Enter the context
-        with context:
-            yield context
-    finally:
-        # Restore original buffer states
-        for buffer, original in zip(buffers, original_buffers):
-            if original is not None:
-                buffer.resize_(original.shape)
-                buffer.copy_(original)
-
-
-def enable_sequence_parallel_for_module(
-    module: nn.Module,
-    sequence_parallel_degree: int = 1,
-    process_group: Optional[dist.ProcessGroup] = None,
-    ring_attn_func: Optional[RingAttnFunc] = None,
-):
-    """
-    Enable sequence parallelism for a specific module.
-
-    This function wraps the module's forward method to automatically apply
-    sequence parallelism without using a context manager.
-
-    Args:
-        module: The module to enable sequence parallelism for.
-        sequence_parallel_degree: The degree of sequence parallelism.
-        process_group: The process group to use for communication.
-        ring_attn_func: The ring attention function to use.
-
-    Returns:
-        The module with sequence parallelism enabled.
-    """
-    # Create a context for this module
-    context = SequenceParallelContext(
-        sequence_parallel_degree=sequence_parallel_degree,
-        process_group=process_group,
-        ring_attn_func=ring_attn_func,
-    )
-
-    # Save the original forward method
-    original_forward = module.forward
-
-    @functools.wraps(original_forward)
-    def sequence_parallel_forward(*args, **kwargs):
-        # Activate the context
-        context.active = True
-
-        # Convert args to kwargs if needed
-        if args:
-            import inspect
-
-            try:
-                signature = inspect.signature(original_forward)
-                param_names = list(signature.parameters.keys())
-                for i, arg in enumerate(args):
-                    if i < len(param_names):
-                        kwargs[param_names[i]] = arg
-                    else:
-                        return original_forward(*args, **kwargs)
-            except (ValueError, TypeError):
-                return original_forward(*args, **kwargs)
-
-        # Apply sequence parallelism to inputs
-        kwargs = context.apply_sequence_parallelism(kwargs)
-
-        # Call original forward with modified inputs
-        result = original_forward(**kwargs)
-
-        # Deactivate the context
-        context.active = False
-
-        return result
-
-    # Replace the forward method
-    module.forward = sequence_parallel_forward
-
-    return module
+        return batch
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 3530654a5..f098be475 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -197,6 +197,7 @@ def execute_training(
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
     elif cfg.sequence_parallel_degree > 1:
         with SequenceParallelContext(
+            model=trainer.model,
             sequence_parallel_degree=cfg.sequence_parallel_degree,
             ring_attn_func=cfg.ring_attn_func,
         ):
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 732ae60cf..07bedbbd7 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -18,6 +18,7 @@ from pydantic import (
 )
 from transformers.utils.import_utils import is_torch_npu_available
 
+from axolotl.utils.distributed import is_main_process
 from axolotl.utils.schemas.datasets import (
     DatasetConfig,
     DPODataset,
@@ -1149,22 +1150,17 @@ class AxolotlInputConfig(
         return data
 
-    @field_validator("sequence_parallel_degree", mode="after")
-    @classmethod
-    def check_sequence_parallel_degree(cls, value, info):
-        if not value:
-            value = 1
-
-        if value > 1:
-            if not info.data.get("flash_attention"):
+    @model_validator(mode="after")
+    def check_sequence_parallel_degree(self):
+        if not self.sequence_parallel_degree:
+            self.sequence_parallel_degree = 1
+        elif self.sequence_parallel_degree > 1:
+            if not self.flash_attention:
                 raise ValueError(
                     "flash_attention: true must be set with sequence_parallel_degree > 1"
                 )
 
-            if (
-                info.data.get("sample_packing")
-                and not info.data["micro_batch_size"] == 1
-            ):
+            if self.sample_packing and self.micro_batch_size != 1:
                 raise ValueError(
                     "micro_batch_size must be set to 1 when sample_packing is enabled"
                     "due to a `ring-flash-attn` requirement"
                 )
@@ -1182,44 +1178,44 @@ class AxolotlInputConfig(
             # TODO: monkeypatch / callback to average losses correctly across SP ranks
             # / fix gradient scaling across SP ranks. Losses, grads should be scaled
            # according to the proportion of non-padding tokens per rank.
-            LOG.warning(
-                "Sequence parallelism (SP) is enabled with "
-                f"sequence_parallel_degree={value}. Please note that logged losses may "
-                "differ slightly to the non-SP losses due to transformers Trainer "
-                "implementation details. Please see "
-                "https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
-                "for more details."
-            )
+            if is_main_process():
+                LOG.warning(
+                    "Sequence parallelism (SP) is enabled with "
+                    f"sequence_parallel_degree={self.sequence_parallel_degree}. "
+                    "Please note that logged losses may differ slightly from the non-SP "
+                    "losses due to transformers Trainer implementation details. "
+                    "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
+                    "for more details."
+                )
 
-        return value
+        return self
 
-    @field_validator("ring_attn_func", mode="after")
-    @classmethod
-    def check_ring_attn_func(cls, value, info):
-        if not info.data.get("sequence_parallel_degree", 1) > 1:
-            return value
+    @model_validator(mode="after")
+    def validate_ring_attn_func(self):
+        if self.sequence_parallel_degree == 1:
+            return self
 
+        # Validate the configured ring attention function, or pick a sensible default
         from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
 
-        if value is not None:
-            # Set the ring attention function if passed in config
+        if self.ring_attn_func is not None:
             valid_funcs = list(RingAttnFunc)
-            if value in valid_funcs:
-                value = RingAttnFunc(value)
+            if self.ring_attn_func in valid_funcs:
+                self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
             else:
                 raise ValueError(
-                    f"ring_attn_func: {value} must be one of {valid_funcs}"
+                    f"ring_attn_func: {self.ring_attn_func} must be in {valid_funcs}"
                 )
         else:
             # Default ring attention function selection
-            sample_packing = info.data.get("sample_packing")
-            value = (
+            sample_packing = getattr(self, "sample_packing", False)
+            self.ring_attn_func = (
                 RingAttnFunc.VARLEN_LLAMA3 if sample_packing else RingAttnFunc.BATCH_RING
             )
 
-        return value
+        return self
 
     @model_validator(mode="before")
     @classmethod