diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 50c3b85f4..aa67c3b1e 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -8,6 +8,7 @@ import os
 from functools import cached_property
 
 import addict
+import torch
 import transformers
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.modeling_flash_attention_utils import is_flash_attn_available
@@ -258,38 +259,6 @@ class PatchManager:
 
             patch_llama4_linearized_modeling()
 
-        if self.cfg.model_config_type == "qwen3_next" and self.cfg.sample_packing:
-            from axolotl.monkeypatch.models.qwen3_next.modeling import (
-                patch_qwen3_next_modeling_packing,
-            )
-
-            patch_qwen3_next_modeling_packing()
-
-        if self.cfg.model_config_type == "qwen3_5" and self.cfg.sample_packing:
-            from axolotl.monkeypatch.models.qwen3_5.modeling import (
-                patch_qwen3_5_modeling_packing,
-            )
-
-            patch_qwen3_5_modeling_packing()
-
-        if self.cfg.model_config_type == "qwen3_5_moe" and self.cfg.sample_packing:
-            from axolotl.monkeypatch.models.qwen3_5.modeling import (
-                patch_qwen3_5_moe_modeling_packing,
-            )
-
-            patch_qwen3_5_moe_modeling_packing()
-
-        if (
-            self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"]
-            and self.cfg.is_multimodal
-            and self.cfg.flash_attention
-        ):
-            from axolotl.monkeypatch.models.qwen3_5.modeling import (
-                patch_qwen3_5_vlm_flash_attention,
-            )
-
-            patch_qwen3_5_vlm_flash_attention()
-
         if self.cfg.model_config_type == "kimi_linear":
             from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
                 patch_kimi_model,
@@ -314,6 +283,40 @@ class PatchManager:
             # False because the original block forward is not GC-safe.
             NemotronHPreTrainedModel.supports_gradient_checkpointing = True
 
+        # Patches requiring CUDA
+        if torch.cuda.is_available():
+            if self.cfg.model_config_type == "qwen3_next" and self.cfg.sample_packing:
+                from axolotl.monkeypatch.models.qwen3_next.modeling import (
+                    patch_qwen3_next_modeling_packing,
+                )
+
+                patch_qwen3_next_modeling_packing()
+
+            if self.cfg.model_config_type == "qwen3_5" and self.cfg.sample_packing:
+                from axolotl.monkeypatch.models.qwen3_5.modeling import (
+                    patch_qwen3_5_modeling_packing,
+                )
+
+                patch_qwen3_5_modeling_packing()
+
+            if self.cfg.model_config_type == "qwen3_5_moe" and self.cfg.sample_packing:
+                from axolotl.monkeypatch.models.qwen3_5.modeling import (
+                    patch_qwen3_5_moe_modeling_packing,
+                )
+
+                patch_qwen3_5_moe_modeling_packing()
+
+            if (
+                self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"]
+                and self.cfg.is_multimodal
+                and self.cfg.flash_attention
+            ):
+                from axolotl.monkeypatch.models.qwen3_5.modeling import (
+                    patch_qwen3_5_vlm_flash_attention,
+                )
+
+                patch_qwen3_5_vlm_flash_attention()
+
     @staticmethod
     def _fix_nemotron_h_conversion_mapping():
         """Remove the spurious embedding→embeddings WeightRenaming from the