support for tiledmlp for GPT-OSS (#3116)

* fix use of flex attn kwargs and add support for tiledmlp for GPT-OSS

* add logging back

* update deps
This commit is contained in:
Wing Lian
2025-08-29 13:52:49 -04:00
committed by GitHub
parent 7ed40f1d70
commit 0094a2d744
6 changed files with 144 additions and 162 deletions

View File

@@ -13,7 +13,7 @@ packaging==23.2
huggingface_hub>=0.33.0 huggingface_hub>=0.33.0
peft>=0.17.0 peft>=0.17.0
transformers==4.55.3 transformers==4.55.4
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.10.0 accelerate==1.10.0
datasets==4.0.0 datasets==4.0.0

View File

@@ -127,7 +127,7 @@ extras_require = {
"yunchang==0.6.0", "yunchang==0.6.0",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed==0.17.2", "deepspeed==0.17.5",
"deepspeed-kernels", "deepspeed-kernels",
], ],
"mamba-ssm": [ "mamba-ssm": [

View File

@@ -149,14 +149,12 @@ class PatchManager:
def _apply_flex_attention_patches(self): def _apply_flex_attention_patches(self):
"""Apply patches for flexible attention.""" """Apply patches for flexible attention."""
if self.cfg.flex_attention: if self.cfg.flex_attention:
# from axolotl.monkeypatch.attention.flex_attn import ( from axolotl.monkeypatch.attention.flex_attn import (
# patch_flex_make_mask, patch_flex_wrapper,
# patch_flex_wrapper, )
# )
# flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
# flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {} patch_flex_wrapper(**flex_attn_compile_kwargs)
# patch_flex_wrapper(**flex_attn_compile_kwargs)
# patch_flex_make_mask()
if self.cfg.sample_packing: if self.cfg.sample_packing:
from axolotl.core.attention.flex_block_mask import ( from axolotl.core.attention.flex_block_mask import (
patch_create_causal_mask, patch_create_causal_mask,

View File

@@ -1,11 +1,11 @@
"""Flex attention monkey patch""" """Flex attention monkey patch"""
import sys import sys
from typing import Optional, Tuple, Union from packaging import version
import torch import torch
import transformers import transformers
from transformers.utils.import_utils import _torch_version, is_torch_less_or_equal
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -46,19 +46,33 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
""" """
self.training = None self.training = None
if not self._is_flex_compiled or training != self.training: if not self._is_flex_compiled or training != self.training:
self.training = training
if is_torch_less_or_equal("2.5.1"):
self._compiled_flex_attention = torch.compile(
flex_attention, dynamic=False
)
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs" # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
# see https://github.com/pytorch/pytorch/issues/146260 for training # see https://github.com/pytorch/pytorch/issues/146260 for training
self.training = training elif version.parse(_torch_version).base_version == "2.6.0" and training:
LOG.info( self._compiled_flex_attention = torch.compile(
"Compiling flex attention with kwargs: %s. This may take a while...", flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
flex_attn_compile_kwargs, )
) # Fallback, usually the most recent torch 2.7.x+ versions
self._compiled_flex_attention = torch.compile( else:
flex_attention, LOG.info(
**flex_attn_compile_kwargs, "Compiling flex attention with kwargs: %s. This may take a while...",
) flex_attn_compile_kwargs,
LOG.info("Flex attention compiled successfully.") main_process_only=True,
)
self._compiled_flex_attention = torch.compile(
flex_attention,
**flex_attn_compile_kwargs,
)
LOG.info(
"Flex attention compiled successfully.", main_process_only=True
)
self._is_flex_compiled = True self._is_flex_compiled = True
def __call__(self): def __call__(self):
@@ -68,139 +82,3 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
sys.modules[ sys.modules[
"transformers.integrations.flex_attention" "transformers.integrations.flex_attention"
].WrappedFlexAttention = WrappedFlexAttention ].WrappedFlexAttention = WrappedFlexAttention
def patch_flex_make_mask():
    """
    Replace transformers' ``make_flex_block_causal_mask`` with a patched version.

    The patch is only applied when ``torch.__version__`` starts with ``"2.6"``;
    on any other version this function returns without doing anything.

    The replacement is installed on every already-imported module whose name
    contains ``".modeling_"`` and that exposes ``make_flex_block_causal_mask``,
    and on ``transformers.integrations.flex_attention``.
    """
    is_torch_2_6 = torch.__version__.startswith("2.6")
    if not is_torch_2_6:
        return

    # Imported lazily so that nothing from flex attention is touched on torch
    # versions where the patch is skipped.
    from torch.nn.attention.flex_attention import (
        _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size,
        BlockMask,
        create_block_mask as create_block_causal_mask_flex,
    )

    Offset = Union[torch.Tensor, int]

    def patched_make_flex_block_causal_mask(
        attention_mask_2d: torch.Tensor,
        attention_chunk_size: Optional[int] = None,
        query_length=None,
        key_length=None,
        offsets: Optional[Tuple[Offset, Offset]] = None,
    ) -> "BlockMask":
        """
        Create a block causal document mask for a batch of sequences, both packed and unpacked.
        Builds the block causal logic and passes it to
        :func:`torch.nn.attention.flex_attention.create_block_mask`.
        The resultant BlockMask is a compressed representation of the full block causal
        mask. BlockMask is essential for performant computation of flex attention.
        See: https://pytorch.org/blog/flexattention/

        Args:
            attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
                of shape (batch_size, total_seq_len). e.g.

                For unpacked sequence:
                [[1, 1, 1, 1, 0, 0, 0],
                 [1, 1, 1, 1, 1, 0, 0]]

                For packed sequence:
                [[1, 1, 1, 2, 2, 2, 0],
                 [1, 1, 2, 2, 2, 3, 3]]
            attention_chunk_size: when given, attention is additionally restricted
                to positions that fall in the same chunk of this size.
            query_length: query length of the mask; defaults to ``total_seq_len``.
            key_length: key/value length of the mask; defaults to ``total_seq_len``.
            offsets: optional ``(q_offset, kv_offset)`` added to the query and
                key/value indices before the mask is evaluated.

        Returns:
            BlockMask
        """
        batch_size, total_seq_len = attention_mask_2d.shape
        if not key_length:
            key_length = total_seq_len
        if not query_length:
            query_length = total_seq_len
        # Right-pad the mask with zeros; the pad width is derived from the larger
        # of the key length and flex attention's default sparse block size.
        attention_mask_2d = torch.nn.functional.pad(
            attention_mask_2d,
            value=0,
            pad=(0, abs(total_seq_len - max(key_length, flex_default_block_size))),
        )
        device = attention_mask_2d.device
        document_ids = attention_mask_2d.clone()

        if attention_chunk_size is not None:
            # we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
            chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (
                attention_chunk_size
            )

        # Instead of passing a tensor mask, flex attention requires a mask_mod function
        # that determines which elements of QK^T should be included in the attention
        # computation prior to the softmax. For sample packing, we need the logic
        # for both the causal mask and the document mask. See PyTorch's official
        # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
        def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
            """
            Defines the logic of a block causal mask by combining both a standard causal mask
            and a block diagonal document mask.

            See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
            for an illustration.
            """
            causal_mask = q_idx >= kv_idx  # not valid when decoding
            document_mask = (
                document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
            )
            padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
            final_mask = causal_mask & padding_mask & document_mask
            return final_mask

        def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
            """
            Combines the chunk mask with the causal mask for chunked attention.
            """
            chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
            causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
            return chunk_mask & causal_doc_mask

        mask_mod_maybe_combined = (
            causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
        )

        if offsets is not None:
            q_offset = offsets[0]
            kv_offset = offsets[1]

            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
                offset_q = q_idx + q_offset
                offset_kv = kv_idx + kv_offset
                return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)

        else:
            mask_mod = mask_mod_maybe_combined

        return create_block_causal_mask_flex(
            mask_mod=mask_mod,
            B=batch_size,
            H=None,  # attention head
            Q_LEN=query_length,
            KV_LEN=key_length,
            device=device,
            _compile=True,
        )

    # Snapshot the module names so the table is not iterated while imports
    # elsewhere may be mutating it.
    for n in tuple(sys.modules):
        if ".modeling_" in n:
            if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
                sys.modules[
                    n
                ].make_flex_block_causal_mask = patched_make_flex_block_causal_mask

    transformers.integrations.flex_attention.make_flex_block_causal_mask = (
        patched_make_flex_block_causal_mask
    )

View File

@@ -8,6 +8,94 @@ from typing import List
import torch import torch
class DeepSpeedTiledMLPMoE(torch.autograd.Function):
    """
    Tiled MLP autograd function that also supports tuple-returning MLPs.

    The input is split into ``shards`` chunks along dim 1. The forward pass runs
    ``fn`` on each chunk under ``torch.no_grad()`` and concatenates the results;
    the backward pass re-runs ``fn`` chunk by chunk with gradients enabled and
    back-propagates the matching slice of the incoming gradient.

    When ``fn`` returns a tuple, element 0 is concatenated along dim 1 and
    element 1 along dim 0, and only element 0 receives a gradient in backward.
    """

    @staticmethod
    def forward(
        ctx,
        fn,
        self,
        x,
        shards,
        compute_params,
    ) -> "torch.Tensor | tuple[torch.Tensor, ...]":
        """
        Run ``fn(self, shard)`` for every dim-1 shard of ``x`` without autograd.

        Args:
            ctx: autograd context used to stash state for ``backward``.
            fn: callable invoked as ``fn(self, x_shard)``.
            self: object passed through to ``fn`` as its first argument.
            x: input tensor, chunked along dim 1.
            shards: requested number of chunks (``torch.chunk`` may return fewer).
            compute_params: parameters used by ``fn``; only those with
                ``requires_grad`` set are kept for ``backward``.

        Returns:
            The concatenated output tensor, or a tuple of concatenated tensors
            when ``fn`` returns a tuple.
        """
        ctx.fn = fn
        ctx.self = self
        ctx.shards = shards
        ctx.compute_params = [p for p in compute_params if p.requires_grad]
        ctx.save_for_backward(x)

        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
        with torch.no_grad():
            output_shards = [fn(self, x_shard) for x_shard in x_shards]

        ctx.is_tuple_output = isinstance(output_shards[0], tuple)
        if ctx.is_tuple_output:
            # concat dim per tuple element: element 0 on dim 1, element 1 on dim 0
            tuple_dim_idx = [1, 0]
            output_unsharded = tuple(
                torch.cat(
                    [output_shard[i] for output_shard in output_shards],
                    dim=tuple_dim_idx[i],
                )
                for i in range(len(output_shards[0]))
            )
        else:
            output_unsharded = torch.cat(output_shards, dim=1)

        return output_unsharded

    @staticmethod
    def backward(ctx, *grads) -> tuple:
        """
        Recompute each shard with gradients enabled and back-propagate through it.

        Only ``grads[0]`` is used; gradients for any further tuple outputs are
        ignored.

        Returns:
            ``(None, None, x_grad, None, None)``, one entry per ``forward``
            argument after ``ctx``.
        """
        fn = ctx.fn
        (x,) = ctx.saved_tensors
        self = ctx.self
        shards = ctx.shards
        compute_params = ctx.compute_params
        is_tuple_output = ctx.is_tuple_output

        x_requires_grad = x.requires_grad
        x = x.detach()
        # detach() unsets `x.requires_grad`, so restore it
        x.requires_grad_(x_requires_grad)

        incoming_grad = grads[0]
        x_grad = torch.zeros_like(x)
        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
        # torch.chunk may return fewer chunks than requested, so the last shard
        # is determined from the actual chunk list rather than from `shards`.
        last_shard_idx = len(x_shards) - 1

        seq_offset = 0
        for i, x_shard in enumerate(x_shards):
            # Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run
            grad_is_ready = i == last_shard_idx
            for param in compute_params:
                param.ds_grad_is_ready = grad_is_ready

            x_shard.requires_grad_(x_requires_grad)

            # Shards were cut along dim 1, so the gradient windows are taken along
            # dim 1 as well; this stays aligned for any batch size.
            shard_len = x_shard.shape[1]
            # Pre-seed `.grad` with a view into `x_grad` so the shard's input
            # gradient lands in the full-size buffer.
            x_shard.grad = x_grad.narrow(1, seq_offset, shard_len)
            incoming_grad_shard = incoming_grad.narrow(1, seq_offset, shard_len)
            seq_offset += shard_len

            with torch.enable_grad():
                output = fn(self, x_shard)

            if is_tuple_output:
                torch.autograd.backward(output[0], incoming_grad_shard)
            else:
                torch.autograd.backward(output, incoming_grad_shard)

        return (None, None, x_grad, None, None)
class TiledMLP(torch.autograd.Function): class TiledMLP(torch.autograd.Function):
""" """
TiledMLP implementation using gradient hooks TiledMLP implementation using gradient hooks
@@ -31,7 +119,18 @@ class TiledMLP(torch.autograd.Function):
x_shards = list(torch.chunk(x, chunks=shards, dim=1)) x_shards = list(torch.chunk(x, chunks=shards, dim=1))
with torch.no_grad(): with torch.no_grad():
output_shards = [fn(self, x_shard) for x_shard in x_shards] output_shards = [fn(self, x_shard) for x_shard in x_shards]
output_unsharded = torch.cat(output_shards, dim=1) ctx.is_tuple_output = isinstance(output_shards[0], tuple)
if isinstance(output_shards[0], tuple):
tuple_dim_idx = [1, 0]
output_unsharded = tuple(
torch.cat(
[output_shard[i] for output_shard in output_shards],
dim=tuple_dim_idx[i],
)
for i in range(len(output_shards[0]))
)
else:
output_unsharded = torch.cat(output_shards, dim=1)
return output_unsharded return output_unsharded
@@ -42,6 +141,7 @@ class TiledMLP(torch.autograd.Function):
self = ctx.self self = ctx.self
shards = ctx.shards shards = ctx.shards
compute_params = ctx.compute_params compute_params = ctx.compute_params
is_tuple_output = ctx.is_tuple_output
x_requires_grad = x.requires_grad x_requires_grad = x.requires_grad
x = x.detach() x = x.detach()
@@ -76,7 +176,10 @@ class TiledMLP(torch.autograd.Function):
with torch.enable_grad(): with torch.enable_grad():
output = fn(self, x_shard) output = fn(self, x_shard)
torch.autograd.backward(output, incoming_grad_shard) if is_tuple_output:
torch.autograd.backward(output[0], incoming_grad_shard)
else:
torch.autograd.backward(output, incoming_grad_shard)
# Clean up hooks # Clean up hooks
grad_accumulator.cleanup() grad_accumulator.cleanup()

View File

@@ -17,7 +17,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
TiledMLP as DeepSpeedTiledMLP, TiledMLP as DeepSpeedTiledMLP,
) )
from axolotl.monkeypatch.tiled_mlp.base import TiledMLP from axolotl.monkeypatch.tiled_mlp.base import DeepSpeedTiledMLPMoE, TiledMLP
try: try:
# Dynamically import the module and MLP class # Dynamically import the module and MLP class
@@ -64,7 +64,10 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
for p in self._compute_params for p in self._compute_params
) )
) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true": ) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
self._tiled_mlp_dist_impl = DeepSpeedTiledMLP if model_type == "gpt_oss":
self._tiled_mlp_dist_impl = DeepSpeedTiledMLPMoE
else:
self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
else: else:
self._tiled_mlp_dist_impl = TiledMLP self._tiled_mlp_dist_impl = TiledMLP