diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py
index 8b69c2c49..8ca6d06b0 100644
--- a/src/axolotl/monkeypatch/attention/flex_attn.py
+++ b/src/axolotl/monkeypatch/attention/flex_attn.py
@@ -1,48 +1,162 @@
 """Flex attention monkey patch"""
 
+from typing import Optional, Tuple, Union
+
 import torch
 import transformers
 
 
-def patch_flex():
+def patch_flex_wrapper():
     is_torch_2_6 = torch.__version__.startswith("2.6")
     is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
 
-    if is_torch_2_6 and is_transformers_below_4_51:
-        from torch.nn.attention.flex_attention import flex_attention
+    if not (is_torch_2_6 and is_transformers_below_4_51):
+        return
 
-        class WrappedFlexAttention:
+    from torch.nn.attention.flex_attention import flex_attention
+
+    class WrappedFlexAttention:
+        """
+        We are doing a singleton class so that flex attention is compiled once when it's first called.
+        """
+
+        _instance = None
+        _is_flex_compiled = False
+        _compiled_flex_attention = None
+
+        def __new__(cls, *args, **kwargs):
+            if cls._instance is None:
+                # Create a new instance if one doesn't already exist
+                cls._instance = super().__new__(cls)
+            return cls._instance
+
+        @torch.compiler.disable(recursive=False)
+        def __init__(self):
             """
-            We are doing a singleton class so that flex attention is compiled once when it's first called.
+            Initialize or update the singleton instance.
             """
+            if not self._is_flex_compiled:
+                self._compiled_flex_attention = torch.compile(
+                    flex_attention,
+                    dynamic=False,
+                    mode="max-autotune-no-cudagraphs",
+                    fullgraph=True,
+                )
+                self._is_flex_compiled = True
 
-            _instance = None
-            _is_flex_compiled = False
-            _compiled_flex_attention = None
+        def __call__(self):
+            return self._compiled_flex_attention
 
-            def __new__(cls, *args, **kwargs):
-                if cls._instance is None:
-                    # Create a new instance if one doesn't already exist
-                    cls._instance = super().__new__(cls)
-                return cls._instance
+    transformers.integrations.flex_attention.WrappedFlexAttention = WrappedFlexAttention
 
-            @torch.compiler.disable(recursive=False)
-            def __init__(self):
-                """
-                Initialize or update the singleton instance.
-                """
-                if not self._is_flex_compiled:
-                    self._compiled_flex_attention = torch.compile(
-                        flex_attention,
-                        dynamic=False,
-                        mode="max-autotune-no-cudagraphs",
-                        fullgraph=True,
-                    )
-                    self._is_flex_compiled = True
 
-            def __call__(self):
-                return self._compiled_flex_attention
+def patch_flex_make_mask():
+    is_torch_2_6 = torch.__version__.startswith("2.6")
+    is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
 
-        transformers.integrations.flex_attention.WrappedFlexAttention = (
-            WrappedFlexAttention
+    if not (is_torch_2_6 and is_transformers_below_4_51):
+        return
+
+    from torch.nn.attention.flex_attention import (
+        BlockMask,
+    )
+    from torch.nn.attention.flex_attention import (
+        create_block_mask as create_block_causal_mask_flex,
+    )
+
+    Offset = Union[torch.Tensor, int]
+
+    def make_flex_block_causal_mask(
+        attention_mask_2d: torch.Tensor,
+        attention_chunk_size: Optional[int] = None,
+        query_length=None,
+        key_length=None,
+        offsets: Optional[Tuple[Offset, Offset]] = None,
+    ) -> "BlockMask":
+        """
+        Create a block causal document mask for a batch of sequences, both packed and unpacked.
+        Block causal logic is created and passed to :func:`torch.nn.attention.flex_attention.create_block_mask`.
+        The resultant BlockMask is a compressed representation of the full block causal
+        mask. BlockMask is essential for performant computation of flex attention.
+        See: https://pytorch.org/blog/flexattention/
+
+        Args:
+            attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
+            of shape (batch_size, total_seq_len). e.g.
+
+            For unpacked sequence:
+            [[1, 1, 1, 1, 0, 0, 0],
+             [1, 1, 1, 1, 1, 0, 0]]
+
+            For packed sequence:
+            [[1, 1, 1, 2, 2, 2, 0],
+             [1, 1, 2, 2, 2, 3, 3]]
+
+        Returns:
+            BlockMask
+        """
+
+        batch_size, total_seq_len = attention_mask_2d.shape
+        if not key_length:
+            key_length = total_seq_len
+        if not query_length:
+            query_length = total_seq_len
+        attention_mask_2d = torch.nn.functional.pad(
+            attention_mask_2d, value=0, pad=(0, key_length)
         )
+        device = attention_mask_2d.device
+        document_ids = attention_mask_2d.clone()
+
+        if attention_chunk_size is not None:
+            # we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
+            document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (
+                attention_chunk_size
+            )
+
+        # Instead of passing a tensor mask, flex attention requires a mask_mod function
+        # that determines which elements of QK^T should be included in the attention
+        # computation prior to the softmax. For sample packing, we need the logic for
+        # both the causal mask and the document mask. See PyTorch's official
+        # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
+        def causal_mask_mod(
+            batch_idx, head_idx, q_idx, kv_idx
+        ):  # pylint: disable=unused-argument
+            """
+            Defines the logic of a block causal mask by combining both a standard causal mask
+            and a block diagonal document mask.
+
+            See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
+            for an illustration.
+            """
+            causal_mask = q_idx >= kv_idx  # not valid when decoding
+            document_mask = (
+                document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
+            )
+            padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
+            final_mask = causal_mask & padding_mask & document_mask
+            return final_mask
+
+        if offsets is not None:
+            q_offset = offsets[0]
+            kv_offset = offsets[1]
+
+            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+                offset_q = q_idx + q_offset
+                offset_kv = kv_idx + kv_offset
+                return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
+
+        else:
+            mask_mod = causal_mask_mod
+        return create_block_causal_mask_flex(
+            mask_mod=mask_mod,
+            B=batch_size,
+            H=None,  # attention head
+            Q_LEN=query_length,
+            KV_LEN=key_length,
+            device=device,
+            _compile=True,
+        )
+
+    transformers.integrations.flex_attention.make_flex_block_causal_mask = (
+        make_flex_block_causal_mask
+    )
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 89f35d7eb..c2bddeeec 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -217,7 +217,7 @@ def save_trained_model(
 
     # Handle FSDP state dict type
     state_dict_type = "FULL_STATE_DICT"
-    if trainer.is_fsdp_enabled:
+    if trainer.is_fsdp_enabled and str(cfg.fsdp_config.fsdp_version) != "2":
         if cfg.fsdp_final_state_dict_type:
             state_dict_type = cfg.fsdp_final_state_dict_type
         trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 663aa1740..0e1329b97 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -889,9 +889,13 @@ class ModelLoader:
             self.model_config._attn_implementation = (  # pylint: disable=protected-access
                 "flex_attention"
             )
-            from axolotl.monkeypatch.attention.flex_attn import patch_flex
+            from axolotl.monkeypatch.attention.flex_attn import (
+                patch_flex_make_mask,
+                patch_flex_wrapper,
+            )
 
-            patch_flex()
+            patch_flex_wrapper()
+            patch_flex_make_mask()
 
         elif self.cfg.flash_attention:
             if not self.cfg.sample_packing and self.cfg.s2_attention:
diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py
index 16c17a15b..d71fa25c8 100644
--- a/tests/e2e/multigpu/test_llama.py
+++ b/tests/e2e/multigpu/test_llama.py
@@ -486,7 +486,7 @@ class TestMultiGPULlama:
                 "gradient_checkpointing": True,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
-                "optimizer": "adamw_8bit",
+                "optimizer": "adamw_torch_8bit",
                 "lr_scheduler": "cosine",
                 "fsdp": [
                     "auto_wrap",
@@ -529,7 +529,7 @@ class TestMultiGPULlama:
         )
 
         check_tensorboard(
-            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
+            temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
        )
 
     def test_fsdp_qlora_prequant_packed(self, temp_dir):
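Reviewer note (illustration only, not part of the patch): the comment above causal_mask_mod explains that flex attention takes a mask_mod predicate rather than a dense tensor mask. The sketch below evaluates the same causal / same-document / non-padding predicate densely on a toy packed mask so the resulting pattern can be inspected. The mask values and shapes are invented for the example; it does not call the patched make_flex_block_causal_mask or build a real BlockMask.

    import torch

    # Toy packed mask: one sequence holding two documents (ids 1 and 2) plus a pad token (0).
    attention_mask_2d = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])
    document_ids = attention_mask_2d.clone()

    def mask_value(batch_idx: int, q_idx: int, kv_idx: int) -> bool:
        # Same predicate as causal_mask_mod: causal AND same document AND query is not padding.
        causal_mask = q_idx >= kv_idx
        document_mask = bool(document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx])
        padding_mask = bool(attention_mask_2d[batch_idx, q_idx] > 0)
        return causal_mask and document_mask and padding_mask

    seq_len = attention_mask_2d.shape[1]
    dense_mask = torch.tensor(
        [[mask_value(0, q, k) for k in range(seq_len)] for q in range(seq_len)]
    )
    print(dense_mask.int())
    # Expected: a block-diagonal causal pattern -- tokens of document 1 attend only to earlier
    # tokens of document 1, document 2 likewise, and the padding row is all zeros.

In the real code path, create_block_mask compiles this predicate into a BlockMask instead of materializing the full (Q_LEN, KV_LEN) matrix, which is what keeps flex attention efficient for packed sequences.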