From 43b1c80aa644c0f65b58952e8262321b0b27b489 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 7 Mar 2026 07:09:24 -0500
Subject: [PATCH] load weights synchronously so they can be converted and not OOM: (#3477)

---
 src/axolotl/monkeypatch/moe_quant.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py
index 42beec6a9..a8ad4c371 100644
--- a/src/axolotl/monkeypatch/moe_quant.py
+++ b/src/axolotl/monkeypatch/moe_quant.py
@@ -7,6 +7,8 @@ on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametriz
 reducing peak VRAM from "all experts in bf16" to "one expert at a time."
 """
 
+import os
+
 import bitsandbytes as bnb
 import torch
 import torch.nn.utils.parametrize as P
@@ -101,6 +103,14 @@ def patch_moe_quantization_on_load(cfg):
     _moe_load_state["quant_type"] = quant_type
     _moe_load_state["compress_statistics"] = compress_statistics
 
+    # Disable async tensor loading. Transformers' convert_and_load_state_dict_in_model
+    # uses a ThreadPoolExecutor to materialise tensors (move from safetensors → CUDA)
+    # ahead of time. With MoE models this pre-fetches many large bf16 expert tensors
+    # onto the GPU simultaneously — long before our set_param_for_module patch can
+    # quantise and free them one-by-one — causing OOM even at <5 % of weights loaded.
+    # Sequential loading ensures only ONE bf16 expert tensor is on-GPU at a time.
+    os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
+
     # Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
     # size for all params, defeating our on-load quantization VRAM savings.
     def _noop_warmup(*args, **kwargs):
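
Illustrative note, not part of the patch: a minimal sketch of the one-tensor-at-a-time flow that synchronous loading makes possible. The helper name quantize_shard_sequentially and the shard path are hypothetical; bnb.functional.quantize_4bit and safetensors.safe_open are existing APIs, though default arguments may differ across versions.

# Sketch only (hypothetical helper, not the patch's code). With async loading
# disabled, tensors arrive one by one, so each bf16 weight can be quantized and
# freed before the next one is moved to the GPU, keeping peak VRAM near the
# quantized size rather than the full bf16 size.
import gc

import bitsandbytes as bnb
import torch
from safetensors import safe_open


def quantize_shard_sequentially(shard_path, quant_type="nf4"):
    """Quantize one tensor at a time so only a single bf16 copy sits on the GPU."""
    quantized = {}
    with safe_open(shard_path, framework="pt", device="cpu") as f:
        for name in f.keys():
            weight = f.get_tensor(name)
            if not weight.is_floating_point():
                continue  # skip non-float buffers; only weights get quantized
            weight = weight.to("cuda")                    # one bf16 tensor on GPU
            packed, quant_state = bnb.functional.quantize_4bit(
                weight, quant_type=quant_type, compress_statistics=True
            )
            quantized[name] = (packed, quant_state)
            del weight                                    # free the bf16 copy now
            torch.cuda.empty_cache()
    gc.collect()
    return quantized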