support for tiledmlp for GPT-OSS (#3116)
* fix use of flex attn kwargs and add support for tiledmlp for GPT-OSS * add logging back * update deps
This commit is contained in:
@@ -13,7 +13,7 @@ packaging==23.2
|
|||||||
|
|
||||||
huggingface_hub>=0.33.0
|
huggingface_hub>=0.33.0
|
||||||
peft>=0.17.0
|
peft>=0.17.0
|
||||||
transformers==4.55.3
|
transformers==4.55.4
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.10.0
|
accelerate==1.10.0
|
||||||
datasets==4.0.0
|
datasets==4.0.0
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -127,7 +127,7 @@ extras_require = {
|
|||||||
"yunchang==0.6.0",
|
"yunchang==0.6.0",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed==0.17.2",
|
"deepspeed==0.17.5",
|
||||||
"deepspeed-kernels",
|
"deepspeed-kernels",
|
||||||
],
|
],
|
||||||
"mamba-ssm": [
|
"mamba-ssm": [
|
||||||
|
|||||||
@@ -149,14 +149,12 @@ class PatchManager:
|
|||||||
def _apply_flex_attention_patches(self):
|
def _apply_flex_attention_patches(self):
|
||||||
"""Apply patches for flexible attention."""
|
"""Apply patches for flexible attention."""
|
||||||
if self.cfg.flex_attention:
|
if self.cfg.flex_attention:
|
||||||
# from axolotl.monkeypatch.attention.flex_attn import (
|
from axolotl.monkeypatch.attention.flex_attn import (
|
||||||
# patch_flex_make_mask,
|
patch_flex_wrapper,
|
||||||
# patch_flex_wrapper,
|
)
|
||||||
# )
|
|
||||||
#
|
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
||||||
# flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
patch_flex_wrapper(**flex_attn_compile_kwargs)
|
||||||
# patch_flex_wrapper(**flex_attn_compile_kwargs)
|
|
||||||
# patch_flex_make_mask()
|
|
||||||
if self.cfg.sample_packing:
|
if self.cfg.sample_packing:
|
||||||
from axolotl.core.attention.flex_block_mask import (
|
from axolotl.core.attention.flex_block_mask import (
|
||||||
patch_create_causal_mask,
|
patch_create_causal_mask,
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""Flex attention monkey patch"""
|
"""Flex attention monkey patch"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
from typing import Optional, Tuple, Union
|
from packaging import version
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
from transformers.utils.import_utils import _torch_version, is_torch_less_or_equal
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
@@ -46,19 +46,33 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
|
|||||||
"""
|
"""
|
||||||
self.training = None
|
self.training = None
|
||||||
if not self._is_flex_compiled or training != self.training:
|
if not self._is_flex_compiled or training != self.training:
|
||||||
|
self.training = training
|
||||||
|
if is_torch_less_or_equal("2.5.1"):
|
||||||
|
self._compiled_flex_attention = torch.compile(
|
||||||
|
flex_attention, dynamic=False
|
||||||
|
)
|
||||||
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
|
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
|
||||||
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
|
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
|
||||||
# see https://github.com/pytorch/pytorch/issues/146260 for training
|
# see https://github.com/pytorch/pytorch/issues/146260 for training
|
||||||
self.training = training
|
elif version.parse(_torch_version).base_version == "2.6.0" and training:
|
||||||
LOG.info(
|
self._compiled_flex_attention = torch.compile(
|
||||||
"Compiling flex attention with kwargs: %s. This may take a while...",
|
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
|
||||||
flex_attn_compile_kwargs,
|
)
|
||||||
)
|
# Fallback, usually the most recent torch 2.7.x+ versions
|
||||||
self._compiled_flex_attention = torch.compile(
|
else:
|
||||||
flex_attention,
|
LOG.info(
|
||||||
**flex_attn_compile_kwargs,
|
"Compiling flex attention with kwargs: %s. This may take a while...",
|
||||||
)
|
flex_attn_compile_kwargs,
|
||||||
LOG.info("Flex attention compiled successfully.")
|
main_process_only=True,
|
||||||
|
)
|
||||||
|
self._compiled_flex_attention = torch.compile(
|
||||||
|
flex_attention,
|
||||||
|
**flex_attn_compile_kwargs,
|
||||||
|
)
|
||||||
|
LOG.info(
|
||||||
|
"Flex attention compiled successfully.", main_process_only=True
|
||||||
|
)
|
||||||
|
|
||||||
self._is_flex_compiled = True
|
self._is_flex_compiled = True
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
@@ -68,139 +82,3 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
|
|||||||
sys.modules[
|
sys.modules[
|
||||||
"transformers.integrations.flex_attention"
|
"transformers.integrations.flex_attention"
|
||||||
].WrappedFlexAttention = WrappedFlexAttention
|
].WrappedFlexAttention = WrappedFlexAttention
|
||||||
|
|
||||||
|
|
||||||
def patch_flex_make_mask():
|
|
||||||
is_torch_2_6 = torch.__version__.startswith("2.6")
|
|
||||||
|
|
||||||
if not is_torch_2_6:
|
|
||||||
return
|
|
||||||
|
|
||||||
from torch.nn.attention.flex_attention import (
|
|
||||||
_DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size,
|
|
||||||
)
|
|
||||||
from torch.nn.attention.flex_attention import (
|
|
||||||
BlockMask,
|
|
||||||
)
|
|
||||||
from torch.nn.attention.flex_attention import (
|
|
||||||
create_block_mask as create_block_causal_mask_flex,
|
|
||||||
)
|
|
||||||
|
|
||||||
Offset = Union[torch.Tensor, int]
|
|
||||||
|
|
||||||
def patched_make_flex_block_causal_mask(
|
|
||||||
attention_mask_2d: torch.Tensor,
|
|
||||||
attention_chunk_size: Optional[int] = None,
|
|
||||||
query_length=None,
|
|
||||||
key_length=None,
|
|
||||||
offsets: Optional[Tuple[Offset, Offset]] = None,
|
|
||||||
) -> "BlockMask":
|
|
||||||
"""
|
|
||||||
Create a block causal document mask for a batch of sequences, both packed and unpacked.
|
|
||||||
Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
|
|
||||||
The resultant BlockMask is a compressed representation of the full block causal
|
|
||||||
mask. BlockMask is essential for performant computation of flex attention.
|
|
||||||
See: https://pytorch.org/blog/flexattention/
|
|
||||||
|
|
||||||
Args:
|
|
||||||
attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
|
|
||||||
of shape (batch_size, total_seq_len). e.g.
|
|
||||||
|
|
||||||
For unpacked sequence:
|
|
||||||
[[1, 1, 1, 1, 0, 0, 0],
|
|
||||||
[1, 1, 1, 1, 1, 0, 0]]
|
|
||||||
|
|
||||||
For packed sequence:
|
|
||||||
[[1, 1, 1, 2, 2, 2, 0],
|
|
||||||
[1, 1, 2, 2, 2, 3, 3]]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BlockMask
|
|
||||||
"""
|
|
||||||
|
|
||||||
batch_size, total_seq_len = attention_mask_2d.shape
|
|
||||||
if not key_length:
|
|
||||||
key_length = total_seq_len
|
|
||||||
if not query_length:
|
|
||||||
query_length = total_seq_len
|
|
||||||
attention_mask_2d = torch.nn.functional.pad(
|
|
||||||
attention_mask_2d,
|
|
||||||
value=0,
|
|
||||||
pad=(0, abs(total_seq_len - max(key_length, flex_default_block_size))),
|
|
||||||
)
|
|
||||||
device = attention_mask_2d.device
|
|
||||||
document_ids = attention_mask_2d.clone()
|
|
||||||
|
|
||||||
if attention_chunk_size is not None:
|
|
||||||
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
|
|
||||||
chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (
|
|
||||||
attention_chunk_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# Instead of passing a tensor mask, flex attention requires a mask_mod function
|
|
||||||
# that determines which elements of QK^T should be included in the attention
|
|
||||||
# computation prior to the softmax. For sample packing, we need both the
|
|
||||||
# logic for both causal mask and document mask. See PyTorch's official
|
|
||||||
# blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
|
|
||||||
def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
|
||||||
"""
|
|
||||||
Defines the logic of a block causal mask by combining both a standard causal mask
|
|
||||||
and a block diagonal document mask.
|
|
||||||
|
|
||||||
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
|
|
||||||
for an illustration.
|
|
||||||
"""
|
|
||||||
causal_mask = q_idx >= kv_idx # not valid when decoding
|
|
||||||
document_mask = (
|
|
||||||
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
|
|
||||||
)
|
|
||||||
padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
|
|
||||||
final_mask = causal_mask & padding_mask & document_mask
|
|
||||||
return final_mask
|
|
||||||
|
|
||||||
def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
|
||||||
"""
|
|
||||||
Combines the chunk mask with the causal mask for chunked attention.
|
|
||||||
"""
|
|
||||||
chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
|
|
||||||
causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
|
|
||||||
return chunk_mask & causal_doc_mask
|
|
||||||
|
|
||||||
mask_mod_maybe_combined = (
|
|
||||||
causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
|
|
||||||
)
|
|
||||||
|
|
||||||
if offsets is not None:
|
|
||||||
q_offset = offsets[0]
|
|
||||||
kv_offset = offsets[1]
|
|
||||||
|
|
||||||
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
|
||||||
offset_q = q_idx + q_offset
|
|
||||||
offset_kv = kv_idx + kv_offset
|
|
||||||
return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
|
|
||||||
|
|
||||||
else:
|
|
||||||
mask_mod = mask_mod_maybe_combined
|
|
||||||
return create_block_causal_mask_flex(
|
|
||||||
mask_mod=mask_mod,
|
|
||||||
B=batch_size,
|
|
||||||
H=None, # attention head
|
|
||||||
Q_LEN=query_length,
|
|
||||||
KV_LEN=key_length,
|
|
||||||
device=device,
|
|
||||||
_compile=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
for n in tuple(sys.modules):
|
|
||||||
if ".modeling_" in n:
|
|
||||||
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
|
|
||||||
sys.modules[
|
|
||||||
n
|
|
||||||
].make_flex_block_causal_mask = patched_make_flex_block_causal_mask
|
|
||||||
sys.modules[
|
|
||||||
n
|
|
||||||
].make_flex_block_causal_mask = patched_make_flex_block_causal_mask
|
|
||||||
|
|
||||||
transformers.integrations.flex_attention.make_flex_block_causal_mask = (
|
|
||||||
patched_make_flex_block_causal_mask
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -8,6 +8,94 @@ from typing import List
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSpeedTiledMLPMoE(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(
|
||||||
|
ctx,
|
||||||
|
fn,
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
shards,
|
||||||
|
compute_params,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
ctx.fn = fn
|
||||||
|
ctx.self = self
|
||||||
|
ctx.shards = shards
|
||||||
|
ctx.compute_params = [p for p in compute_params if p.requires_grad]
|
||||||
|
ctx.save_for_backward(x)
|
||||||
|
|
||||||
|
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
|
||||||
|
with torch.no_grad():
|
||||||
|
output_shards = [fn(self, x_shard) for x_shard in x_shards]
|
||||||
|
|
||||||
|
ctx.is_tuple_output = isinstance(output_shards[0], tuple)
|
||||||
|
if isinstance(output_shards[0], tuple):
|
||||||
|
tuple_dim_idx = [1, 0]
|
||||||
|
output_unsharded = tuple(
|
||||||
|
torch.cat(
|
||||||
|
[output_shard[i] for output_shard in output_shards],
|
||||||
|
dim=tuple_dim_idx[i],
|
||||||
|
)
|
||||||
|
for i in range(len(output_shards[0]))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_unsharded = torch.cat(output_shards, dim=1)
|
||||||
|
|
||||||
|
return output_unsharded
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, *grads) -> torch.Tensor:
|
||||||
|
fn = ctx.fn
|
||||||
|
(x,) = ctx.saved_tensors
|
||||||
|
self = ctx.self
|
||||||
|
shards = ctx.shards
|
||||||
|
compute_params = ctx.compute_params
|
||||||
|
is_tuple_output = ctx.is_tuple_output
|
||||||
|
|
||||||
|
x_requires_grad = x.requires_grad
|
||||||
|
x = x.detach()
|
||||||
|
# detach() unsets `x.requires_grad`, so restore it
|
||||||
|
x.requires_grad_(x_requires_grad)
|
||||||
|
|
||||||
|
incoming_grad = grads[0]
|
||||||
|
x_grad = torch.zeros_like(x)
|
||||||
|
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
|
||||||
|
|
||||||
|
shard_step = x_shards[0].numel()
|
||||||
|
for i, x_shard in enumerate(x_shards):
|
||||||
|
# Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run
|
||||||
|
if compute_params is not None:
|
||||||
|
if i + 1 < shards:
|
||||||
|
for param in compute_params:
|
||||||
|
param.ds_grad_is_ready = False
|
||||||
|
else:
|
||||||
|
# last shard, can add the grad
|
||||||
|
for param in compute_params:
|
||||||
|
param.ds_grad_is_ready = True
|
||||||
|
|
||||||
|
x_shard.requires_grad_(x_requires_grad)
|
||||||
|
|
||||||
|
shard_offset = i * shard_step
|
||||||
|
x_shard.grad = (
|
||||||
|
x_grad.view(-1)
|
||||||
|
.narrow(0, shard_offset, x_shard.numel())
|
||||||
|
.view_as(x_shard)
|
||||||
|
)
|
||||||
|
incoming_grad_shard = (
|
||||||
|
incoming_grad.view(-1)
|
||||||
|
.narrow(0, shard_offset, x_shard.numel())
|
||||||
|
.view_as(x_shard)
|
||||||
|
)
|
||||||
|
with torch.enable_grad():
|
||||||
|
output = fn(self, x_shard)
|
||||||
|
if is_tuple_output:
|
||||||
|
torch.autograd.backward(output[0], incoming_grad_shard)
|
||||||
|
else:
|
||||||
|
torch.autograd.backward(output, incoming_grad_shard)
|
||||||
|
|
||||||
|
return (None, None, x_grad, None, None)
|
||||||
|
|
||||||
|
|
||||||
class TiledMLP(torch.autograd.Function):
|
class TiledMLP(torch.autograd.Function):
|
||||||
"""
|
"""
|
||||||
TiledMLP implementation using gradient hooks
|
TiledMLP implementation using gradient hooks
|
||||||
@@ -31,7 +119,18 @@ class TiledMLP(torch.autograd.Function):
|
|||||||
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
|
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output_shards = [fn(self, x_shard) for x_shard in x_shards]
|
output_shards = [fn(self, x_shard) for x_shard in x_shards]
|
||||||
output_unsharded = torch.cat(output_shards, dim=1)
|
ctx.is_tuple_output = isinstance(output_shards[0], tuple)
|
||||||
|
if isinstance(output_shards[0], tuple):
|
||||||
|
tuple_dim_idx = [1, 0]
|
||||||
|
output_unsharded = tuple(
|
||||||
|
torch.cat(
|
||||||
|
[output_shard[i] for output_shard in output_shards],
|
||||||
|
dim=tuple_dim_idx[i],
|
||||||
|
)
|
||||||
|
for i in range(len(output_shards[0]))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_unsharded = torch.cat(output_shards, dim=1)
|
||||||
|
|
||||||
return output_unsharded
|
return output_unsharded
|
||||||
|
|
||||||
@@ -42,6 +141,7 @@ class TiledMLP(torch.autograd.Function):
|
|||||||
self = ctx.self
|
self = ctx.self
|
||||||
shards = ctx.shards
|
shards = ctx.shards
|
||||||
compute_params = ctx.compute_params
|
compute_params = ctx.compute_params
|
||||||
|
is_tuple_output = ctx.is_tuple_output
|
||||||
|
|
||||||
x_requires_grad = x.requires_grad
|
x_requires_grad = x.requires_grad
|
||||||
x = x.detach()
|
x = x.detach()
|
||||||
@@ -76,7 +176,10 @@ class TiledMLP(torch.autograd.Function):
|
|||||||
|
|
||||||
with torch.enable_grad():
|
with torch.enable_grad():
|
||||||
output = fn(self, x_shard)
|
output = fn(self, x_shard)
|
||||||
torch.autograd.backward(output, incoming_grad_shard)
|
if is_tuple_output:
|
||||||
|
torch.autograd.backward(output[0], incoming_grad_shard)
|
||||||
|
else:
|
||||||
|
torch.autograd.backward(output, incoming_grad_shard)
|
||||||
|
|
||||||
# Clean up hooks
|
# Clean up hooks
|
||||||
grad_accumulator.cleanup()
|
grad_accumulator.cleanup()
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
|
|||||||
TiledMLP as DeepSpeedTiledMLP,
|
TiledMLP as DeepSpeedTiledMLP,
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.monkeypatch.tiled_mlp.base import TiledMLP
|
from axolotl.monkeypatch.tiled_mlp.base import DeepSpeedTiledMLPMoE, TiledMLP
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Dynamically import the module and MLP class
|
# Dynamically import the module and MLP class
|
||||||
@@ -64,7 +64,10 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
|
|||||||
for p in self._compute_params
|
for p in self._compute_params
|
||||||
)
|
)
|
||||||
) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
|
) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
|
||||||
self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
|
if model_type == "gpt_oss":
|
||||||
|
self._tiled_mlp_dist_impl = DeepSpeedTiledMLPMoE
|
||||||
|
else:
|
||||||
|
self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
|
||||||
else:
|
else:
|
||||||
self._tiled_mlp_dist_impl = TiledMLP
|
self._tiled_mlp_dist_impl = TiledMLP
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user