fix: remove cuda alloc for moe and disable async load

This commit is contained in:
NanoCode012
2026-02-25 18:01:58 +07:00
parent ca822cd24c
commit 6ad4b4ecbe

View File

@@ -53,17 +53,25 @@ def patch_moe_quantization_on_load(cfg):
LOG.debug("MoE loading-time quantization patch already active")
return
import os
import transformers.core_model_loading
import transformers.modeling_utils
from bitsandbytes.nn.parametrize import replace_parameter_4bit
# Disable transformers' async weight loading thread pool. Without this,
# the ThreadPoolExecutor pre-fetches tensors to CUDA faster than the main
# loop can quantize them, causing all expert weights to accumulate in bf16
# on GPU — defeating the purpose of loading-time quantization.
os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
LOG.info("Disabled async weight loading (HF_DEACTIVATE_ASYNC_LOAD=1)")
# Patch caching_allocator_warmup to be a no-op. This function pre-allocates
# a single huge GPU tensor equal to the model's total param bytes to warm the
# CUDA caching allocator. For MoE models, it calculates expert params at bf16
# size (BnB doesn't know we'll quantize them), causing a ~50+ GiB reservation
# that defeats loading-time quantization. Disabling it trades slightly slower
# weight loading for dramatically lower peak VRAM.
_original_warmup = transformers.modeling_utils.caching_allocator_warmup
def _noop_warmup(*args, **kwargs):
    """No-op stand-in for ``caching_allocator_warmup``.

    Accepts and ignores any arguments so it is signature-compatible with the
    original warmup function; only records that the warmup was skipped.
    """
    _ = (args, kwargs)  # intentionally unused — kept for call compatibility
    LOG.info(
        "Skipped caching_allocator_warmup (MoE loading-time quantization active)"
    )
transformers.modeling_utils.caching_allocator_warmup = _noop_warmup
LOG.info("Patched caching_allocator_warmup to no-op for MoE quantization")
# Read quantization settings from config
quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
@@ -77,44 +85,27 @@ def patch_moe_quantization_on_load(cfg):
original_set_param = transformers.core_model_loading.set_param_for_module
_first_call = [True]
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
if _first_call[0]:
LOG.info(
"MoE quant patch: set_param_for_module intercepted (first call) "
"(alloc=%.2f GiB, reserved=%.2f GiB, max_alloc=%.2f GiB)",
torch.cuda.memory_allocated() / 1024**3,
torch.cuda.memory_reserved() / 1024**3,
torch.cuda.max_memory_allocated() / 1024**3,
)
_first_call[0] = False
original_set_param(model, target_name, param_value, *args, **kwargs)
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
if param_value.ndim >= 3:
LOG.info(
"MoE quant patch: 3D param %s shape=%s cuda=%s",
target_name,
param_value.shape,
param_value.is_cuda,
)
if param_value.is_cuda:
mod_path, _, pname = target_name.rpartition(".")
mod = model.get_submodule(mod_path) if mod_path else model
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
replace_parameter_4bit(
mod,
pname,
compress_statistics=_moe_load_state["compress_statistics"],
quant_type=_moe_load_state["quant_type"],
)
torch.cuda.empty_cache()
_moe_load_state["count"] += 1
if param_value.ndim >= 3 and param_value.is_cuda:
mod_path, _, pname = target_name.rpartition(".")
mod = model.get_submodule(mod_path) if mod_path else model
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
replace_parameter_4bit(
mod,
pname,
compress_statistics=_moe_load_state["compress_statistics"],
quant_type=_moe_load_state["quant_type"],
)
torch.cuda.empty_cache()
_moe_load_state["count"] += 1
if _moe_load_state["count"] % 10 == 1:
LOG.info(
"Quantized 3D expert param: %s "
"Quantized expert param #%d: %s "
"(alloc=%.2f GiB, reserved=%.2f GiB)",
_moe_load_state["count"],
target_name,
torch.cuda.memory_allocated() / 1024**3,
torch.cuda.memory_reserved() / 1024**3,
@@ -122,11 +113,6 @@ def patch_moe_quantization_on_load(cfg):
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
_moe_load_state["patched"] = True
LOG.info(
"Pre-load GPU memory: alloc=%.2f GiB, reserved=%.2f GiB",
torch.cuda.memory_allocated() / 1024**3,
torch.cuda.memory_reserved() / 1024**3,
)
LOG.info(
"Activated MoE loading-time quantization patch "
"(quant_type=%s, compress_statistics=%s)",