---
title: "MoE Expert Quantization"
description: "Reduce VRAM usage when training MoE model adapters by quantizing expert weights on load"
---

Transformers v5 changed MoE expert layers from `nn.Linear` to fused `nn.Parameter` (3D+ tensors).
This means `bitsandbytes` can no longer quantize them during model loading, so all expert
weights are loaded in full bf16 precision, causing massive VRAM usage.

`quantize_moe_experts` solves this by quantizing expert weights during model loading.
It intercepts the weight-loading process, quantizes each expert tensor on the fly, and
immediately frees the original bf16 tensor from VRAM. This dramatically reduces peak memory:
for example, GLM-4.7-Flash QLoRA drops from ~127GiB to ~23GiB of reserved memory.

## Usage

Enable expert quantization in your Axolotl config:

```yaml
quantize_moe_experts: true
```

This works with both 4-bit (QLoRA) and 8-bit (LoRA) quantization.
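
For example, a minimal sketch of the 4-bit (QLoRA) combination, using only keys discussed on this page (a real config also needs model, dataset, and adapter settings, which are omitted here):

```yaml
# 4-bit (QLoRA) expert quantization: a sketch, not a complete config
adapter: qlora
load_in_4bit: true
quantize_moe_experts: true
```
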
### Expert LoRA targeting

You can optionally apply LoRA adapters directly to expert weights using `lora_target_parameters`:

```yaml
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
# - mlp.gate.weight # router
```

::: {.callout-note}
`lora_dropout` must be `0` when using `lora_target_parameters`.
:::
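
Putting these pieces together, a hedged sketch of an expert-targeting QLoRA setup (every key below appears on this page; everything else a real config needs is omitted):

```yaml
adapter: qlora
load_in_4bit: true
quantize_moe_experts: true

lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
lora_dropout: 0  # must be 0 when using lora_target_parameters
```
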
## Requirements

- Requires (`adapter: lora` and `load_in_8bit: true`) or (`adapter: qlora` and `load_in_4bit: true`); the 8-bit pairing is sketched after this list
- CUDA GPUs only (not tested with ROCm or other backends)
- FSDP2 compatible for distributed training
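
For reference, a sketch of the 8-bit (LoRA) combination from the first requirement (again, only the keys named on this page):

```yaml
# 8-bit (LoRA) expert quantization: a sketch, not a complete config
adapter: lora
load_in_8bit: true
quantize_moe_experts: true
```
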
## Limitations

- `lora_target_linear` is not compatible with `quantize_moe_experts`. See [Expert LoRA targeting](#expert-lora-targeting) instead.
- `cpu_ram_efficient_loading` may hang or take a very long time with FSDP2 + QLoRA.
- The total model parameter count may display incorrectly (the trainable parameter count is correct).
- FSDP LoRA (8-bit) may show a large VRAM spike during the first 1-2 steps, which then drops. QLoRA does not exhibit this.
- FSDP2 may use more VRAM per GPU than single-GPU training because not all layers are properly sharded across ranks.
- Model loading takes longer due to on-demand quantization, even on consecutive runs.
- DeepSpeed has not been tested.

## Implementation details

The quantization is applied by patching transformers to intercept weight loading.
When a 3D+ CUDA tensor with "expert" in its name is detected:

- **4-bit mode:** Uses bitsandbytes NF4 parametrization (configurable via `bnb_4bit_quant_type`; see the sketch after this list).
- **8-bit mode:** Uses a custom row-wise int8 parametrization with bitsandbytes dequantization.
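
As an illustration, the 4-bit quant type named above can be set alongside the flag (a sketch; `nf4` is the default, `fp4` is the other bitsandbytes 4-bit option, and where this key sits may vary with your Axolotl version):

```yaml
quantize_moe_experts: true
bnb_4bit_quant_type: nf4  # default; bitsandbytes also supports fp4
```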

The original bf16 tensor is freed immediately after quantization. Multiple sub-patches are applied to
transformers, PEFT, and accelerate FSDP2 to support these parametrized expert modules.

For full implementation details, see [PR #3439](https://github.com/axolotl-ai-cloud/axolotl/pull/3439).