"""Patch manager class implementation to complement `axolotl.loaders.ModelLoader`.
Applies pre- and post-model load patches for various fixes and optimizations.
"""

import importlib.util
from functools import cached_property

import addict
import transformers
from transformers import PretrainedConfig, PreTrainedModel

from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import (
    SUPPORTED_MULTIPACK_MODEL_TYPES,
    patch_for_multipack,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()

class PatchManager:
"""Manages the application of patches during the model loading process."""
def __init__(
self,
cfg: DictDefault,
model_config: PretrainedConfig | addict.Dict,
inference: bool = False,
):
"""Initialize the `PatchManager`.
Args:
cfg: Configuration dictionary with model and training settings.
model_config: Configuration object for the model.
inference: Whether the model is being loaded for inference mode.
"""
self.cfg = cfg
self.model_config = model_config
self.inference = inference
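
    # Typical lifecycle, as a hedged sketch: the `cfg` contents and the
    # model-loading step are illustrative, and the sequencing relative to
    # plugin `pre_model_load` hooks is inferred from the method names.
    #
    #     patch_manager = PatchManager(cfg, model_config)
    #     patch_manager.apply_pre_model_load_patches()
    #     ...  # plugins run their pre_model_load hooks
    #     patch_manager.apply_post_plugin_pre_model_load_patches()
    #     model = ...  # loaded via `axolotl.loaders.ModelLoader`
    #     patch_manager.apply_post_model_load_patches(model)
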
@cached_property
def has_flash_attn(self) -> bool:
"""Check if flash attention is installed."""
return importlib.util.find_spec("flash_attn") is not None

    def apply_pre_model_load_patches(self):
"""Apply pre-model load patches based on config."""
self._apply_transformers_patches()
# self._apply_flex_attention_patches()
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
self._apply_fp8_patches()
self._apply_flash_attention_peft_patches()
self._apply_gradient_checkpointing_patches()
self._patch_attention()
self._apply_multipack_patches()
self._patch_loss_llama()
self._patch_llama_derived_model()
self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch()
self._apply_fsdp2_bnb_patches()

    def apply_post_plugin_pre_model_load_patches(self):
        """Apply patches that run after plugin `pre_model_load` hooks, but still before model load."""
self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_voxtral_patches()

    def _apply_transformers_patches(self):
        """Apply patches to the transformers `Trainer` evaluation loop and loss logging."""
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
patch_evaluation_loop,
patch_maybe_log_save_evaluate,
)
patch_fsdp2 = (
self.cfg.torch_compile
and self.cfg.fsdp_config
and self.cfg.fsdp_version == 2
)
patch_evaluation_loop(patch_fsdp2)
patch_maybe_log_save_evaluate()

    def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)
self._apply_unsloth_patches(model)
self._apply_lora_kernel_patch(model)

    def _apply_flash_attention_patches(self):
"""Apply patches related to Flash Attention."""
if self.cfg.xformers_attention and self.cfg.sample_packing:
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
patch_xformers_attn_over_fa2()
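            # Once xformers attention is patched in over the FA2 code path,
            # downstream loader logic should take the flash-attention branches,
            # hence the flag override below.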
self.cfg.flash_attention = True

    def _apply_chunked_cross_entropy_patch(self):
        """Apply chunked cross entropy loss patch if configured."""
if self.cfg.chunked_cross_entropy:
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
if self.cfg.chunked_cross_entropy_num_chunks:
patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
else:
patch_chunked_ce_loss_fn()
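
    # Config sketch that exercises this patch; keys come from the checks
    # above, values are illustrative:
    #
    #     chunked_cross_entropy: true
    #     chunked_cross_entropy_num_chunks: 4
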
def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations."""
if self.cfg.context_parallel_size > 1 or (
self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
):
from axolotl.monkeypatch.accelerate.parallelism_config import (
patch_parallelism_config,
)
patch_parallelism_config()
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
patch_accelerate_fsdp2()
if self.cfg.rl:
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
patch_trl_prepare_fsdp2()
# if self.cfg.fsdp_config:
# # see transformers#39152
# from axolotl.monkeypatch.trainer_fsdp_optim import (
# patch_training_loop_for_fsdp,
# )
#
# patch_training_loop_for_fsdp()
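
    # Config sketch that triggers the FSDP2 and context-parallel patches
    # above; keys come from the checks, values are illustrative:
    #
    #     fsdp_version: 2
    #     fsdp_config:
    #       ...
    #     context_parallel_size: 2
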
def _apply_adapter_patches(self):
"""Apply patches for adapter configurations."""
if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
patch_peft_prep_code()

    def _apply_flex_attention_patches(self):
        """Apply patches for flex attention."""
if self.cfg.flex_attention:
# from axolotl.monkeypatch.attention.flex_attn import (
# patch_flex_make_mask,
# patch_flex_wrapper,
# )
#
# flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
# patch_flex_wrapper(**flex_attn_compile_kwargs)
# patch_flex_make_mask()
if self.cfg.sample_packing:
from axolotl.core.attention.flex_block_mask import (
patch_create_causal_mask,
)
patch_create_causal_mask(self.cfg.model_config_type)

    def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures."""
if (
self.cfg.model_config_type == "llama4"
and self.cfg.llama4_linearized_experts
):
from axolotl.monkeypatch.models.llama4.modeling import (
patch_llama4_linearized_modeling,
)
patch_llama4_linearized_modeling()

    def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:
from axolotl.monkeypatch.trainer_accelerator_args import (
patch_create_accelerate_code_for_fp8,
)
patch_create_accelerate_code_for_fp8(
self.cfg.fp8_enable_fsdp_float8_all_gather
)
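
    # Config sketch for the FP8 path; keys come from the checks above,
    # values are illustrative:
    #
    #     fp8: true
    #     fp8_enable_fsdp_float8_all_gather: true
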
def _apply_flash_attention_peft_patches(self):
"""Apply patches for Flash Attention with PEFT."""
if self.cfg.adapter:
from axolotl.monkeypatch.transformers_fa_utils import (
patch_fa_peft_integration,
)
patch_fa_peft_integration()

    def _apply_gradient_checkpointing_patches(self):
"""Apply patches for gradient checkpointing."""
if (
self.cfg.gradient_checkpointing
and self.cfg.activation_offloading == "legacy"
):
from axolotl.monkeypatch.gradient_checkpointing import (
hf_grad_checkpoint_offload_wrapper,
)
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
elif (
self.cfg.gradient_checkpointing
and self.cfg.activation_offloading == "offload_disk"
):
from axolotl.monkeypatch.gradient_checkpointing import (
hf_grad_checkpoint_disk_offload_wrapper,
)
transformers.modeling_utils.checkpoint = (
hf_grad_checkpoint_disk_offload_wrapper
)
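
    # Both branches above rely on the same monkeypatching technique: rebinding
    # a module-level attribute so that callers resolving it at runtime pick up
    # the wrapper. A minimal sketch of the pattern, with a hypothetical wrapper:
    #
    #     import transformers
    #
    #     def checkpoint_wrapper(*args, **kwargs):
    #         ...  # offload activations, then delegate to the original
    #
    #     transformers.modeling_utils.checkpoint = checkpoint_wrapper
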
def _apply_mistral_cross_entropy_patch(self):
"""Apply Mistral cross entropy patch if configured."""
if (
self.cfg.model_config_type == "mistral"
and self.cfg.flash_attn_cross_entropy_loss
):
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
patch_mistral_cross_entropy,
)
patch_mistral_cross_entropy()

    def _apply_self_attention_lora_patch(self):
"""Apply self-attention LoRA patches if configured."""
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
            # Only patch when LoRA dropout is disabled (the kernels assume no dropout)
can_patch = (
self.cfg.lora_dropout == 0
if hasattr(self.cfg, "lora_dropout")
else True
) # default to True if lora_dropout is not set
if not can_patch:
LOG.warning("Cannot patch self-attention - requires no dropout")
return
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.cfg)
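
    # Config sketch for the self-attention LoRA kernel patch; per the guard
    # above, lora_dropout must be 0 (or unset). Values are illustrative:
    #
    #     lora_qkv_kernel: true
    #     lora_o_kernel: true
    #     lora_dropout: 0
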
def _apply_multipack_patches(self):
"""Apply multipack patches if necessary."""
if (
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and (self.cfg.flash_attention or self.cfg.flex_attention)
and self.cfg.sample_packing
):
# Get automap config if it exists
auto_map_config = None
if isinstance(self.model_config, dict) and "auto_map" in self.model_config:
auto_map_config = self.model_config["auto_map"]
elif hasattr(self.model_config, "auto_map"):
auto_map_config = self.model_config.auto_map
# Determine if the model has remote code
if auto_map_config is not None:
has_remote_code = "AutoModelForCausalLM" in auto_map_config
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is False:
# If explicitly set in YAML, prefer that
has_remote_code = self.cfg.trust_remote_code
patch_for_multipack(
self.cfg.model_config_type,
model_name=self.cfg.base_model,
has_remote_code=has_remote_code,
)
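
    # For reference, a remote-code model advertises its custom classes via
    # `auto_map` in its `config.json`, which is what the `has_remote_code`
    # detection above keys off. The class path here is hypothetical:
    #
    #     "auto_map": {
    #         "AutoModelForCausalLM": "modeling_custom.CustomForCausalLM"
    #     }
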
def _apply_fsdp2_bnb_patches(self):
"""Apply FSDP2 BNB patches."""
if (
self.cfg.fsdp_config
and str(self.cfg.fsdp_version) == "2"
and self.cfg.adapter == "qlora"
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_bnb_torch_function_patch,
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
)
apply_bnb_torch_function_patch()
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()

    def _apply_tiled_mlp(self, model_type: str):
        """Apply tiled MLP patch if configured."""
if self.cfg.tiled_mlp:
from axolotl.monkeypatch.tiled_mlp import (
patch_tiled_mlp,
)
patch_tiled_mlp(
model_type,
use_original_mlp=self.cfg.tiled_mlp_use_original_mlp,
cfg_num_shards=self.cfg.tiled_mlp_num_shards,
)

    def _apply_voxtral_patches(self):
"""Apply patches for Voxtral model."""
if self.cfg.model_config_type == "voxtral":
from axolotl.monkeypatch.models.voxtral.modeling import (
patch_voxtral_conditional_generation_forward,
)
patch_voxtral_conditional_generation_forward()

    def _patch_attention(self):
"""Apply attention-specific patches based on model type."""
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
return
if self.model_config.model_type == "btlm":
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
replace_btlm_attn_with_flash_attn,
)
replace_btlm_attn_with_flash_attn(self.cfg.base_model)
if self.model_config.model_type == "stablelm_epoch" and self.cfg.sample_packing:
from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
replace_stablelm_attn_with_flash_attn,
)
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)

    def _patch_loss_llama(self):
"""Patch loss functions and other optimizations for LLaMA models."""
if not self.cfg.is_llama_derived_model:
return
if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_fa_llama_cross_entropy,
)
patch_fa_llama_cross_entropy()
elif self.cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch(model_type="llama")
if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
from axolotl.monkeypatch.llama_attn_hijack_flash import patch_llama_rms_norm
patch_llama_rms_norm()
elif self.cfg.unsloth_rms_norm:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
patch_unsloth_layernorm()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()

    def _patch_llama_flash_attention(self):
"""Apply Flash Attention patches for LLaMA models."""
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)
if self.cfg.s2_attention:
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
replace_llama_attn_with_flash_attn(
cross_entropy=self.cfg.flash_attn_cross_entropy,
rms_norm=self.cfg.flash_attn_rms_norm,
use_shifted_sparse_attn=True,
)
elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm:
replace_llama_attn_with_flash_attn(
cross_entropy=self.cfg.flash_attn_cross_entropy,
rms_norm=self.cfg.flash_attn_rms_norm,
)

    def _patch_llama_xformers_attention(self):
"""Apply xformers attention patches for LLaMA models."""
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)
LOG.info("Patching with xformers attention...")
hijack_llama_attention()

    def _patch_llama_sample_packing(self):
"""Apply sample packing patches for LLaMA models."""
from axolotl.monkeypatch.llama_patch_multipack import (
hijack_llama_prepare_4d_mask,
)
LOG.info("Patching llama _prepare_4d_causal_attention_mask*...")
hijack_llama_prepare_4d_mask()

    def _patch_llama_derived_model(self):
        """Apply attention-related patches for llama-derived models not covered by multipack patching."""
if self.cfg.is_llama_derived_model and not (
self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and (self.cfg.flash_attention or self.cfg.flex_attention)
and self.cfg.sample_packing
):
if self.cfg.flash_attention:
self._patch_llama_flash_attention()
elif self.cfg.xformers_attention:
self._patch_llama_xformers_attention()
elif self.cfg.sample_packing:
self._patch_llama_sample_packing()
elif self.cfg.s2_attention:
raise NotImplementedError(
"Shifted-sparse attention not currently implemented without flash attention."
)
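
    # Dispatch summary for the method above (first matching branch wins):
    #
    #     flash_attention    -> _patch_llama_flash_attention()
    #     xformers_attention -> _patch_llama_xformers_attention()
    #     sample_packing     -> _patch_llama_sample_packing()
    #     s2_attention alone -> NotImplementedError (requires flash attention)
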
def _apply_llama_flash_attn_patches(self, model):
"""Apply LLaMA-specific flash attention patches."""
if (
self.model_config.model_type in ["llama", "llama4"]
and not self.cfg.trust_remote_code
and not self.cfg.gptq
and self.cfg.flash_attention
and not self.inference
):
            # TODO(MengqingCao): split these patches separately
from axolotl.monkeypatch.llama_attn_hijack_flash import (
is_xformers_swiglu_available,
replace_llama_mlp_with_swiglu,
)
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
LOG.info("Patching with SwiGLU...")
replace_llama_mlp_with_swiglu(model)

    def _apply_unsloth_patches(self, model):
"""Apply unsloth optimization patches."""
if self.cfg.unsloth_lora_mlp:
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
integrate_lora_mlp_patch(peft_model=model)
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
integrate_lora_patch(peft_model=model, cfg=self.cfg)
if self.cfg.unsloth_rope:
from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
integrate_rope_embeddings()

    def _apply_lora_kernel_patch(self, model):
"""Apply LoRA kernel patches."""
if (
self.cfg.lora_mlp_kernel
or self.cfg.lora_qkv_kernel
or self.cfg.lora_o_kernel
):
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
apply_lora_kernel_patches(model=model, cfg=self.cfg)