From 1558436c69efb7a029140f5e2e1cc1404f50a7c9 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Wed, 25 Feb 2026 17:30:32 +0700
Subject: [PATCH] fix: attempt to disable async load

---
 src/axolotl/monkeypatch/moe_quant.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py
index ad81556da..b16d30e13 100644
--- a/src/axolotl/monkeypatch/moe_quant.py
+++ b/src/axolotl/monkeypatch/moe_quant.py
@@ -53,9 +53,18 @@ def patch_moe_quantization_on_load(cfg):
         LOG.debug("MoE loading-time quantization patch already active")
         return
 
+    import os
+
     import transformers.core_model_loading
     from bitsandbytes.nn.parametrize import replace_parameter_4bit
 
+    # Disable transformers' async weight loading thread pool. Without this,
+    # the ThreadPoolExecutor pre-fetches tensors to CUDA faster than the main
+    # loop can quantize them, causing all expert weights to accumulate in bf16
+    # on GPU — defeating the purpose of loading-time quantization.
+    os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
+    LOG.info("Disabled async weight loading (HF_DEACTIVATE_ASYNC_LOAD=1)")
+
     # Read quantization settings from config
     quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
     compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
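
Note (not part of the patch): below is a minimal sketch of the load path this change targets, assuming a transformers release whose core_model_loading module honors HF_DEACTIVATE_ASYNC_LOAD, as the patch expects. The model id is a placeholder, and the BitsAndBytesConfig values simply mirror the nf4/double-quant settings the patch reads from cfg; this is an illustration, not axolotl's actual loading code.

    import os

    # Must be set before transformers starts materializing weights, otherwise
    # the async ThreadPoolExecutor is already running (assumption: the env var
    # is read at load time by transformers.core_model_loading, which this
    # patch relies on).
    os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Same knobs patch_moe_quantization_on_load reads from the axolotl config
    # (bnb_4bit_quant_type, bnb_4bit_use_double_quant); values are examples.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Placeholder MoE checkpoint; with async loading disabled, expert weights
    # can be quantized as they are loaded instead of accumulating in bf16 on
    # the GPU.
    model = AutoModelForCausalLM.from_pretrained(
        "org/placeholder-moe-model",
        quantization_config=bnb_config,
        device_map="auto",
    )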