Dan Saunders
2025-06-14 17:40:21 +00:00
parent 7a88de4fa8
commit f8f87321bd
6 changed files with 75 additions and 64 deletions

View File

@@ -7,13 +7,11 @@ from __future__ import annotations
import os
from collections import defaultdict
from functools import partial, wraps
from typing import Any, Callable, Literal, Optional
from typing import Callable, Literal, Optional
from axolotl.utils.ctx_managers.utils import get_context_parallel_manager
import datasets
import torch
from datasets import Dataset
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,

View File

@@ -11,7 +11,6 @@ from pathlib import Path
from typing import Any, Dict
import torch
import torch.distributed as dist
import transformers.modelcard
from accelerate.utils import save_fsdp_model
from datasets import Dataset
@@ -33,7 +32,7 @@ from axolotl.loaders import (
load_processor,
load_tokenizer,
)
from axolotl.utils.ctx_managers.context_parallel import ContextParallelContextManager
from axolotl.utils.ctx_managers import ContextParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except

View File

@@ -1,6 +1,5 @@
"""Init for context manager submodule."""
from .context_parallel.manager import ContextParallelContextManager
__all__ = ["ContextParallelContextManager"]

View File

@@ -35,11 +35,11 @@ https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5
import contextlib
from typing import Callable, Generator, Optional, Union
from axolotl.utils.dict import DictDefault
import torch
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import BlockMask
from transformers import PreTrainedModel
@@ -85,7 +85,6 @@ def get_context_parallel_manager(
dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends.
Args:
enabled: Whether context parallel is enabled. Default: False
world_mesh: Global device mesh.
model: Model to apply context parallelism to.
@@ -102,12 +101,8 @@ def get_context_parallel_manager(
"Context parallel is enabled but no context parallel device mesh is provided."
)
# TODO: context parallel for multimodal models requires extra work
if not isinstance(model, TransformerDecoder):
raise ValueError("Context parallel is only supported for text models")
# TODO: this is a hacky proxy for whether we use flex for chunked attention
# remove this once flex is supported
if any([layer.mask_mod is not None for layer in model.layers]):
raise ValueError("Context parallel with flex attention is not yet supported")
# if not isinstance(model, TransformerDecoder):
# raise ValueError("Context parallel is only supported for text models")
model_buffers = list(model.buffers())
@contextlib.contextmanager
@@ -132,4 +127,4 @@ def get_context_parallel_manager(
with sdpa_context(cp_context):
yield
return context
return context
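
For orientation, a hedged usage sketch of the manager returned by get_context_parallel_manager (not part of this commit): world_mesh, model, dataloader, and the batch keys are placeholders for a real distributed run launched with torchrun, and an HF-style output with a .loss attribute is assumed.

context_parallel_manager = get_context_parallel_manager(
    world_mesh=world_mesh,  # DeviceMesh that includes a "cp" dimension
    model=model,            # decoder-only text model
)

for batch in dataloader:
    # The returned context manager is called with the tensors to shard along the
    # sequence dimension; inside the block, SDPA is restricted to backends that
    # support context parallelism.
    with context_parallel_manager([batch["input_ids"], batch["labels"]]):
        loss = model(**batch).loss
        loss.backward()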

View File

@@ -17,8 +17,13 @@ from axolotl.monkeypatch.ring_attn import (
patch_prepare_device_mesh,
register_ring_attn,
)
from axolotl.utils.ctx_managers.context_parallel.utils import AllGatherWithGrad, apply_context_parallelism
from axolotl.utils.ctx_managers.utils import get_context_parallel_manager
from axolotl.utils.ctx_managers.context_parallel.distributed import (
get_context_parallel_manager,
)
from axolotl.utils.ctx_managers.context_parallel.utils import (
AllGatherWithGrad,
apply_context_parallelism,
)
from axolotl.utils.schemas.enums import RingAttnFunc
@@ -57,46 +62,48 @@ class ContextParallelContextManager:
self.heads_k_stride = heads_k_stride
self._register_ring_attn()
# Set distributed info for local rank
self.process_group = get_ring_attn_group()
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Will store hook handles for removal
# Store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
# Store original sequence length and padding information
self.original_seq_len = 0
self.pad_len = 0
# Create a partially applied version of the apply_context_parallelism function
self.apply_context_parallelism = functools.partial(
apply_context_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
ring_attn_func=self.ring_attn_func,
)
# SDPA CP initialization
world_size = dist.get_world_size()
mesh_shape = (
world_size // self.context_parallel_degree,
self.context_parallel_degree,
)
world_mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
self.context_parallel_managers = []
for model in models:
ctx_manager = get_context_parallel_manager(
enabled=self.context_parallel_degree > 1,
world_mesh=world_mesh,
model=model,
if self.backend == "flash_attention":
# Set distributed info for local rank
self.process_group = get_ring_attn_group()
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Create a partially applied version of the apply_context_parallelism function
self.apply_context_parallelism = functools.partial(
apply_context_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
ring_attn_func=self.ring_attn_func,
)
self.context_parallel_managers.append(ctx_manager)
# Store original sequence length and padding information
self.original_seq_len = 0
self.pad_len = 0
else:
# SDPA device mesh init
world_size = dist.get_world_size()
mesh_shape = (
world_size // self.context_parallel_degree,
self.context_parallel_degree,
)
world_mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
# SDPA context parallel managers
self.context_parallel_managers = []
for model in models:
ctx_manager = get_context_parallel_manager(
world_mesh=world_mesh,
model=model,
)
self.context_parallel_managers.append(ctx_manager)
def __enter__(self):
self._register_model_hooks()
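
As a side note on the ("dp", "cp") mesh built in __init__ above: the layout follows directly from the global world size and the context-parallel degree. A minimal, runnable sketch of the rank layout (example values assumed; no process group required):

world_size = 8                 # assumed example value (dist.get_world_size())
context_parallel_degree = 2    # assumed example value

mesh_shape = (world_size // context_parallel_degree, context_parallel_degree)

# Ranks laid out row-major: each row is one data-parallel replica group,
# each column position is a context-parallel shard within that replica.
ranks = [
    list(range(r * context_parallel_degree, (r + 1) * context_parallel_degree))
    for r in range(mesh_shape[0])
]
print(mesh_shape, ranks)  # (4, 2) [[0, 1], [2, 3], [4, 5], [6, 7]]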
@@ -162,6 +169,28 @@ class ContextParallelContextManager:
return output
def make_sdpa_pre_hook(model_idx: int) -> Callable:
def cp_sdpa_pre_hook(_, args, kwargs):
# Get parameter names from the model's forward function
forward_params = list(
inspect.signature(self.models[0].forward).parameters.keys()
)
updated_kwargs = kwargs.copy()
for i, arg in enumerate(args):
if i < len(forward_params):
updated_kwargs[forward_params[i]] = arg
# Any excess positional arguments are kept as-is
remaining_args = args[len(forward_params) :]
to_shard = {k: v for k, v in updated_kwargs.items() if v.ndim > 1}
with self.context_parallel_managers[model_idx](list(to_shard.values())):
return remaining_args, updated_kwargs
return cp_sdpa_pre_hook
# Register both hooks
for i, model in enumerate(self.models):
if self.backend == "flash_attention":
@@ -172,14 +201,6 @@ class ContextParallelContextManager:
model.register_forward_hook(cp_flash_post_hook)
)
else:
def make_sdpa_pre_hook(model_idx: int) -> Callable:
def cp_sdpa_pre_hook(_, args, kwargs):
with self.context_parallel_managers[model_idx]:
return args, kwargs
return cp_sdpa_pre_hook
self.hook_handles.append(
model.register_forward_pre_hook(
make_sdpa_pre_hook(i), with_kwargs=True
@@ -193,4 +214,3 @@ class ContextParallelContextManager:
output[key] = AllGatherWithGrad.apply(value, self.process_group)
return output
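
To illustrate the positional-arg-to-kwarg normalization that cp_sdpa_pre_hook performs, a self-contained sketch outside any distributed setup (TinyModel and normalize_args_pre_hook are illustrative names, not part of the commit):

import inspect

import torch
from torch import nn


class TinyModel(nn.Module):
    def forward(self, input_ids, attention_mask=None):
        return input_ids.sum()


model = TinyModel()


def normalize_args_pre_hook(module, args, kwargs):
    # Map positional args onto the forward signature's parameter names so
    # downstream logic can handle every tensor uniformly by keyword.
    params = list(inspect.signature(module.forward).parameters.keys())
    updated_kwargs = dict(kwargs)
    for i, arg in enumerate(args):
        if i < len(params):
            updated_kwargs[params[i]] = arg
    remaining_args = args[len(params):]
    return remaining_args, updated_kwargs


handle = model.register_forward_pre_hook(normalize_args_pre_hook, with_kwargs=True)
out = model(torch.ones(2, 4, dtype=torch.long))  # positional arg becomes input_ids=...
handle.remove()
print(out)  # tensor(8)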