From 3d4562000866b81cdd80923353d388f5045b707c Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 11 Aug 2025 09:34:41 -0400
Subject: [PATCH] remove prepare-from-posids patch (#3052) [skip ci]

---
 src/axolotl/loaders/patch_manager.py     |  4 -
 .../modeling_flash_attention_utils.py    | 87 ------------------
 2 files changed, 91 deletions(-)
 delete mode 100644 src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py

diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 795fc3e37..f1ca3c725 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -73,9 +73,6 @@ class PatchManager:
         self._apply_voxtral_patches()
 
     def _apply_transformers_patches(self):
-        from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
-            patch_prepare_from_posids,
-        )
         from axolotl.monkeypatch.transformers.trainer_loss_calc import (
             patch_evaluation_loop,
             patch_maybe_log_save_evaluate,
@@ -87,7 +84,6 @@
             and self.cfg.fsdp_version == 2
         )
 
-        patch_prepare_from_posids()
         patch_evaluation_loop(patch_fsdp2)
         patch_maybe_log_save_evaluate()
 
diff --git a/src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py b/src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py
deleted file mode 100644
index 1bd8ac6bc..000000000
--- a/src/axolotl/monkeypatch/transformers/modeling_flash_attention_utils.py
+++ /dev/null
@@ -1,87 +0,0 @@
-"""
-Monkey patch to fix transformers.modeling_flash_attention_utils.
-
-see https://github.com/huggingface/transformers/pull/39653/files
-"""
-
-import sys
-
-import torch
-
-
-def _prepare_from_posids(query, key, value, position_ids):
-    """
-    This function returns necessary arguments to call `flash_attn_varlen_func`.
-    All three query, key, value states will be flattened.
-    Cumulative lengths of each examples in the batch will be extracted from position_ids.
-    NOTE: ideally cumulative lengths should be prepared at the data collator stage
-    Arguments:
-        query (`torch.Tensor`):
-            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
-        key (`torch.Tensor`):
-            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
-        value (`torch.Tensor`):
-            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
-        position_ids (`torch.Tensor`):
-            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
-    Return:
-        query (`torch.Tensor`):
-            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
-        key (`torch.Tensor`):
-            Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
-        value (`torch.Tensor`):
-            Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
-        indices_q (`torch.Tensor`):
-            The indices of non-masked tokens from the flattened input target sequence.
-        (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
-            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
-        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
-            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
-    """
-    query = query.contiguous().view(-1, query.size(-2), query.size(-1))
-    key = key.contiguous().view(-1, key.size(-2), key.size(-1))
-    value = value.contiguous().view(-1, value.size(-2), value.size(-1))
-
-    position_ids = position_ids.flatten()
-    indices_q = torch.arange(
-        position_ids.size(0), device=position_ids.device, dtype=torch.int32
-    )
-
-    cu_seq_lens = torch.cat(
-        (
-            indices_q[position_ids == 0],
-            torch.tensor(
-                position_ids.size(), device=position_ids.device, dtype=torch.int32
-            ),
-        )
-    )
-    # NOTE: With torch compile, this will cause a graph break if you don't set
-    # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
-    # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
-    # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
-    # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
-    # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
-    # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
-    # for some models (e.g. qwen2-vl).
-    max_length = cu_seq_lens.diff().max().item()
-    return (
-        query,
-        key,
-        value,
-        indices_q,
-        (cu_seq_lens, cu_seq_lens),
-        (max_length, max_length),
-    )
-
-
-def patch_prepare_from_posids():
-    import transformers.modeling_flash_attention_utils
-
-    transformers.modeling_flash_attention_utils._prepare_from_posids = (  # pylint: disable=protected-access
-        _prepare_from_posids
-    )
-    setattr(
-        sys.modules["transformers.modeling_flash_attention_utils"],
-        "_prepare_from_posids",
-        _prepare_from_posids,
-    )