Support for TiledMLP for GPT-OSS (#3116)

* Fix use of flex attention compile kwargs and add support for TiledMLP for GPT-OSS

* Add logging back

* Update dependencies
Author: Wing Lian
Date: 2025-08-29 13:52:49 -04:00
Committed by: GitHub
Parent: 7ed40f1d70
Commit: 0094a2d744
6 changed files with 144 additions and 162 deletions

View File

@@ -13,7 +13,7 @@ packaging==23.2
huggingface_hub>=0.33.0
peft>=0.17.0
transformers==4.55.3
transformers==4.55.4
tokenizers>=0.21.1
accelerate==1.10.0
datasets==4.0.0

View File

@@ -127,7 +127,7 @@ extras_require = {
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.17.2",
"deepspeed==0.17.5",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -149,14 +149,12 @@ class PatchManager:
def _apply_flex_attention_patches(self):
"""Apply patches for flexible attention."""
if self.cfg.flex_attention:
# from axolotl.monkeypatch.attention.flex_attn import (
# patch_flex_make_mask,
# patch_flex_wrapper,
# )
#
# flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
# patch_flex_wrapper(**flex_attn_compile_kwargs)
# patch_flex_make_mask()
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_wrapper,
)
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs)
if self.cfg.sample_packing:
from axolotl.core.attention.flex_block_mask import (
patch_create_causal_mask,

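With the make-mask patch dropped, the flex attention setup now only wraps torch.compile of flex_attention with whatever compile kwargs the user supplies. A minimal sketch of what this hunk wires up, assuming the import path shown above; the kwargs values are illustrative (they mirror the torch.compile options that appear in the old flex_attn code below), and in axolotl they come from cfg.flex_attn_compile_kwargs:

    from axolotl.monkeypatch.attention.flex_attn import patch_flex_wrapper

    # Illustrative compile options; axolotl reads these from
    # cfg.flex_attn_compile_kwargs (empty dict when unset).
    flex_attn_compile_kwargs = {"dynamic": False, "mode": "max-autotune-no-cudagraphs"}
    patch_flex_wrapper(**flex_attn_compile_kwargs)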
View File

@@ -1,11 +1,11 @@
"""Flex attention monkey patch"""
import sys
from typing import Optional, Tuple, Union
from packaging import version
import torch
import transformers
from transformers.utils.import_utils import _torch_version, is_torch_less_or_equal
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@@ -46,19 +46,33 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
"""
self.training = None
if not self._is_flex_compiled or training != self.training:
self.training = training
if is_torch_less_or_equal("2.5.1"):
self._compiled_flex_attention = torch.compile(
flex_attention, dynamic=False
)
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
# see https://github.com/pytorch/pytorch/issues/146260 for training
self.training = training
LOG.info(
"Compiling flex attention with kwargs: %s. This may take a while...",
flex_attn_compile_kwargs,
)
self._compiled_flex_attention = torch.compile(
flex_attention,
**flex_attn_compile_kwargs,
)
LOG.info("Flex attention compiled successfully.")
elif version.parse(_torch_version).base_version == "2.6.0" and training:
self._compiled_flex_attention = torch.compile(
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
)
# Fallback, usually the most recent torch 2.7.x+ versions
else:
LOG.info(
"Compiling flex attention with kwargs: %s. This may take a while...",
flex_attn_compile_kwargs,
main_process_only=True,
)
self._compiled_flex_attention = torch.compile(
flex_attention,
**flex_attn_compile_kwargs,
)
LOG.info(
"Flex attention compiled successfully.", main_process_only=True
)
self._is_flex_compiled = True
def __call__(self):
@@ -68,139 +82,3 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
sys.modules[
"transformers.integrations.flex_attention"
].WrappedFlexAttention = WrappedFlexAttention
def patch_flex_make_mask():
is_torch_2_6 = torch.__version__.startswith("2.6")
if not is_torch_2_6:
return
from torch.nn.attention.flex_attention import (
_DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size,
)
from torch.nn.attention.flex_attention import (
BlockMask,
)
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_causal_mask_flex,
)
Offset = Union[torch.Tensor, int]
def patched_make_flex_block_causal_mask(
attention_mask_2d: torch.Tensor,
attention_chunk_size: Optional[int] = None,
query_length=None,
key_length=None,
offsets: Optional[Tuple[Offset, Offset]] = None,
) -> "BlockMask":
"""
Create a block causal document mask for a batch of sequences, both packed and unpacked.
Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
The resultant BlockMask is a compressed representation of the full block causal
mask. BlockMask is essential for performant computation of flex attention.
See: https://pytorch.org/blog/flexattention/
Args:
attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
of shape (batch_size, total_seq_len). e.g.
For unpacked sequence:
[[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0]]
For packed sequence:
[[1, 1, 1, 2, 2, 2, 0],
[1, 1, 2, 2, 2, 3, 3]]
Returns:
BlockMask
"""
batch_size, total_seq_len = attention_mask_2d.shape
if not key_length:
key_length = total_seq_len
if not query_length:
query_length = total_seq_len
attention_mask_2d = torch.nn.functional.pad(
attention_mask_2d,
value=0,
pad=(0, abs(total_seq_len - max(key_length, flex_default_block_size))),
)
device = attention_mask_2d.device
document_ids = attention_mask_2d.clone()
if attention_chunk_size is not None:
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (
attention_chunk_size
)
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
# computation prior to the softmax. For sample packing, we need both the
# logic for both causal mask and document mask. See PyTorch's official
# blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask = q_idx >= kv_idx # not valid when decoding
document_mask = (
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
)
padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
final_mask = causal_mask & padding_mask & document_mask
return final_mask
def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Combines the chunk mask with the causal mask for chunked attention.
"""
chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
return chunk_mask & causal_doc_mask
mask_mod_maybe_combined = (
causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
)
if offsets is not None:
q_offset = offsets[0]
kv_offset = offsets[1]
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
offset_q = q_idx + q_offset
offset_kv = kv_idx + kv_offset
return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
else:
mask_mod = mask_mod_maybe_combined
return create_block_causal_mask_flex(
mask_mod=mask_mod,
B=batch_size,
H=None, # attention head
Q_LEN=query_length,
KV_LEN=key_length,
device=device,
_compile=True,
)
for n in tuple(sys.modules):
if ".modeling_" in n:
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
sys.modules[
n
].make_flex_block_causal_mask = patched_make_flex_block_causal_mask
transformers.integrations.flex_attention.make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask
)
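The removed patch_flex_make_mask above carried the block-causal document-mask logic; when sample packing is enabled, that role is now filled by patch_create_causal_mask from axolotl.core.attention.flex_block_mask (see the patch manager hunk). For reference, a self-contained sketch of the mask_mod idea the removed code implemented, using PyTorch's public flex attention API; the packed document ids below are made up for illustration:

    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    seq_len = 128  # multiple of the default flex-attention block size
    # One packed row: document 1, document 2, then padding (id 0).
    document_ids = torch.zeros(1, seq_len, dtype=torch.long)
    document_ids[0, :48] = 1
    document_ids[0, 48:112] = 2

    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        same_doc = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
        not_padding = document_ids[batch_idx, q_idx] > 0
        return causal & same_doc & not_padding

    block_mask = create_block_mask(
        mask_mod, B=1, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cpu"
    )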

View File

@@ -8,6 +8,94 @@ from typing import List
import torch
class DeepSpeedTiledMLPMoE(torch.autograd.Function):
@staticmethod
def forward(
ctx,
fn,
self,
x,
shards,
compute_params,
) -> torch.Tensor:
ctx.fn = fn
ctx.self = self
ctx.shards = shards
ctx.compute_params = [p for p in compute_params if p.requires_grad]
ctx.save_for_backward(x)
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
with torch.no_grad():
output_shards = [fn(self, x_shard) for x_shard in x_shards]
ctx.is_tuple_output = isinstance(output_shards[0], tuple)
if isinstance(output_shards[0], tuple):
tuple_dim_idx = [1, 0]
output_unsharded = tuple(
torch.cat(
[output_shard[i] for output_shard in output_shards],
dim=tuple_dim_idx[i],
)
for i in range(len(output_shards[0]))
)
else:
output_unsharded = torch.cat(output_shards, dim=1)
return output_unsharded
@staticmethod
def backward(ctx, *grads) -> torch.Tensor:
fn = ctx.fn
(x,) = ctx.saved_tensors
self = ctx.self
shards = ctx.shards
compute_params = ctx.compute_params
is_tuple_output = ctx.is_tuple_output
x_requires_grad = x.requires_grad
x = x.detach()
# detach() unsets `x.requires_grad`, so restore it
x.requires_grad_(x_requires_grad)
incoming_grad = grads[0]
x_grad = torch.zeros_like(x)
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
shard_step = x_shards[0].numel()
for i, x_shard in enumerate(x_shards):
# Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run
if compute_params is not None:
if i + 1 < shards:
for param in compute_params:
param.ds_grad_is_ready = False
else:
# last shard, can add the grad
for param in compute_params:
param.ds_grad_is_ready = True
x_shard.requires_grad_(x_requires_grad)
shard_offset = i * shard_step
x_shard.grad = (
x_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
incoming_grad_shard = (
incoming_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
with torch.enable_grad():
output = fn(self, x_shard)
if is_tuple_output:
torch.autograd.backward(output[0], incoming_grad_shard)
else:
torch.autograd.backward(output, incoming_grad_shard)
return (None, None, x_grad, None, None)
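To illustrate the tuple handling in DeepSpeedTiledMLPMoE.forward above as a standalone snippet: hidden states are re-joined along the sequence dimension (dim=1), while the second tuple element is stacked along dim=0, which is what the [1, 0] concat dims imply for an MoE block that returns (hidden_states, router_logits) with router outputs flattened over tokens. The toy block below is illustrative, not the transformers GPT-OSS implementation:

    import torch

    class ToyMoEBlock(torch.nn.Module):
        """Returns (hidden_states, router_logits), like an MoE MLP."""

        def __init__(self, hidden=32, num_experts=4):
            super().__init__()
            self.router = torch.nn.Linear(hidden, num_experts)
            self.proj = torch.nn.Linear(hidden, hidden)

        def forward(self, x):  # x: (batch, seq, hidden)
            router_logits = self.router(x).reshape(-1, self.router.out_features)
            return self.proj(x), router_logits

    block = ToyMoEBlock()
    x = torch.randn(1, 16, 32)
    shards = 4

    x_shards = list(torch.chunk(x, chunks=shards, dim=1))
    with torch.no_grad():
        output_shards = [block(x_shard) for x_shard in x_shards]

    tuple_dim_idx = [1, 0]  # hidden states on dim=1, router logits on dim=0
    output_unsharded = tuple(
        torch.cat([shard[i] for shard in output_shards], dim=tuple_dim_idx[i])
        for i in range(len(output_shards[0]))
    )

    with torch.no_grad():
        full = block(x)
    assert all(torch.allclose(a, b) for a, b in zip(output_unsharded, full))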
class TiledMLP(torch.autograd.Function):
"""
TiledMLP implementation using gradient hooks
@@ -31,7 +119,18 @@ class TiledMLP(torch.autograd.Function):
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
with torch.no_grad():
output_shards = [fn(self, x_shard) for x_shard in x_shards]
output_unsharded = torch.cat(output_shards, dim=1)
ctx.is_tuple_output = isinstance(output_shards[0], tuple)
if isinstance(output_shards[0], tuple):
tuple_dim_idx = [1, 0]
output_unsharded = tuple(
torch.cat(
[output_shard[i] for output_shard in output_shards],
dim=tuple_dim_idx[i],
)
for i in range(len(output_shards[0]))
)
else:
output_unsharded = torch.cat(output_shards, dim=1)
return output_unsharded
@@ -42,6 +141,7 @@ class TiledMLP(torch.autograd.Function):
self = ctx.self
shards = ctx.shards
compute_params = ctx.compute_params
is_tuple_output = ctx.is_tuple_output
x_requires_grad = x.requires_grad
x = x.detach()
@@ -76,7 +176,10 @@ class TiledMLP(torch.autograd.Function):
with torch.enable_grad():
output = fn(self, x_shard)
torch.autograd.backward(output, incoming_grad_shard)
if is_tuple_output:
torch.autograd.backward(output[0], incoming_grad_shard)
else:
torch.autograd.backward(output, incoming_grad_shard)
# Clean up hooks
grad_accumulator.cleanup()
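Both tiled classes share the same backward strategy: detach the saved input, re-chunk it, and for each shard recompute the forward under enable_grad, backpropagating only that shard's slice of the incoming gradient directly into a preallocated flat gradient buffer. A self-contained sketch of that mechanism, with a plain linear layer standing in for the MLP and batch size 1 so the flat offsets line up with the dim=1 chunks:

    import torch

    fn = torch.nn.Linear(16, 16)
    x = torch.randn(1, 8, 16, requires_grad=True)
    incoming_grad = torch.randn(1, 8, 16)
    shards = 4

    x_grad = torch.zeros_like(x)
    # Chunk a detached copy so each shard is a fresh autograd leaf.
    x_shards = list(torch.chunk(x.detach(), chunks=shards, dim=1))
    shard_step = x_shards[0].numel()

    for i, x_shard in enumerate(x_shards):
        x_shard.requires_grad_(True)
        shard_offset = i * shard_step
        # Point .grad into the preallocated buffer so backward accumulates in place.
        x_shard.grad = (
            x_grad.view(-1).narrow(0, shard_offset, x_shard.numel()).view_as(x_shard)
        )
        incoming_grad_shard = (
            incoming_grad.view(-1).narrow(0, shard_offset, x_shard.numel()).view_as(x_shard)
        )
        with torch.enable_grad():  # recompute only this shard's forward
            output = fn(x_shard)
        torch.autograd.backward(output, incoming_grad_shard)

    # The tiled gradient matches the untiled one.
    torch.autograd.backward(fn(x), incoming_grad)
    assert torch.allclose(x_grad, x.grad, atol=1e-5)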

View File

@@ -17,7 +17,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
TiledMLP as DeepSpeedTiledMLP,
)
from axolotl.monkeypatch.tiled_mlp.base import TiledMLP
from axolotl.monkeypatch.tiled_mlp.base import DeepSpeedTiledMLPMoE, TiledMLP
try:
# Dynamically import the module and MLP class
@@ -64,7 +64,10 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
for p in self._compute_params
)
) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
if model_type == "gpt_oss":
self._tiled_mlp_dist_impl = DeepSpeedTiledMLPMoE
else:
self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
else:
self._tiled_mlp_dist_impl = TiledMLP
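A hedged usage sketch of the dispatch above, based on the patch_tiled_mlp signature in the hunk header; the import path and shard count are assumptions for illustration only:

    # Import path assumed from the package layout implied by this diff.
    from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp

    # Under DeepSpeed, gpt_oss now routes to DeepSpeedTiledMLPMoE, while other
    # model types keep DeepSpeedTiledMLP; outside DeepSpeed, TiledMLP is used.
    patch_tiled_mlp("gpt_oss", use_original_mlp=True, cfg_num_shards=4)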