diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py index 4329d9f13..7d733cfc1 100644 --- a/src/axolotl/monkeypatch/ring_attn/patch.py +++ b/src/axolotl/monkeypatch/ring_attn/patch.py @@ -51,6 +51,8 @@ NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1 def get_ring_attn_group() -> dist.ProcessGroup: """Getter for ring attention group on this rank.""" + if RING_ATTN_GROUP is None: + raise RuntimeError("register_ring_attn() not yet called") return RING_ATTN_GROUP @@ -69,8 +71,8 @@ def register_ring_attn( Args: sequence_parallel_degree: Sequence parallelism factor. - heads_k_stride: Sequence parallelism K head stride size. Passed - through to `ring_flash_attn.substitute_hf_flash_attn`. + heads_k_stride: Sequence parallelism K head stride size. Passed through to + `varlen_llama3` `ring_flash_attn` implementation. ring_attn_func: `ring_flash_attn` ring attention implemention. If sample packing is enabled, it must be a `varlen` function; otherwise, it must be a `batch` function. diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 90ab10e9f..46f722eeb 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -209,6 +209,7 @@ def execute_training( sequence_parallel_degree=cfg.sequence_parallel_degree, gradient_accumulation_steps=cfg.gradient_accumulation_steps, ring_attn_func=cfg.ring_attn_func, + heads_k_stride=cfg.heads_k_stride, ) ) diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py index 6e4f9bada..2ae93acad 100644 --- a/src/axolotl/utils/ctx_managers/sequence_parallel.py +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -12,6 +12,9 @@ from transformers.utils import ModelOutput from axolotl.monkeypatch.ring_attn.patch import ( get_ring_attn_group, + patch_prepare_data_loader, + patch_prepare_device_mesh, + register_ring_attn, update_ring_attn_params, ) from axolotl.utils.schemas.enums import RingAttnFunc @@ -169,6 +172,8 @@ class SequenceParallelContextManager: sequence_parallel_degree: Number of processes to split sequences over. gradient_accumulation_steps: Number of steps to accumulate gradients over. ring_attn_func: Which ring attention function to use. Currently unused. + heads_k_stride: Sequence parallelism K head stride size. Passed through to + `varlen_llama3` `ring_flash_attn` implementation. """ def __init__( @@ -177,14 +182,17 @@ class SequenceParallelContextManager: sequence_parallel_degree: int, gradient_accumulation_steps: int, ring_attn_func: RingAttnFunc, + heads_k_stride: int | None, ): self.models = models self.sequence_parallel_degree = sequence_parallel_degree self.gradient_accumulation_steps = gradient_accumulation_steps self.ring_attn_func = ring_attn_func - self.process_group = get_ring_attn_group() + self.heads_k_stride = heads_k_stride + self._register_ring_attn() - # Initialize sequence parallel group details + # Set distributed info for local rank + self.process_group = get_ring_attn_group() self.local_rank = dist.get_rank(self.process_group) self.local_world_size = dist.get_world_size(self.process_group) @@ -205,6 +213,33 @@ class SequenceParallelContextManager: ) def __enter__(self): + self._register_model_hooks() + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Remove all hooks + for handle in self.hook_handles: + handle.remove() + self.hook_handles = [] + + # TODO(djsaunde): Un-patch attention and accelerate functions (low priority) + + def _register_ring_attn(self): + # Initialize ring attn for sequence parallelism + register_ring_attn( + sequence_parallel_degree=self.sequence_parallel_degree, + heads_k_stride=self.heads_k_stride, + ring_attn_func=self.ring_attn_func, + ) + + # Patches for accelerate functionality + patch_prepare_data_loader() + patch_prepare_device_mesh( + sequence_parallel_degree=self.sequence_parallel_degree + ) + + def _register_model_hooks(self): # Forward pre-hook to apply sequence parallelism def sequence_parallel_pre_hook(_, args, kwargs): # Get parameter names from the model's forward function @@ -230,7 +265,7 @@ class SequenceParallelContextManager: # Forward post-hook to gather outputs def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput: # Gather the sharded outputs - output = self.gather_outputs(output) + output = self._gather_outputs(output) # Remove padding if it was added if self.pad_len > 0: @@ -253,15 +288,7 @@ class SequenceParallelContextManager: model.register_forward_hook(sequence_parallel_post_hook) ) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - # Remove all hooks - for handle in self.hook_handles: - handle.remove() - self.hook_handles = [] - - def gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast: + def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast: """Gather sharded outputs from all ranks and reconstruct the full tensor.""" for key, value in output.items(): if isinstance(value, torch.Tensor) and value.dim() > 1: diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6236f78e8..cd7499869 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -59,7 +59,6 @@ from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, patch_for_multipack, ) -from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.chat_templates import get_chat_template_from_config @@ -681,27 +680,6 @@ class ModelLoader: patch_self_attn_lora(self.cfg) - if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1: - from axolotl.monkeypatch.ring_attn import ( - patch_prepare_data_loader, - patch_prepare_device_mesh, - register_ring_attn, - ) - - # Initialize ring attn for sequence parallelism. This must be done after - # model init but before the first forward pass, since it modifies flash - # attn to use ring comm for SP training across multiple GPUs. - if get_ring_attn_group() is None: # If already set, this is already patched - register_ring_attn( - sequence_parallel_degree=self.cfg.sequence_parallel_degree, - heads_k_stride=self.cfg.heads_k_stride, - ring_attn_func=self.cfg.ring_attn_func, - ) - patch_prepare_data_loader() - patch_prepare_device_mesh( - sequence_parallel_degree=self.cfg.sequence_parallel_degree - ) - def patch_attention(self) -> None: if hasattr(self.model_config, "model_type"): if self.model_config.model_type == "mllama" and self.cfg.flash_attention: diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 83faa779f..2b4d11b30 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -84,16 +84,16 @@ class TestRingAttention: def test_get_ring_attn_group_no_registration( self, mock_world_size, mock_rank, partial_state ): - """Test that get_ring_attn_group returns None when no group has been registered.""" + """Test that get_ring_attn_group raises RuntimeError when no group has been registered.""" # Setup mocks mock_world_size.return_value = 4 mock_rank.return_value = 0 - # Get the group without registration - group = get_ring_attn_group() - - # Verify that None was returned - assert group is None + # Verify that RuntimeError is raised when no group is registered + with pytest.raises( + RuntimeError, match="register_ring_attn\\(\\) not yet called" + ): + get_ring_attn_group() @patch("torch.distributed.new_group") @patch("torch.distributed.get_rank") @@ -323,8 +323,11 @@ class TestApplySequenceParallelism: lambda **kwargs: None, ) - def test_world_size_one(self, sequence_parallel_batch): + @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") + def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch): """Test that function returns original batch when world size is 1.""" + mock_get_ring_attn_group.return_value = 0 + result, _, _ = apply_sequence_parallelism( batch=sequence_parallel_batch, local_rank=0, @@ -336,8 +339,11 @@ class TestApplySequenceParallelism: # Should return the original batch unchanged assert result == sequence_parallel_batch - def test_batch_ring_rank0(self, sequence_parallel_batch): + @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") + def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch): """Test BATCH_RING sharding for rank 0 in a 2-process group.""" + mock_get_ring_attn_group.return_value = 0 + batch = sequence_parallel_batch seq_len = batch["input_ids"].size(1) @@ -359,8 +365,11 @@ class TestApplySequenceParallelism: result["position_ids"], batch["position_ids"][:, : seq_len // 2] ) - def test_batch_ring_rank1(self, sequence_parallel_batch): + @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") + def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch): """Test BATCH_RING sharding for rank 1 in a 2-process group.""" + mock_get_ring_attn_group.return_value = 0 + batch = sequence_parallel_batch seq_len = batch["input_ids"].size(1) original_input_ids = batch["input_ids"].clone() @@ -419,8 +428,13 @@ class TestApplySequenceParallelism: # assert torch.equal(result_rank0["input_ids"], rank0_expected) # assert torch.equal(result_rank1["input_ids"], rank1_expected) - def test_partial_application(self, sequence_parallel_batch): + @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group") + def test_partial_application( + self, mock_get_ring_attn_group, sequence_parallel_batch + ): """Test that we can create a partially applied version of the function.""" + mock_get_ring_attn_group.return_value = 0 + batch = sequence_parallel_batch original_input_ids = batch["input_ids"].clone()