progress
This commit is contained in:
@@ -7,13 +7,11 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from typing import Any, Callable, Literal, Optional
|
from typing import Callable, Literal, Optional
|
||||||
|
|
||||||
from axolotl.utils.ctx_managers.utils import get_context_parallel_manager
|
|
||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from torch import nn
|
|
||||||
from torch.utils.data import (
|
from torch.utils.data import (
|
||||||
BatchSampler,
|
BatchSampler,
|
||||||
DataLoader,
|
DataLoader,
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
|
||||||
import transformers.modelcard
|
import transformers.modelcard
|
||||||
from accelerate.utils import save_fsdp_model
|
from accelerate.utils import save_fsdp_model
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
@@ -33,7 +32,7 @@ from axolotl.loaders import (
|
|||||||
load_processor,
|
load_processor,
|
||||||
load_tokenizer,
|
load_tokenizer,
|
||||||
)
|
)
|
||||||
from axolotl.utils.ctx_managers.context_parallel import ContextParallelContextManager
|
from axolotl.utils.ctx_managers import ContextParallelContextManager
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import cleanup_distributed
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
from axolotl.utils.freeze import freeze_layers_except
|
from axolotl.utils.freeze import freeze_layers_except
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Init for context manager submodule."""
|
"""Init for context manager submodule."""
|
||||||
|
|
||||||
|
|
||||||
from .context_parallel.manager import ContextParallelContextManager
|
from .context_parallel.manager import ContextParallelContextManager
|
||||||
|
|
||||||
__all__ = ["ContextParallelContextManager"]
|
__all__ = ["ContextParallelContextManager"]
|
||||||
|
|||||||
@@ -35,11 +35,11 @@ https://github.com/pytorch/torchtune/blob/2344509cf83bd886538fe3e8263e5145d1afb5
|
|||||||
import contextlib
|
import contextlib
|
||||||
from typing import Callable, Generator, Optional, Union
|
from typing import Callable, Generator, Optional, Union
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch.distributed.tensor.experimental import context_parallel
|
from torch.distributed.tensor.experimental import context_parallel
|
||||||
from torch.distributed.tensor.experimental._attention import set_rotate_method
|
from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||||
from torch.nn.attention import sdpa_kernel, SDPBackend
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||||
from torch.nn.attention.flex_attention import BlockMask
|
from torch.nn.attention.flex_attention import BlockMask
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
@@ -85,7 +85,6 @@ def get_context_parallel_manager(
|
|||||||
dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends.
|
dimension, this context manager also calls into _get_sdpa_context to filter to acceptable SDPA backends.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
enabled: Whether context parallel is enabled. Default: False
|
|
||||||
world_mesh: Global device mesh.
|
world_mesh: Global device mesh.
|
||||||
model: Model to apply context parallelism to.
|
model: Model to apply context parallelism to.
|
||||||
|
|
||||||
@@ -102,12 +101,8 @@ def get_context_parallel_manager(
|
|||||||
"Context parallel is enabled but no context parallel device mesh is provided."
|
"Context parallel is enabled but no context parallel device mesh is provided."
|
||||||
)
|
)
|
||||||
# TODO: context parallel for multimodal models requires extra work
|
# TODO: context parallel for multimodal models requires extra work
|
||||||
if not isinstance(model, TransformerDecoder):
|
# if not isinstance(model, TransformerDecoder):
|
||||||
raise ValueError("Context parallel is only supported for text models")
|
# raise ValueError("Context parallel is only supported for text models")
|
||||||
# TODO: this is a hacky proxy for whether we use flex for chunked attention
|
|
||||||
# remove this once flex is supported
|
|
||||||
if any([layer.mask_mod is not None for layer in model.layers]):
|
|
||||||
raise ValueError("Context parallel with flex attention is not yet supported")
|
|
||||||
model_buffers = list(model.buffers())
|
model_buffers = list(model.buffers())
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@@ -132,4 +127,4 @@ def get_context_parallel_manager(
|
|||||||
with sdpa_context(cp_context):
|
with sdpa_context(cp_context):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
return context
|
return context
|
||||||
|
|||||||
@@ -17,8 +17,13 @@ from axolotl.monkeypatch.ring_attn import (
|
|||||||
patch_prepare_device_mesh,
|
patch_prepare_device_mesh,
|
||||||
register_ring_attn,
|
register_ring_attn,
|
||||||
)
|
)
|
||||||
from axolotl.utils.ctx_managers.context_parallel.utils import AllGatherWithGrad, apply_context_parallelism
|
from axolotl.utils.ctx_managers.context_parallel.distributed import (
|
||||||
from axolotl.utils.ctx_managers.utils import get_context_parallel_manager
|
get_context_parallel_manager,
|
||||||
|
)
|
||||||
|
from axolotl.utils.ctx_managers.context_parallel.utils import (
|
||||||
|
AllGatherWithGrad,
|
||||||
|
apply_context_parallelism,
|
||||||
|
)
|
||||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||||
|
|
||||||
|
|
||||||
@@ -57,46 +62,48 @@ class ContextParallelContextManager:
|
|||||||
self.heads_k_stride = heads_k_stride
|
self.heads_k_stride = heads_k_stride
|
||||||
self._register_ring_attn()
|
self._register_ring_attn()
|
||||||
|
|
||||||
# Set distributed info for local rank
|
# Store hook handles for removal
|
||||||
self.process_group = get_ring_attn_group()
|
|
||||||
self.local_rank = dist.get_rank(self.process_group)
|
|
||||||
self.local_world_size = dist.get_world_size(self.process_group)
|
|
||||||
|
|
||||||
# Will store hook handles for removal
|
|
||||||
self.hook_handles: list[RemovableHandle] = []
|
self.hook_handles: list[RemovableHandle] = []
|
||||||
|
|
||||||
# Store original sequence length and padding information
|
if self.backend == "flash_attention":
|
||||||
self.original_seq_len = 0
|
# Set distributed info for local rank
|
||||||
self.pad_len = 0
|
self.process_group = get_ring_attn_group()
|
||||||
|
self.local_rank = dist.get_rank(self.process_group)
|
||||||
# Create a partially applied version of the apply_context_parallelism function
|
self.local_world_size = dist.get_world_size(self.process_group)
|
||||||
self.apply_context_parallelism = functools.partial(
|
|
||||||
apply_context_parallelism,
|
# Create a partially applied version of the apply_context_parallelism function
|
||||||
local_rank=self.local_rank,
|
self.apply_context_parallelism = functools.partial(
|
||||||
local_world_size=self.local_world_size,
|
apply_context_parallelism,
|
||||||
gradient_accumulation_steps=self.gradient_accumulation_steps,
|
local_rank=self.local_rank,
|
||||||
ring_attn_func=self.ring_attn_func,
|
local_world_size=self.local_world_size,
|
||||||
)
|
gradient_accumulation_steps=self.gradient_accumulation_steps,
|
||||||
|
ring_attn_func=self.ring_attn_func,
|
||||||
# SPDA CP initialization
|
|
||||||
world_size = dist.get_world_size()
|
|
||||||
mesh_shape = (
|
|
||||||
world_size // self.context_parallel_degree,
|
|
||||||
self.context_parallel_degree,
|
|
||||||
)
|
|
||||||
world_mesh = dist.DeviceMesh(
|
|
||||||
"cuda",
|
|
||||||
torch.tensor(list(range(world_size))).reshape(mesh_shape),
|
|
||||||
mesh_dim_names=("dp", "cp"),
|
|
||||||
)
|
|
||||||
self.context_parallel_managers = []
|
|
||||||
for model in models:
|
|
||||||
ctx_manager = get_context_parallel_manager(
|
|
||||||
enabled=self.context_parallel_degree > 1,
|
|
||||||
world_mesh=world_mesh,
|
|
||||||
model=model,
|
|
||||||
)
|
)
|
||||||
self.context_parallel_managers.append(ctx_manager)
|
|
||||||
|
# Store original sequence length and padding information
|
||||||
|
self.original_seq_len = 0
|
||||||
|
self.pad_len = 0
|
||||||
|
else:
|
||||||
|
# SPDA device mesh init
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
mesh_shape = (
|
||||||
|
world_size // self.context_parallel_degree,
|
||||||
|
self.context_parallel_degree,
|
||||||
|
)
|
||||||
|
world_mesh = dist.DeviceMesh(
|
||||||
|
"cuda",
|
||||||
|
torch.tensor(list(range(world_size))).reshape(mesh_shape),
|
||||||
|
mesh_dim_names=("dp", "cp"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# SDPA context parallel managers
|
||||||
|
self.context_parallel_managers = []
|
||||||
|
for model in models:
|
||||||
|
ctx_manager = get_context_parallel_manager(
|
||||||
|
world_mesh=world_mesh,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
self.context_parallel_managers.append(ctx_manager)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self._register_model_hooks()
|
self._register_model_hooks()
|
||||||
@@ -162,6 +169,28 @@ class ContextParallelContextManager:
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def make_sdpa_pre_hook(model_idx: int) -> Callable:
|
||||||
|
def cp_sdpa_pre_hook(_, args, kwargs):
|
||||||
|
# Get parameter names from the model's forward function
|
||||||
|
forward_params = list(
|
||||||
|
inspect.signature(self.models[0].forward).parameters.keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_kwargs = kwargs.copy()
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
if i < len(forward_params):
|
||||||
|
updated_kwargs[forward_params[i]] = arg
|
||||||
|
|
||||||
|
# Any excess positional arguments are kept as-is
|
||||||
|
remaining_args = args[len(forward_params) :]
|
||||||
|
|
||||||
|
to_shard = {k: v for k, v in updated_kwargs.items() if v.ndim > 1}
|
||||||
|
|
||||||
|
with self.context_parallel_managers[model_idx](list(to_shard.values())):
|
||||||
|
return remaining_args, updated_kwargs
|
||||||
|
|
||||||
|
return cp_sdpa_pre_hook
|
||||||
|
|
||||||
# Register both hooks
|
# Register both hooks
|
||||||
for i, model in enumerate(self.models):
|
for i, model in enumerate(self.models):
|
||||||
if self.backend == "flash_attention":
|
if self.backend == "flash_attention":
|
||||||
@@ -172,14 +201,6 @@ class ContextParallelContextManager:
|
|||||||
model.register_forward_hook(cp_flash_post_hook)
|
model.register_forward_hook(cp_flash_post_hook)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def make_sdpa_pre_hook(model_idx: int) -> Callable:
|
|
||||||
def cp_sdpa_pre_hook(_, args, kwargs):
|
|
||||||
with self.context_parallel_managers[model_idx]:
|
|
||||||
return args, kwargs
|
|
||||||
|
|
||||||
return cp_sdpa_pre_hook
|
|
||||||
|
|
||||||
self.hook_handles.append(
|
self.hook_handles.append(
|
||||||
model.register_forward_pre_hook(
|
model.register_forward_pre_hook(
|
||||||
make_sdpa_pre_hook(i), with_kwargs=True
|
make_sdpa_pre_hook(i), with_kwargs=True
|
||||||
@@ -193,4 +214,3 @@ class ContextParallelContextManager:
|
|||||||
output[key] = AllGatherWithGrad.apply(value, self.process_group)
|
output[key] = AllGatherWithGrad.apply(value, self.process_group)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user