fix: remove cuda alloc for moe and enable async load
@@ -53,17 +53,25 @@ def patch_moe_quantization_on_load(cfg):
        LOG.debug("MoE loading-time quantization patch already active")
        return

    import os

    import transformers.core_model_loading
    import transformers.modeling_utils
    from bitsandbytes.nn.parametrize import replace_parameter_4bit

    # Disable transformers' async weight loading thread pool. Without this,
    # the ThreadPoolExecutor pre-fetches tensors to CUDA faster than the main
    # loop can quantize them, causing all expert weights to accumulate in bf16
    # on GPU — defeating the purpose of loading-time quantization.
    os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
    LOG.info("Disabled async weight loading (HF_DEACTIVATE_ASYNC_LOAD=1)")
    # Patch caching_allocator_warmup to be a no-op. This function pre-allocates
    # a single huge GPU tensor equal to the model's total param bytes to warm the
    # CUDA caching allocator. For MoE models, it calculates expert params at bf16
    # size (BnB doesn't know we'll quantize them), causing a ~50+ GiB reservation
    # that defeats loading-time quantization. Disabling it trades slightly slower
    # weight loading for dramatically lower peak VRAM.
    _original_warmup = transformers.modeling_utils.caching_allocator_warmup

    def _noop_warmup(*args, **kwargs):
        LOG.info(
            "Skipped caching_allocator_warmup (MoE loading-time quantization active)"
        )

    transformers.modeling_utils.caching_allocator_warmup = _noop_warmup
    LOG.info("Patched caching_allocator_warmup to no-op for MoE quantization")

    # Read quantization settings from config
    quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
@@ -77,44 +85,27 @@ def patch_moe_quantization_on_load(cfg):

    original_set_param = transformers.core_model_loading.set_param_for_module

    _first_call = [True]

    def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
        if _first_call[0]:
            LOG.info(
                "MoE quant patch: set_param_for_module intercepted (first call) "
                "(alloc=%.2f GiB, reserved=%.2f GiB, max_alloc=%.2f GiB)",
                torch.cuda.memory_allocated() / 1024**3,
                torch.cuda.memory_reserved() / 1024**3,
                torch.cuda.max_memory_allocated() / 1024**3,
            )
            _first_call[0] = False

        original_set_param(model, target_name, param_value, *args, **kwargs)

        # Quantize 3D+ expert params that BnB skipped (only on CUDA).
        if param_value.ndim >= 3:
            LOG.info(
                "MoE quant patch: 3D param %s shape=%s cuda=%s",
                target_name,
                param_value.shape,
                param_value.is_cuda,
            )
            if param_value.is_cuda:
                mod_path, _, pname = target_name.rpartition(".")
                mod = model.get_submodule(mod_path) if mod_path else model
                if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
                    replace_parameter_4bit(
                        mod,
                        pname,
                        compress_statistics=_moe_load_state["compress_statistics"],
                        quant_type=_moe_load_state["quant_type"],
                    )
                    torch.cuda.empty_cache()
                    _moe_load_state["count"] += 1
        if param_value.ndim >= 3 and param_value.is_cuda:
            mod_path, _, pname = target_name.rpartition(".")
            mod = model.get_submodule(mod_path) if mod_path else model
            if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
                replace_parameter_4bit(
                    mod,
                    pname,
                    compress_statistics=_moe_load_state["compress_statistics"],
                    quant_type=_moe_load_state["quant_type"],
                )
                torch.cuda.empty_cache()
                _moe_load_state["count"] += 1
                if _moe_load_state["count"] % 10 == 1:
                    LOG.info(
                        "Quantized 3D expert param: %s "
                        "Quantized expert param #%d: %s "
                        "(alloc=%.2f GiB, reserved=%.2f GiB)",
                        _moe_load_state["count"],
                        target_name,
                        torch.cuda.memory_allocated() / 1024**3,
                        torch.cuda.memory_reserved() / 1024**3,
@@ -122,11 +113,6 @@ def patch_moe_quantization_on_load(cfg):

    transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
    _moe_load_state["patched"] = True
    LOG.info(
        "Pre-load GPU memory: alloc=%.2f GiB, reserved=%.2f GiB",
        torch.cuda.memory_allocated() / 1024**3,
        torch.cuda.memory_reserved() / 1024**3,
    )
    LOG.info(
        "Activated MoE loading-time quantization patch "
        "(quant_type=%s, compress_statistics=%s)",
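Below is a minimal usage sketch (not part of this commit): the patch has to be installed before the model is loaded so that the async-load flag, the no-op warmup, and the set_param_for_module hook are all active while checkpoint shards stream in. Names such as cfg.base_model and the exact from_pretrained arguments are assumptions, not taken from this diff.

# Illustrative sketch only; assumes an axolotl-style cfg exposing bnb_4bit_*
# fields and a standard transformers + bitsandbytes load path.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

patch_moe_quantization_on_load(cfg)  # install hooks before any weights load

model = AutoModelForCausalLM.from_pretrained(
    cfg.base_model,  # assumed config field
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=getattr(cfg, "bnb_4bit_quant_type", None) or "nf4",
    ),
)
# MoE expert weights are usually stacked 3D tensors ([num_experts, out, in]),
# which is why the hook only quantizes params with ndim >= 3.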