SP dataloader patching + removing custom sampler / dataloader logic (#2686)

* utilize accelerate prepare_data_loader with patching

* lint

* cleanup, fix

* update to support DPO quirk

* small change

* coderabbit commits, cleanup, remove dead code

* quarto fix

* patch fix

* review comments

* moving monkeypatch up one level

* fix
This commit is contained in:
Dan Saunders
2025-05-21 11:20:20 -04:00
committed by GitHub
parent a27b909c5c
commit 6aa41740df
20 changed files with 304 additions and 477 deletions

View File

@@ -60,7 +60,6 @@ quartodoc:
- core.trainers.mixins.optimizer
- core.trainers.mixins.rng_state_loader
- core.trainers.mixins.scheduler
- core.trainers.mixins.sequence_parallel
- title: Context Managers
desc: Context managers for altering trainer behaviors
contents:

View File

@@ -87,20 +87,7 @@ We support sequence parallelism (SP) via the
allows one to split up sequences across GPUs, which is useful in the event that a
single sequence causes OOM errors during model training.
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
or from source with `pip install .[ring-flash-attn]`.
Your Axolotl YAML config should contain the following lines:
```{.yaml}
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
heads_k_stride: 1
```
See our [dedicated guide](sequence_parallelism.qmd) for more details.
See our [dedicated guide](sequence_parallelism.qmd) for more information.
### FSDP + QLoRA {#sec-fsdp-qlora}

View File

@@ -41,7 +41,7 @@ When sequence parallelism is enabled:
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
3. Position IDs are adjusted to maintain proper relative positions
4. The trainer uses special ring communication patterns for attention operations
## Requirements
@@ -67,9 +67,11 @@ sequence_len: 8192
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:
...
```

View File

@@ -2,7 +2,6 @@ base_model: Qwen/Qwen2.5-0.5B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
chat_template: qwen_25
rl: dpo
datasets:

View File

@@ -798,11 +798,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.kd_top_k_before_softmax
)
training_arguments_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_arguments_kwargs["ring_attn_func"] = self.cfg.ring_attn_func
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
elif self.cfg.process_reward_model:
@@ -1083,10 +1078,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_kwargs["sequence_parallel_degree"] = (
self.cfg.sequence_parallel_degree
)
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl is RLType.SIMPO:

View File

@@ -29,7 +29,6 @@ from axolotl.core.trainers.mixins import (
OptimizerMixin,
RngLoaderMixin,
SchedulerMixin,
SequenceParallelMixin,
)
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
@@ -40,9 +39,7 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__)
class AxolotlTrainer(
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
):
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
@@ -68,10 +65,6 @@ class AxolotlTrainer(
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
# Initialize sequence parallelism if enabled
if self.args.sequence_parallel_degree > 1:
self._setup_sequence_parallel()
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
@@ -122,8 +115,8 @@ class AxolotlTrainer(
def _get_train_sampler(self) -> Sampler | None:
"""
Helper method to get the sampler for training. Handles cases for sequence
parallelism, sample packing, and curriculum sampling (sequential).
Helper method to get the sampler for training. Handles cases for sample packing
and curriculum sampling (sequential).
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
@@ -132,9 +125,7 @@ class AxolotlTrainer(
use_sample_packing = self.args.sample_packing and not self.args.pretraining
# Determine the base sampler first
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_train_sampler(self.train_dataset)
elif self.args.curriculum_sampling:
if self.args.curriculum_sampling:
base_sampler = SequentialSampler(self.train_dataset)
elif use_sample_packing:
base_sampler = RandomSampler(self.train_dataset)
@@ -153,8 +144,7 @@ class AxolotlTrainer(
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
"""
Helper method to get the sampler for evaluation. Handles sequence parallelism
and sample packing cases.
Helper method to get the sampler for evaluation. Handles sample packing case.
Returns:
If the dataset is non-empty, a sampler is returned, the type of which
@@ -168,9 +158,7 @@ class AxolotlTrainer(
)
# Determine the base sampler
if self.args.sequence_parallel_degree > 1:
base_sampler = self._sp_get_eval_sampler(eval_dataset)
elif use_multipack:
if use_multipack:
base_sampler = SequentialSampler(eval_dataset)
else:
return super()._get_eval_sampler(eval_dataset)
@@ -236,14 +224,6 @@ class AxolotlTrainer(
):
self.accelerator.even_batches = False
# Return unprepared dataloader if using sequence parallelism
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
# slice each batch along the sequence dimension).
if self.args.sequence_parallel_degree > 1:
return dataloader
# Otherwise prepare with accelerator
return self.accelerator.prepare_data_loader(dataloader)
def get_train_dataloader(self) -> DataLoader:
@@ -287,12 +267,7 @@ class AxolotlTrainer(
return dataloader
# Handle sample packing or sequence parallelism
if (
self.args.sample_packing
and self.args.eval_sample_packing is not False
or self.args.sequence_parallel_degree > 1
):
if self.args.sample_packing and self.args.eval_sample_packing is not False:
# Get appropriate data collator
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
@@ -302,17 +277,6 @@ class AxolotlTrainer(
if "length" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns(["length"])
# Handle dataset preprocessing for SP
if self.args.sequence_parallel_degree > 1:
if isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(
eval_dataset, description="evaluation"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
self.data_collator, description="evaluation"
)
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
batch_size = (
self.args.eval_batch_size

View File

@@ -1,31 +1,15 @@
"""
DPO trainer for axolotl
"""
"""DPO trainer for axolotl"""
import gc
import random
from functools import wraps
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Union
import pandas as pd
import torch
import wandb
from accelerate import PartialState
from datasets import Dataset, IterableDataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.utils.data import DataLoader
from transformers import (
BaseImageProcessor,
FeatureExtractionMixin,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
)
from transformers.trainer_utils import EvalLoopOutput
from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt
from trl.trainer.utils import log_table_to_comet_experiment
from trl import DPOTrainer
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.utils import (
@@ -38,9 +22,7 @@ if is_sagemaker_mp_enabled():
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
"""Extend the base DPOTrainer for axolotl helpers."""
tag_names = ["axolotl", "dpo"]
@@ -85,8 +67,9 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
Overwrite the `push_to_hub` method in order to force-add the tags when pushing
the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub`
for more details.
"""
kwargs = sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
@@ -95,64 +78,6 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
return super().push_to_hub(*args, **kwargs)
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
def _prepare_dataset(
self,
dataset: Union[Dataset, IterableDataset],
processing_class: Union[
PreTrainedTokenizerBase,
BaseImageProcessor,
FeatureExtractionMixin,
ProcessorMixin,
],
args: DPOConfig,
dataset_name: str,
) -> Union[Dataset, IterableDataset]:
# Build the kwargs for the `map` function
map_kwargs: Dict[str, Any] = {"writer_batch_size": 10}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
map_kwargs["num_proc"] = args.dataset_num_proc
with PartialState().main_process_first():
# Extract prompt if needed
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
# Apply the chat template if needed
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
dataset = dataset.map(
maybe_apply_chat_template,
fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
**map_kwargs,
)
# Tokenize the dataset
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
dataset = dataset.map(
self.tokenize_row if not self.is_vision_model else self.process_row,
remove_columns=["chosen", "rejected"],
fn_kwargs={
"processing_class": processing_class,
"max_prompt_length": args.max_prompt_length,
"max_completion_length": args.max_completion_length,
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
"add_special_tokens": False,
},
**map_kwargs,
)
return dataset
@staticmethod
def tokenize_row(
features,
@@ -192,69 +117,3 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
gc.collect()
torch.cuda.empty_cache()
return loss
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
Overriding built-in evaluation loop to store metrics for each batch.
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Works both with or without labels.
"""
# Sample and save to game log if requested (for one batch to save time)
if self.generate_during_eval:
# Generate random indices within the range of the total number of samples
num_samples = len(dataloader.dataset)
random_indices = random.sample(
range(num_samples), k=self.args.eval_batch_size
)
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
random_batch_dataset = dataloader.dataset.select(random_indices)
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)
policy_output_decoded, ref_output_decoded = (
self.generate_from_model_and_ref(self.model, random_batch)
)
table = pd.DataFrame(
columns=["Prompt", "Policy", "Ref Model"],
data=[
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
for prompt, pol, ref in zip(
random_batch_dataset["prompt"],
policy_output_decoded,
ref_output_decoded,
)
],
)
if "wandb" in self.args.report_to and self.accelerator.is_main_process:
wandb.log({"game_log": wandb.Table(data=table)})
if "comet_ml" in self.args.report_to:
log_table_to_comet_experiment(
name="game_log.csv",
table=table,
)
# Base evaluation
initial_output = super( # pylint: disable=bad-super-call
DPOTrainer, self
).evaluation_loop(
dataloader,
description,
prediction_loss_only,
ignore_keys,
metric_key_prefix,
)
return initial_output

View File

@@ -43,7 +43,7 @@ from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group
from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group
if is_peft_available():
# pylint: disable=unused-import

View File

@@ -6,4 +6,3 @@
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -1,87 +0,0 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
import torch.distributed as dist
from datasets import Dataset
from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import (
get_ring_attn_group,
)
class SequenceParallelMixin:
"""
Mixin class for sequence parallelism support in trainers.
This mixin provides functionality for handling sequence parallelism,
specifically for creating appropriate data samplers.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def _setup_sequence_parallel(self):
"""Set up sequence parallelism environment."""
self.ring_attn_group = get_ring_attn_group()
def _create_sequence_parallel_sampler(
self,
dataset: Dataset,
shuffle: bool = True,
is_eval: bool = False,
) -> DistributedSampler:
"""
Helper method to create sampler for sequence parallelism (SP).
We create a distributed sampler with rank equal to the SP group ID, which
means that all ranks in the SP group receive the same sample / set of samples
per training step. We also set the number of replicas equal to the number of
SP groups, which is a bit of a hack / unintended use, but works!
Args:
dataset: Dataset to sample from.
shuffle: Whether to shuffle the dataset.
is_eval: Whether we are creating a sampler for evaluation or training.
Returns:
Distributed sampler.
"""
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
return DistributedSampler(
dataset,
num_replicas=num_sp_groups,
rank=sp_group_id,
seed=self.args.seed if shuffle else None,
shuffle=shuffle,
drop_last=not is_eval,
)
def _sp_get_train_sampler(self, dataset) -> Sampler | None:
"""
Get a training sampler configured for sequence parallelism.
Args:
dataset: The training dataset
Returns:
Configured sequence parallel sampler.
"""
return self._create_sequence_parallel_sampler(
dataset,
shuffle=not self.args.curriculum_sampling,
)
def _sp_get_eval_sampler(self, eval_dataset) -> Sampler | None:
"""
Get an evaluation sampler configured for sequence parallelism.
Args:
eval_dataset: The evaluation dataset.
Returns:
Configured sequence parallel sampler.
"""
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)

View File

@@ -9,8 +9,6 @@ from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.utils.schemas.enums import RingAttnFunc
@dataclass
class AxolotlTrainingMixins:
@@ -216,17 +214,6 @@ class AxolotlTrainingMixins:
},
)
sequence_parallel_degree: Optional[int] = field(
default=1,
metadata={"help": "The number of workers to use in sequence parallelism"},
)
ring_attn_func: Optional[RingAttnFunc] = field(
default=None,
metadata={
"help": "The ring-flash-attn function to use in sequence parallelism"
},
)
adam_beta3: Optional[float] = field(
default=None,
metadata={

View File

@@ -1,11 +0,0 @@
"""Init for ring attention monkeypatch module"""
# pylint: disable=unused-import
# flake8: noqa
from .patch import (
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
update_ring_attn_params,
)

View File

@@ -1,131 +0,0 @@
"""
Ring attention group registration and flash attention patching.
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
their sequence parallel version of Flash Attention 2.
"""
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.schemas.enums import RingAttnFunc
LOG = get_logger(__name__)
RING_ATTN_GROUP = None
def get_ring_attn_group() -> dist.ProcessGroup:
"""
Getter for ring attention group on this rank.
Returns:
The process group for ring attention for this rank.
"""
return RING_ATTN_GROUP
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
"""
Setter for ring attention group on this rank.
Args:
Process group for ring attention.
"""
global RING_ATTN_GROUP # pylint: disable=global-statement
RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(
sequence_parallel_degree: int,
heads_k_stride: int | None,
ring_attn_func: RingAttnFunc | None,
):
"""
Create ring attention group and substitute flash attn with ring flash attn.
Args:
sequence_parallel_degree: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed
through to `ring_flash_attn.substitute_hf_flash_attn`.
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
packing is enabled, it must be a `varlen` function; otherwise, it must be a
`batch` function.
"""
if get_ring_attn_group() is not None:
LOG.info("Ring attention already registered, exiting early...")
return
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
)
rank = dist.get_rank()
world_size = dist.get_world_size()
assert sequence_parallel_degree <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})"
)
assert world_size % sequence_parallel_degree == 0, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must evenly divide world_size ({world_size})"
)
# Assign ranks to sequence parallel groups
group_assignments = {}
for i in range(world_size // sequence_parallel_degree):
ring_attn_ranks = list(
range(
i * sequence_parallel_degree,
(i + 1) * sequence_parallel_degree,
)
)
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
# Track which GPUs are in which groups
for r in ring_attn_ranks:
group_assignments[r] = i
if rank in ring_attn_ranks:
set_ring_attn_group(group)
# Log the GPU group assignments
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
from ring_flash_attn import substitute_hf_flash_attn
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
)
elif ring_attn_func is RingAttnFunc.BATCH_RING:
from axolotl.monkeypatch.attention.ring_attn.adapters.batch import (
substitute_hf_flash_attn,
)
substitute_hf_flash_attn(
process_group=get_ring_attn_group(),
ring_attn_func=ring_attn_func,
)
def update_ring_attn_params(position_ids: torch.Tensor | None):
"""
Calculate the cumulative sequence lengths for the current forward pass and pass the
value to the substituted `ring_flash_attn`.
Args:
position_ids: Optional tensor of position IDs (for sample packed data).
"""
from ring_flash_attn import update_ring_flash_attn_params
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())

View File

@@ -0,0 +1,22 @@
"""Init for ring attention monkeypatch module"""
# pylint: disable=unused-import
# flake8: noqa
from .patch import (
get_ring_attn_group,
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn,
set_ring_attn_group,
update_ring_attn_params,
)
__all__ = (
"get_ring_attn_group",
"patch_prepare_data_loader",
"patch_prepare_device_mesh",
"register_ring_attn",
"set_ring_attn_group",
"update_ring_attn_params",
)

View File

@@ -0,0 +1,223 @@
"""Ring attention group registration and flash attention patching.
Make use of the `ring-flash-attn` (https://github.com/zhuzilin/ring-flash-attention)
package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patch in
their sequence parallel version of Flash Attention 2.
We also provide some patches for accelerate functions to prepare the dataloader for
sequence parallelism training.
"""
import inspect
import accelerate
import torch
import torch.distributed as dist
from accelerate.logging import get_logger
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.schemas.enums import RingAttnFunc
LOG = get_logger(__name__)
RING_ATTN_GROUP = None
ORIGINAL_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
process_index = process_index // submesh_tp_size"""
NEW_PREPARE_DATALOADER_CODE = """ submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
submesh_cp_size = 1
if "cp" in torch_device_mesh.mesh_dim_names:
submesh_cp_size = torch_device_mesh["cp"].size()
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
process_index = process_index // (submesh_tp_size * submesh_cp_size)"""
def get_ring_attn_group() -> dist.ProcessGroup:
"""Getter for ring attention group on this rank."""
return RING_ATTN_GROUP
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
"""Setter for ring attention group on this rank."""
global RING_ATTN_GROUP # pylint: disable=global-statement
RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(
sequence_parallel_degree: int,
heads_k_stride: int | None,
ring_attn_func: RingAttnFunc | None,
):
"""Create ring attention group and substitute flash attn with ring flash attn.
Args:
sequence_parallel_degree: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed
through to `ring_flash_attn.substitute_hf_flash_attn`.
ring_attn_func: `ring_flash_attn` ring attention implemention. If sample
packing is enabled, it must be a `varlen` function; otherwise, it must be a
`batch` function.
"""
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
)
assert sequence_parallel_degree <= world_size, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must be less than or equal to world_size ({world_size})"
)
assert world_size % sequence_parallel_degree == 0, (
f"sequence_parallel_degree ({sequence_parallel_degree}) "
f"must evenly divide world_size ({world_size})"
)
# Assign ranks to sequence parallel groups
group_assignments = {}
for i in range(world_size // sequence_parallel_degree):
ring_attn_ranks = list(
range(
i * sequence_parallel_degree,
(i + 1) * sequence_parallel_degree,
)
)
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
# Track which GPUs are in which groups
for r in ring_attn_ranks:
group_assignments[r] = i
if rank in ring_attn_ranks:
set_ring_attn_group(group)
# Log the GPU group assignments
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
from ring_flash_attn import substitute_hf_flash_attn
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
)
elif ring_attn_func is RingAttnFunc.BATCH_RING:
from axolotl.monkeypatch.ring_attn.adapters.batch import (
substitute_hf_flash_attn,
)
substitute_hf_flash_attn(
process_group=get_ring_attn_group(),
ring_attn_func=ring_attn_func,
)
def update_ring_attn_params(position_ids: torch.Tensor | None):
"""
Calculate the cumulative sequence lengths for the current forward pass and pass the
value to the substituted `ring_flash_attn`.
Args:
position_ids: Optional tensor of position IDs (for sample packed data).
"""
from ring_flash_attn import update_ring_flash_attn_params
cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
def patch_prepare_data_loader():
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
Raies:
RuntimeError: If source code to patch does not exist.
"""
original_fn = accelerate.data_loader.prepare_data_loader
original_source = inspect.getsource(original_fn)
if ORIGINAL_PREPARE_DATALOADER_CODE not in original_source:
raise RuntimeError(
"SP patch failed - target snippet not found. "
"Check accelerate's version or update the patch."
)
patched_source = original_source.replace(
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
)
# Create a new function from the patched source
namespace = {}
exec( # pylint: disable=exec-used # nosec B102
patched_source, accelerate.data_loader.__dict__, namespace
)
patched_function = namespace["prepare_data_loader"]
accelerate.data_loader.prepare_data_loader = patched_function
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
def patch_prepare_device_mesh(sequence_parallel_degree: int):
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
that includes sequence parallelism with the specified degree.
Args:
sequence_parallel_degree (int): The degree of sequence parallelism to use.
"""
def _prepare_device_mesh(self):
"""Prepare the device mesh for distributed training. The dataloader will
determine how to load data based on the device mesh.
"""
if self.state.torch_tp_plugin:
return self.state.torch_tp_plugin.torch_device_mesh
if (
self.distributed_type == accelerate.accelerator.DistributedType.DEEPSPEED
and hasattr(self.state, "ds_device_mesh")
):
return self.state.ds_device_mesh
# Create device mesh with sequence parallelism
world_size = dist.get_world_size()
mesh_shape = (
world_size // sequence_parallel_degree,
sequence_parallel_degree,
)
device_ids = list(range(world_size))
# Note that we use "cp" instead of "sp" to match the PyTorch native "context
# parallelism" implementation naming
return dist.DeviceMesh(
"cuda",
torch.tensor(device_ids).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
# Replace the original method with our new method
# pylint: disable=protected-access
accelerate.accelerator.Accelerator._prepare_device_mesh = _prepare_device_mesh
LOG.info(
"Successfully patched Accelerator._prepare_device_mesh "
f"with sequence_parallel_degree={sequence_parallel_degree}"
)

View File

@@ -1,6 +1,7 @@
"""Module for Axolotl trainer sequence parallelism manager and utilities"""
import functools
import inspect
import torch
import torch.distributed as dist
@@ -9,7 +10,7 @@ from torch.utils.hooks import RemovableHandle
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput
from axolotl.monkeypatch.attention.ring_attn.patch import (
from axolotl.monkeypatch.ring_attn.patch import (
get_ring_attn_group,
update_ring_attn_params,
)
@@ -206,12 +207,25 @@ class SequenceParallelContextManager:
def __enter__(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
# Apply sequence parallelism to kwargs and get original sequence length and padding info
kwargs, self.original_seq_len, self.pad_len = (
self.apply_sequence_parallelism(batch=kwargs)
# Get parameter names from the model's forward function
forward_params = list(
inspect.signature(self.models[0].forward).parameters.keys()
)
return args, kwargs
updated_kwargs = kwargs.copy()
for i, arg in enumerate(args):
if i < len(forward_params):
updated_kwargs[forward_params[i]] = arg
# Any excess positional arguments are kept as-is
remaining_args = args[len(forward_params) :]
# Apply sequence parallelism to updated kwargs
updated_kwargs, self.original_seq_len, self.pad_len = (
self.apply_sequence_parallelism(updated_kwargs)
)
return remaining_args, updated_kwargs
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:

View File

@@ -59,6 +59,7 @@ from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
)
from axolotl.monkeypatch.ring_attn.patch import get_ring_attn_group
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
@@ -681,16 +682,25 @@ class ModelLoader:
patch_self_attn_lora(self.cfg)
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
from axolotl.monkeypatch.ring_attn import (
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn,
)
# Initialize ring attn for sequence parallelism. This must be done after
# model init but before the first forward pass, since it modifies flash
# attn to use ring comm for SP training across multiple GPUs.
register_ring_attn(
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
heads_k_stride=self.cfg.heads_k_stride,
ring_attn_func=self.cfg.ring_attn_func,
)
if get_ring_attn_group() is None: # If already set, this is already patched
register_ring_attn(
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
heads_k_stride=self.cfg.heads_k_stride,
ring_attn_func=self.cfg.ring_attn_func,
)
patch_prepare_data_loader()
patch_prepare_device_mesh(
sequence_parallel_degree=self.cfg.sequence_parallel_degree
)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):

View File

@@ -10,7 +10,7 @@ import pytest
import torch
from accelerate.state import PartialState
from axolotl.monkeypatch.attention.ring_attn import (
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
@@ -313,13 +313,13 @@ class TestApplySequenceParallelism:
# Mock the process group
monkeypatch.setattr(
"axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group",
"axolotl.monkeypatch.ring_attn.get_ring_attn_group",
MagicMock,
)
# Mock update_ring_attn_params
monkeypatch.setattr(
"axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params",
"axolotl.monkeypatch.ring_attn.update_ring_attn_params",
lambda **kwargs: None,
)