# MoE Expert Quantization
Transformers v5 changed MoE expert layers from `nn.Linear` to fused `nn.Parameter` (3D+ tensors). This means bitsandbytes can no longer quantize them during model loading, resulting in all expert weights being loaded in full bf16 precision and causing massive VRAM usage.
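To make the structural change concrete, the snippet below contrasts the two layouts. The shapes and variable names are made up for illustration and do not come from any particular model:

```python
import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
num_experts, hidden, intermediate = 64, 2048, 1408

# Pre-v5 layout: one nn.Linear per expert. bitsandbytes can swap each Linear
# for a quantized Linear4bit / Linear8bitLt module while the model loads.
experts_as_linears = nn.ModuleList(
    nn.Linear(hidden, intermediate, bias=False) for _ in range(num_experts)
)

# Transformers v5 layout: all experts fused into a single 3D nn.Parameter.
# There is no Linear module left to replace, so the tensor stays in bf16.
fused_experts = nn.Parameter(
    torch.empty(num_experts, hidden, intermediate, dtype=torch.bfloat16)
)
print(fused_experts.shape)  # torch.Size([64, 2048, 1408])
```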
`quantize_moe_experts` solves this problem by quantizing expert weights during model loading. It intercepts the weight-loading process, quantizes each expert tensor on the fly, and immediately frees the original bf16 tensor from VRAM. This dramatically reduces peak memory; for example, GLM-4.7-Flash QLoRA drops from ~127GiB to ~23GiB reserved memory.
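As a rough sketch of where the savings come from, the snippet below quantizes one fused expert tensor with bitsandbytes' `quantize_4bit` and frees the bf16 original. The shape is made up, and the real hook runs inside transformers' weight-loading path rather than by hand:

```python
import torch
import bitsandbytes.functional as bnb_F

# Hypothetical fused expert tensor on a CUDA device; shape is illustrative only.
experts = torch.randn(64, 1408, 2048, dtype=torch.bfloat16, device="cuda")
before = torch.cuda.memory_allocated()

# Quantize to NF4: the packed 4-bit payload plus per-block absmax statistics
# replace the bf16 weights.
packed, quant_state = bnb_F.quantize_4bit(experts, quant_type="nf4")

# Free the bf16 original right away, as quantize_moe_experts does during loading.
del experts
after = torch.cuda.memory_allocated()
print(f"{before / 2**20:.1f} MiB -> {after / 2**20:.1f} MiB for this tensor")
```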
## Usage
Enable expert quantization in your Axolotl config:

```yaml
quantize_moe_experts: true
```

This works with both 4-bit (QLoRA) and 8-bit (LoRA) quantization.
## Expert LoRA targeting
You can optionally apply LoRA adapters directly to expert weights using `lora_target_parameters`:

```yaml
lora_target_parameters:
  - mlp.experts.gate_up_proj
  - mlp.experts.down_proj
  # - mlp.gate.weight  # router
```

`lora_dropout` must be `0` when using `lora_target_parameters`.
## Requirements
- Requires (`adapter: lora` and `load_in_8bit: true`) or (`adapter: qlora` and `load_in_4bit: true`)
- CUDA GPUs only (not tested with ROCm or other backends)
- FSDP2 compatible for distributed training
## Limitations
- `cpu_ram_efficient_loading` hangs or takes a long time with FSDP2 + QLoRA.
- The total model parameter count may display incorrectly (the trainable parameter count is correct).
- FSDP LoRA (8-bit) may have a large initial VRAM spike during the first 1-2 steps, which then drops. QLoRA does not exhibit this.
- FSDP2 may use more VRAM per GPU than single-GPU training because not all layers are properly sharded across ranks.
- Model loading takes longer due to on-demand quantization, even on consecutive runs.
- DeepSpeed has not been tested.
## Implementation details
The quantization is applied by patching transformers to intercept weight loading. When a 3D+ CUDA tensor with "expert" in its name is detected:

- 4-bit mode: uses bitsandbytes NF4 parametrization (configurable via `bnb_4bit_quant_type`).
- 8-bit mode: uses a custom row-wise int8 parametrization with bitsandbytes dequantization.

The original bf16 tensor is freed immediately after quantization. Multiple sub-patches are applied to transformers, PEFT, and accelerate FSDP2 to support these parametrized expert modules.
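Conceptually, each parametrization stores the compact representation and rebuilds bf16 weights whenever the parameter is accessed. The sketch below illustrates that round trip for both modes; it uses `bitsandbytes.functional` for NF4 and plain PyTorch for the row-wise int8 math, with made-up shapes, and is not the code that actually gets patched in:

```python
import torch
import bitsandbytes.functional as bnb_F

# Hypothetical fused expert tensor; shape is illustrative only.
experts = torch.randn(8, 256, 512, dtype=torch.bfloat16, device="cuda")

# 4-bit mode: store packed NF4 bytes plus the quantization state...
packed, quant_state = bnb_F.quantize_4bit(experts, quant_type="nf4")
# ...and dequantize back to bf16 when the parameter is read.
nf4_view = bnb_F.dequantize_4bit(packed, quant_state, quant_type="nf4")
assert nf4_view.shape == experts.shape and nf4_view.dtype == torch.bfloat16

# 8-bit mode: store int8 values plus one scale per row (row absmax / 127)...
scales = experts.abs().amax(dim=-1, keepdim=True).to(torch.float32).clamp_min(1e-8) / 127.0
int8_weights = (experts.to(torch.float32) / scales).round().clamp(-127, 127).to(torch.int8)
# ...and multiply by the per-row scales on access to recover approximate bf16 weights.
int8_view = (int8_weights.to(torch.float32) * scales).to(torch.bfloat16)
assert int8_view.shape == experts.shape
```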
For full implementation details, see PR #3439.