updates
This commit is contained in:
@@ -66,7 +66,6 @@ from axolotl.utils.distributed import (
|
|||||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
from axolotl.utils.ring_attn import register_ring_attn
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -549,6 +548,12 @@ class ModelLoader:
|
|||||||
patch_self_attn_lora(self.cfg)
|
patch_self_attn_lora(self.cfg)
|
||||||
|
|
||||||
if self.cfg.sequence_parallel_size > 1:
|
if self.cfg.sequence_parallel_size > 1:
|
||||||
|
from axolotl.utils.ring_attn import register_ring_attn
|
||||||
|
|
||||||
|
# Initialize ring attention for sequence parallelism if enabled.
|
||||||
|
# This must be done after model initialization but before the first forward pass,
|
||||||
|
# as it modifies the flash attention implementation to use ring communication
|
||||||
|
# patterns for efficient sequence-parallel training across multiple GPUs.
|
||||||
register_ring_attn(self.cfg.sequence_parallel_size)
|
register_ring_attn(self.cfg.sequence_parallel_size)
|
||||||
|
|
||||||
def patch_attention(self) -> None:
|
def patch_attention(self) -> None:
|
||||||
|
|||||||
@@ -1,6 +1,14 @@
|
|||||||
|
"""Ring attention group registration and utils."""
|
||||||
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from accelerate.logging import get_logger
|
||||||
from ring_flash_attn import substitute_hf_flash_attn
|
from ring_flash_attn import substitute_hf_flash_attn
|
||||||
|
|
||||||
|
from axolotl.logging_config import configure_logging
|
||||||
|
|
||||||
|
configure_logging()
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
RING_ATTN_GROUP = None
|
RING_ATTN_GROUP = None
|
||||||
|
|
||||||
|
|
||||||
@@ -9,17 +17,21 @@ def get_ring_attn_group():
|
|||||||
|
|
||||||
|
|
||||||
def set_ring_attn_group(ring_attn_group):
    """Store ``ring_attn_group`` in the module-level ``RING_ATTN_GROUP``.

    Args:
        ring_attn_group: Value to record as the current ring attention group.
    """
    # Rebinding a module-level name from inside a function needs `global`.
    global RING_ATTN_GROUP  # pylint: disable=global-statement
    RING_ATTN_GROUP = ring_attn_group
|
||||||
|
|
||||||
|
|
||||||
def register_ring_attn(sequence_parallel_size):
|
def register_ring_attn(sequence_parallel_size: int):
|
||||||
"""
|
"""
|
||||||
Create ring attention group and substitute flash attention with ring flash
|
Create ring attention group and substitute flash attn with ring flash attn.
|
||||||
attention.
|
|
||||||
|
Args:
|
||||||
|
sequence_parallel_size: Sequence parallelism factor.
|
||||||
"""
|
"""
|
||||||
if sequence_parallel_size == 1:
|
LOG.info(
|
||||||
return
|
"Enabling ring attention sequence parallelism: "
|
||||||
|
f"each sequence will be processed across {sequence_parallel_size} GPUs"
|
||||||
|
)
|
||||||
|
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
assert world_size % sequence_parallel_size == 0, (
|
assert world_size % sequence_parallel_size == 0, (
|
||||||
|
|||||||
@@ -1104,6 +1104,17 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
@classmethod
def check_sequence_parallel_config(cls, data):
    """Require flash attention whenever sequence parallelism is enabled.

    Args:
        data: Config mapping as received by this ``mode="before"`` validator.

    Returns:
        The ``data`` mapping, unchanged.

    Raises:
        ValueError: If ``sequence_parallel_size`` is greater than 1 while
            ``flash_attention`` is missing or falsy.
    """
    # The key may be absent (or explicitly null) in the incoming mapping;
    # comparing ``None > 1`` raises TypeError, so treat a missing value as
    # "sequence parallelism disabled" instead of crashing validation.
    sequence_parallel_size = data.get("sequence_parallel_size")
    if sequence_parallel_size is not None and sequence_parallel_size > 1:
        if not data.get("flash_attention"):
            raise ValueError(
                "flash_attention: true must be set with sequence_parallel_size > 1"
            )

    return data
|
||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
"""wrapper to validate gpu capabilities with the configured options"""
|
"""wrapper to validate gpu capabilities with the configured options"""
|
||||||
|
|||||||
Reference in New Issue
Block a user