This commit is contained in:
Dan Saunders
2025-06-14 17:40:21 +00:00
parent 7a88de4fa8
commit f8f87321bd
6 changed files with 75 additions and 64 deletions

View File

@@ -7,13 +7,11 @@ from __future__ import annotations
import os import os
from collections import defaultdict from collections import defaultdict
from functools import partial, wraps from functools import partial, wraps
from typing import Any, Callable, Literal, Optional from typing import Callable, Literal, Optional
from axolotl.utils.ctx_managers.utils import get_context_parallel_manager
import datasets import datasets
import torch import torch
from datasets import Dataset from datasets import Dataset
from torch import nn
from torch.utils.data import ( from torch.utils.data import (
BatchSampler, BatchSampler,
DataLoader, DataLoader,

View File

@@ -11,7 +11,6 @@ from pathlib import Path
from typing import Any, Dict from typing import Any, Dict
import torch import torch
import torch.distributed as dist
import transformers.modelcard import transformers.modelcard
from accelerate.utils import save_fsdp_model from accelerate.utils import save_fsdp_model
from datasets import Dataset from datasets import Dataset
@@ -33,7 +32,7 @@ from axolotl.loaders import (
load_processor, load_processor,
load_tokenizer, load_tokenizer,
) )
from axolotl.utils.ctx_managers.context_parallel import ContextParallelContextManager from axolotl.utils.ctx_managers import ContextParallelContextManager
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.freeze import freeze_layers_except

View File

@@ -1,6 +1,5 @@
"""Init for context manager submodule.""" """Init for context manager submodule."""
from .context_parallel.manager import ContextParallelContextManager from .context_parallel.manager import ContextParallelContextManager
__all__ = ["ContextParallelContextManager"] __all__ = ["ContextParallelContextManager"]

View File

@@ -35,11 +35,11 @@ https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5
import contextlib import contextlib
from typing import Callable, Generator, Optional, Union from typing import Callable, Generator, Optional, Union
from axolotl.utils.dict import DictDefault
import torch import torch
from torch.distributed.tensor.experimental import context_parallel from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method from torch.distributed.tensor.experimental._attention import set_rotate_method
from torch.nn.attention import sdpa_kernel, SDPBackend from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import BlockMask from torch.nn.attention.flex_attention import BlockMask
from transformers import PreTrainedModel from transformers import PreTrainedModel
@@ -85,7 +85,6 @@ def get_context_parallel_manager(
dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends. dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends.
Args: Args:
enabled: Whether context parallel is enabled. Default: False
world_mesh: Global device mesh. world_mesh: Global device mesh.
model: Model to apply context parallelism to. model: Model to apply context parallelism to.
@@ -102,12 +101,8 @@ def get_context_parallel_manager(
"Context parallel is enabled but no context parallel device mesh is provided." "Context parallel is enabled but no context parallel device mesh is provided."
) )
# TODO: context parallel for multimodal models requires extra work # TODO: context parallel for multimodal models requires extra work
if not isinstance(model, TransformerDecoder): # if not isinstance(model, TransformerDecoder):
raise ValueError("Context parallel is only supported for text models") # raise ValueError("Context parallel is only supported for text models")
# TODO: this is a hacky proxy for whether we use flex for chunked attention
# remove this once flex is supported
if any([layer.mask_mod is not None for layer in model.layers]):
raise ValueError("Context parallel with flex attention is not yet supported")
model_buffers = list(model.buffers()) model_buffers = list(model.buffers())
@contextlib.contextmanager @contextlib.contextmanager

View File

@@ -17,8 +17,13 @@ from axolotl.monkeypatch.ring_attn import (
patch_prepare_device_mesh, patch_prepare_device_mesh,
register_ring_attn, register_ring_attn,
) )
from axolotl.utils.ctx_managers.context_parallel.utils import AllGatherWithGrad, apply_context_parallelism from axolotl.utils.ctx_managers.context_parallel.distributed import (
from axolotl.utils.ctx_managers.utils import get_context_parallel_manager get_context_parallel_manager,
)
from axolotl.utils.ctx_managers.context_parallel.utils import (
AllGatherWithGrad,
apply_context_parallelism,
)
from axolotl.utils.schemas.enums import RingAttnFunc from axolotl.utils.schemas.enums import RingAttnFunc
@@ -57,46 +62,48 @@ class ContextParallelContextManager:
self.heads_k_stride = heads_k_stride self.heads_k_stride = heads_k_stride
self._register_ring_attn() self._register_ring_attn()
# Set distributed info for local rank # Store hook handles for removal
self.process_group = get_ring_attn_group()
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = [] self.hook_handles: list[RemovableHandle] = []
# Store original sequence length and padding information if self.backend == "flash_attention":
self.original_seq_len = 0 # Set distributed info for local rank
self.pad_len = 0 self.process_group = get_ring_attn_group()
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Create a partially applied version of the apply_context_parallelism function # Create a partially applied version of the apply_context_parallelism function
self.apply_context_parallelism = functools.partial( self.apply_context_parallelism = functools.partial(
apply_context_parallelism, apply_context_parallelism,
local_rank=self.local_rank, local_rank=self.local_rank,
local_world_size=self.local_world_size, local_world_size=self.local_world_size,
gradient_accumulation_steps=self.gradient_accumulation_steps, gradient_accumulation_steps=self.gradient_accumulation_steps,
ring_attn_func=self.ring_attn_func, ring_attn_func=self.ring_attn_func,
)
# SPDA CP initialization
world_size = dist.get_world_size()
mesh_shape = (
world_size // self.context_parallel_degree,
self.context_parallel_degree,
)
world_mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
self.context_parallel_managers = []
for model in models:
ctx_manager = get_context_parallel_manager(
enabled=self.context_parallel_degree > 1,
world_mesh=world_mesh,
model=model,
) )
self.context_parallel_managers.append(ctx_manager)
# Store original sequence length and padding information
self.original_seq_len = 0
self.pad_len = 0
else:
# SDPA device mesh init
world_size = dist.get_world_size()
mesh_shape = (
world_size // self.context_parallel_degree,
self.context_parallel_degree,
)
world_mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
# SDPA context parallel managers
self.context_parallel_managers = []
for model in models:
ctx_manager = get_context_parallel_manager(
world_mesh=world_mesh,
model=model,
)
self.context_parallel_managers.append(ctx_manager)
def __enter__(self): def __enter__(self):
self._register_model_hooks() self._register_model_hooks()
@@ -162,6 +169,28 @@ class ContextParallelContextManager:
return output return output
def make_sdpa_pre_hook(model_idx: int) -> Callable:
    """Build the SDPA forward pre-hook for the model at ``model_idx``.

    The returned hook rewrites the positional arguments of the model's
    ``forward`` call into keyword arguments, then invokes the context
    parallel manager stored at the same index with every argument value
    that has more than one dimension.

    Args:
        model_idx: Index into ``self.models`` and
            ``self.context_parallel_managers`` selecting the model this
            hook is built for.

    Returns:
        A hook with the ``(module, args, kwargs)`` signature that returns
        the ``(remaining_args, updated_kwargs)`` pair.
    """

    def cp_sdpa_pre_hook(_, args, kwargs):
        # Read the signature of the model this hook belongs to, not always
        # the first one: models with different forward signatures would
        # otherwise get their positional arguments bound to wrong names.
        forward_params = list(
            inspect.signature(self.models[model_idx].forward).parameters
        )

        # Bind positional arguments to their parameter names; `zip` stops at
        # the shorter sequence, so surplus positionals are left untouched.
        updated_kwargs = {**kwargs, **dict(zip(forward_params, args))}

        # Any excess positional arguments are kept as-is
        remaining_args = args[len(forward_params) :]

        # Only values exposing `ndim > 1` are handed to the manager. Values
        # without an `ndim` attribute (None, bools, ints, ...) are skipped
        # instead of raising AttributeError.
        to_shard = [
            value
            for value in updated_kwargs.values()
            if getattr(value, "ndim", 0) > 1
        ]

        # NOTE(review): the `with` block is left as soon as this hook
        # returns, i.e. before the hooked forward call runs. Confirm that
        # whatever the manager sets up is meant to outlive its `__exit__`.
        with self.context_parallel_managers[model_idx](to_shard):
            return remaining_args, updated_kwargs

    return cp_sdpa_pre_hook
# Register both hooks # Register both hooks
for i, model in enumerate(self.models): for i, model in enumerate(self.models):
if self.backend == "flash_attention": if self.backend == "flash_attention":
@@ -172,14 +201,6 @@ class ContextParallelContextManager:
model.register_forward_hook(cp_flash_post_hook) model.register_forward_hook(cp_flash_post_hook)
) )
else: else:
def make_sdpa_pre_hook(model_idx: int) -> Callable:
def cp_sdpa_pre_hook(_, args, kwargs):
with self.context_parallel_managers[model_idx]:
return args, kwargs
return cp_sdpa_pre_hook
self.hook_handles.append( self.hook_handles.append(
model.register_forward_pre_hook( model.register_forward_pre_hook(
make_sdpa_pre_hook(i), with_kwargs=True make_sdpa_pre_hook(i), with_kwargs=True
@@ -193,4 +214,3 @@ class ContextParallelContextManager:
output[key] = AllGatherWithGrad.apply(value, self.process_group) output[key] = AllGatherWithGrad.apply(value, self.process_group)
return output return output