support for TiledMLP for GPT-OSS (#3116)
* fix use of flex attn kwargs and add support for TiledMLP for GPT-OSS
* add logging back
* update deps
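For context on the feature itself: tiled MLP execution chunks the input along the sequence dimension and runs the MLP once per chunk, so only one chunk's intermediate activations are live at a time; gradients are then recomputed shard by shard in the backward pass. A minimal sketch of the idea, independent of the implementation in this diff (names are illustrative, not from the codebase):

import torch

def tiled_mlp_forward(mlp, x, shards=4):
    # Split the (batch, seq_len, hidden) input along the sequence dim and run
    # the MLP per tile; peak activation memory scales with one tile rather
    # than with the full sequence length.
    x_shards = torch.chunk(x, chunks=shards, dim=1)
    return torch.cat([mlp(s) for s in x_shards], dim=1)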
@@ -13,7 +13,7 @@ packaging==23.2
 huggingface_hub>=0.33.0
 peft>=0.17.0
-transformers==4.55.3
+transformers==4.55.4
 tokenizers>=0.21.1
 accelerate==1.10.0
 datasets==4.0.0

setup.py
@@ -127,7 +127,7 @@ extras_require = {
         "yunchang==0.6.0",
     ],
     "deepspeed": [
-        "deepspeed==0.17.2",
+        "deepspeed==0.17.5",
         "deepspeed-kernels",
     ],
     "mamba-ssm": [
@@ -149,14 +149,12 @@ class PatchManager:
     def _apply_flex_attention_patches(self):
         """Apply patches for flexible attention."""
         if self.cfg.flex_attention:
-            # from axolotl.monkeypatch.attention.flex_attn import (
-            #     patch_flex_make_mask,
-            #     patch_flex_wrapper,
-            # )
-            #
-            # flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
-            # patch_flex_wrapper(**flex_attn_compile_kwargs)
-            # patch_flex_make_mask()
+            from axolotl.monkeypatch.attention.flex_attn import (
+                patch_flex_wrapper,
+            )
+
+            flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
+            patch_flex_wrapper(**flex_attn_compile_kwargs)
         if self.cfg.sample_packing:
             from axolotl.core.attention.flex_block_mask import (
                 patch_create_causal_mask,
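The kwargs fix means flex_attn_compile_kwargs from the user config is forwarded into patch_flex_wrapper again, which hands them to torch.compile. An illustrative call, assuming values that mirror torch.compile's signature (these are not defaults taken from this diff):

# Hypothetical config-driven values; any torch.compile kwargs are accepted.
flex_attn_compile_kwargs = {"dynamic": False, "mode": "max-autotune-no-cudagraphs"}
patch_flex_wrapper(**flex_attn_compile_kwargs)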
@@ -1,11 +1,11 @@
 """Flex attention monkey patch"""
 
 import sys
-from typing import Optional, Tuple, Union
 from packaging import version
 
 import torch
 import transformers
 from transformers.utils.import_utils import _torch_version, is_torch_less_or_equal
 
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
@@ -46,19 +46,33 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
             """
             self.training = None
             if not self._is_flex_compiled or training != self.training:
-                self.training = training
-                LOG.info(
-                    "Compiling flex attention with kwargs: %s. This may take a while...",
-                    flex_attn_compile_kwargs,
-                )
-                self._compiled_flex_attention = torch.compile(
-                    flex_attention,
-                    **flex_attn_compile_kwargs,
-                )
-                LOG.info("Flex attention compiled successfully.")
+                self.training = training
+                if is_torch_less_or_equal("2.5.1"):
+                    self._compiled_flex_attention = torch.compile(
+                        flex_attention, dynamic=False
+                    )
+                # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
+                # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
+                # see https://github.com/pytorch/pytorch/issues/146260 for training
+                elif version.parse(_torch_version).base_version == "2.6.0" and training:
+                    self._compiled_flex_attention = torch.compile(
+                        flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
+                    )
+                # Fallback, usually the most recent torch 2.7.x+ versions
+                else:
+                    LOG.info(
+                        "Compiling flex attention with kwargs: %s. This may take a while...",
+                        flex_attn_compile_kwargs,
+                        main_process_only=True,
+                    )
+                    self._compiled_flex_attention = torch.compile(
+                        flex_attention,
+                        **flex_attn_compile_kwargs,
+                    )
+                    LOG.info(
+                        "Flex attention compiled successfully.", main_process_only=True
+                    )
 
                 self._is_flex_compiled = True
 
         def __call__(self):
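The same version gate can be reproduced standalone with public PyTorch APIs; a hedged sketch (torch >= 2.5 is assumed for the flex_attention import, and the variable names are mine):

import torch
from packaging import version
from torch.nn.attention.flex_attention import flex_attention

base_ver = version.parse(torch.__version__).base_version
if version.parse(base_ver) <= version.parse("2.5.1"):
    compiled = torch.compile(flex_attention, dynamic=False)
elif base_ver == "2.6.0":
    # Workaround for https://github.com/pytorch/pytorch/issues/146260
    compiled = torch.compile(
        flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
    )
else:
    # torch 2.7.x+: compile with whatever kwargs the user configured
    compiled = torch.compile(flex_attention)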
@@ -68,139 +82,3 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
     sys.modules[
         "transformers.integrations.flex_attention"
     ].WrappedFlexAttention = WrappedFlexAttention
-
-
-def patch_flex_make_mask():
-    is_torch_2_6 = torch.__version__.startswith("2.6")
-
-    if not is_torch_2_6:
-        return
-
-    from torch.nn.attention.flex_attention import (
-        _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size,
-    )
-    from torch.nn.attention.flex_attention import (
-        BlockMask,
-    )
-    from torch.nn.attention.flex_attention import (
-        create_block_mask as create_block_causal_mask_flex,
-    )
-
-    Offset = Union[torch.Tensor, int]
-
-    def patched_make_flex_block_causal_mask(
-        attention_mask_2d: torch.Tensor,
-        attention_chunk_size: Optional[int] = None,
-        query_length=None,
-        key_length=None,
-        offsets: Optional[Tuple[Offset, Offset]] = None,
-    ) -> "BlockMask":
-        """
-        Create a block causal document mask for a batch of sequences, both packed and unpacked.
-        Creates the block causal logic and passes it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
-        The resultant BlockMask is a compressed representation of the full block causal
-        mask. BlockMask is essential for performant computation of flex attention.
-        See: https://pytorch.org/blog/flexattention/
-
-        Args:
-            attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
-                of shape (batch_size, total_seq_len). e.g.
-
-                For unpacked sequence:
-                [[1, 1, 1, 1, 0, 0, 0],
-                 [1, 1, 1, 1, 1, 0, 0]]
-
-                For packed sequence:
-                [[1, 1, 1, 2, 2, 2, 0],
-                 [1, 1, 2, 2, 2, 3, 3]]
-
-        Returns:
-            BlockMask
-        """
-
-        batch_size, total_seq_len = attention_mask_2d.shape
-        if not key_length:
-            key_length = total_seq_len
-        if not query_length:
-            query_length = total_seq_len
-        attention_mask_2d = torch.nn.functional.pad(
-            attention_mask_2d,
-            value=0,
-            pad=(0, abs(total_seq_len - max(key_length, flex_default_block_size))),
-        )
-        device = attention_mask_2d.device
-        document_ids = attention_mask_2d.clone()
-
-        if attention_chunk_size is not None:
-            # we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
-            chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (
-                attention_chunk_size
-            )
-
-        # Instead of passing a tensor mask, flex attention requires a mask_mod function
-        # that determines which elements of QK^T should be included in the attention
-        # computation prior to the softmax. For sample packing, we need the logic for
-        # both the causal mask and the document mask. See PyTorch's official
-        # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
-        def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
-            """
-            Defines the logic of a block causal mask by combining both a standard causal mask
-            and a block diagonal document mask.
-
-            See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
-            for an illustration.
-            """
-            causal_mask = q_idx >= kv_idx  # not valid when decoding
-            document_mask = (
-                document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
-            )
-            padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
-            final_mask = causal_mask & padding_mask & document_mask
-            return final_mask
-
-        def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
-            """
-            Combines the chunk mask with the causal mask for chunked attention.
-            """
-            chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
-            causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
-            return chunk_mask & causal_doc_mask
-
-        mask_mod_maybe_combined = (
-            causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
-        )
-
-        if offsets is not None:
-            q_offset = offsets[0]
-            kv_offset = offsets[1]
-
-            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
-                offset_q = q_idx + q_offset
-                offset_kv = kv_idx + kv_offset
-                return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
-
-        else:
-            mask_mod = mask_mod_maybe_combined
-        return create_block_causal_mask_flex(
-            mask_mod=mask_mod,
-            B=batch_size,
-            H=None,  # attention head
-            Q_LEN=query_length,
-            KV_LEN=key_length,
-            device=device,
-            _compile=True,
-        )
-
-    for n in tuple(sys.modules):
-        if ".modeling_" in n:
-            if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
-                sys.modules[
-                    n
-                ].make_flex_block_causal_mask = patched_make_flex_block_causal_mask
-
-    transformers.integrations.flex_attention.make_flex_block_causal_mask = (
-        patched_make_flex_block_causal_mask
-    )
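The deleted patch_flex_make_mask only ever applied to torch 2.6 (it returned early otherwise), and its call site was dropped in the PatchManager hunk above. The document-aware causal mask_mod pattern it used survives in public PyTorch APIs; a toy reproduction (the document ids below are illustrative):

import torch
from torch.nn.attention.flex_attention import create_block_mask

# One row packing two documents (ids 1 and 2) plus right padding (0).
document_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0, 0]])

def doc_causal_mask_mod(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    same_doc = document_ids[b, q_idx] == document_ids[b, kv_idx]
    not_padding = document_ids[b, q_idx] > 0
    return causal & same_doc & not_padding

block_mask = create_block_mask(
    doc_causal_mask_mod, B=1, H=None, Q_LEN=8, KV_LEN=8, device="cpu"
)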
@@ -8,6 +8,94 @@ from typing import List
 import torch
 
 
+class DeepSpeedTiledMLPMoE(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        fn,
+        self,
+        x,
+        shards,
+        compute_params,
+    ) -> torch.Tensor:
+        ctx.fn = fn
+        ctx.self = self
+        ctx.shards = shards
+        ctx.compute_params = [p for p in compute_params if p.requires_grad]
+        ctx.save_for_backward(x)
+
+        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
+        with torch.no_grad():
+            output_shards = [fn(self, x_shard) for x_shard in x_shards]
+
+        ctx.is_tuple_output = isinstance(output_shards[0], tuple)
+        if isinstance(output_shards[0], tuple):
+            tuple_dim_idx = [1, 0]
+            output_unsharded = tuple(
+                torch.cat(
+                    [output_shard[i] for output_shard in output_shards],
+                    dim=tuple_dim_idx[i],
+                )
+                for i in range(len(output_shards[0]))
+            )
+        else:
+            output_unsharded = torch.cat(output_shards, dim=1)
+
+        return output_unsharded
+
+    @staticmethod
+    def backward(ctx, *grads) -> torch.Tensor:
+        fn = ctx.fn
+        (x,) = ctx.saved_tensors
+        self = ctx.self
+        shards = ctx.shards
+        compute_params = ctx.compute_params
+        is_tuple_output = ctx.is_tuple_output
+
+        x_requires_grad = x.requires_grad
+        x = x.detach()
+        # detach() unsets `x.requires_grad`, so restore it
+        x.requires_grad_(x_requires_grad)
+
+        incoming_grad = grads[0]
+        x_grad = torch.zeros_like(x)
+        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
+
+        shard_step = x_shards[0].numel()
+        for i, x_shard in enumerate(x_shards):
+            # Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run
+            if compute_params is not None:
+                if i + 1 < shards:
+                    for param in compute_params:
+                        param.ds_grad_is_ready = False
+                else:
+                    # last shard, can add the grad
+                    for param in compute_params:
+                        param.ds_grad_is_ready = True
+
+            x_shard.requires_grad_(x_requires_grad)
+
+            shard_offset = i * shard_step
+            x_shard.grad = (
+                x_grad.view(-1)
+                .narrow(0, shard_offset, x_shard.numel())
+                .view_as(x_shard)
+            )
+            incoming_grad_shard = (
+                incoming_grad.view(-1)
+                .narrow(0, shard_offset, x_shard.numel())
+                .view_as(x_shard)
+            )
+            with torch.enable_grad():
+                output = fn(self, x_shard)
+                if is_tuple_output:
+                    torch.autograd.backward(output[0], incoming_grad_shard)
+                else:
+                    torch.autograd.backward(output, incoming_grad_shard)
+
+        return (None, None, x_grad, None, None)
+
+
 class TiledMLP(torch.autograd.Function):
     """
     TiledMLP implementation using gradient hooks
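The tuple branch exists because a MoE MLP forward can return a pair such as (hidden_states, router_logits): hidden states concatenate along the sequence dim (1), while the second element concatenates along dim 0, suggesting it is flattened over tokens; hence tuple_dim_idx = [1, 0]. A hedged sketch of how such an autograd.Function is wired in; the fn(module, x_shard) calling convention is taken from the forward above, while the module and variable names are hypothetical:

# Hypothetical call site; DeepSpeedTiledMLPMoE.apply takes forward()'s
# arguments after ctx, in order: fn, self, x, shards, compute_params.
output = DeepSpeedTiledMLPMoE.apply(
    mlp_forward,                    # original unbound forward, called as fn(module, x_shard)
    mlp_module,                     # the MLP module instance ("self" inside fn)
    hidden_states,                  # (batch, seq_len, hidden) activations
    4,                              # number of sequence tiles
    list(mlp_module.parameters()),  # params DeepSpeed tracks via ds_grad_is_ready
)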
@@ -31,7 +119,18 @@ class TiledMLP(torch.autograd.Function):
         x_shards = list(torch.chunk(x, chunks=shards, dim=1))
         with torch.no_grad():
             output_shards = [fn(self, x_shard) for x_shard in x_shards]
-        output_unsharded = torch.cat(output_shards, dim=1)
+        ctx.is_tuple_output = isinstance(output_shards[0], tuple)
+        if isinstance(output_shards[0], tuple):
+            tuple_dim_idx = [1, 0]
+            output_unsharded = tuple(
+                torch.cat(
+                    [output_shard[i] for output_shard in output_shards],
+                    dim=tuple_dim_idx[i],
+                )
+                for i in range(len(output_shards[0]))
+            )
+        else:
+            output_unsharded = torch.cat(output_shards, dim=1)
 
         return output_unsharded
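Assuming tuple outputs of the form (hidden_shard, router_logits_shard), the per-element concat dims can be sanity-checked in isolation:

import torch

# Two fake shards: hidden states (batch=2, seq=4, hidden=8) and
# token-flattened router logits (batch * seq = 8, experts=3).
shard_outs = [(torch.ones(2, 4, 8), torch.ones(8, 3)) for _ in range(2)]
tuple_dim_idx = [1, 0]  # hidden states cat on seq dim; logits cat on token dim
merged = tuple(
    torch.cat([s[i] for s in shard_outs], dim=tuple_dim_idx[i]) for i in range(2)
)
assert merged[0].shape == (2, 8, 8)
assert merged[1].shape == (16, 3)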
@@ -42,6 +141,7 @@ class TiledMLP(torch.autograd.Function):
         self = ctx.self
         shards = ctx.shards
         compute_params = ctx.compute_params
+        is_tuple_output = ctx.is_tuple_output
 
         x_requires_grad = x.requires_grad
         x = x.detach()
@@ -76,7 +176,10 @@ class TiledMLP(torch.autograd.Function):
 
             with torch.enable_grad():
                 output = fn(self, x_shard)
-                torch.autograd.backward(output, incoming_grad_shard)
+                if is_tuple_output:
+                    torch.autograd.backward(output[0], incoming_grad_shard)
+                else:
+                    torch.autograd.backward(output, incoming_grad_shard)
 
         # Clean up hooks
         grad_accumulator.cleanup()
@@ -17,7 +17,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
         TiledMLP as DeepSpeedTiledMLP,
     )
 
-    from axolotl.monkeypatch.tiled_mlp.base import TiledMLP
+    from axolotl.monkeypatch.tiled_mlp.base import DeepSpeedTiledMLPMoE, TiledMLP
 
     try:
         # Dynamically import the module and MLP class
@@ -64,7 +64,10 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
                 for p in self._compute_params
             )
         ) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
-            self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
+            if model_type == "gpt_oss":
+                self._tiled_mlp_dist_impl = DeepSpeedTiledMLPMoE
+            else:
+                self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
         else:
             self._tiled_mlp_dist_impl = TiledMLP
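Net effect of the dispatch change, condensed (this restates the diff above; gpt_oss is the only model_type routed to the MoE-aware implementation):

# Condensed restatement of the selection logic above.
using_deepspeed = True  # stand-in for the ZeRO-3 param / ACCELERATE_USE_DEEPSPEED check
if using_deepspeed:
    impl = DeepSpeedTiledMLPMoE if model_type == "gpt_oss" else DeepSpeedTiledMLP
else:
    impl = TiledMLP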