diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index bc3a200d4..fd72cd6db 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -235,6 +235,9 @@ class AxolotlTrainer( self.accelerator.even_batches = False # Return unprepared dataloader if using sequence parallelism + # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation + # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e., + # slice each batch along the sequence dimension). if self.args.sequence_parallel_degree > 1: return dataloader diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py index 9bcd5db57..3930c6cb3 100644 --- a/src/axolotl/core/trainers/mixins/sequence_parallel.py +++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py @@ -1,34 +1,22 @@ """Module for Axolotl trainer sequence parallelism mixin""" import logging -from typing import Any -import torch import torch.distributed as dist -import torch.nn.functional as F from datasets import Dataset -from torch import nn from torch.utils.data import DistributedSampler, Sampler from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group LOG = logging.getLogger(__name__) -try: - from ring_flash_attn import update_ring_flash_attn_params -except ImportError: - # We pass silently here, but raise an ImportError in our Axolotl config validation - # if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed. - pass - class SequenceParallelMixin: """ Mixin class for sequence parallelism support in trainers. This mixin provides functionality for handling sequence parallelism, - including creating appropriate samplers, managing data partitioning, - and updating ring flash attention parameters during training. + specifically for creating appropriate data samplers. """ args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] @@ -99,84 +87,3 @@ class SequenceParallelMixin: return self._create_sequence_parallel_sampler( eval_dataset, shuffle=False, is_eval=True ) - - def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]): - """ - Calculate the cu_seqlens for the current forward pass and pass the value to - the substituted ring_flash_attn. This is accomplished by using the passed - `input_ids`. - - Args: - inputs: Current batch of inputs. - """ - # At this point, inputs should already be partitioned by the sequence - # parallel data collator - batch_size = inputs["input_ids"].shape[0] - seq_len = inputs["input_ids"].shape[1] - packed_seq_lens = [seq_len] * batch_size - - # Calculate the full sequence length across all GPUs in this SP group - total_seq_len = seq_len * self.args.sequence_parallel_degree - - cu_seqlens = torch.cumsum( - torch.tensor( - packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32 - ), - dim=-1, - dtype=torch.int32, - ) - cu_seqlens = F.pad( - F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len - ) - - update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group) - - def training_step( - self, - model: nn.Module, - inputs: dict[str, torch.Tensor | Any], - num_items_in_batch: int | None = None, - ) -> torch.Tensor: - """ - Perform a training step on a batch of inputs. Overrides the - `transformers.trainer.Trainer` method to handle sequence parallelism if - enabled. - - Args: - model: Model to perform training step for. - inputs: Dictionary mapping. - """ - # Set up sequence parallelism for this step if enabled - if self.args.sequence_parallel_degree > 1: - self._update_ring_flash_attn_params(inputs) - - # Proceed with normal training step - return super().training_step(model, inputs, num_items_in_batch) # type: ignore - - def prediction_step( - self, - model: nn.Module, - inputs: dict[str, torch.Tensor | Any], - prediction_loss_only: bool, - ignore_keys: list[str] | None = None, - ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: - """ - Perform a prediction step on a batch of inputs. Overrides the - `transformers.trainer.Trainer` method to handle sequence parallelism if - enabled. - - Args: - model: Model to perform prediction step for. - inputs: Dictionary mapping of inputs. - prediction_loss_only: Whether to return only the loss. - ignore_keys: Keys to ignore in the inputs. - - Returns: - Tuple of (loss, logits, labels). - """ - # Set up sequence parallelism for this prediction step if enabled - if self.args.sequence_parallel_degree > 1: - self._update_ring_flash_attn_params(inputs) - - # Proceed with normal prediction step - return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore diff --git a/src/axolotl/monkeypatch/attention/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn.py index 6c9d0b429..30aa78f01 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn.py +++ b/src/axolotl/monkeypatch/attention/ring_attn.py @@ -6,10 +6,12 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc their sequence parallel version of Flash Attention 2. """ +import torch import torch.distributed as dist from accelerate.logging import get_logger from axolotl.logging_config import configure_logging +from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids configure_logging() LOG = get_logger(__name__) @@ -98,3 +100,27 @@ def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None substitute_hf_flash_attn( process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride ) + + +def update_ring_attn_params(batch: dict[str, torch.Tensor]): + """ + Calculate the cumulative sequence lengths for the current forward pass and pass the + value to the substituted `ring_flash_attn`. + + Args: + batch: A dictionary with a batch of data. May or may not contain `position_ids` + data; if not, we compute it. + """ + from ring_flash_attn import update_ring_flash_attn_params + + input_ids = batch["input_ids"] + position_ids = batch.get("position_ids") + if position_ids is None: + seq_len = input_ids.shape[1] + position_ids = torch.arange( + 0, seq_len, dtype=torch.long, device=input_ids.device + ).unsqueeze(0) + + cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids) + cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device()) + update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group()) diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py index 43496c7c8..4c6a4de11 100644 --- a/src/axolotl/monkeypatch/utils.py +++ b/src/axolotl/monkeypatch/utils.py @@ -96,7 +96,9 @@ def get_cu_seqlens(attn_mask): return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) -def get_cu_seqlens_from_pos_ids(position_ids): +def get_cu_seqlens_from_pos_ids( + position_ids: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: """generate a cumulative sequence length mask for flash attention using pos ids""" if len(position_ids.shape) == 1: position_ids = position_ids.unsqueeze(0) diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 33bb4b4cc..ed445ae56 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -3,7 +3,6 @@ Data collators for axolotl to pad labels and position_ids for packed sequences. includes logic for handling sequence parallelism collation. """ -import logging from dataclasses import dataclass from typing import Any, Optional, Union @@ -13,46 +12,7 @@ import torch.distributed as dist from transformers import PreTrainedTokenizerBase from transformers.utils import PaddingStrategy -logger = logging.getLogger(__name__) - - -def adjust_position_ids_for_slice( - position_ids: torch.Tensor, start_idx: int -) -> torch.Tensor: - """ - Adjust position IDs for a sliced sequence to maintain proper relative positions. - This handles the case where position IDs might not be contiguous due to sample - packing. - """ - # Convert to tensor if not already - # Find the boundaries between samples (where position_ids reset) - adjusted_pos_ids = position_ids.clone() - - # Process each sequence in the batch - for i in range(position_ids.shape[0]): - seq = position_ids[i] - - # Find sample boundaries - boundaries = [] - for j in range(1, len(seq)): - if seq[j] < seq[j - 1]: - boundaries.append(j) - - # No need to adjust if there are no boundaries or this is a single sample - if not boundaries: - adjusted_pos_ids[i] = seq - start_idx - continue - - # Adjust each segment separately - prev_boundary = 0 - for boundary in boundaries: - adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx - prev_boundary = boundary - - # Last segment - adjusted_pos_ids[i, prev_boundary:] -= start_idx - - return adjusted_pos_ids +from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params @dataclass @@ -196,23 +156,20 @@ class DataCollatorForSeq2Seq: Returns: Sliced batch dictionary. """ - keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"] + # Get local (start, end) for sequence parallelism slicing + total_seq_len = batch["input_ids"].shape[1] + slice_size = total_seq_len // self.local_world_size + start = self.local_rank * slice_size + end = start + slice_size + # Update params for ring attention calculation + update_ring_attn_params(batch=batch) + + # Slice batch for sequence parallel processing + keys_to_slice = ["input_ids", "attention_mask", "labels", "position_ids"] for key in keys_to_slice: if key in batch: - seq_len = batch[key].shape[1] - slice_size = seq_len // self.local_world_size - start_idx = self.local_rank * slice_size - end_idx = ( - start_idx + slice_size - if self.local_rank < self.local_world_size - 1 - else seq_len - ) - batch[key] = batch[key][:, start_idx:end_idx] - - # Special handling for position_ids - if key == "position_ids" and self.local_rank > 0: - batch[key] = adjust_position_ids_for_slice(batch[key], start_idx) + batch[key] = batch[key][:, start:end] return batch diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 0f9a3a1f9..4083fcc22 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1156,6 +1156,12 @@ class AxolotlInputConfig( "flash_attention: true must be set with sequence_parallel_degree > 1" ) + if not info.data["micro_batch_size"] == 1: + raise ValueError( + "micro_batch_size must be set to 1 " + "due to a `ring-flash-attn` requirement" + ) + try: import ring_flash_attn # noqa: F401 # pylint:disable=unused-import except ImportError as exception: @@ -1165,6 +1171,18 @@ class AxolotlInputConfig( "or `pip install ring-flash-attn>=0.1.4`." ) from exception + # TODO: monkeypatch / callback to average losses correctly across SP ranks + # / fix gradient scaling across SP ranks. Losses, grads should be scaled + # according to the proportion of non-padding tokens per rank. + LOG.warning( + "Sequence parallelism (SP) is enabled with " + f"sequence_parallel_degree={value}. Please note that logged losses may " + "differ slightly to the non-SP losses due to transformers Trainer " + "implementation details. Please see " + "https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " + "for more details." + ) + return value @model_validator(mode="before") diff --git a/tests/e2e/multigpu/test_sp.py b/tests/e2e/multigpu/test_sp.py new file mode 100644 index 000000000..2bd10beb5 --- /dev/null +++ b/tests/e2e/multigpu/test_sp.py @@ -0,0 +1,87 @@ +"""E2E tests for sequence parallelism""" + +import os +from pathlib import Path + +import yaml +from accelerate.test_utils import execute_subprocess_async +from transformers.testing_utils import get_torch_dist_unique_port + +from axolotl.utils.dict import DictDefault + +from ..utils import check_tensorboard + +os.environ["WANDB_DISABLED"] = "true" + + +class TestSequenceParallelism: + """Test case for training with sequence parallelism enabled""" + + def test_sequence_parallel_training(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "load_in_8bit": False, + "load_in_4bit": True, + "strict": False, + "sequence_len": 2048, + "adapter": "qlora", + "sample_packing": True, + "eval_sample_packing": True, + "pad_to_sequence_len": True, + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "lora_modules_to_save": ["embed_tokens", "lm_head"], + "special_tokens": {"pad_token": "<|endoftext|>"}, + "datasets": [ + { + "path": "tatsu-lab/alpaca", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 8, + "micro_batch_size": 1, + "gradient_accumulation_steps": 2, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_8bit", + "lr_scheduler": "cosine", + "flash_attention": True, + "loss_watchdog_threshold": 5.0, + "loss_watchdog_patience": 3, + "bf16": "auto", + "warmup_steps": 1, + "saves_per_epoch": 1, + "logging_steps": 1, + "weight_decay": 0.0, + "use_tensorboard": True, + "sequence_parallel_degree": 2, + } + ) + + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "accelerate", + "launch", + "--num-processes", + "2", + "--main_process_port", + f"{get_torch_dist_unique_port()}", + "-m", + "axolotl.cli.train", + str(Path(temp_dir) / "config.yaml"), + ] + ) + + check_tensorboard( + temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high" + ) diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 70beb8a54..1361a8522 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -12,7 +12,6 @@ from axolotl.monkeypatch.attention.ring_attn import ( get_ring_attn_group, set_ring_attn_group, ) -from axolotl.utils.collators.batching import adjust_position_ids_for_slice from axolotl.utils.dict import DictDefault @@ -48,33 +47,6 @@ def fixture_cfg(): return cfg -class TestSequenceParallelHelpers: - """Test helper functions used in sequence parallelism.""" - - def test_adjust_position_ids_for_slice(self, partial_state): - """Test position_ids adjustment for sequence slices.""" - # Create sample position_ids with multiple sequences - position_ids = torch.tensor( - [ - # First sequence with 2 samples - [0, 1, 2, 3, 4, 0, 1, 2, 3], - # Second sequence with 3 samples - [0, 1, 2, 0, 1, 2, 3, 0, 1], - ] - ) - - # Adjust as if this was the second slice (start_idx = 4) - adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4) - - # For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1] - # For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3] - expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4 - expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4 - - assert torch.all(adjusted[0] == expected_first_seq) - assert torch.all(adjusted[1] == expected_second_seq) - - class TestRingAttention: """Tests for the ring attention functionality."""