fix: disable async load when loading quantized bnb
This commit is contained in:
@@ -93,6 +93,7 @@ class PatchManager:
|
|||||||
|
|
||||||
def apply_pre_model_load_patches(self):
|
def apply_pre_model_load_patches(self):
|
||||||
"""Apply pre-model load patches based on config."""
|
"""Apply pre-model load patches based on config."""
|
||||||
|
self._deactivate_hf_async_load()
|
||||||
self._apply_transformers_patches()
|
self._apply_transformers_patches()
|
||||||
# self._apply_flex_attention_patches()
|
# self._apply_flex_attention_patches()
|
||||||
self._apply_flash_attention_patches()
|
self._apply_flash_attention_patches()
|
||||||
@@ -409,6 +410,11 @@ class PatchManager:
|
|||||||
if self.cfg.load_in_8bit:
|
if self.cfg.load_in_8bit:
|
||||||
apply_linear8bitlt_save_patch()
|
apply_linear8bitlt_save_patch()
|
||||||
|
|
||||||
|
def _deactivate_hf_async_load(self):
|
||||||
|
"""Load weights synchronously so they can be converted and not OOM."""
|
||||||
|
if self.cfg.load_in_4bit or self.cfg.load_in_8bit:
|
||||||
|
os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
|
||||||
|
|
||||||
def _apply_moe_expert_quantization_patch(self):
|
def _apply_moe_expert_quantization_patch(self):
|
||||||
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
|
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
|
||||||
if not self.cfg.quantize_moe_experts:
|
if not self.cfg.quantize_moe_experts:
|
||||||
|
|||||||
@@ -7,8 +7,6 @@ on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametriz
|
|||||||
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
|
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.utils.parametrize as P
|
import torch.nn.utils.parametrize as P
|
||||||
@@ -103,14 +101,6 @@ def patch_moe_quantization_on_load(cfg):
|
|||||||
_moe_load_state["quant_type"] = quant_type
|
_moe_load_state["quant_type"] = quant_type
|
||||||
_moe_load_state["compress_statistics"] = compress_statistics
|
_moe_load_state["compress_statistics"] = compress_statistics
|
||||||
|
|
||||||
# Disable async tensor loading. Transformers' convert_and_load_state_dict_in_model
|
|
||||||
# uses a ThreadPoolExecutor to materialise tensors (move from safetensors → CUDA)
|
|
||||||
# ahead of time. With MoE models this pre-fetches many large bf16 expert tensors
|
|
||||||
# onto the GPU simultaneously — long before our set_param_for_module patch can
|
|
||||||
# quantise and free them one-by-one — causing OOM even at <5 % of weights loaded.
|
|
||||||
# Sequential loading ensures only ONE bf16 expert tensor is on-GPU at a time.
|
|
||||||
os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
|
|
||||||
|
|
||||||
# Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
|
# Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
|
||||||
# size for all params, defeating our on-load quantization VRAM savings.
|
# size for all params, defeating our on-load quantization VRAM savings.
|
||||||
def _noop_warmup(*args, **kwargs):
|
def _noop_warmup(*args, **kwargs):
|
||||||
|
|||||||
Reference in New Issue
Block a user