diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 9b8f4ba92..70e443cb3 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -7,13 +7,11 @@ from __future__ import annotations import os from collections import defaultdict from functools import partial, wraps -from typing import Any, Callable, Literal, Optional +from typing import Callable, Literal, Optional -from axolotl.utils.ctx_managers.utils import get_context_parallel_manager import datasets import torch from datasets import Dataset -from torch import nn from torch.utils.data import ( BatchSampler, DataLoader, diff --git a/src/axolotl/train.py b/src/axolotl/train.py index d60f432aa..5d2e77b4f 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -11,7 +11,6 @@ from pathlib import Path from typing import Any, Dict import torch -import torch.distributed as dist import transformers.modelcard from accelerate.utils import save_fsdp_model from datasets import Dataset @@ -33,7 +32,7 @@ from axolotl.loaders import ( load_processor, load_tokenizer, ) -from axolotl.utils.ctx_managers.context_parallel import ContextParallelContextManager +from axolotl.utils.ctx_managers import ContextParallelContextManager from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.freeze import freeze_layers_except diff --git a/src/axolotl/utils/ctx_managers/__init__.py b/src/axolotl/utils/ctx_managers/__init__.py index 15eec49e4..b5e70eb71 100644 --- a/src/axolotl/utils/ctx_managers/__init__.py +++ b/src/axolotl/utils/ctx_managers/__init__.py @@ -1,6 +1,5 @@ """Init for context manager submodule.""" - from .context_parallel.manager import ContextParallelContextManager __all__ = ["ContextParallelContextManager"] diff --git a/src/axolotl/utils/ctx_managers/context_parallel.py b/src/axolotl/utils/ctx_managers/context_parallel.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/axolotl/utils/ctx_managers/context_parallel/distributed.py b/src/axolotl/utils/ctx_managers/context_parallel/distributed.py index c2aa603c6..7d7d774d1 100644 --- a/src/axolotl/utils/ctx_managers/context_parallel/distributed.py +++ b/src/axolotl/utils/ctx_managers/context_parallel/distributed.py @@ -35,11 +35,11 @@ https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5 import contextlib from typing import Callable, Generator, Optional, Union +from axolotl.utils.dict import DictDefault import torch - from torch.distributed.tensor.experimental import context_parallel from torch.distributed.tensor.experimental._attention import set_rotate_method -from torch.nn.attention import sdpa_kernel, SDPBackend +from torch.nn.attention import SDPBackend, sdpa_kernel from torch.nn.attention.flex_attention import BlockMask from transformers import PreTrainedModel @@ -85,7 +85,6 @@ def get_context_parallel_manager( dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends. Args: - enabled: Whether context parallel is enabled. Default: False world_mesh: Global device mesh. model: Model to apply context parallelism to. @@ -102,12 +101,8 @@ def get_context_parallel_manager( "Context parallel is enabled but no context parallel device mesh is provided." ) # TODO: context parallel for multimodal models requires extra work - if not isinstance(model, TransformerDecoder): - raise ValueError("Context parallel is only supported for text models") - # TODO: this is a hacky proxy for whether we use flex for chunked attention - # remove this once flex is supported - if any([layer.mask_mod is not None for layer in model.layers]): - raise ValueError("Context parallel with flex attention is not yet supported") + # if not isinstance(model, TransformerDecoder): + # raise ValueError("Context parallel is only supported for text models") model_buffers = list(model.buffers()) @contextlib.contextmanager @@ -132,4 +127,4 @@ def get_context_parallel_manager( with sdpa_context(cp_context): yield - return context \ No newline at end of file + return context diff --git a/src/axolotl/utils/ctx_managers/context_parallel/manager.py b/src/axolotl/utils/ctx_managers/context_parallel/manager.py index bdc98e1cd..bc16c66cd 100644 --- a/src/axolotl/utils/ctx_managers/context_parallel/manager.py +++ b/src/axolotl/utils/ctx_managers/context_parallel/manager.py @@ -17,8 +17,13 @@ from axolotl.monkeypatch.ring_attn import ( patch_prepare_device_mesh, register_ring_attn, ) -from axolotl.utils.ctx_managers.context_parallel.utils import AllGatherWithGrad, apply_context_parallelism -from axolotl.utils.ctx_managers.utils import get_context_parallel_manager +from axolotl.utils.ctx_managers.context_parallel.distributed import ( + get_context_parallel_manager, +) +from axolotl.utils.ctx_managers.context_parallel.utils import ( + AllGatherWithGrad, + apply_context_parallelism, +) from axolotl.utils.schemas.enums import RingAttnFunc @@ -57,46 +62,48 @@ class ContextParallelContextManager: self.heads_k_stride = heads_k_stride self._register_ring_attn() - # Set distributed info for local rank - self.process_group = get_ring_attn_group() - self.local_rank = dist.get_rank(self.process_group) - self.local_world_size = dist.get_world_size(self.process_group) - - # Will store hook handles for removal + # Store hook handles for removal self.hook_handles: list[RemovableHandle] = [] - # Store original sequence length and padding information - self.original_seq_len = 0 - self.pad_len = 0 - - # Create a partially applied version of the apply_context_parallelism function - self.apply_context_parallelism = functools.partial( - apply_context_parallelism, - local_rank=self.local_rank, - local_world_size=self.local_world_size, - gradient_accumulation_steps=self.gradient_accumulation_steps, - ring_attn_func=self.ring_attn_func, - ) - - # SPDA CP initialization - world_size = dist.get_world_size() - mesh_shape = ( - world_size // self.context_parallel_degree, - self.context_parallel_degree, - ) - world_mesh = dist.DeviceMesh( - "cuda", - torch.tensor(list(range(world_size))).reshape(mesh_shape), - mesh_dim_names=("dp", "cp"), - ) - self.context_parallel_managers = [] - for model in models: - ctx_manager = get_context_parallel_manager( - enabled=self.context_parallel_degree > 1, - world_mesh=world_mesh, - model=model, + if self.backend == "flash_attention": + # Set distributed info for local rank + self.process_group = get_ring_attn_group() + self.local_rank = dist.get_rank(self.process_group) + self.local_world_size = dist.get_world_size(self.process_group) + + # Create a partially applied version of the apply_context_parallelism function + self.apply_context_parallelism = functools.partial( + apply_context_parallelism, + local_rank=self.local_rank, + local_world_size=self.local_world_size, + gradient_accumulation_steps=self.gradient_accumulation_steps, + ring_attn_func=self.ring_attn_func, ) - self.context_parallel_managers.append(ctx_manager) + + # Store original sequence length and padding information + self.original_seq_len = 0 + self.pad_len = 0 + else: + # SPDA device mesh init + world_size = dist.get_world_size() + mesh_shape = ( + world_size // self.context_parallel_degree, + self.context_parallel_degree, + ) + world_mesh = dist.DeviceMesh( + "cuda", + torch.tensor(list(range(world_size))).reshape(mesh_shape), + mesh_dim_names=("dp", "cp"), + ) + + # SDPA context parallel managers + self.context_parallel_managers = [] + for model in models: + ctx_manager = get_context_parallel_manager( + world_mesh=world_mesh, + model=model, + ) + self.context_parallel_managers.append(ctx_manager) def __enter__(self): self._register_model_hooks() @@ -162,6 +169,28 @@ class ContextParallelContextManager: return output + def make_sdpa_pre_hook(model_idx: int) -> Callable: + def cp_sdpa_pre_hook(_, args, kwargs): + # Get parameter names from the model's forward function + forward_params = list( + inspect.signature(self.models[0].forward).parameters.keys() + ) + + updated_kwargs = kwargs.copy() + for i, arg in enumerate(args): + if i < len(forward_params): + updated_kwargs[forward_params[i]] = arg + + # Any excess positional arguments are kept as-is + remaining_args = args[len(forward_params) :] + + to_shard = {k: v for k, v in updated_kwargs.items() if v.ndim > 1} + + with self.context_parallel_managers[model_idx](list(to_shard.values())): + return remaining_args, updated_kwargs + + return cp_sdpa_pre_hook + # Register both hooks for i, model in enumerate(self.models): if self.backend == "flash_attention": @@ -172,14 +201,6 @@ class ContextParallelContextManager: model.register_forward_hook(cp_flash_post_hook) ) else: - - def make_sdpa_pre_hook(model_idx: int) -> Callable: - def cp_sdpa_pre_hook(_, args, kwargs): - with self.context_parallel_managers[model_idx]: - return args, kwargs - - return cp_sdpa_pre_hook - self.hook_handles.append( model.register_forward_pre_hook( make_sdpa_pre_hook(i), with_kwargs=True @@ -193,4 +214,3 @@ class ContextParallelContextManager: output[key] = AllGatherWithGrad.apply(value, self.process_group) return output -