From d92ac7a41dec5e4d645ed7bed039d98fa53fc1ee Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Thu, 24 Apr 2025 00:11:37 +0000
Subject: [PATCH] reorg

---
 .../core/trainers/mixins/sequence_parallel.py | 116 ++++++++++++++++-
 src/axolotl/core/trainers/sp.py               | 119 ------------------
 src/axolotl/train.py                          |   2 +-
 3 files changed, 115 insertions(+), 122 deletions(-)
 delete mode 100644 src/axolotl/core/trainers/sp.py

diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py
index 3930c6cb3..d53f90ade 100644
--- a/src/axolotl/core/trainers/mixins/sequence_parallel.py
+++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py
@@ -1,12 +1,21 @@
-"""Module for Axolotl trainer sequence parallelism mixin"""
+"""
+Module for Axolotl trainer sequence parallelism mixin and training context manager
+"""
 
 import logging
 
+import torch
 import torch.distributed as dist
 from datasets import Dataset
+from torch import nn
 from torch.utils.data import DistributedSampler, Sampler
+from torch.utils.hooks import RemovableHandle
 
-from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
+from axolotl.monkeypatch.attention.ring_attn import (
+    RingAttnFunc,
+    get_ring_attn_group,
+    update_ring_attn_params,
+)
 
 LOG = logging.getLogger(__name__)
 
@@ -87,3 +96,106 @@ class SequenceParallelMixin:
         return self._create_sequence_parallel_sampler(
             eval_dataset, shuffle=False, is_eval=True
         )
+
+
+class SequenceParallelContext:
+    """
+    Context manager for sequence parallelism operations.
+
+    This class provides a context that will automatically apply sequence parallelism
+    during model forward passes using a pre-forward hook.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        sequence_parallel_degree: int,
+        ring_attn_func: RingAttnFunc,
+    ):
+        self.model = model
+        self.sequence_parallel_degree = sequence_parallel_degree
+        self.ring_attn_func = ring_attn_func
+        self.process_group = get_ring_attn_group()
+
+        # Initialize sequence parallel group details
+        self.local_rank = dist.get_rank(self.process_group)
+        self.local_world_size = dist.get_world_size(self.process_group)
+
+        # Will store hook handles for removal
+        self.hook_handle: RemovableHandle | None = None
+
+    def __enter__(self):
+        # Define a forward pre-hook to apply sequence parallelism with kwargs support
+        def sequence_parallel_pre_hook(
+            module, args, kwargs
+        ):  # pylint: disable=unused-argument
+            # Apply sequence parallelism to kwargs
+            kwargs = self.apply_sequence_parallelism(kwargs)
+            return args, kwargs
+
+        # Register the pre-forward hook on the model
+        self.hook_handle = self.model.register_forward_pre_hook(
+            sequence_parallel_pre_hook, with_kwargs=True
+        )
+
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # Remove the forward pre-hook
+        self.hook_handle.remove()
+        self.hook_handle = None
+
+    def apply_sequence_parallelism(
+        self, batch: dict[str, torch.Tensor]
+    ) -> dict[str, torch.Tensor]:
+        """
+        Apply sequence parallelism slicing to a batch.
+
+        Args:
+            batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
+
+        Returns:
+            Sliced batch dictionary.
+        """
+        # Update ring attention params if needed
+        if batch.get("position_ids") is not None:
+            update_ring_attn_params(position_ids=batch["position_ids"])
+
+        # Slice batch for sequence parallel processing
+        total_seq_len = batch["input_ids"].size(1)
+        for key in batch:
+            if (
+                key in batch
+                and isinstance(batch[key], torch.Tensor)
+                and batch[key].dim() > 1
+                and batch[key].size(1) == total_seq_len
+            ):
+
+                if self.ring_attn_func in [
+                    RingAttnFunc.VARLEN_LLAMA3,
+                    RingAttnFunc.BATCH_RING,
+                ]:
+                    # Split in sequential fashion and grab this rank's chunk
+                    batch[key] = (
+                        batch[key]
+                        .chunk(self.local_world_size, dim=1)[self.local_rank]
+                        .contiguous()
+                    )
+                elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
+                    chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
+
+                    # Take rank's chunk and opposing chunk for zigzag pattern
+                    selected_chunks = [
+                        chunks[self.local_rank],
+                        chunks[2 * self.local_world_size - self.local_rank - 1],
+                    ]
+                    batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
+                elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
+                    # Split into striped data and stack
+                    tensor = torch.stack(
+                        batch[key].split(self.local_world_size, dim=1),
+                        dim=1,
+                    ).transpose(1, 2)
+                    batch[key] = tensor[:, self.local_rank].contiguous()
+
+        return batch
diff --git a/src/axolotl/core/trainers/sp.py b/src/axolotl/core/trainers/sp.py
deleted file mode 100644
index ba6cb86a6..000000000
--- a/src/axolotl/core/trainers/sp.py
+++ /dev/null
@@ -1,119 +0,0 @@
-"""Module for definition of sequence parallel context manager"""
-
-import logging
-
-import torch
-import torch.distributed as dist
-from torch import nn
-from torch.utils.hooks import RemovableHandle
-
-from axolotl.monkeypatch.attention.ring_attn.patch import (
-    RingAttnFunc,
-    get_ring_attn_group,
-    update_ring_attn_params,
-)
-
-logger = logging.getLogger(__name__)
-
-
-class SequenceParallelContext:
-    """
-    Context manager for sequence parallelism operations.
-
-    This class provides a context that will automatically apply sequence parallelism
-    during model forward passes using a pre-forward hook.
-    """
-
-    def __init__(
-        self,
-        model: nn.Module,
-        sequence_parallel_degree: int,
-        ring_attn_func: RingAttnFunc,
-    ):
-        self.model = model
-        self.sequence_parallel_degree = sequence_parallel_degree
-        self.ring_attn_func = ring_attn_func
-        self.process_group = get_ring_attn_group()
-
-        # Initialize sequence parallel group details
-        self.local_rank = dist.get_rank(self.process_group)
-        self.local_world_size = dist.get_world_size(self.process_group)
-
-        # Will store hook handles for removal
-        self.hook_handle: RemovableHandle | None = None
-
-    def __enter__(self):
-        # Define a forward pre-hook to apply sequence parallelism with kwargs support
-        def sequence_parallel_pre_hook(
-            module, args, kwargs
-        ):  # pylint: disable=unused-argument
-            # Apply sequence parallelism to kwargs
-            kwargs = self.apply_sequence_parallelism(kwargs)
-            return args, kwargs
-
-        # Register the pre-forward hook on the model
-        self.hook_handle = self.model.register_forward_pre_hook(
-            sequence_parallel_pre_hook, with_kwargs=True
-        )
-
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        # Remove the forward pre-hook
-        self.hook_handle.remove()
-        self.hook_handle = None
-
-    def apply_sequence_parallelism(
-        self, batch: dict[str, torch.Tensor]
-    ) -> dict[str, torch.Tensor]:
-        """
-        Apply sequence parallelism slicing to a batch.
-
-        Args:
-            batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
-
-        Returns:
-            Sliced batch dictionary.
-        """
-        # Update ring attention params if needed
-        if batch.get("position_ids") is not None:
-            update_ring_attn_params(position_ids=batch["position_ids"])
-
-        # Slice batch for sequence parallel processing
-        total_seq_len = batch["input_ids"].size(1)
-        for key in batch:
-            if (
-                key in batch
-                and isinstance(batch[key], torch.Tensor)
-                and batch[key].dim() > 1
-                and batch[key].size(1) == total_seq_len
-            ):
-
-                if self.ring_attn_func in [
-                    RingAttnFunc.VARLEN_LLAMA3,
-                    RingAttnFunc.BATCH_RING,
-                ]:
-                    # Split in sequential fashion and grab this rank's chunk
-                    batch[key] = (
-                        batch[key]
-                        .chunk(self.local_world_size, dim=1)[self.local_rank]
-                        .contiguous()
-                    )
-                elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
-                    chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
-
-                    # Take rank's chunk and opposing chunk for zigzag pattern
-                    selected_chunks = [
-                        chunks[self.local_rank],
-                        chunks[2 * self.local_world_size - self.local_rank - 1],
-                    ]
-                    batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
-                elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
-                    # Split into striped data and stack
-                    tensor = torch.stack(
-                        batch[key].split(self.local_world_size, dim=1),
-                        dim=1,
-                    ).transpose(1, 2)
-                    batch[key] = tensor[:, self.local_rank].contiguous()
-
-        return batch
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 5c73647f0..02a69c909 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -26,7 +26,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
     fix_untrained_tokens,
 )
 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
-from axolotl.core.trainers.sp import SequenceParallelContext
+from axolotl.core.trainers.mixins.sequence_parallel import SequenceParallelContext
 from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import cleanup_distributed
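
Usage sketch (illustrative, not part of the patch): the relocated SequenceParallelContext is consumed as a `with` block around the forward pass. The sketch below assumes distributed init and ring-attention group registration have already happened, so get_ring_attn_group() resolves inside the constructor; the function name, degree value, and batch are hypothetical.

    from axolotl.core.trainers.mixins.sequence_parallel import SequenceParallelContext
    from axolotl.monkeypatch.attention.ring_attn import RingAttnFunc

    def sharded_forward(model, batch):
        # While the context is active, a forward pre-hook slices every
        # sequence-length kwarg tensor down to this rank's shard.
        with SequenceParallelContext(
            model=model,
            sequence_parallel_degree=4,  # illustrative value
            ring_attn_func=RingAttnFunc.BATCH_RING,
        ):
            outputs = model(**batch)  # kwargs are sharded by the pre-hook
        # On exit the hook is removed, so later forwards see full sequences.
        return outputs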
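The three ring_attn_func branches in apply_sequence_parallelism differ only in which sequence positions each rank keeps. A standalone, single-process sketch of the same tensor arithmetic, with world size and rank as plain local variables (no process group needed) and a sequence length that divides evenly, matching the patch's assumptions:

    import torch

    world_size, rank = 2, 0
    x = torch.arange(8).unsqueeze(0)  # shape (1, 8): one sequence, positions 0..7

    # VARLEN_LLAMA3 / BATCH_RING: one contiguous chunk per rank.
    ring = x.chunk(world_size, dim=1)[rank]  # rank 0 -> [0, 1, 2, 3]

    # BATCH_ZIGZAG: chunk i paired with its mirror chunk 2*world_size - 1 - i.
    chunks = x.chunk(2 * world_size, dim=1)  # [0,1], [2,3], [4,5], [6,7]
    zigzag = torch.cat(
        [chunks[rank], chunks[2 * world_size - rank - 1]], dim=1
    )  # rank 0 -> [0, 1, 6, 7]

    # BATCH_STRIPE: stack stripes of width world_size, then take this rank's
    # column, i.e. every world_size-th position starting at the rank index.
    striped = torch.stack(x.split(world_size, dim=1), dim=1).transpose(1, 2)
    stripe = striped[:, rank]  # rank 0 -> [0, 2, 4, 6]

The mirror pairing in the zigzag layout gives each rank one early and one late chunk, the usual motivation for zigzag schedules in ring attention, since later positions attend to more context under causal masking.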