updates
@@ -66,7 +66,6 @@ from axolotl.utils.distributed import (
 from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
-from axolotl.utils.ring_attn import register_ring_attn

 LOG = logging.getLogger("axolotl")

@@ -549,6 +548,12 @@ class ModelLoader:
         patch_self_attn_lora(self.cfg)

+        if self.cfg.sequence_parallel_size > 1:
+            from axolotl.utils.ring_attn import register_ring_attn
+
+            # Initialize ring attention for sequence parallelism if enabled.
+            # This must be done after model initialization but before the first forward pass,
+            # as it modifies the flash attention implementation to use ring communication
+            # patterns for efficient sequence-parallel training across multiple GPUs.
+            register_ring_attn(self.cfg.sequence_parallel_size)

     def patch_attention(self) -> None:

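The module-level import of register_ring_attn is removed in the first hunk and re-imported inside the method here, presumably so that ring_flash_attn only has to be importable when sequence parallelism is actually requested. A minimal standalone sketch of that guarded lazy-import pattern (the _LoaderSketch class and maybe_enable_sequence_parallelism method are illustrative names, not axolotl's API):

class _LoaderSketch:
    """Illustrative stand-in for a loader that optionally enables sequence parallelism."""

    def __init__(self, sequence_parallel_size: int = 1):
        self.sequence_parallel_size = sequence_parallel_size

    def maybe_enable_sequence_parallelism(self) -> None:
        if self.sequence_parallel_size > 1:
            # Import inside the branch so the optional ring_flash_attn dependency
            # is only needed when the feature is turned on.
            from axolotl.utils.ring_attn import register_ring_attn

            register_ring_attn(self.sequence_parallel_size)
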
@@ -1,6 +1,14 @@
 """Ring attention group registration and utils."""

 import torch.distributed as dist
+from accelerate.logging import get_logger
 from ring_flash_attn import substitute_hf_flash_attn

+from axolotl.logging_config import configure_logging
+
+configure_logging()
+LOG = get_logger(__name__)

 RING_ATTN_GROUP = None

@@ -9,17 +17,21 @@ def get_ring_attn_group():
 def set_ring_attn_group(ring_attn_group):
-    global RING_ATTN_GROUP
+    global RING_ATTN_GROUP  # pylint: disable=global-statement
     RING_ATTN_GROUP = ring_attn_group


-def register_ring_attn(sequence_parallel_size):
+def register_ring_attn(sequence_parallel_size: int):
     """
-    Create ring attention group and substitute flash attention with ring flash
-    attention.
+    Create ring attention group and substitute flash attn with ring flash attn.
+
+    Args:
+        sequence_parallel_size: Sequence parallelism factor.
     """
     if sequence_parallel_size == 1:
         return
     LOG.info(
         "Enabling ring attention sequence parallelism: "
         f"each sequence will be processed across {sequence_parallel_size} GPUs"
     )

     world_size = dist.get_world_size()
     assert world_size % sequence_parallel_size == 0, (
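The assertion above requires the world size to be an integer multiple of sequence_parallel_size, since each sequence-parallel ring needs exactly that many ranks. A hedged sketch of how such ring groups are commonly formed from consecutive ranks (this grouping is an assumption for illustration, not necessarily the exact layout ring_flash_attn's substitute_hf_flash_attn uses):

def sketch_ring_groups(world_size: int, sequence_parallel_size: int) -> list[list[int]]:
    # Each group of `sequence_parallel_size` consecutive ranks shares one sequence.
    assert world_size % sequence_parallel_size == 0, (
        "world_size must be divisible by sequence_parallel_size"
    )
    return [
        list(range(start, start + sequence_parallel_size))
        for start in range(0, world_size, sequence_parallel_size)
    ]


# Example: 8 GPUs with sequence_parallel_size=4 -> two rings of 4 ranks each;
# the remaining factor (8 // 4 = 2) stays available for data parallelism.
print(sketch_ring_groups(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
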
@@ -1104,6 +1104,17 @@ class AxolotlInputConfig(
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_sequence_parallel_config(cls, data):
+        if data.get("sequence_parallel_size", 1) > 1:
+            if not data.get("flash_attention"):
+                raise ValueError(
+                    "flash_attention: true must be set with sequence_parallel_size > 1"
+                )
+
+        return data


 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to validate gpu capabilities with the configured options"""

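The new validator encodes the dependency between the two options: ring attention works by substituting the flash-attention implementation, so sequence_parallel_size > 1 without flash_attention has nothing to patch. A standalone sketch of the same rule, outside pydantic, purely for illustration:

def check_sequence_parallel_config(data: dict) -> dict:
    # Mirrors the validator above: sequence parallelism requires flash attention.
    if data.get("sequence_parallel_size", 1) > 1 and not data.get("flash_attention"):
        raise ValueError(
            "flash_attention: true must be set with sequence_parallel_size > 1"
        )
    return data


check_sequence_parallel_config({"sequence_parallel_size": 4, "flash_attention": True})  # passes
# check_sequence_parallel_config({"sequence_parallel_size": 4})  # would raise ValueError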