progress (messy :O)

Dan Saunders
2025-06-12 18:54:41 +00:00
parent ae73123eae
commit aced809989
8 changed files with 333 additions and 199 deletions

View File

@@ -7,11 +7,13 @@ from __future__ import annotations
import os
from collections import defaultdict
from functools import partial, wraps
from typing import Callable, Literal, Optional
from typing import Any, Callable, Literal, Optional
from axolotl.utils.ctx_managers.utils import get_context_parallel_manager
import datasets
import torch
from datasets import Dataset
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,

View File

@@ -17,7 +17,6 @@ from accelerate.utils import save_fsdp_model
from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled
from peft import PeftConfig, PeftModel
from torch.distributed.tensor.experimental import context_parallel
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
@@ -34,7 +33,7 @@ from axolotl.loaders import (
load_processor,
load_tokenizer,
)
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.ctx_managers.sequence_parallel import ContextParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
@@ -205,32 +204,24 @@ def execute_training(
)
if cfg.sequence_parallel_degree > 1:
if cfg.sdp_attention:
world_size = dist.get_world_size()
mesh_shape = (
world_size // cfg.sequence_parallel_degree,
cfg.sequence_parallel_degree,
)
mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
stack.enter_context(context_parallel(mesh=mesh))
else: # flash_attention
models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model:
models.append(trainer.ref_model)
# Models to enter context parallel manager for
models = [trainer.model]
if hasattr(trainer, "ref_model") and trainer.ref_model:
models.append(trainer.ref_model)
stack.enter_context(
SequenceParallelContextManager(
models=models,
sequence_parallel_degree=cfg.sequence_parallel_degree,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride,
)
# Attention backend
backend = "sdp_attention" if cfg.sdp_attention else "flash_attention"
stack.enter_context(
ContextParallelContextManager(
models=models,
backend=backend,
context_parallel_degree=cfg.sequence_parallel_degree,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride,
)
)
LOG.info("Starting trainer...")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -3,4 +3,4 @@
# pylint: disable=unused-import
# flake8: noqa
from .sequence_parallel import SequenceParallelContextManager
from .sequence_parallel import ContextParallelContextManager

View File

@@ -2,11 +2,13 @@
import functools
import inspect
from typing import Literal
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.tensor.experimental import context_parallel
from torch.utils.hooks import RemovableHandle
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import ModelOutput
@@ -15,177 +17,41 @@ from axolotl.monkeypatch.ring_attn import (
patch_prepare_data_loader,
patch_prepare_device_mesh,
register_ring_attn,
update_ring_attn_params,
)
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.utils.ctx_managers.utils import get_context_parallel_manager
# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this
# module. Currently, we just focus on batch ring and varlen llama3 for simplicity.
def apply_sequence_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
local_world_size: int,
gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc, # pylint: disable=unused-argument
) -> tuple[dict[str, torch.Tensor], int, int]:
"""
Apply sequence parallelism slicing to a batch.
Special handling is implemented for integer logits_to_keep, which indicates
to only keep the last N tokens in the sequence during generation.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
local_rank: Local rank in the sequence parallel group.
local_world_size: World size of the sequence parallel group.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused, but
related to above TODO.
Returns:
tuple of:
- Batch dictionary with sliced tensors.
- The original sequence length before padding.
- The number of padding tokens added.
"""
original_seq_len = batch["input_ids"].size(1)
# Update ring attention params if needed
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
else:
# If position_ids aren't already in the batch, create them
batch["position_ids"] = torch.arange(
0,
original_seq_len,
dtype=torch.long,
device=batch["input_ids"].device,
).expand(batch["input_ids"].size(0), -1)
if "logits_to_keep" in batch and isinstance(batch["logits_to_keep"], int):
logits_to_keep = batch["logits_to_keep"]
# Calculate which positions in the full sequence contain the last N tokens
start_position = max(0, original_seq_len - logits_to_keep)
chunk_size = original_seq_len // local_world_size
rank_start = local_rank * chunk_size
rank_end = rank_start + chunk_size
# Create a boolean mask tensor for this rank's chunk
mask = torch.zeros(
chunk_size,
dtype=torch.bool,
device=batch["input_ids"].device,
)
if rank_end > start_position:
# Calculate how many of the last N tokens fall within this rank's range
tokens_in_rank = min(rank_end, original_seq_len) - max(
rank_start, start_position
)
# Calculate where these tokens start in the local chunk
local_start_idx = max(0, start_position - rank_start)
# Set the appropriate positions in the mask to True
mask[local_start_idx : local_start_idx + tokens_in_rank] = True
# Replace the integer with the boolean mask
batch["logits_to_keep"] = mask
# Add padding to make sequence length divisible by local_world_size
total_seq_len = original_seq_len
pad_len = 0
divisor = min(local_world_size, 64)
if total_seq_len % divisor != 0:
pad_len = divisor - (total_seq_len % divisor)
# Apply padding to all relevant tensors
for key in batch:
if (
isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
# Create padding tensor
pad_value = -100 if key == "labels" else 0
padding = torch.full(
(batch[key].size(0), pad_len, *batch[key].shape[2:]),
pad_value,
dtype=batch[key].dtype,
device=batch[key].device,
)
# Concatenate padding to the right side of the tensor
batch[key] = torch.cat([batch[key], padding], dim=1)
if key == "logits_to_keep":
# Create padding tensor
padding = torch.ones(
1,
dtype=batch[key].dtype,
device=batch[key].device,
)
# Concatenate padding to the right side of the tensor
batch[key] = torch.cat([batch[key], padding], dim=0)
# Update the total sequence length after padding
total_seq_len = batch["input_ids"].size(1)
# Slice batch for sequence parallel
for key in batch:
if not isinstance(batch[key], torch.Tensor) or batch[key].dim() <= 1:
continue
# Split in sequential fashion and grab this rank's chunk
if batch[key].size(1) == total_seq_len:
batch[key] = (
batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
)
elif key == "logits_to_keep":
batch[key] = (
batch[key].chunk(local_world_size, dim=0)[local_rank].contiguous()
)
# Handle num_items_in_batch
if "num_items_in_batch" in batch:
# Approximation; this is needed since num_items_in_batch may be counted across
# all samples in a gradient accumulated batch, not on a per-step basis.
batch["num_items_in_batch"] = (
batch["labels"] != -100
).sum() * gradient_accumulation_steps
return batch, original_seq_len, pad_len
class SequenceParallelContextManager:
class ContextParallelContextManager:
"""Context manager for sequence parallelism operations.
"""Context manager for context parallelism operations.
This class provides a context that will automatically apply sequence parallelism
This class provides a context that will automatically apply context parallelism
during model forward passes using a pre-forward hook, and gather outputs from
across the sequence parallelism group using a post-forward hook.
across the context parallelism group using a post-forward hook.
Args:
models: List of models to apply sequence parallelism to pre- and post- forward
models: List of models to apply context parallelism to pre- and post- forward
hooks.
sequence_parallel_degree: Number of processes to split sequences over.
backend: Which attention backend to use.
context_parallel_degree: Number of processes to split sequences over.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused.
heads_k_stride: Sequence parallelism K head stride size. Passed through to
heads_k_stride: Context parallelism K head stride size. Passed through to
`varlen_llama3` `ring_flash_attn` implementation.
"""
def __init__(
self,
models: list[nn.Module],
sequence_parallel_degree: int,
models: list[PreTrainedModel],
backend: Literal["sdp_attention", "flash_attention"],
context_parallel_degree: int,
gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc,
heads_k_stride: int | None,
):
self.models = models
self.sequence_parallel_degree = sequence_parallel_degree
self.backend = backend
self.context_parallel_degree = context_parallel_degree
self.gradient_accumulation_steps = gradient_accumulation_steps
self.ring_attn_func = ring_attn_func
self.heads_k_stride = heads_k_stride
@@ -204,14 +70,34 @@ class SequenceParallelContextManager:
self.pad_len = 0
# Create a partially applied version of the apply_context_parallelism function
self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism,
self.apply_context_parallelism = functools.partial(
apply_context_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
ring_attn_func=self.ring_attn_func,
)
# SDPA CP initialization
world_size = dist.get_world_size()
mesh_shape = (
world_size // self.context_parallel_degree,
self.context_parallel_degree,
)
world_mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
self.context_parallel_managers = []
for model in models:
ctx_manager = get_context_parallel_manager(
enabled=self.context_parallel_degree > 1,
world_mesh=world_mesh,
model=model,
)
self.context_parallel_managers.append(ctx_manager)
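# For intuition: with 8 ranks and context_parallel_degree=2 (hypothetical
# values), the mesh above is torch.arange(8).reshape(4, 2), i.e.
# [[0, 1], [2, 3], [4, 5], [6, 7]], with mesh_dim_names ("dp", "cp"). Each
# row is one context parallel group whose ranks split a single sequence;
# rows are data parallel replicas.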
def __enter__(self):
self._register_model_hooks()
@@ -226,22 +112,25 @@ class SequenceParallelContextManager:
# TODO(djsaunde): Un-patch attention and accelerate functions (low priority)
def _register_ring_attn(self):
# Initialize ring attn for sequence parallelism
register_ring_attn(
sequence_parallel_degree=self.sequence_parallel_degree,
heads_k_stride=self.heads_k_stride,
ring_attn_func=self.ring_attn_func,
)
if self.backend == "flash_attention":
# Initialize ring attn for context parallelism
register_ring_attn(
sequence_parallel_degree=self.context_parallel_degree,
heads_k_stride=self.heads_k_stride,
ring_attn_func=self.ring_attn_func,
)
else:
# SDPA backend: context parallelism is entered around each forward pass by
# the per-model context parallel managers created in __init__
pass
# Patches for accelerate functionality
patch_prepare_data_loader()
patch_prepare_device_mesh(
sequence_parallel_degree=self.sequence_parallel_degree
sequence_parallel_degree=self.context_parallel_degree
)
def _register_model_hooks(self):
# Forward pre-hook to apply context parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
def cp_flash_pre_hook(_, args, kwargs):
# Get parameter names from the model's forward function
forward_params = list(
inspect.signature(self.models[0].forward).parameters.keys()
@@ -257,13 +146,13 @@ class SequenceParallelContextManager:
# Apply context parallelism to updated kwargs
updated_kwargs, self.original_seq_len, self.pad_len = (
self.apply_sequence_parallelism(updated_kwargs)
self.apply_context_parallelism(updated_kwargs)
)
return remaining_args, updated_kwargs
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput:
def cp_flash_post_hook(_, __, output: ModelOutput) -> ModelOutput:
# Gather the sharded outputs
output = self._gather_outputs(output)
@@ -277,15 +166,19 @@ class SequenceParallelContextManager:
return output
def cp_sdpa_pre_hook(_, args, kwargs):
# TODO(djsaunde): enter this model's context parallel manager here, e.g.
# with self.context_parallel_managers[model_idx](list(kwargs.values())).
return args, kwargs
# Register both hooks
for model in self.models:
self.hook_handles.append(
model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
cp_flash_pre_hook, with_kwargs=True
)
)
self.hook_handles.append(
model.register_forward_hook(sequence_parallel_post_hook)
model.register_forward_hook(cp_flash_post_hook)
)
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:

View File

@@ -0,0 +1,145 @@
"""Utils for context parallel context manager."""
import torch
from axolotl.monkeypatch.ring_attn.patch import update_ring_attn_params
from axolotl.utils.schemas.enums import RingAttnFunc
# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this
# module. Currently, we just focus on batch ring and varlen llama3 for simplicity.
def apply_context_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
local_world_size: int,
gradient_accumulation_steps: int,
ring_attn_func: RingAttnFunc, # pylint: disable=unused-argument
) -> tuple[dict[str, torch.Tensor], int, int]:
"""
Apply context parallelism slicing to a batch.
Special handling is implemented for an integer logits_to_keep, which indicates
that only the last N tokens in the input sequence should be kept during generation.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.).
local_rank: Local rank in the context parallel group.
local_world_size: World size of the context parallel group.
gradient_accumulation_steps: Number of steps to accumulate gradients over.
ring_attn_func: Which ring attention function to use. Currently unused, but
related to above TODO.
Returns:
tuple of:
- Batch dictionary with sliced tensors.
- The original sequence length before padding.
- The number of padding tokens added.
"""
original_seq_len = batch["input_ids"].size(1)
# Update ring attention params if needed
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
else:
# If position_ids aren't already in the batch, create them
batch["position_ids"] = torch.arange(
0,
original_seq_len,
dtype=torch.long,
device=batch["input_ids"].device,
).expand(batch["input_ids"].size(0), -1)
if "logits_to_keep" in batch and isinstance(batch["logits_to_keep"], int):
logits_to_keep = batch["logits_to_keep"]
# Calculate which positions in the full sequence contain the last N tokens
start_position = max(0, original_seq_len - logits_to_keep)
chunk_size = original_seq_len // local_world_size
rank_start = local_rank * chunk_size
rank_end = rank_start + chunk_size
# Create a boolean mask tensor for this rank's chunk
mask = torch.zeros(
chunk_size,
dtype=torch.bool,
device=batch["input_ids"].device,
)
if rank_end > start_position:
# Calculate how many of the last N tokens fall within this rank's range
tokens_in_rank = min(rank_end, original_seq_len) - max(
rank_start, start_position
)
# Calculate where these tokens start in the local chunk
local_start_idx = max(0, start_position - rank_start)
# Set the appropriate positions in the mask to True
mask[local_start_idx : local_start_idx + tokens_in_rank] = True
# Replace the integer with the boolean mask
batch["logits_to_keep"] = mask
# Add padding to make sequence length divisible by local_world_size
total_seq_len = original_seq_len
pad_len = 0
divisor = min(local_world_size, 64)
if total_seq_len % divisor != 0:
pad_len = divisor - (total_seq_len % divisor)
# Apply padding to all relevant tensors
for key in batch:
if (
isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
# Create padding tensor
pad_value = -100 if key == "labels" else 0
padding = torch.full(
(batch[key].size(0), pad_len, *batch[key].shape[2:]),
pad_value,
dtype=batch[key].dtype,
device=batch[key].device,
)
# Concatenate padding to the right side of the tensor
batch[key] = torch.cat([batch[key], padding], dim=1)
if key == "logits_to_keep":
# Create padding tensor
padding = torch.ones(
1,
dtype=batch[key].dtype,
device=batch[key].device,
)
# Concatenate padding to the right side of the tensor
batch[key] = torch.cat([batch[key], padding], dim=0)
# Update the total sequence length after padding
total_seq_len = batch["input_ids"].size(1)
# Slice batch for context parallel
for key in batch:
if not isinstance(batch[key], torch.Tensor) or batch[key].dim() <= 1:
continue
# Split in sequential fashion and grab this rank's chunk
if batch[key].size(1) == total_seq_len:
batch[key] = (
batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
)
elif key == "logits_to_keep":
batch[key] = (
batch[key].chunk(local_world_size, dim=0)[local_rank].contiguous()
)
# Handle num_items_in_batch
if "num_items_in_batch" in batch:
# Approximation; this is needed since num_items_in_batch may be counted across
# all samples in a gradient accumulated batch, not on a per-step basis.
batch["num_items_in_batch"] = (
batch["labels"] != -100
).sum() * gradient_accumulation_steps
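# Example (hypothetical values): labels [[-100, 7, 7, -100]] with
# gradient_accumulation_steps=2 gives (2 non-masked tokens) * 2 = 4.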
return batch, original_seq_len, pad_len
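A self-contained sketch of the pad-then-chunk slicing above, using hypothetical shapes and plain torch (not the axolotl API):

import torch

local_world_size, local_rank = 2, 1
input_ids = torch.arange(10).unsqueeze(0)  # (batch=1, seq_len=10)

# Right-pad so the sequence length divides evenly across the group
divisor = min(local_world_size, 64)
pad_len = (divisor - input_ids.size(1) % divisor) % divisor
if pad_len:
    padding = torch.zeros(input_ids.size(0), pad_len, dtype=input_ids.dtype)
    input_ids = torch.cat([input_ids, padding], dim=1)

# Each rank keeps its contiguous chunk of the (padded) sequence
local_chunk = input_ids.chunk(local_world_size, dim=1)[local_rank].contiguous()
print(local_chunk)  # tensor([[5, 6, 7, 8, 9]])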

View File

@@ -0,0 +1,103 @@
import contextlib
from typing import Callable, Generator, Optional, Union
import torch
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention.flex_attention import BlockMask
from transformers import PreTrainedModel
def _get_sdpa_context() -> (
Callable[[Optional[Generator[None, None, None]]], Generator[None, None, None]]
):
"""
Creates a context manager that confines SDPA to the flash/efficient/cuDNN attention
backends when a context parallel context is supplied (a no-op otherwise).
Returns:
A context manager function that takes an optional context parallel context.
"""
@contextlib.contextmanager
def context(cp_context: Union[Generator[None, None, None], None] = None):
with contextlib.ExitStack() as stack:
if cp_context is not None:
stack.enter_context(
sdpa_kernel(
[
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.CUDNN_ATTENTION,
]
)
)
stack.enter_context(cp_context)
yield
return context
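For reference, a minimal standalone example of the sdpa_kernel restriction used above (assumes a CUDA device; shapes are hypothetical):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.float16)
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    # The math backend cannot be dispatched inside this block
    out = F.scaled_dot_product_attention(q, k, v)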
def get_context_parallel_manager(
*,
enabled: bool = False,
world_mesh: torch.distributed.DeviceMesh,
model: PreTrainedModel,
) -> Callable[[list[torch.Tensor]], Generator[None, None, None]]:
"""
Builds a context manager that applies context parallelism to a model: it shards model
inputs and buffers along the sequence dimension and, via _get_sdpa_context, restricts
SDPA to compatible attention backends.
Args:
enabled: Whether context parallel is enabled. Default: False
world_mesh: Global device mesh.
model: Model to apply context parallelism to.
Returns:
A context manager applying context parallelism if enabled is True; otherwise a
no-op context manager (the SDPA backend restriction is only applied alongside
context parallelism).
Raises:
ValueError: if enabled is True but world_mesh does not contain a "cp" dimension
"""
if enabled and "cp" not in world_mesh.mesh_dim_names:
raise ValueError(
"Context parallel is enabled but no context parallel device mesh is provided."
)
# TODO: context parallel for multimodal models requires extra work; as a rough
# proxy, only decoder-only text models are allowed for now
if enabled and getattr(model.config, "is_encoder_decoder", False):
raise ValueError("Context parallel is only supported for decoder-only text models")
# TODO: this is a hacky proxy for whether we use flex for chunked attention
# remove this once flex is supported
if enabled and any(
getattr(layer, "mask_mod", None) is not None for layer in getattr(model, "layers", [])
):
raise ValueError("Context parallel with flex attention is not yet supported")
model_buffers = list(model.buffers())
@contextlib.contextmanager
def context(model_inputs: list[torch.Tensor]):
# Create context parallel context if enabled
cp_context = None
if enabled and any(isinstance(inp, BlockMask) for inp in model_inputs):
raise ValueError(
"Context parallel with flex attention is not yet supported"
)
if enabled:
set_rotate_method("allgather")
cp_context = context_parallel(
world_mesh["cp"],
buffers=model_inputs + model_buffers,
buffer_seq_dims=[1] * len(model_inputs) + [0] * len(model_buffers),
no_restore_buffers=set(model_inputs),
)
# Create and enter the train context with the optional cp_context
sdpa_context = _get_sdpa_context()
with sdpa_context(cp_context):
yield
return context
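A hedged usage sketch of this manager (model, input_ids, and labels are placeholders; assumes torch.distributed is initialized across 2 CUDA ranks):

import torch
import torch.distributed as dist
from axolotl.utils.ctx_managers.utils import get_context_parallel_manager

world_mesh = dist.DeviceMesh(
    "cuda", torch.arange(2).reshape(1, 2), mesh_dim_names=("dp", "cp")
)
cp_manager = get_context_parallel_manager(
    enabled=True, world_mesh=world_mesh, model=model
)
inputs = {"input_ids": input_ids, "labels": labels}
with cp_manager(list(inputs.values())):
    # Inputs and buffers are sharded along the sequence dim inside this
    # block, and SDPA is limited to flash/efficient/cuDNN backends
    outputs = model(**inputs)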

View File

@@ -15,7 +15,7 @@ from axolotl.monkeypatch.ring_attn import (
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
from axolotl.utils.ctx_managers.sequence_parallel import apply_context_parallelism
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.utils.schemas.trl import TRLConfig
@@ -328,7 +328,7 @@ class TestApplySequenceParallelism:
"""Test that function returns original batch when world size is 1."""
mock_get_ring_attn_group.return_value = 0
result, _, _ = apply_sequence_parallelism(
result, _, _ = apply_context_parallelism(
batch=sequence_parallel_batch,
local_rank=0,
local_world_size=1,
@@ -347,7 +347,7 @@ class TestApplySequenceParallelism:
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
result, _, _ = apply_sequence_parallelism(
result, _, _ = apply_context_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
@@ -374,7 +374,7 @@ class TestApplySequenceParallelism:
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
result, _, _ = apply_sequence_parallelism(
result, _, _ = apply_context_parallelism(
batch=batch,
local_rank=1,
local_world_size=2,
@@ -440,7 +440,7 @@ class TestApplySequenceParallelism:
# Create a partially applied function
rank0_ring_parallel = functools.partial(
apply_sequence_parallelism,
apply_context_parallelism,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
@@ -466,7 +466,7 @@ class TestApplySequenceParallelism:
original_input_ids = batch["input_ids"].clone()
# This should run without error even though position_ids is missing
result, _, _ = apply_sequence_parallelism(
result, _, _ = apply_context_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,