From 14baaf6e0a39736eb88a225be26a214f3aa09f37 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 5 Mar 2025 15:39:45 +0000
Subject: [PATCH] updates

---
 src/axolotl/utils/models.py         |  7 ++++++-
 src/axolotl/utils/ring_attn.py      | 24 ++++++++++++++++++------
 src/axolotl/utils/schemas/config.py | 11 +++++++++++
 3 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 96ecdb4b1..3f2374608 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -66,7 +66,6 @@ from axolotl.utils.distributed import (
 from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
-from axolotl.utils.ring_attn import register_ring_attn
 
 LOG = logging.getLogger("axolotl")
 
@@ -549,6 +548,12 @@ class ModelLoader:
             patch_self_attn_lora(self.cfg)
 
         if self.cfg.sequence_parallel_size > 1:
+            from axolotl.utils.ring_attn import register_ring_attn
+
+            # Initialize ring attention for sequence parallelism if enabled.
+            # This must be done after model initialization but before the first forward pass,
+            # as it modifies the flash attention implementation to use ring communication
+            # patterns for efficient sequence-parallel training across multiple GPUs.
             register_ring_attn(self.cfg.sequence_parallel_size)
 
     def patch_attention(self) -> None:
diff --git a/src/axolotl/utils/ring_attn.py b/src/axolotl/utils/ring_attn.py
index ddd70cf98..5552a9047 100644
--- a/src/axolotl/utils/ring_attn.py
+++ b/src/axolotl/utils/ring_attn.py
@@ -1,6 +1,14 @@
+"""Ring attention group registration and utils."""
+
 import torch.distributed as dist
+from accelerate.logging import get_logger
 from ring_flash_attn import substitute_hf_flash_attn
 
+from axolotl.logging_config import configure_logging
+
+configure_logging()
+LOG = get_logger(__name__)
+
 RING_ATTN_GROUP = None
 
 
@@ -9,17 +17,21 @@ def get_ring_attn_group():
 
 
 def set_ring_attn_group(ring_attn_group):
-    global RING_ATTN_GROUP
+    global RING_ATTN_GROUP  # pylint: disable=global-statement
     RING_ATTN_GROUP = ring_attn_group
 
 
-def register_ring_attn(sequence_parallel_size):
+def register_ring_attn(sequence_parallel_size: int):
     """
-    Create ring attention group and substitute flash attention with ring flash
-    attention.
+    Create ring attention group and substitute flash attn with ring flash attn.
+
+    Args:
+        sequence_parallel_size: Sequence parallelism factor.
     """
-    if sequence_parallel_size == 1:
-        return
+    LOG.info(
+        "Enabling ring attention sequence parallelism: "
+        f"each sequence will be processed across {sequence_parallel_size} GPUs"
+    )
 
     world_size = dist.get_world_size()
     assert world_size % sequence_parallel_size == 0, (
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index f57586a82..89b3612fb 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1104,6 +1104,17 @@ class AxolotlInputConfig(
 
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_sequence_parallel_config(cls, data):
+        if data.get("sequence_parallel_size", 1) > 1:
+            if not data.get("flash_attention"):
+                raise ValueError(
+                    "flash_attention: true must be set with sequence_parallel_size > 1"
+                )
+
+        return data
+
 
 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""
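
Note on the register_ring_attn call patched in above: per its docstring, it creates a ring attention process group and substitutes HF flash attention with ring flash attention (substitute_hf_flash_attn). The sketch below illustrates only the group-formation step, assuming contiguous rank grouping; make_ring_attn_group is a hypothetical helper for illustration, not part of axolotl or ring_flash_attn.

    # Minimal sketch: deriving a ring-attention process group from the world
    # size, assuming contiguous rank grouping (an assumption, not necessarily
    # what register_ring_attn / ring_flash_attn do internally).
    import torch.distributed as dist

    def make_ring_attn_group(sequence_parallel_size: int):
        world_size = dist.get_world_size()
        assert world_size % sequence_parallel_size == 0, (
            "sequence_parallel_size must evenly divide world size"
        )

        rank = dist.get_rank()
        group = None
        # dist.new_group() is collective: every rank must create every group,
        # including the groups it does not belong to.
        for start in range(0, world_size, sequence_parallel_size):
            ranks = list(range(start, start + sequence_parallel_size))
            candidate = dist.new_group(ranks=ranks)
            if rank in ranks:
                group = candidate
        return group

For example, with a world size of 8 and sequence_parallel_size of 4, ranks 0-3 and ranks 4-7 would each form a ring, so every sequence is processed across 4 GPUs, matching the log message added in ring_attn.py.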
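On the new check_sequence_parallel_config validator: enabling sequence parallelism now requires flash attention at config-validation time. A YAML config fragment that satisfies the check might look like this (key names come from this patch; the values are illustrative):

    # axolotl config fragment (illustrative values)
    sequence_parallel_size: 2   # split each sequence across 2 GPUs
    flash_attention: true       # required whenever sequence_parallel_size > 1

Setting sequence_parallel_size > 1 without flash_attention: true now fails fast with the ValueError above rather than surfacing later as a runtime error.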