diff --git a/requirements.txt b/requirements.txt
index 5accd13ed..9e3dbbca4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,7 +13,7 @@
 packaging==23.2
 huggingface_hub>=0.33.0
 peft>=0.17.0
-transformers==4.55.3
+transformers==4.55.4
 tokenizers>=0.21.1
 accelerate==1.10.0
 datasets==4.0.0
diff --git a/setup.py b/setup.py
index 5bf9ae840..4cbc562e0 100644
--- a/setup.py
+++ b/setup.py
@@ -127,7 +127,7 @@ extras_require = {
         "yunchang==0.6.0",
     ],
     "deepspeed": [
-        "deepspeed==0.17.2",
+        "deepspeed==0.17.5",
         "deepspeed-kernels",
     ],
     "mamba-ssm": [
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index eafe89d29..94b307a62 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -149,14 +149,12 @@ class PatchManager:
     def _apply_flex_attention_patches(self):
         """Apply patches for flexible attention."""
         if self.cfg.flex_attention:
-            # from axolotl.monkeypatch.attention.flex_attn import (
-            #     patch_flex_make_mask,
-            #     patch_flex_wrapper,
-            # )
-            #
-            # flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
-            # patch_flex_wrapper(**flex_attn_compile_kwargs)
-            # patch_flex_make_mask()
+            from axolotl.monkeypatch.attention.flex_attn import (
+                patch_flex_wrapper,
+            )
+
+            flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
+            patch_flex_wrapper(**flex_attn_compile_kwargs)
         if self.cfg.sample_packing:
             from axolotl.core.attention.flex_block_mask import (
                 patch_create_causal_mask,
diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py
index f59b8abe2..65ccad533 100644
--- a/src/axolotl/monkeypatch/attention/flex_attn.py
+++ b/src/axolotl/monkeypatch/attention/flex_attn.py
@@ -1,11 +1,11 @@
 """Flex attention monkey patch"""
 
 import sys
-from typing import Optional, Tuple, Union
+from packaging import version
 
 import torch
 import transformers
-
+from transformers.utils.import_utils import _torch_version, is_torch_less_or_equal
 from axolotl.utils.logging import get_logger
 
 LOG = get_logger(__name__)
@@ -46,19 +46,33 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
         """
         self.training = None
         if not self._is_flex_compiled or training != self.training:
+            self.training = training
+            if is_torch_less_or_equal("2.5.1"):
+                self._compiled_flex_attention = torch.compile(
+                    flex_attention, dynamic=False
+                )
             # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
             # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
             # see https://github.com/pytorch/pytorch/issues/146260 for training
-            self.training = training
-            LOG.info(
-                "Compiling flex attention with kwargs: %s. This may take a while...",
-                flex_attn_compile_kwargs,
-            )
-            self._compiled_flex_attention = torch.compile(
-                flex_attention,
-                **flex_attn_compile_kwargs,
-            )
-            LOG.info("Flex attention compiled successfully.")
+            elif version.parse(_torch_version).base_version == "2.6.0" and training:
+                self._compiled_flex_attention = torch.compile(
+                    flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
+                )
+            # Fallback, usually the most recent torch 2.7.x+ versions
+            else:
+                LOG.info(
+                    "Compiling flex attention with kwargs: %s. This may take a while...",
+                    flex_attn_compile_kwargs,
+                    main_process_only=True,
+                )
+                self._compiled_flex_attention = torch.compile(
+                    flex_attention,
+                    **flex_attn_compile_kwargs,
+                )
+                LOG.info(
+                    "Flex attention compiled successfully.", main_process_only=True
+                )
+
             self._is_flex_compiled = True
 
     def __call__(self):
@@ -68,139 +82,3 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
     sys.modules[
         "transformers.integrations.flex_attention"
     ].WrappedFlexAttention = WrappedFlexAttention
-
-
-def patch_flex_make_mask():
-    is_torch_2_6 = torch.__version__.startswith("2.6")
-
-    if not is_torch_2_6:
-        return
-
-    from torch.nn.attention.flex_attention import (
-        _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size,
-    )
-    from torch.nn.attention.flex_attention import (
-        BlockMask,
-    )
-    from torch.nn.attention.flex_attention import (
-        create_block_mask as create_block_causal_mask_flex,
-    )
-
-    Offset = Union[torch.Tensor, int]
-
-    def patched_make_flex_block_causal_mask(
-        attention_mask_2d: torch.Tensor,
-        attention_chunk_size: Optional[int] = None,
-        query_length=None,
-        key_length=None,
-        offsets: Optional[Tuple[Offset, Offset]] = None,
-    ) -> "BlockMask":
-        """
-        Create a block causal document mask for a batch of sequences, both packed and unpacked.
-        Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
-        The resultant BlockMask is a compressed representation of the full block causal
-        mask. BlockMask is essential for performant computation of flex attention.
-        See: https://pytorch.org/blog/flexattention/
-
-        Args:
-            attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
-            of shape (batch_size, total_seq_len). e.g.
-
-            For unpacked sequence:
-            [[1, 1, 1, 1, 0, 0, 0],
-             [1, 1, 1, 1, 1, 0, 0]]
-
-            For packed sequence:
-            [[1, 1, 1, 2, 2, 2, 0],
-             [1, 1, 2, 2, 2, 3, 3]]
-
-        Returns:
-            BlockMask
-        """
-
-        batch_size, total_seq_len = attention_mask_2d.shape
-        if not key_length:
-            key_length = total_seq_len
-        if not query_length:
-            query_length = total_seq_len
-        attention_mask_2d = torch.nn.functional.pad(
-            attention_mask_2d,
-            value=0,
-            pad=(0, abs(total_seq_len - max(key_length, flex_default_block_size))),
-        )
-        device = attention_mask_2d.device
-        document_ids = attention_mask_2d.clone()
-
-        if attention_chunk_size is not None:
-            # we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
-            chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (
-                attention_chunk_size
-            )
-
-        # Instead of passing a tensor mask, flex attention requires a mask_mod function
-        # that determines which elements of QK^T should be included in the attention
-        # computation prior to the softmax. For sample packing, we need both the
-        # logic for both causal mask and document mask. See PyTorch's official
-        # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
-        def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
-            """
-            Defines the logic of a block causal mask by combining both a standard causal mask
-            and a block diagonal document mask.
-
-            See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
-            for an illustration.
-            """
-            causal_mask = q_idx >= kv_idx  # not valid when decoding
-            document_mask = (
-                document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
-            )
-            padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
-            final_mask = causal_mask & padding_mask & document_mask
-            return final_mask
-
-        def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
-            """
-            Combines the chunk mask with the causal mask for chunked attention.
-            """
-            chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
-            causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
-            return chunk_mask & causal_doc_mask
-
-        mask_mod_maybe_combined = (
-            causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
-        )
-
-        if offsets is not None:
-            q_offset = offsets[0]
-            kv_offset = offsets[1]
-
-            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
-                offset_q = q_idx + q_offset
-                offset_kv = kv_idx + kv_offset
-                return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
-
-        else:
-            mask_mod = mask_mod_maybe_combined
-        return create_block_causal_mask_flex(
-            mask_mod=mask_mod,
-            B=batch_size,
-            H=None,  # attention head
-            Q_LEN=query_length,
-            KV_LEN=key_length,
-            device=device,
-            _compile=True,
-        )
-
-    for n in tuple(sys.modules):
-        if ".modeling_" in n:
-            if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
-                sys.modules[
-                    n
-                ].make_flex_block_causal_mask = patched_make_flex_block_causal_mask
-                sys.modules[
-                    n
-                ].make_flex_block_causal_mask = patched_make_flex_block_causal_mask
-
-    transformers.integrations.flex_attention.make_flex_block_causal_mask = (
-        patched_make_flex_block_causal_mask
-    )
diff --git a/src/axolotl/monkeypatch/tiled_mlp/base.py b/src/axolotl/monkeypatch/tiled_mlp/base.py
index 3b7326bdb..2c9dc8e4c 100644
--- a/src/axolotl/monkeypatch/tiled_mlp/base.py
+++ b/src/axolotl/monkeypatch/tiled_mlp/base.py
@@ -8,6 +8,94 @@ from typing import List
 import torch
 
 
+class DeepSpeedTiledMLPMoE(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        fn,
+        self,
+        x,
+        shards,
+        compute_params,
+    ) -> torch.Tensor:
+        ctx.fn = fn
+        ctx.self = self
+        ctx.shards = shards
+        ctx.compute_params = [p for p in compute_params if p.requires_grad]
+        ctx.save_for_backward(x)
+
+        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
+        with torch.no_grad():
+            output_shards = [fn(self, x_shard) for x_shard in x_shards]
+
+        ctx.is_tuple_output = isinstance(output_shards[0], tuple)
+        if isinstance(output_shards[0], tuple):
+            tuple_dim_idx = [1, 0]
+            output_unsharded = tuple(
+                torch.cat(
+                    [output_shard[i] for output_shard in output_shards],
+                    dim=tuple_dim_idx[i],
+                )
+                for i in range(len(output_shards[0]))
+            )
+        else:
+            output_unsharded = torch.cat(output_shards, dim=1)
+
+        return output_unsharded
+
+    @staticmethod
+    def backward(ctx, *grads) -> torch.Tensor:
+        fn = ctx.fn
+        (x,) = ctx.saved_tensors
+        self = ctx.self
+        shards = ctx.shards
+        compute_params = ctx.compute_params
+        is_tuple_output = ctx.is_tuple_output
+
+        x_requires_grad = x.requires_grad
+        x = x.detach()
+        # detach() unsets `x.requires_grad`, so restore it
+        x.requires_grad_(x_requires_grad)
+
+        incoming_grad = grads[0]
+        x_grad = torch.zeros_like(x)
+        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
+
+        shard_step = x_shards[0].numel()
+        for i, x_shard in enumerate(x_shards):
+            # Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run
+            if compute_params is not None:
+                if i + 1 < shards:
+                    for param in compute_params:
+                        param.ds_grad_is_ready = False
+                else:
+                    # last shard, can add the grad
+                    for param in compute_params:
+                        param.ds_grad_is_ready = True
+
+            x_shard.requires_grad_(x_requires_grad)
+
+            shard_offset = i * shard_step
+            x_shard.grad = (
+                x_grad.view(-1)
+                .narrow(0, shard_offset, x_shard.numel())
+                .view_as(x_shard)
+            )
+            incoming_grad_shard = (
+                incoming_grad.view(-1)
+                .narrow(0, shard_offset, x_shard.numel())
+                .view_as(x_shard)
+            )
+            with torch.enable_grad():
+                output = fn(self, x_shard)
+            if is_tuple_output:
+                torch.autograd.backward(output[0], incoming_grad_shard)
+            else:
+                torch.autograd.backward(output, incoming_grad_shard)
+
+        return (None, None, x_grad, None, None)
+
+
 class TiledMLP(torch.autograd.Function):
     """
     TiledMLP implementation using gradient hooks
@@ -31,7 +119,18 @@ class TiledMLP(torch.autograd.Function):
         x_shards = list(torch.chunk(x, chunks=shards, dim=1))
         with torch.no_grad():
             output_shards = [fn(self, x_shard) for x_shard in x_shards]
-        output_unsharded = torch.cat(output_shards, dim=1)
+        ctx.is_tuple_output = isinstance(output_shards[0], tuple)
+        if isinstance(output_shards[0], tuple):
+            tuple_dim_idx = [1, 0]
+            output_unsharded = tuple(
+                torch.cat(
+                    [output_shard[i] for output_shard in output_shards],
+                    dim=tuple_dim_idx[i],
+                )
+                for i in range(len(output_shards[0]))
+            )
+        else:
+            output_unsharded = torch.cat(output_shards, dim=1)
 
         return output_unsharded
 
@@ -42,6 +141,7 @@ class TiledMLP(torch.autograd.Function):
         self = ctx.self
         shards = ctx.shards
         compute_params = ctx.compute_params
+        is_tuple_output = ctx.is_tuple_output
 
         x_requires_grad = x.requires_grad
         x = x.detach()
@@ -76,7 +176,10 @@ class TiledMLP(torch.autograd.Function):
 
         with torch.enable_grad():
             output = fn(self, x_shard)
-            torch.autograd.backward(output, incoming_grad_shard)
+            if is_tuple_output:
+                torch.autograd.backward(output[0], incoming_grad_shard)
+            else:
+                torch.autograd.backward(output, incoming_grad_shard)
 
         # Clean up hooks
         grad_accumulator.cleanup()
diff --git a/src/axolotl/monkeypatch/tiled_mlp/patch.py b/src/axolotl/monkeypatch/tiled_mlp/patch.py
index 7cdc6d3a3..c0f89236b 100644
--- a/src/axolotl/monkeypatch/tiled_mlp/patch.py
+++ b/src/axolotl/monkeypatch/tiled_mlp/patch.py
@@ -17,7 +17,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
         TiledMLP as DeepSpeedTiledMLP,
     )
 
-    from axolotl.monkeypatch.tiled_mlp.base import TiledMLP
+    from axolotl.monkeypatch.tiled_mlp.base import DeepSpeedTiledMLPMoE, TiledMLP
 
     try:
         # Dynamically import the module and MLP class
@@ -64,7 +64,10 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
             for p in self._compute_params
         )
     ) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
-        self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
+        if model_type == "gpt_oss":
+            self._tiled_mlp_dist_impl = DeepSpeedTiledMLPMoE
+        else:
+            self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
     else:
         self._tiled_mlp_dist_impl = TiledMLP
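
The core behavioral change in `DeepSpeedTiledMLPMoE` (and the matching `TiledMLP` updates) is support for MLP blocks that return a tuple instead of a single tensor: the hidden states are re-concatenated along the sequence dimension (`dim=1`), while the flattened router logits are concatenated along `dim=0`, which is what the `tuple_dim_idx = [1, 0]` convention encodes. Routing `model_type == "gpt_oss"` to this variant suggests the gpt_oss MLP forward returns such a tuple. The sketch below illustrates only that shard-and-restitch step; `ToyMoEMLP` and `tiled_forward` are hypothetical stand-ins rather than code from this patch, and the gradient tiling and DeepSpeed bookkeeping are omitted.

```python
import torch


class ToyMoEMLP(torch.nn.Module):
    """Hypothetical stand-in for an MoE block that returns a tuple of
    (hidden_states, router_logits); not code from this patch."""

    def __init__(self, hidden_size: int, num_experts: int):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)
        self.router = torch.nn.Linear(hidden_size, num_experts)

    def forward(self, x: torch.Tensor):
        # hidden_states keep the (batch, seq, hidden) layout;
        # router logits are flattened over tokens: (batch * seq, num_experts)
        hidden_states = self.proj(x)
        router_logits = self.router(x).reshape(-1, self.router.out_features)
        return hidden_states, router_logits


def tiled_forward(mlp: torch.nn.Module, x: torch.Tensor, shards: int = 4):
    """Chunk the input along the sequence dim, run the MLP per shard, then
    stitch the tuple outputs back together: dim=1 for hidden states, dim=0
    for the flattened router logits (the `tuple_dim_idx = [1, 0]` idea)."""
    x_shards = torch.chunk(x, chunks=shards, dim=1)
    output_shards = [mlp(x_shard) for x_shard in x_shards]
    hidden = torch.cat([out[0] for out in output_shards], dim=1)
    logits = torch.cat([out[1] for out in output_shards], dim=0)
    return hidden, logits


if __name__ == "__main__":
    torch.manual_seed(0)
    mlp = ToyMoEMLP(hidden_size=16, num_experts=4)
    # batch size 1 so the flattened router-logit row order matches the
    # untiled reference; with larger batches the rows interleave differently
    x = torch.randn(1, 8, 16)
    ref_hidden, ref_logits = mlp(x)
    tiled_hidden, tiled_logits = tiled_forward(mlp, x, shards=4)
    assert torch.allclose(ref_hidden, tiled_hidden, atol=1e-6)
    assert torch.allclose(ref_logits, tiled_logits, atol=1e-6)
```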