update transformers to 4.53.1 (#2844) [skip ci]

* update transformers to 4.53.0

* remove attention_mask from signature columns if using packing

* remove attention_mask column from dataloader

* update signature of flash attn forward for ring attn patch

* fix FSDP

* patch ring-flash-attn with upstream signature fix

* fix patch indentation level

* fix the patch

* add batch flattening smoke test with loss check that works in older transformers

* fix patch

* don't drop attention mask for flex

* more fixes

* patch create_causal_mask for packing w flex

* global torch manual_seed fixture

* tweak loss checks

* fix patch and use single batch for flex

* don't need to reload

* fix causal mask patch

* use transformers patch release

* make sure env var is string

* make sure to drop attention mask for flex w packing for latest transformers patch release

* tweak loss

* guard on signature columns before removing attention mask

* bump loss

* set remove isn't chainable

* skip slow mistral test in 2.5.1
Wing Lian, 2025-07-07 09:35:22 -04:00, committed by GitHub
parent 5a961ecadf · commit 69cd49a7aa
23 changed files with 449 additions and 32 deletions
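Several of the commit messages above revolve around dropping `attention_mask` when sample packing is enabled. As a toy illustration of why that is safe (hypothetical tensors, not from this repo): once sequences are packed into a single row, the kernels that support packing recover per-document boundaries from the `position_ids` resets (cf. `get_cu_seqlens_from_pos_ids` in the diffs below), making the mask column redundant.

import torch

input_ids = torch.tensor([[5, 6, 7, 8, 9]])     # two documents packed into one row
position_ids = torch.tensor([[0, 1, 2, 0, 1]])  # positions reset at each document start
attention_mask = torch.ones(1, 5, dtype=torch.long)  # carries no packing info

# recover cumulative sequence lengths from the position_ids resets
starts = (position_ids[0] == 0).nonzero().flatten()
cu_seqlens = torch.cat([starts, torch.tensor([position_ids.shape[1]])])
print(cu_seqlens)  # tensor([0, 3, 5]) -> documents span [0, 3) and [3, 5)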

View File

@@ -0,0 +1,162 @@
"""
monkeypatch for flex + packing
"""
import sys
from typing import Callable, Optional, Union
import torch
from torch.nn.attention.flex_attention import BlockMask
from transformers import Cache, PretrainedConfig
from transformers.masking_utils import (
ALL_MASK_ATTENTION_FUNCTIONS,
_preprocess_mask_arguments,
and_masks,
causal_mask_function,
or_masks,
)
from transformers.utils import is_torch_greater_or_equal
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
def create_causal_mask(
config: PretrainedConfig,
input_embeds: torch.Tensor,
attention_mask: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Optional[Cache],
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
"""
Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
to what is needed in the `modeling_xxx.py` files).
Args:
config (`PretrainedConfig`):
The model config.
input_embeds (`torch.Tensor`):
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
batch size, query length and dtype.
attention_mask (`torch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
It can also be an already prepared 4D mask, in which case it is returned as-is.
cache_position (`torch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
or_mask_function (`Callable`, optional):
An optional mask function to combine with the causal mask function (by doing the union of both). This is
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
and_mask_function (`Callable`, optional):
An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
"""
# If we have an HybridCache structure, here we want to create the mask for the full layers
if (
past_key_values
and hasattr(past_key_values, "is_sliding")
and False in past_key_values.is_sliding
):
layer_idx = past_key_values.is_sliding.index(False)
else:
layer_idx = 0
original_attention_mask = (
None
if attention_mask is None
else attention_mask.clone().to(cache_position.device)
)
early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
)
if early_exit:
return attention_mask
batch_size, total_seq_len = cache_position.shape
key_length = total_seq_len
document_ids = torch.nn.functional.pad(
original_attention_mask, value=0, pad=(0, key_length)
)
batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
if attention_mask is not None:
def causal_doc_mask_mod(
batch_idx, head_idx, q_idx, kv_idx
): # pylint: disable=unused-argument
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask_ = q_idx >= kv_idx # not valid when decoding
document_mask = (
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
)
final_mask = causal_mask_ & document_mask
return final_mask
mask_factory_function = causal_doc_mask_mod
else:
mask_factory_function = causal_mask_function
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[
config._attn_implementation # pylint: disable=protected-access
]
# Do not allow skip if we are compiling (this is to match BC)
allow_is_causal_skip = (
not past_key_values.is_compileable if past_key_values is not None else True
)
# Allow slight deviations from causal mask
if or_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError(
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
)
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
allow_is_causal_skip = False
if and_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError(
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
)
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
allow_is_causal_skip = False
# We now create the mask
causal_mask = mask_interface(
batch_size=batch_size,
cache_position=cache_position,
kv_length=kv_length,
kv_offset=kv_offset,
mask_function=mask_factory_function,
attention_mask=attention_mask,
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
dtype=dtype, # Additional kwarg for eager
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
)
return causal_mask
def patch_create_causal_mask(model_type):
import transformers.masking_utils
transformers.masking_utils.create_causal_mask = create_causal_mask
if model_type:
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = __import__(module_path)
module.create_causal_mask = create_causal_mask
del sys.modules[module_path]
except (ImportError, AttributeError) as e:
raise ValueError(
f"Could not import attention class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
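For context, here is a standalone sketch of the block-causal document mask the patched `create_causal_mask` builds, driven directly through torch's public `create_block_mask` API (toy inputs; `device="cpu"` so the sketch runs anywhere):

import torch
from torch.nn.attention.flex_attention import create_block_mask

# one packed row: three tokens of document 0 followed by two tokens of document 1
document_ids = torch.tensor([[0, 0, 0, 1, 1]])

def causal_doc_mask_mod(b, h, q_idx, kv_idx):  # pylint: disable=unused-argument
    # attend only within the same document, and never to future positions
    return (q_idx >= kv_idx) & (document_ids[b, q_idx] == document_ids[b, kv_idx])

block_mask = create_block_mask(
    causal_doc_mask_mod, B=1, H=None, Q_LEN=5, KV_LEN=5, device="cpu"
)
print(block_mask)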

View File

@@ -245,10 +245,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool(
self.cfg.flash_attention
or self.cfg.xformers_attention
or self.cfg.flex_attention
)
training_arguments_kwargs["multipack_real_batches"] = (
self.cfg.multipack_real_batches
if self.cfg.multipack_real_batches is not None
else not self.cfg.flash_attention
else not (
self.cfg.flash_attention
or self.cfg.flex_attention
or self.cfg.xformers_attention
)
)
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing

View File

@@ -27,6 +27,7 @@ from typing_extensions import override
from axolotl.core.trainers.mixins import (
    CheckpointSaveMixin,
    OptimizerMixin,
    PackingMixin,
    RngLoaderMixin,
    SchedulerMixin,
)
@@ -42,7 +43,12 @@ LOG = get_logger(__name__)
class AxolotlTrainer(
-   SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
+   PackingMixin,
+   SchedulerMixin,
+   OptimizerMixin,
+   RngLoaderMixin,
+   CheckpointSaveMixin,
+   Trainer,
):
    """Extend the base Trainer for axolotl helpers"""
@@ -206,6 +212,14 @@ class AxolotlTrainer(
        if dataset.column_names and "length" in dataset.column_names:
            dataset = dataset.remove_columns(["length"])

        if (
            dataset.column_names
            and "position_ids" in dataset.column_names
            and "attention_mask" in dataset.column_names
            and self.args.sample_packing
            and self.args.sample_packing_drop_attention_mask
        ):
            dataset = dataset.remove_columns(["attention_mask"])

        if isinstance(dataset, datasets.Dataset):
            if is_training:

View File

@@ -5,5 +5,6 @@
from .checkpoints import CheckpointSaveMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin

View File

@@ -0,0 +1,20 @@
"""Trainer mixin to support packing"""
from transformers import Trainer
class PackingMixin(Trainer):
"""
Trainer mixin to support packing
"""
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
if (
self._signature_columns
and self.args.sample_packing
and self.args.sample_packing_drop_attention_mask
):
set_sig_columns = set(self._signature_columns)
set_sig_columns.remove("attention_mask")
self._signature_columns = list(set_sig_columns)
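The mixin works because `Trainer._remove_unused_columns` drops any dataset column not present in `_signature_columns`. A simplified paraphrase of that mechanism (an assumption-level sketch, not the actual transformers code):

def remove_unused_columns_sketch(dataset, signature_columns):
    # columns the model's forward() does not accept are dropped before batching
    ignored = [c for c in dataset.column_names if c not in signature_columns]
    return dataset.remove_columns(ignored)

With "attention_mask" removed from the signature columns, packed batches reach the collator as input_ids / labels / position_ids only.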

View File

@@ -42,6 +42,10 @@ class AxolotlTrainingMixins:
        default=None,
        metadata={"help": "The multiprocessing start method to use."},
    )
    sample_packing_drop_attention_mask: bool = field(
        default=False,
        metadata={"help": "Drop attention mask from inputs when using packing."},
    )
    multipack_real_batches: bool = field(
        default=False,
        metadata={"help": "Use real batches for efficient training."},

View File

@@ -49,11 +49,11 @@ class PatchManager:
    def apply_pre_model_load_patches(self):
        """Apply pre-model load patches based on config."""
        # self._apply_flex_attention_patches()
        self._apply_flash_attention_patches()
        self._apply_chunked_cross_entropy_patch()
        self._apply_fsdp_patches()
        self._apply_adapter_patches()
        self._apply_flex_attention_patches()
        self._apply_model_specific_patches()
        self._apply_fp8_patches()
        self._apply_flash_attention_peft_patches()
@@ -97,6 +97,14 @@ class PatchManager:
            patch_accelerate_fsdp2()

        # if self.cfg.fsdp_config:
        #     # see transformers#39152
        #     from axolotl.monkeypatch.trainer_fsdp_optim import (
        #         patch_training_loop_for_fsdp,
        #     )
        #
        #     patch_training_loop_for_fsdp()

    def _apply_adapter_patches(self):
        """Apply patches for adapter configurations."""
        if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
@@ -107,14 +115,20 @@ class PatchManager:
    def _apply_flex_attention_patches(self):
        """Apply patches for flexible attention."""
        if self.cfg.flex_attention:
-           from axolotl.monkeypatch.attention.flex_attn import (
-               patch_flex_make_mask,
-               patch_flex_wrapper,
-           )
+           # from axolotl.monkeypatch.attention.flex_attn import (
+           #     patch_flex_make_mask,
+           #     patch_flex_wrapper,
+           # )
+           #
+           # flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
+           # patch_flex_wrapper(**flex_attn_compile_kwargs)
+           # patch_flex_make_mask()
+           if self.cfg.sample_packing:
+               from axolotl.core.attention.flex_block_mask import (
+                   patch_create_causal_mask,
+               )
-           flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
-           patch_flex_wrapper(**flex_attn_compile_kwargs)
-           patch_flex_make_mask()
+
+               patch_create_causal_mask(self.cfg.model_config_type)

    def _apply_model_specific_patches(self):
        """Apply patches specific to model architectures."""

View File

@@ -33,7 +33,7 @@ RING_ATTN_FUNC_MAPPING = {
}
-def create_flash_attn_forward(
+def create_flash_attn_forward_varlen_llama3(
    process_group: dist.ProcessGroup, ring_attn_func: RingAttnFunc
) -> Callable:
    """
@@ -71,6 +71,7 @@ def create_flash_attn_forward(
        max_length_q: int | None = None,
        max_length_k: int | None = None,
        target_dtype: torch.dtype | None = None,
+       attn_implementation: str | None = None,
        **kwargs,
    ):
        """
@@ -97,6 +98,7 @@ def create_flash_attn_forward(
            max_length_q: Not used in this implementation.
            max_length_k: Not used in this implementation.
            target_dtype: Not used in this implementation.
+           attn_implementation: Not used in this implementation.
            **kwargs: Additional keyword arguments. Not used in this implementation.

        Returns:
@@ -161,7 +163,7 @@ def substitute_hf_flash_attn(
    old_flash_attention_forward = (
        transformers.modeling_flash_attention_utils._flash_attention_forward
    )
-   new_flash_attention_forward = create_flash_attn_forward(
+   new_flash_attention_forward = create_flash_attn_forward_varlen_llama3(
        process_group=process_group, ring_attn_func=ring_attn_func
    )
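The new `attn_implementation` parameter tracks the upstream signature change this commit adapts to: newer transformers threads that argument into `_flash_attention_forward`, so any replacement must accept it or attention calls fail with a `TypeError`. A defensive pattern for such patches (hypothetical sketch, not the repo's code):

def patched_flash_attention_forward(
    query_states,
    key_states,
    value_states,
    attention_mask,
    query_length,
    attn_implementation=None,  # accepted and ignored; newer transformers passes it
    **kwargs,  # tolerate future upstream kwargs instead of raising TypeError
):
    del attention_mask, attn_implementation, kwargs  # unused in this sketch
    return query_states  # placeholder; a real patch dispatches to a flash kernel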

View File

@@ -9,10 +9,13 @@ sequence parallelism training.
"""
import inspect
import os
from typing import Optional

import accelerate
import torch
import torch.distributed as dist
from transformers.modeling_flash_attention_utils import _flash_supports_window_size

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
@@ -62,6 +65,96 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
    RING_ATTN_GROUP = ring_attn_group


def create_ring_flash_attention_forward(
    process_group: dist.ProcessGroup, heads_k_stride: int
):
    from ring_flash_attn import llama3_flash_attn_varlen_func
    from ring_flash_attn.adapters.hf_adapter import DATA_PARAMS

    def _flash_attention_forward_v3(
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: torch.Tensor,  # pylint: disable=unused-argument
        query_length: int,
        is_causal: bool,
        dropout: float = 0.0,
        position_ids: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
        softmax_scale: Optional[float] = None,
        sliding_window: Optional[int] = None,
        use_top_left_mask: bool = False,
        softcap: Optional[float] = None,
        deterministic: Optional[bool] = None,
        cu_seq_lens_q: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
        cu_seq_lens_k: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
        max_length_q: Optional[int] = None,  # pylint: disable=unused-argument
        max_length_k: Optional[int] = None,  # pylint: disable=unused-argument
        target_dtype: Optional[torch.dtype] = None,  # pylint: disable=unused-argument
        attn_implementation: Optional[str] = None,  # pylint: disable=unused-argument
        **kwargs,  # pylint: disable=unused-argument
    ):
        # pylint: disable=duplicate-code
        if not use_top_left_mask:
            causal = is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm
            # is bumped to 2.1. For details, please see the comment in
            # transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
            causal = is_causal and query_length != 1

        # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length
        # (source length).
        use_sliding_windows = (
            _flash_supports_window_size
            and sliding_window is not None
            and key_states.shape[1] > sliding_window
        )
        flash_kwargs = (
            {"window_size": (sliding_window, sliding_window)}
            if use_sliding_windows
            else {}
        )

        if deterministic is None:
            deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
        flash_kwargs["deterministic"] = deterministic
        assert (
            softcap is None
        ), "llama3_flash_attn_varlen_func does not support softcap yet."
        # flash_kwargs["softcap"] = softcap
        flash_kwargs["group"] = process_group

        # not sure why attention_mask can be not None...
        assert causal, "only causal attention is supported yet."
        batch_size = query_states.size(0)
        assert batch_size == 1, "varlen data should be processed in advance."

        attn_output = llama3_flash_attn_varlen_func(
            query_states.squeeze(dim=0),
            key_states.squeeze(dim=0),
            value_states.squeeze(dim=0),
            cu_seqlens_q=DATA_PARAMS["cu_seqlens_q"],
            cu_seqlens_k=DATA_PARAMS["cu_seqlens_k"],
            max_seqlen_q=DATA_PARAMS["max_seqlen_q"],
            max_seqlen_k=DATA_PARAMS["max_seqlen_k"],
            heads_k_stride=heads_k_stride,
            local_k_slice=DATA_PARAMS["local_k_slice"],
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )

        attn_output = attn_output.unsqueeze(dim=0)
        return attn_output

    return [
        _flash_attention_forward_v3,
    ]
def register_ring_attn(
sequence_parallel_degree: int,
heads_k_stride: int | None,
@@ -118,9 +211,20 @@ def register_ring_attn(
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if ring_attn_func is RingAttnFunc.VARLEN_LLAMA3:
from ring_flash_attn import substitute_hf_flash_attn
# fmt: off
import ring_flash_attn.adapters.hf_adapter
substitute_hf_flash_attn(
from ring_flash_attn.adapters.hf_adapter import ( # isort: skip # pylint: disable=unused-import
create_ring_flash_attention_forward as create_ring_flash_attention_forward_orig,
)
create_ring_flash_attention_forward_orig = ( # noqa: F811,F841
create_ring_flash_attention_forward
)
ring_flash_attn.adapters.hf_adapter.create_ring_flash_attention_forward = create_ring_flash_attention_forward
# fmt: on
ring_flash_attn.adapters.hf_adapter.substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1
)
elif ring_attn_func is RingAttnFunc.BATCH_RING:
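The import gymnastics above exist so that `substitute_hf_flash_attn` picks up the patched factory: a function resolves module-level names at call time through its defining module's globals, so rebinding the attribute on `ring_flash_attn.adapters.hf_adapter` is enough. A self-contained toy illustration of that lookup rule:

import types

adapter = types.ModuleType("adapter")
exec(
    "def factory():\n"
    "    return 'original forward'\n"
    "\n"
    "def substitute():\n"
    "    return factory()  # resolved in module globals at call time\n",
    adapter.__dict__,
)

adapter.factory = lambda: "patched forward"  # rebind the module attribute
print(adapter.substitute())  # -> 'patched forward'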

View File

@@ -12,15 +12,13 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)

ORIGINAL_TRAINER_CODE = """
        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled
        if delay_optimizer_creation:
            self.optimizer = self.accelerator.prepare(self.optimizer)
"""

PATCHED_TRAINER_CODE = """
        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
        if delay_optimizer_creation:
            model = self.accelerator.prepare(self.model)
"""

View File

@@ -203,7 +203,7 @@ class AxolotlInputConfig(
        },
    )
    dataset_processes: int | None = Field(
-       default=min(32, os.cpu_count()),  # type: ignore[type-var]
+       default=min(int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count()),  # type: ignore[type-var]
        json_schema_extra={
            "description": "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set."
        },
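The new default makes the preprocessing worker count overridable from the environment; hypothetical usage (the variable name is taken from the diff above):

import os

os.environ["AXOLOTL_DATASET_PROCESSES"] = "8"  # os.environ values must be strings

default = min(int(os.environ.get("AXOLOTL_DATASET_PROCESSES", 32)), os.cpu_count())
print(default)  # 8 on any machine with at least 8 CPUs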

View File

@@ -535,6 +535,9 @@ def setup_deepspeed_env(cfg, stage=None):
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(
cfg.gradient_accumulation_steps
)
if stage:
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3:
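The `str(...)` wrapper is what the "make sure env var is string" commit message refers to: `os.environ` accepts only `str` values and raises `TypeError` for anything else.

import os

os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(4)  # OK
try:
    os.environ["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = 4  # type: ignore[assignment]
except TypeError as err:
    print(err)  # str expected, not int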