commit 14baaf6e0a
parent f487910444
Author: Dan Saunders
Date: 2025-03-05 15:39:45 +00:00

3 changed files with 35 additions and 7 deletions


@@ -66,7 +66,6 @@ from axolotl.utils.distributed import (
 from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
-from axolotl.utils.ring_attn import register_ring_attn

 LOG = logging.getLogger("axolotl")

@@ -549,6 +548,12 @@ class ModelLoader:
         patch_self_attn_lora(self.cfg)

+        if self.cfg.sequence_parallel_size > 1:
+            from axolotl.utils.ring_attn import register_ring_attn
+
+            # Initialize ring attention for sequence parallelism. This must be
+            # done after model initialization but before the first forward
+            # pass, as it modifies the flash attention implementation to use
+            # ring communication patterns for efficient sequence-parallel
+            # training across multiple GPUs.
+            register_ring_attn(self.cfg.sequence_parallel_size)
+
     def patch_attention(self) -> None:
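Read in isolation, the guarded registration added above boils down to the following minimal sketch. The helper name `maybe_register_ring_attn` is hypothetical and `cfg` stands in for axolotl's config object; only the `sequence_parallel_size` attribute is assumed. The deferred import keeps `ring_flash_attn` an optional dependency, required only when sequence parallelism is actually requested.

def maybe_register_ring_attn(cfg) -> None:
    # Hypothetical helper illustrating the pattern above; `cfg` is a
    # stand-in for axolotl's configuration object.
    if getattr(cfg, "sequence_parallel_size", 1) > 1:
        # Deferred import: ring_flash_attn only has to be installed when
        # sequence parallelism is enabled.
        from axolotl.utils.ring_attn import register_ring_attn

        register_ring_attn(cfg.sequence_parallel_size)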


@@ -1,6 +1,14 @@
 """Ring attention group registration and utils."""

 import torch.distributed as dist
+from accelerate.logging import get_logger
 from ring_flash_attn import substitute_hf_flash_attn

+from axolotl.logging_config import configure_logging
+
+configure_logging()
+LOG = get_logger(__name__)
+
 RING_ATTN_GROUP = None
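The swap to `accelerate.logging.get_logger` matters in multi-GPU runs: accelerate's adapter emits each message once from the main process by default instead of once per rank. A short illustration of that behavior (the `main_process_only` keyword is part of accelerate's logging API; the messages here are made up):

from accelerate.logging import get_logger

LOG = get_logger(__name__)

# Emitted once, from the main process (main_process_only defaults to True).
LOG.info("registering ring attention groups")

# Emitted by every rank, useful for per-rank debugging.
LOG.debug("per-rank ring attention state", main_process_only=False)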
@@ -9,17 +17,21 @@ def get_ring_attn_group():
 def set_ring_attn_group(ring_attn_group):
-    global RING_ATTN_GROUP
+    global RING_ATTN_GROUP  # pylint: disable=global-statement
     RING_ATTN_GROUP = ring_attn_group


-def register_ring_attn(sequence_parallel_size):
+def register_ring_attn(sequence_parallel_size: int):
     """
-    Create ring attention group and substitute flash attention with ring flash
-    attention.
+    Create ring attention group and substitute flash attn with ring flash attn.

     Args:
         sequence_parallel_size: Sequence parallelism factor.
     """
     if sequence_parallel_size == 1:
         return

+    LOG.info(
+        "Enabling ring attention sequence parallelism: "
+        f"each sequence will be processed across {sequence_parallel_size} GPUs"
+    )
+
     world_size = dist.get_world_size()
     assert world_size % sequence_parallel_size == 0, (
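The hunk is truncated at the divisibility assert. For orientation, here is a hedged sketch of how such a registration function typically completes, reusing this module's own names (`set_ring_attn_group`, `get_ring_attn_group`, `substitute_hf_flash_attn`): contiguous blocks of `sequence_parallel_size` ranks each form a process group, the current rank keeps its group, and HF flash attention is patched to communicate over it. The actual group layout and `substitute_hf_flash_attn` arguments in this commit may differ (some `ring_flash_attn` versions take additional parameters such as a heads stride).

def register_ring_attn_sketch(sequence_parallel_size: int) -> None:
    # Hedged sketch of the truncated body, not the commit's exact code.
    world_size = dist.get_world_size()
    assert world_size % sequence_parallel_size == 0, (
        f"world_size ({world_size}) must be divisible by "
        f"sequence_parallel_size ({sequence_parallel_size})"
    )

    rank = dist.get_rank()
    for start in range(0, world_size, sequence_parallel_size):
        ranks = list(range(start, start + sequence_parallel_size))
        # dist.new_group is collective: every rank calls it for every block,
        # but only members of `ranks` keep the resulting handle.
        group = dist.new_group(ranks)
        if rank in ranks:
            set_ring_attn_group(group)

    # Patch HF flash attention to use ring communication over the group.
    substitute_hf_flash_attn(get_ring_attn_group())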


@@ -1104,6 +1104,17 @@ class AxolotlInputConfig(
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_sequence_parallel_config(cls, data):
+        if data.get("sequence_parallel_size", 1) > 1:
+            if not data.get("flash_attention"):
+                raise ValueError(
+                    "flash_attention: true must be set with sequence_parallel_size > 1"
+                )
+
+        return data
+

 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to validate gpu capabilities with the configured options"""