axolotl/src/axolotl/loaders/patch_manager.py
Avaya Aggarwal 7ddfb2d8a0 cleanup: remove dead SDPA patches (#3488) [skip ci]
Transformers 5.x routes attention through sdpa_attention.py and no longer
calls the _prepare_4d_causal_attention_mask* or _expand_mask functions that
these patches targeted. This makes the following patches dead code:

- llama_patch_multipack.py (patched _prepare_4d_causal_attention_mask*)
- llama_expand_mask.py (patched _expand_mask, never called)
- Related utility functions in monkeypatch/utils.py

Closes axolotl-ai-cloud/axolotl#3331

"""Patch manager class implementation to complement `axolotl.loaders.ModelLoader`.
Applies pre- and post-model load patches for various fixes and optimizations.
"""
import importlib.util
import os
from functools import cached_property
import addict
import transformers
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_flash_attention_utils import is_flash_attn_available
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
class PatchManager:
"""Manages the application of patches during the model loading process."""
@staticmethod
def apply_pre_config_load_patches(cfg: DictDefault):
"""
Apply patches that must be set up before config loading.
These patches intercept remote code loading from HuggingFace and must be in
place before AutoConfig.from_pretrained() is called.
Args:
cfg: Configuration dictionary with model and training settings.
"""
if (
hasattr(cfg, "base_model_config")
and cfg.base_model_config
and "kimi-linear" in cfg.base_model_config.lower()
):
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_config,
)
patch_kimi_config()
@staticmethod
def apply_pre_tokenizer_load_patches(cfg: DictDefault):
"""
Apply patches that must be set up before tokenizer loading.
These patches intercept remote code loading from HuggingFace and must be in
place before AutoTokenizer.from_pretrained() is called.
Args:
cfg: Configuration dictionary with model and training settings.
"""
if (
hasattr(cfg, "tokenizer_config")
and cfg.tokenizer_config
and "kimi-linear" in cfg.tokenizer_config.lower()
):
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_tokenizer,
)
patch_kimi_tokenizer()
def __init__(
self,
cfg: DictDefault,
model_config: PretrainedConfig | addict.Dict,
inference: bool = False,
):
"""Initialize the `PatchManager`.
Args:
cfg: Configuration dictionary with model and training settings.
model_config: Configuration object for the model.
inference: Whether the model is being loaded for inference mode.
"""
self.cfg = cfg
self.model_config = model_config
self.inference = inference
@cached_property
def has_flash_attn(self) -> bool:
"""Check if flash attention is installed."""
return importlib.util.find_spec("flash_attn") is not None
def apply_pre_model_load_patches(self):
"""Apply pre-model load patches based on config."""
self._deactivate_hf_async_load()
self._apply_transformers_patches()
# self._apply_flex_attention_patches()
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_sageattn_patches()
self._apply_flash_attn_4_patches()
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
self._apply_fp8_patches()
self._apply_flash_attention_peft_patches()
self._apply_gradient_checkpointing_patches()
self._patch_attention()
self._apply_multipack_patches()
self._patch_loss_llama()
self._patch_llama_derived_model()
self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch()
self._apply_fsdp2_bnb_patches()
self._apply_patch_deepspeed_zero3()
self._apply_voxtral_patches()
self._apply_apertus_patches()
self._apply_trl_vllm_patches()
self._apply_trl_trainer_utils_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_moe_expert_quantization_patch()
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
patch_evaluation_loop,
patch_maybe_log_save_evaluate,
)
patch_evaluation_loop()
patch_maybe_log_save_evaluate()
if self.cfg.context_parallel_size > 1:
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
patch_prepare_context_parallel_inputs,
)
patch_prepare_context_parallel_inputs()
def apply_post_model_build_patches(self, model: PreTrainedModel):
"""Apply patches right after model build, before post-load setup."""
self._finalize_moe_expert_quantization(model)
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)
self._apply_unsloth_patches(model)
self._apply_lora_kernel_patch(model)
self._apply_scaling_softmax_patch(model)
def _apply_flash_attention_patches(self):
"""Apply patches related to Flash Attention."""
if self.cfg.xformers_attention and self.cfg.sample_packing:
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
patch_xformers_attn_over_fa2()
self.cfg.flash_attention = True
def _apply_chunked_cross_entropy_patch(self):
if self.cfg.chunked_cross_entropy:
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
if self.cfg.chunked_cross_entropy_num_chunks:
patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
else:
patch_chunked_ce_loss_fn()
def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations."""
if self.cfg.fsdp_config:
from axolotl.monkeypatch.accelerate.fsdp2 import (
patch_initialize_missing_keys_for_fsdp,
)
patch_initialize_missing_keys_for_fsdp()
if self.cfg.context_parallel_size > 1 or (
self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
):
from axolotl.monkeypatch.accelerate.parallelism_config import (
patch_parallelism_config,
)
patch_parallelism_config()
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import (
patch_accelerate_fsdp2,
patch_tied_keys_for_meta_device,
)
patch_accelerate_fsdp2()
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
patch_tied_keys_for_meta_device()
if self.cfg.rl:
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
patch_trl_prepare_fsdp2()
# if self.cfg.fsdp_config:
# # see transformers#39152
# from axolotl.monkeypatch.trainer_fsdp_optim import (
# patch_training_loop_for_fsdp,
# )
#
# patch_training_loop_for_fsdp()
def _apply_adapter_patches(self):
"""Apply patches for adapter configurations."""
if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
patch_peft_prep_code()
def _apply_flex_attention_patches(self):
"""Apply patches for flexible attention."""
if self.cfg.flex_attention:
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_wrapper,
)
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs)
def _apply_sageattn_patches(self):
"""Apply patches for SageAttention."""
if self.cfg.sage_attention:
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
patch_sageattn()
def _apply_flash_attn_4_patches(self):
"""Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+."""
if not self.cfg.flash_attention:
return
from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4
patch_flash_attn_4(self.model_config)
def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures."""
if (
self.cfg.model_config_type == "llama4"
and self.cfg.llama4_linearized_experts
):
from axolotl.monkeypatch.models.llama4.modeling import (
patch_llama4_linearized_modeling,
)
patch_llama4_linearized_modeling()
if self.cfg.model_config_type == "qwen3_next" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_next.modeling import (
patch_qwen3_next_modeling_packing,
)
patch_qwen3_next_modeling_packing()
if self.cfg.model_config_type == "qwen3_5" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_modeling_packing,
)
patch_qwen3_5_modeling_packing()
if self.cfg.model_config_type == "qwen3_5_moe" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_moe_modeling_packing,
)
patch_qwen3_5_moe_modeling_packing()
if (
self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"]
and self.cfg.is_multimodal
and self.cfg.flash_attention
):
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_vlm_flash_attention,
)
patch_qwen3_5_vlm_flash_attention()
if self.cfg.model_config_type == "kimi_linear":
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_model,
)
patch_kimi_model()
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:
from axolotl.monkeypatch.trainer_accelerator_args import (
patch_create_accelerate_code_for_fp8,
)
patch_create_accelerate_code_for_fp8(
self.cfg.fp8_enable_fsdp_float8_all_gather
)
def _apply_flash_attention_peft_patches(self):
"""Apply patches for Flash Attention with PEFT."""
if self.cfg.adapter:
from axolotl.monkeypatch.transformers_fa_utils import (
patch_fa_peft_integration,
)
patch_fa_peft_integration()
def _apply_gradient_checkpointing_patches(self):
"""Apply patches for gradient checkpointing."""
if (
self.cfg.gradient_checkpointing
and self.cfg.activation_offloading == "legacy"
):
from axolotl.monkeypatch.gradient_checkpointing import (
hf_grad_checkpoint_offload_wrapper,
)
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
elif (
self.cfg.gradient_checkpointing
and self.cfg.activation_offloading == "offload_disk"
):
from axolotl.monkeypatch.gradient_checkpointing import (
hf_grad_checkpoint_disk_offload_wrapper,
)
transformers.modeling_utils.checkpoint = (
hf_grad_checkpoint_disk_offload_wrapper
)
def _apply_mistral_cross_entropy_patch(self):
"""Apply Mistral cross entropy patch if configured."""
if (
self.cfg.model_config_type == "mistral"
and self.cfg.flash_attn_cross_entropy_loss
):
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
patch_mistral_cross_entropy,
)
patch_mistral_cross_entropy()
def _apply_self_attention_lora_patch(self):
"""Apply self-attention LoRA patches if configured."""
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
# Patching is only supported when lora_dropout is 0 (or unset)
can_patch = (
self.cfg.lora_dropout == 0
if hasattr(self.cfg, "lora_dropout")
else True
) # default to True if lora_dropout is not set
if not can_patch:
LOG.warning("Cannot patch self-attention - requires no dropout")
return
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.cfg)
def _apply_multipack_patches(self):
"""Apply multipack patches if necessary."""
if (
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and (self.cfg.flash_attention or self.cfg.flex_attention)
and self.cfg.sample_packing
):
# Get automap config if it exists
auto_map_config = None
if isinstance(self.model_config, dict) and "auto_map" in self.model_config:
auto_map_config = self.model_config["auto_map"]
elif hasattr(self.model_config, "auto_map"):
auto_map_config = self.model_config.auto_map
# Determine if the model has remote code
if auto_map_config is not None:
has_remote_code = "AutoModelForCausalLM" in auto_map_config
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is not None:
# If explicitly set in YAML, prefer that
has_remote_code = self.cfg.trust_remote_code
patch_for_multipack(
self.cfg.model_config_type,
model_name=self.cfg.base_model,
has_remote_code=has_remote_code,
)
if self.cfg.sample_packing:
from axolotl.monkeypatch.data.batch_dataset_fetcher import (
apply_multipack_dataloader_patch,
)
LOG.info("Applying multipack dataloader patch for sample packing...")
apply_multipack_dataloader_patch()
def _apply_fsdp2_bnb_patches(self):
"""Apply FSDP2 BNB patches."""
if (
self.cfg.fsdp_config
and str(self.cfg.fsdp_version) == "2"
and (self.cfg.load_in_4bit or self.cfg.load_in_8bit)
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_dtype_attrs_patch,
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
apply_linear8bitlt_save_patch,
)
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
apply_init_dtype_attrs_patch()
if self.cfg.load_in_8bit:
apply_linear8bitlt_save_patch()
def _deactivate_hf_async_load(self):
"""Load weights synchronously so they can be converted and not OOM."""
if self.cfg.load_in_4bit or self.cfg.load_in_8bit:
os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
def _apply_moe_expert_quantization_patch(self):
"""Patch transformers weight loading and PEFT for MoE expert quantization."""
has_target_params = bool(getattr(self.cfg, "lora_target_parameters", None))
if not self.cfg.quantize_moe_experts and not has_target_params:
return
from axolotl.monkeypatch.moe_quant import (
patch_peft_target_parameters_matching,
)
if self.cfg.quantize_moe_experts:
from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load
patch_moe_quantization_on_load(self.cfg)
patch_peft_target_parameters_matching()
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
"""Log quantization results and set model flag for downstream use."""
import torch
model._moe_experts_quantized = False
if self.cfg.quantize_moe_experts:
from axolotl.monkeypatch.moe_quant import get_moe_quantized_count
count = get_moe_quantized_count()
if count > 0:
import gc
model._moe_experts_quantized = True
LOG.info(
"Quantized %d MoE expert parameter(s) to %s during model loading",
count,
"4-bit" if self.cfg.load_in_4bit else "8-bit",
)
gc.collect()
torch.cuda.empty_cache()
def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:
from axolotl.monkeypatch.tiled_mlp import (
patch_tiled_mlp,
)
patch_tiled_mlp(
model_type,
use_original_mlp=self.cfg.tiled_mlp_use_original_mlp,
cfg_num_shards=self.cfg.tiled_mlp_num_shards,
)
def _apply_voxtral_patches(self):
"""Apply patches for Voxtral model."""
if self.cfg.model_config_type == "voxtral":
from axolotl.monkeypatch.models.voxtral.modeling import (
patch_voxtral_conditional_generation_forward,
)
patch_voxtral_conditional_generation_forward()
def _patch_attention(self):
"""Apply attention-specific patches based on model type."""
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
return
if self.model_config.model_type == "btlm":
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
replace_btlm_attn_with_flash_attn,
)
replace_btlm_attn_with_flash_attn(self.cfg.base_model)
if self.model_config.model_type == "stablelm_epoch" and self.cfg.sample_packing:
from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
replace_stablelm_attn_with_flash_attn,
)
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
if self.model_config.model_type in ("mistral3", "llava"):
from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (
apply_patch_is_packed_sequence,
)
apply_patch_is_packed_sequence()
def _patch_loss_llama(self):
"""Patch loss functions and other optimizations for LLaMA models."""
if not self.cfg.is_llama_derived_model:
return
if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_fa_llama_cross_entropy,
)
patch_fa_llama_cross_entropy()
elif self.cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch(model_type="llama")
if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
from axolotl.monkeypatch.llama_attn_hijack_flash import patch_llama_rms_norm
patch_llama_rms_norm()
elif self.cfg.unsloth_rms_norm:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
patch_unsloth_layernorm()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
def _patch_llama_flash_attention(self):
"""Apply Flash Attention patches for LLaMA models."""
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)
if self.cfg.s2_attention:
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
replace_llama_attn_with_flash_attn(
cross_entropy=self.cfg.flash_attn_cross_entropy,
rms_norm=self.cfg.flash_attn_rms_norm,
use_shifted_sparse_attn=True,
)
elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm:
replace_llama_attn_with_flash_attn(
cross_entropy=self.cfg.flash_attn_cross_entropy,
rms_norm=self.cfg.flash_attn_rms_norm,
)
def _patch_llama_xformers_attention(self):
"""Apply xformers attention patches for LLaMA models."""
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)
LOG.info("Patching with xformers attention...")
hijack_llama_attention()
def _patch_llama_derived_model(self):
"""Modify all llama derived models in one block."""
if self.cfg.is_llama_derived_model and not (
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and (self.cfg.flash_attention or self.cfg.flex_attention)
and self.cfg.sample_packing
):
if self.cfg.flash_attention:
self._patch_llama_flash_attention()
elif self.cfg.xformers_attention:
self._patch_llama_xformers_attention()
elif self.cfg.s2_attention:
raise NotImplementedError(
"Shifted-sparse attention not currently implemented without flash attention."
)
def _apply_llama_flash_attn_patches(self, model):
"""Apply LLaMA-specific flash attention patches."""
if (
self.model_config.model_type in ["llama", "llama4"]
and not self.cfg.trust_remote_code
and not self.cfg.gptq
and self.cfg.flash_attention
and is_flash_attn_available()
and not self.inference
):
# TODO(MengqingCao): split these patches separately
from axolotl.monkeypatch.llama_attn_hijack_flash import (
is_xformers_swiglu_available,
replace_llama_mlp_with_swiglu,
)
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
LOG.info("Patching with SwiGLU...")
replace_llama_mlp_with_swiglu(model)
def _apply_unsloth_patches(self, model):
"""Apply unsloth optimization patches."""
if self.cfg.unsloth_lora_mlp:
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
integrate_lora_mlp_patch(peft_model=model)
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
integrate_lora_patch(peft_model=model, cfg=self.cfg)
if self.cfg.unsloth_rope:
from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
integrate_rope_embeddings()
def _apply_lora_kernel_patch(self, model):
"""Apply LoRA kernel patches."""
if (
self.cfg.lora_mlp_kernel
or self.cfg.lora_qkv_kernel
or self.cfg.lora_o_kernel
):
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
apply_lora_kernel_patches(model=model, cfg=self.cfg)
def _apply_patch_deepspeed_zero3(self):
try:
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches
if self.cfg.activation_offloading is True and (
is_deepspeed_zero3_enabled()
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
):
apply_deepspeed_patches()
except ImportError as e:
LOG.warning(f"DeepSpeed patches not applied: {e}")
def _apply_apertus_patches(self):
"""Apply patches for Apertus model."""
if self.cfg.model_config_type == "apertus":
from axolotl.monkeypatch.models.apertus.activation import (
patch_apertus_xielu_activation,
)
patch_apertus_xielu_activation()
def _apply_trl_vllm_patches(self):
"""Apply TRL vLLM patches for batched weight sync, NaN logprobs fix, and scalar handling."""
if (
self.cfg.rl
and getattr(self.cfg, "trl", None)
and getattr(self.cfg.trl, "use_vllm", False)
):
from axolotl.monkeypatch.trainer.trl_vllm import patch_trl_vllm
patch_trl_vllm()
def _apply_trl_trainer_utils_patches(self):
"""Replace trl.trainer.utils.{selective_log_softmax, entropy_from_logits} with Triton kernels."""
if not self.cfg.rl:
return
try:
from axolotl.monkeypatch.trainer.utils import (
entropy_from_logits,
selective_log_softmax,
)
except (ImportError, ModuleNotFoundError):
LOG.warning("Triton not available — skipping trl.trainer.utils patches")
return
import trl.trainer.utils
# Guard against repeated calls: only stash the original if trl still
# points at its own implementation (not our wrapper).
if trl.trainer.utils.selective_log_softmax is not selective_log_softmax:
from axolotl.monkeypatch.trainer import utils as _axolotl_trainer_utils
_axolotl_trainer_utils.selective_log_softmax_original = (
trl.trainer.utils.selective_log_softmax
)
trl.trainer.utils.selective_log_softmax = selective_log_softmax
if trl.trainer.utils.entropy_from_logits is not entropy_from_logits:
trl.trainer.utils.entropy_from_logits = entropy_from_logits
LOG.info(
"Patched trl.trainer.utils with Triton selective_log_softmax and entropy_from_logits"
)
def _apply_scaling_softmax_patch(self, model: PreTrainedModel):
"""Apply Scaling Softmax (SSMax) patch. Ref: https://arxiv.org/abs/2501.19399"""
if self.cfg.scaling_softmax:
from axolotl.monkeypatch.scaled_softmax_attn import (
patch_scaled_softmax_attention,
)
patch_scaled_softmax_attention(
scaling_factor_init=self.cfg.scaling_softmax_factor or 0.43,
bias=self.cfg.scaling_softmax_bias or 0.0,
model=model,
)