SP context manager update (#2699)
* utilize accelerate prepare_data_loader with patching * lint * cleanup, fix * update to support DPO quirk * coderabbit commits, cleanup, remove dead code * fix * move ring attn patching to sp ctx manager * lint * lint * test fix * test fix
This commit is contained in:
@@ -51,6 +51,8 @@ NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
|
||||
|
||||
def get_ring_attn_group() -> dist.ProcessGroup:
|
||||
"""Getter for ring attention group on this rank."""
|
||||
if RING_ATTN_GROUP is None:
|
||||
raise RuntimeError("register_ring_attn() not yet called")
|
||||
return RING_ATTN_GROUP
|
||||
|
||||
|
||||
@@ -69,8 +71,8 @@ def register_ring_attn(
|
||||
|
||||
Args:
|
||||
sequence_parallel_degree: Sequence parallelism factor.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed
|
||||
through to `ring_flash_attn.substitute_hf_flash_attn`.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||
`varlen_llama3` `ring_flash_attn` implementation.
|
||||
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
|
||||
packing is enabled, it must be a `varlen` function; otherwise, it must be a
|
||||
`batch` function.
|
||||
|
||||
@@ -209,6 +209,7 @@ def execute_training(
|
||||
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
ring_attn_func=cfg.ring_attn_func,
|
||||
heads_k_stride=cfg.heads_k_stride,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -12,6 +12,9 @@ from transformers.utils import ModelOutput
|
||||
|
||||
from axolotl.monkeypatch.ring_attn.patch import (
|
||||
get_ring_attn_group,
|
||||
patch_prepare_data_loader,
|
||||
patch_prepare_device_mesh,
|
||||
register_ring_attn,
|
||||
update_ring_attn_params,
|
||||
)
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
@@ -169,6 +172,8 @@ class SequenceParallelContextManager:
|
||||
sequence_parallel_degree: Number of processes to split sequences over.
|
||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||
ring_attn_func: Which ring attention function to use. Currently unused.
|
||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||
`varlen_llama3` `ring_flash_attn` implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -177,14 +182,17 @@ class SequenceParallelContextManager:
|
||||
sequence_parallel_degree: int,
|
||||
gradient_accumulation_steps: int,
|
||||
ring_attn_func: RingAttnFunc,
|
||||
heads_k_stride: int | None,
|
||||
):
|
||||
self.models = models
|
||||
self.sequence_parallel_degree = sequence_parallel_degree
|
||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||
self.ring_attn_func = ring_attn_func
|
||||
self.process_group = get_ring_attn_group()
|
||||
self.heads_k_stride = heads_k_stride
|
||||
self._register_ring_attn()
|
||||
|
||||
# Initialize sequence parallel group details
|
||||
# Set distributed info for local rank
|
||||
self.process_group = get_ring_attn_group()
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self.local_world_size = dist.get_world_size(self.process_group)
|
||||
|
||||
@@ -205,6 +213,33 @@ class SequenceParallelContextManager:
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
self._register_model_hooks()
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Remove all hooks
|
||||
for handle in self.hook_handles:
|
||||
handle.remove()
|
||||
self.hook_handles = []
|
||||
|
||||
# TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
|
||||
|
||||
def _register_ring_attn(self):
|
||||
# Initialize ring attn for sequence parallelism
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=self.sequence_parallel_degree,
|
||||
heads_k_stride=self.heads_k_stride,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
)
|
||||
|
||||
# Patches for accelerate functionality
|
||||
patch_prepare_data_loader()
|
||||
patch_prepare_device_mesh(
|
||||
sequence_parallel_degree=self.sequence_parallel_degree
|
||||
)
|
||||
|
||||
def _register_model_hooks(self):
|
||||
# Forward pre-hook to apply sequence parallelism
|
||||
def sequence_parallel_pre_hook(_, args, kwargs):
|
||||
# Get parameter names from the model's forward function
|
||||
@@ -230,7 +265,7 @@ class SequenceParallelContextManager:
|
||||
# Forward post-hook to gather outputs
|
||||
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
|
||||
# Gather the sharded outputs
|
||||
output = self.gather_outputs(output)
|
||||
output = self._gather_outputs(output)
|
||||
|
||||
# Remove padding if it was added
|
||||
if self.pad_len > 0:
|
||||
@@ -253,15 +288,7 @@ class SequenceParallelContextManager:
|
||||
model.register_forward_hook(sequence_parallel_post_hook)
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Remove all hooks
|
||||
for handle in self.hook_handles:
|
||||
handle.remove()
|
||||
self.hook_handles = []
|
||||
|
||||
def gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
|
||||
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
|
||||
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
|
||||
for key, value in output.items():
|
||||
if isinstance(value, torch.Tensor) and value.dim() > 1:
|
||||
|
||||
@@ -59,7 +59,6 @@ from axolotl.monkeypatch.multipack import (
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||
patch_for_multipack,
|
||||
)
|
||||
from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
@@ -681,27 +680,6 @@ class ModelLoader:
|
||||
|
||||
patch_self_attn_lora(self.cfg)
|
||||
|
||||
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||
from axolotl.monkeypatch.ring_attn import (
|
||||
patch_prepare_data_loader,
|
||||
patch_prepare_device_mesh,
|
||||
register_ring_attn,
|
||||
)
|
||||
|
||||
# Initialize ring attn for sequence parallelism. This must be done after
|
||||
# model init but before the first forward pass, since it modifies flash
|
||||
# attn to use ring comm for SP training across multiple GPUs.
|
||||
if get_ring_attn_group() is None: # If already set, this is already patched
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
|
||||
heads_k_stride=self.cfg.heads_k_stride,
|
||||
ring_attn_func=self.cfg.ring_attn_func,
|
||||
)
|
||||
patch_prepare_data_loader()
|
||||
patch_prepare_device_mesh(
|
||||
sequence_parallel_degree=self.cfg.sequence_parallel_degree
|
||||
)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
if hasattr(self.model_config, "model_type"):
|
||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||
|
||||
@@ -84,16 +84,16 @@ class TestRingAttention:
|
||||
def test_get_ring_attn_group_no_registration(
|
||||
self, mock_world_size, mock_rank, partial_state
|
||||
):
|
||||
"""Test that get_ring_attn_group returns None when no group has been registered."""
|
||||
"""Test that get_ring_attn_group raises RuntimeError when no group has been registered."""
|
||||
# Setup mocks
|
||||
mock_world_size.return_value = 4
|
||||
mock_rank.return_value = 0
|
||||
|
||||
# Get the group without registration
|
||||
group = get_ring_attn_group()
|
||||
|
||||
# Verify that None was returned
|
||||
assert group is None
|
||||
# Verify that RuntimeError is raised when no group is registered
|
||||
with pytest.raises(
|
||||
RuntimeError, match="register_ring_attn\\(\\) not yet called"
|
||||
):
|
||||
get_ring_attn_group()
|
||||
|
||||
@patch("torch.distributed.new_group")
|
||||
@patch("torch.distributed.get_rank")
|
||||
@@ -323,8 +323,11 @@ class TestApplySequenceParallelism:
|
||||
lambda **kwargs: None,
|
||||
)
|
||||
|
||||
def test_world_size_one(self, sequence_parallel_batch):
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test that function returns original batch when world size is 1."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=sequence_parallel_batch,
|
||||
local_rank=0,
|
||||
@@ -336,8 +339,11 @@ class TestApplySequenceParallelism:
|
||||
# Should return the original batch unchanged
|
||||
assert result == sequence_parallel_batch
|
||||
|
||||
def test_batch_ring_rank0(self, sequence_parallel_batch):
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
|
||||
@@ -359,8 +365,11 @@ class TestApplySequenceParallelism:
|
||||
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
|
||||
)
|
||||
|
||||
def test_batch_ring_rank1(self, sequence_parallel_batch):
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
@@ -419,8 +428,13 @@ class TestApplySequenceParallelism:
|
||||
# assert torch.equal(result_rank0["input_ids"], rank0_expected)
|
||||
# assert torch.equal(result_rank1["input_ids"], rank1_expected)
|
||||
|
||||
def test_partial_application(self, sequence_parallel_batch):
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_partial_application(
|
||||
self, mock_get_ring_attn_group, sequence_parallel_batch
|
||||
):
|
||||
"""Test that we can create a partially applied version of the function."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user