fix: remove cuda alloc for moe and enable async load
This commit is contained in:
@@ -53,17 +53,25 @@ def patch_moe_quantization_on_load(cfg):
|
|||||||
LOG.debug("MoE loading-time quantization patch already active")
|
LOG.debug("MoE loading-time quantization patch already active")
|
||||||
return
|
return
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import transformers.core_model_loading
|
import transformers.core_model_loading
|
||||||
|
import transformers.modeling_utils
|
||||||
from bitsandbytes.nn.parametrize import replace_parameter_4bit
|
from bitsandbytes.nn.parametrize import replace_parameter_4bit
|
||||||
|
|
||||||
# Disable transformers' async weight loading thread pool. Without this,
|
# Patch caching_allocator_warmup to be a no-op. This function pre-allocates
|
||||||
# the ThreadPoolExecutor pre-fetches tensors to CUDA faster than the main
|
# a single huge GPU tensor equal to the model's total param bytes to warm the
|
||||||
# loop can quantize them, causing all expert weights to accumulate in bf16
|
# CUDA caching allocator. For MoE models, it calculates expert params at bf16
|
||||||
# on GPU — defeating the purpose of loading-time quantization.
|
# size (BnB doesn't know we'll quantize them), causing a ~50+ GiB reservation
|
||||||
os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
|
# that defeats loading-time quantization. Disabling it trades slightly slower
|
||||||
LOG.info("Disabled async weight loading (HF_DEACTIVATE_ASYNC_LOAD=1)")
|
# weight loading for dramatically lower peak VRAM.
|
||||||
|
_original_warmup = transformers.modeling_utils.caching_allocator_warmup
|
||||||
|
|
||||||
|
def _noop_warmup(*args, **kwargs):
|
||||||
|
LOG.info(
|
||||||
|
"Skipped caching_allocator_warmup (MoE loading-time quantization active)"
|
||||||
|
)
|
||||||
|
|
||||||
|
transformers.modeling_utils.caching_allocator_warmup = _noop_warmup
|
||||||
|
LOG.info("Patched caching_allocator_warmup to no-op for MoE quantization")
|
||||||
|
|
||||||
# Read quantization settings from config
|
# Read quantization settings from config
|
||||||
quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
|
quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
|
||||||
@@ -77,44 +85,27 @@ def patch_moe_quantization_on_load(cfg):
|
|||||||
|
|
||||||
original_set_param = transformers.core_model_loading.set_param_for_module
|
original_set_param = transformers.core_model_loading.set_param_for_module
|
||||||
|
|
||||||
_first_call = [True]
|
|
||||||
|
|
||||||
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
|
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
|
||||||
if _first_call[0]:
|
|
||||||
LOG.info(
|
|
||||||
"MoE quant patch: set_param_for_module intercepted (first call) "
|
|
||||||
"(alloc=%.2f GiB, reserved=%.2f GiB, max_alloc=%.2f GiB)",
|
|
||||||
torch.cuda.memory_allocated() / 1024**3,
|
|
||||||
torch.cuda.memory_reserved() / 1024**3,
|
|
||||||
torch.cuda.max_memory_allocated() / 1024**3,
|
|
||||||
)
|
|
||||||
_first_call[0] = False
|
|
||||||
|
|
||||||
original_set_param(model, target_name, param_value, *args, **kwargs)
|
original_set_param(model, target_name, param_value, *args, **kwargs)
|
||||||
|
|
||||||
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
|
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
|
||||||
if param_value.ndim >= 3:
|
if param_value.ndim >= 3 and param_value.is_cuda:
|
||||||
LOG.info(
|
mod_path, _, pname = target_name.rpartition(".")
|
||||||
"MoE quant patch: 3D param %s shape=%s cuda=%s",
|
mod = model.get_submodule(mod_path) if mod_path else model
|
||||||
target_name,
|
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
|
||||||
param_value.shape,
|
replace_parameter_4bit(
|
||||||
param_value.is_cuda,
|
mod,
|
||||||
)
|
pname,
|
||||||
if param_value.is_cuda:
|
compress_statistics=_moe_load_state["compress_statistics"],
|
||||||
mod_path, _, pname = target_name.rpartition(".")
|
quant_type=_moe_load_state["quant_type"],
|
||||||
mod = model.get_submodule(mod_path) if mod_path else model
|
)
|
||||||
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
|
torch.cuda.empty_cache()
|
||||||
replace_parameter_4bit(
|
_moe_load_state["count"] += 1
|
||||||
mod,
|
if _moe_load_state["count"] % 10 == 1:
|
||||||
pname,
|
|
||||||
compress_statistics=_moe_load_state["compress_statistics"],
|
|
||||||
quant_type=_moe_load_state["quant_type"],
|
|
||||||
)
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
_moe_load_state["count"] += 1
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Quantized 3D expert param: %s "
|
"Quantized expert param #%d: %s "
|
||||||
"(alloc=%.2f GiB, reserved=%.2f GiB)",
|
"(alloc=%.2f GiB, reserved=%.2f GiB)",
|
||||||
|
_moe_load_state["count"],
|
||||||
target_name,
|
target_name,
|
||||||
torch.cuda.memory_allocated() / 1024**3,
|
torch.cuda.memory_allocated() / 1024**3,
|
||||||
torch.cuda.memory_reserved() / 1024**3,
|
torch.cuda.memory_reserved() / 1024**3,
|
||||||
@@ -122,11 +113,6 @@ def patch_moe_quantization_on_load(cfg):
|
|||||||
|
|
||||||
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
|
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
|
||||||
_moe_load_state["patched"] = True
|
_moe_load_state["patched"] = True
|
||||||
LOG.info(
|
|
||||||
"Pre-load GPU memory: alloc=%.2f GiB, reserved=%.2f GiB",
|
|
||||||
torch.cuda.memory_allocated() / 1024**3,
|
|
||||||
torch.cuda.memory_reserved() / 1024**3,
|
|
||||||
)
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Activated MoE loading-time quantization patch "
|
"Activated MoE loading-time quantization patch "
|
||||||
"(quant_type=%s, compress_statistics=%s)",
|
"(quant_type=%s, compress_statistics=%s)",
|
||||||
|
|||||||
Reference in New Issue
Block a user