progress
@@ -7,13 +7,11 @@ from __future__ import annotations
import os
from collections import defaultdict
from functools import partial, wraps
from typing import Any, Callable, Literal, Optional
from typing import Callable, Literal, Optional

from axolotl.utils.ctx_managers.utils import get_context_parallel_manager
import datasets
import torch
from datasets import Dataset
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,

@@ -11,7 +11,6 @@ from pathlib import Path
from typing import Any, Dict

import torch
import torch.distributed as dist
import transformers.modelcard
from accelerate.utils import save_fsdp_model
from datasets import Dataset

@@ -33,7 +32,7 @@ from axolotl.loaders import (
load_processor,
load_tokenizer,
)
from axolotl.utils.ctx_managers.context_parallel import ContextParallelContextManager
from axolotl.utils.ctx_managers import ContextParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except

@@ -1,6 +1,5 @@
"""Init for context manager submodule."""


from .context_parallel.manager import ContextParallelContextManager

__all__ = ["ContextParallelContextManager"]

@@ -35,11 +35,11 @@ https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5
import contextlib
from typing import Callable, Generator, Optional, Union

from axolotl.utils.dict import DictDefault
import torch

from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import BlockMask
from transformers import PreTrainedModel

@@ -85,7 +85,6 @@ def get_context_parallel_manager(
dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends.

Args:
enabled: Whether context parallel is enabled. Default: False
world_mesh: Global device mesh.
model: Model to apply context parallelism to.

@@ -102,12 +101,8 @@ def get_context_parallel_manager(
"Context parallel is enabled but no context parallel device mesh is provided."
)
# TODO: context parallel for multimodal models requires extra work
if not isinstance(model, TransformerDecoder):
raise ValueError("Context parallel is only supported for text models")
# TODO: this is a hacky proxy for whether we use flex for chunked attention
# remove this once flex is supported
if any([layer.mask_mod is not None for layer in model.layers]):
raise ValueError("Context parallel with flex attention is not yet supported")
# if not isinstance(model, TransformerDecoder):
# raise ValueError("Context parallel is only supported for text models")
model_buffers = list(model.buffers())

@contextlib.contextmanager

@@ -132,4 +127,4 @@ def get_context_parallel_manager(
with sdpa_context(cp_context):
yield

return context
return context

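For orientation, `get_context_parallel_manager` (modified by the three hunks above) follows the torchtune pattern linked in the module docstring: validate the model, collect its buffers, and return a context manager that shards sequence-dimension buffers through torch's experimental `context_parallel` while restricting SDPA to compatible backends. The following is a rough, hypothetical sketch of that pattern; the name `make_cp_manager` and the sequence-dim choices are illustrative assumptions, not this file's exact code.

import contextlib
from typing import Generator

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
from torch.nn.attention import SDPBackend, sdpa_kernel


def make_cp_manager(cp_mesh: DeviceMesh, model: torch.nn.Module):
    """Hypothetical sketch of a context-parallel manager factory (not repo code)."""
    # Choose how KV shards rotate between ranks during attention.
    set_rotate_method("allgather")
    model_buffers = list(model.buffers())

    @contextlib.contextmanager
    def context(cp_buffers: list[torch.Tensor]) -> Generator[None, None, None]:
        # Shard the forward inputs and model buffers along their sequence dims.
        # Assumption: inputs are (batch, seq, ...) and model buffers are seq-first.
        cp_context = context_parallel(
            cp_mesh,
            buffers=cp_buffers + model_buffers,
            buffer_seq_dims=[1] * len(cp_buffers) + [0] * len(model_buffers),
            no_restore_buffers=set(cp_buffers),
        )
        # Context parallelism only supports a subset of SDPA backends.
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
            with cp_context:
                yield

    return context
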
@@ -17,8 +17,13 @@ from axolotl.monkeypatch.ring_attn import (
patch_prepare_device_mesh,
register_ring_attn,
)
from axolotl.utils.ctx_managers.context_parallel.utils import AllGatherWithGrad, apply_context_parallelism
from axolotl.utils.ctx_managers.utils import get_context_parallel_manager
from axolotl.utils.ctx_managers.context_parallel.distributed import (
get_context_parallel_manager,
)
from axolotl.utils.ctx_managers.context_parallel.utils import (
AllGatherWithGrad,
apply_context_parallelism,
)
from axolotl.utils.schemas.enums import RingAttnFunc

@@ -57,46 +62,48 @@ class ContextParallelContextManager:
self.heads_k_stride = heads_k_stride
self._register_ring_attn()

# Set distributed info for local rank
self.process_group = get_ring_attn_group()
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)

# Will store hook handles for removal
# Store hook handles for removal
self.hook_handles: list[RemovableHandle] = []

# Store original sequence length and padding information
self.original_seq_len = 0
self.pad_len = 0

# Create a partially applied version of the apply_context_parallelism function
self.apply_context_parallelism = functools.partial(
apply_context_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
ring_attn_func=self.ring_attn_func,
)

# SPDA CP initialization
world_size = dist.get_world_size()
mesh_shape = (
world_size // self.context_parallel_degree,
self.context_parallel_degree,
)
world_mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)
self.context_parallel_managers = []
for model in models:
ctx_manager = get_context_parallel_manager(
enabled=self.context_parallel_degree > 1,
world_mesh=world_mesh,
model=model,
if self.backend == "flash_attention":
# Set distributed info for local rank
self.process_group = get_ring_attn_group()
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)

# Create a partially applied version of the apply_context_parallelism function
self.apply_context_parallelism = functools.partial(
apply_context_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
gradient_accumulation_steps=self.gradient_accumulation_steps,
ring_attn_func=self.ring_attn_func,
)
self.context_parallel_managers.append(ctx_manager)

# Store original sequence length and padding information
self.original_seq_len = 0
self.pad_len = 0
else:
# SPDA device mesh init
world_size = dist.get_world_size()
mesh_shape = (
world_size // self.context_parallel_degree,
self.context_parallel_degree,
)
world_mesh = dist.DeviceMesh(
"cuda",
torch.tensor(list(range(world_size))).reshape(mesh_shape),
mesh_dim_names=("dp", "cp"),
)

# SDPA context parallel managers
self.context_parallel_managers = []
for model in models:
ctx_manager = get_context_parallel_manager(
world_mesh=world_mesh,
model=model,
)
self.context_parallel_managers.append(ctx_manager)

def __enter__(self):
self._register_model_hooks()

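In the SDPA branch above, the manager carves the world into a two-dimensional (dp, cp) layout and builds one context-parallel manager per wrapped model. The same layout can also be expressed with torch's `init_device_mesh` helper; the snippet below is a standalone sketch with an assumed `context_parallel_degree`, not code from this commit.

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Assumes torch.distributed is already initialized.
context_parallel_degree = 2  # hypothetical value for illustration
world_size = dist.get_world_size()
assert world_size % context_parallel_degree == 0

# Row-major layout: e.g. with 8 ranks and cp=2, pairs (0,1), (2,3), ... form the
# "cp" groups, while the "dp" dimension spans across those pairs.
world_mesh = init_device_mesh(
    "cuda",
    (world_size // context_parallel_degree, context_parallel_degree),
    mesh_dim_names=("dp", "cp"),
)
cp_mesh = world_mesh["cp"]  # sub-mesh a per-model CP manager would operate on
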
@@ -162,6 +169,28 @@ class ContextParallelContextManager:

return output

def make_sdpa_pre_hook(model_idx: int) -> Callable:
def cp_sdpa_pre_hook(_, args, kwargs):
# Get parameter names from the model's forward function
forward_params = list(
inspect.signature(self.models[0].forward).parameters.keys()
)

updated_kwargs = kwargs.copy()
for i, arg in enumerate(args):
if i < len(forward_params):
updated_kwargs[forward_params[i]] = arg

# Any excess positional arguments are kept as-is
remaining_args = args[len(forward_params) :]

to_shard = {k: v for k, v in updated_kwargs.items() if v.ndim > 1}

with self.context_parallel_managers[model_idx](list(to_shard.values())):
return remaining_args, updated_kwargs

return cp_sdpa_pre_hook

# Register both hooks
for i, model in enumerate(self.models):
if self.backend == "flash_attention":

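The new `cp_sdpa_pre_hook` above first normalizes the model's positional arguments into keyword form, so every shardable tensor can be found by name before entering the per-model CP manager. In isolation, the `inspect.signature`-based mapping it relies on behaves like this toy example (illustrative only, not repo code).

import inspect


def demo_forward(input_ids, attention_mask=None, labels=None):
    return input_ids


forward_params = list(inspect.signature(demo_forward).parameters.keys())
# forward_params == ["input_ids", "attention_mask", "labels"]

args = ([1, 2, 3], [1, 1, 1])     # positional call: (input_ids, attention_mask)
kwargs = {"labels": [9, 9, 9]}

updated_kwargs = dict(kwargs)
for i, arg in enumerate(args):
    if i < len(forward_params):
        updated_kwargs[forward_params[i]] = arg

# Positional args beyond the named parameters (none here) pass through untouched.
remaining_args = args[len(forward_params):]

# updated_kwargs == {"labels": [9, 9, 9], "input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}
# remaining_args == ()
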
@@ -172,14 +201,6 @@ class ContextParallelContextManager:
model.register_forward_hook(cp_flash_post_hook)
)
else:

def make_sdpa_pre_hook(model_idx: int) -> Callable:
def cp_sdpa_pre_hook(_, args, kwargs):
with self.context_parallel_managers[model_idx]:
return args, kwargs

return cp_sdpa_pre_hook

self.hook_handles.append(
model.register_forward_pre_hook(
make_sdpa_pre_hook(i), with_kwargs=True

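Both the old and new branches above attach the hook through `register_forward_pre_hook(..., with_kwargs=True)`, which hands the hook the module's positional and keyword arguments and lets it return a replacement `(args, kwargs)` pair. A minimal standalone illustration of that mechanism (not repo code):

import torch
from torch import nn

module = nn.Linear(4, 4)


def pre_hook(mod, args, kwargs):
    # Hooks registered with with_kwargs=True receive (module, args, kwargs)
    # and may return a replacement (args, kwargs) tuple.
    print(f"{len(args)} positional arg(s), kwargs: {list(kwargs)}")
    return args, kwargs


handle = module.register_forward_pre_hook(pre_hook, with_kwargs=True)
module(torch.randn(2, 4))  # prints: 1 positional arg(s), kwargs: []
handle.remove()            # mirrors the RemovableHandle bookkeeping in hook_handles
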
@@ -193,4 +214,3 @@ class ContextParallelContextManager:
output[key] = AllGatherWithGrad.apply(value, self.process_group)

return output
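The post-hook in this final hunk gathers sharded outputs back to full sequence length with `AllGatherWithGrad`, which is defined in `axolotl.utils.ctx_managers.context_parallel.utils` (imported earlier in this diff); its exact implementation is not shown here. A generic, hypothetical sketch of a gradient-aware all-gather along the sequence dimension looks like this:

import torch
import torch.distributed as dist


class NaiveAllGatherWithGrad(torch.autograd.Function):
    """Hypothetical sketch only, NOT the repo's AllGatherWithGrad implementation."""

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
        ctx.rank = dist.get_rank(group)
        ctx.shard_len = tensor.shape[1]  # assumes dim 1 is an equal-sized seq dim
        world_size = dist.get_world_size(group)
        gathered = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(gathered, tensor.contiguous(), group=group)
        return torch.cat(gathered, dim=1)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Assumes every rank computes the same loss on the gathered tensor, so the
        # local shard's gradient is the matching slice of grad_output; otherwise the
        # incoming grads would first need a reduction over the group.
        start = ctx.rank * ctx.shard_len
        return grad_output[:, start : start + ctx.shard_len].contiguous(), None


# Usage sketch: full_output = NaiveAllGatherWithGrad.apply(local_output, process_group)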