bug-fix: only apply patches when CUDA is available (#3561)

* bug-fix: only apply patches when CUDA is available

This will otherwise crash when performing operations with CUDA_VISIBLE_DEVICES=, such as LoRA merging on CPU.

This patch only patches the Qwen 3.5 model, since that's the only one I've tested. This patch should most likely check torch.cuda for all other models as well. One limitation here is that I'm assuming the user runs CUDA, but that assumption is not restricted to this patch so it is probably fine.

* include patch_qwen3_next_modeling_packing, patch_qwen3_5_moe_modeling_packing, and patch_qwen3_5_vlm_flash_attention in cuda guard
This commit is contained in:
kallewoof
2026-04-01 08:05:15 +09:00
committed by GitHub
parent a81feabbd9
commit a4c94416eb

View File

@@ -8,6 +8,7 @@ import os
from functools import cached_property
import addict
import torch
import transformers
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_flash_attention_utils import is_flash_attn_available
@@ -258,38 +259,6 @@ class PatchManager:
patch_llama4_linearized_modeling()
if self.cfg.model_config_type == "qwen3_next" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_next.modeling import (
patch_qwen3_next_modeling_packing,
)
patch_qwen3_next_modeling_packing()
if self.cfg.model_config_type == "qwen3_5" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_modeling_packing,
)
patch_qwen3_5_modeling_packing()
if self.cfg.model_config_type == "qwen3_5_moe" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_moe_modeling_packing,
)
patch_qwen3_5_moe_modeling_packing()
if (
self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"]
and self.cfg.is_multimodal
and self.cfg.flash_attention
):
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_vlm_flash_attention,
)
patch_qwen3_5_vlm_flash_attention()
if self.cfg.model_config_type == "kimi_linear":
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_model,
@@ -314,6 +283,40 @@ class PatchManager:
# False because the original block forward is not GC-safe.
NemotronHPreTrainedModel.supports_gradient_checkpointing = True
# Patches requiring CUDA
if torch.cuda.is_available():
if self.cfg.model_config_type == "qwen3_next" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_next.modeling import (
patch_qwen3_next_modeling_packing,
)
patch_qwen3_next_modeling_packing()
if self.cfg.model_config_type == "qwen3_5" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_modeling_packing,
)
patch_qwen3_5_modeling_packing()
if self.cfg.model_config_type == "qwen3_5_moe" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_moe_modeling_packing,
)
patch_qwen3_5_moe_modeling_packing()
if (
self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"]
and self.cfg.is_multimodal
and self.cfg.flash_attention
):
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_vlm_flash_attention,
)
patch_qwen3_5_vlm_flash_attention()
@staticmethod
def _fix_nemotron_h_conversion_mapping():
"""Remove the spurious embedding→embeddings WeightRenaming from the