Fix: quantize and target moe layers in transformers v5 for adapters and many misc fixes (#3439)
* fix: saving clones state dict
* fix: apply fix for only CP mode
* fix: add dropout check when using lora target param
* fix: re-add patch from transformers PR #39866
* feat: add moe quant to test by ved
* fix: try match target param properly end with
* fix: clear cache per param quant
* fix: attempt on-load quantize experts instead of post-load
* fix: attempt disable async load
* chore: add log
* chore: adjust log
* fix: remove cuda alloc for moe and enable async load
* chore: remove leftover logs
* chore: add extra empty cache
* fix(doc): clarify support
* fix: handle fsdp2 for paramwrapper dtensor
* feat: attempt to quant experts in 8bit mode too
* feat: attempt to release bf16 experts from vram
* feat: upgrade cce
* fix: fsdp2 init_sharded_param load int8/uint4 dtensor as
require_grad=true on init
* fix: remove unnecessary gc and empty cache
* Revert "fix: remove unnecessary gc and empty cache"
This reverts commit 1d54518990.
* fix: do not call full_tensor on non-dtensors
* fix: attempt to address fsdp2 with quant exp high loss
* fix: attempt lora quant experts wrong dim
* fix: ensure require_grad patch applied for lora 8bit
* fix: attempt lora 8bit fsdp2
* fix: attribute access on save for lora 8bit fsdp2
* fix: wrong weight attrib access
* chore(refactor): add config, re-arrange position of patches, clean
comments
* feat: add example docs
* chore: cherry pick trinity fixes from PR 3399
* chore: comments refactor; add guards
* fix: guard using wrong key
* fix: mamba save does not accept main process param
* fix: guard prevent double hook
* fix: move gc to upper scope
* chore: add comment on proxy forward patch
* fix: add comment to clarify
* feat: add test idempotency
* fix: AttributeError: `e_score_correction_bias` is not an nn.Parameter
* fix: AttributeError: 'NoneType' object has no attribute 'to'
* fix: update docs on cpu_ram_efficient_loading
This commit is contained in:
@@ -18,4 +18,7 @@ MOE_ARCH_BLOCK = {
|
||||
"gpt_oss": "GptOssDecoderLayer",
|
||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||
"afmoe": "AfmoeMoE",
|
||||
"glm4_moe": "Glm4MoeDecoderLayer",
|
||||
"glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
|
||||
"glm_moe_dsa": "GlmMoeDsaDecoderLayer",
|
||||
}
|
||||
|
||||
@@ -720,12 +720,16 @@ class AxolotlTrainer(
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
LOG.info(f"Saving model checkpoint to {output_dir}")
|
||||
|
||||
# fix for Context Parallel save
|
||||
if state_dict is None:
|
||||
state_dict = self.accelerator.get_state_dict(self.model)
|
||||
if state_dict is not None:
|
||||
# fix for Context Parallel save: CP eval invalidates tensor storage
|
||||
# pointers, so clone to CPU to get fresh valid storage for safetensors
|
||||
if (
|
||||
state_dict is not None
|
||||
and self.axolotl_cfg
|
||||
and self.axolotl_cfg.context_parallel_size
|
||||
and self.axolotl_cfg.context_parallel_size > 1
|
||||
):
|
||||
state_dict = {
|
||||
k: v.clone() if isinstance(v, torch.Tensor) else v
|
||||
k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
@@ -761,7 +765,11 @@ class AxolotlTrainer(
|
||||
metadata={"format": "pt"},
|
||||
)
|
||||
else:
|
||||
self.model.save_pretrained(output_dir, state_dict=state_dict)
|
||||
self.model.save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
is_main_process=self.accelerator.is_main_process,
|
||||
)
|
||||
|
||||
if self.processing_class is not None:
|
||||
self.processing_class.save_pretrained(output_dir)
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -88,9 +88,9 @@ plugins:
|
||||
- qwen2_vl
|
||||
- qwen3
|
||||
- qwen3_5
|
||||
- qwen3_5_text
|
||||
- qwen3_5_moe
|
||||
- qwen3_5_moe_vl
|
||||
- qwen3_5_vl
|
||||
- qwen3_5_moe_text
|
||||
- qwen3_moe
|
||||
- qwen3_next
|
||||
- qwen3_vl
|
||||
|
||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -39,6 +39,8 @@ This works for any MoE model in transformers that uses a `SparseMoeBlock` class
|
||||
|
||||
ScatterMoE uses softmax -> topk routing, so results may differ from the baseline for some model architectures (GPT-OSS, GLM_MOE_DSA).
|
||||
|
||||
ScatterMoE does not work for GLM4.7 Flash (glm4_moe_lite) atm.
|
||||
|
||||
## Note on MegaBlocks
|
||||
|
||||
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.
|
||||
|
||||
@@ -34,7 +34,7 @@ def setup_quantized_meta_for_peft(model: torch.nn.Module):
|
||||
return self
|
||||
|
||||
for param in model.parameters():
|
||||
if isinstance(param, Params4bit):
|
||||
if isinstance(param, Params4bit) and param.quant_state is not None:
|
||||
param.quant_state._orig_to = param.quant_state.to
|
||||
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
|
||||
|
||||
|
||||
@@ -172,7 +172,10 @@ class ModelLoader:
|
||||
# Build the model
|
||||
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
||||
self.patch_manager.apply_post_plugin_pre_model_load_patches()
|
||||
|
||||
skip_move_to_device = self._build_model()
|
||||
self.patch_manager.apply_post_model_build_patches(self.model)
|
||||
|
||||
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
||||
|
||||
# Post-build model configuration
|
||||
@@ -860,6 +863,10 @@ class ModelLoader:
|
||||
# Make sure everything is in the same dtype
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if getattr(self.model, "_moe_experts_quantized", False):
|
||||
# Parametrized expert tensors dequantize on access — would OOM.
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if (
|
||||
not skip_prepare_model_for_kbit_training
|
||||
and self.cfg.adapter in ["lora", "qlora"]
|
||||
|
||||
@@ -118,6 +118,7 @@ class PatchManager:
|
||||
def apply_post_plugin_pre_model_load_patches(self):
|
||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||
self._apply_moe_expert_quantization_patch()
|
||||
|
||||
def _apply_transformers_patches(self):
|
||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||
@@ -135,6 +136,10 @@ class PatchManager:
|
||||
|
||||
patch_prepare_context_parallel_inputs()
|
||||
|
||||
def apply_post_model_build_patches(self, model: PreTrainedModel):
|
||||
"""Apply patches right after model build, before post-load setup."""
|
||||
self._finalize_moe_expert_quantization(model)
|
||||
|
||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||
"""Apply patches that require the model instance."""
|
||||
self._apply_llama_flash_attn_patches(model)
|
||||
@@ -170,9 +175,14 @@ class PatchManager:
|
||||
|
||||
patch_parallelism_config()
|
||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
||||
from axolotl.monkeypatch.accelerate.fsdp2 import (
|
||||
patch_accelerate_fsdp2,
|
||||
patch_tied_keys_for_meta_device,
|
||||
)
|
||||
|
||||
patch_accelerate_fsdp2()
|
||||
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||
patch_tied_keys_for_meta_device()
|
||||
if self.cfg.rl:
|
||||
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
|
||||
|
||||
@@ -352,15 +362,54 @@ class PatchManager:
|
||||
if (
|
||||
self.cfg.fsdp_config
|
||||
and str(self.cfg.fsdp_version) == "2"
|
||||
and self.cfg.adapter == "qlora"
|
||||
and (self.cfg.load_in_4bit or self.cfg.load_in_8bit)
|
||||
):
|
||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||
apply_init_dtype_attrs_patch,
|
||||
apply_init_sharded_param_patch,
|
||||
apply_init_unsharded_param_patch,
|
||||
apply_linear8bitlt_save_patch,
|
||||
)
|
||||
|
||||
apply_init_sharded_param_patch()
|
||||
apply_init_unsharded_param_patch()
|
||||
apply_init_dtype_attrs_patch()
|
||||
if self.cfg.load_in_8bit:
|
||||
apply_linear8bitlt_save_patch()
|
||||
|
||||
def _apply_moe_expert_quantization_patch(self):
|
||||
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
|
||||
if not self.cfg.quantize_moe_experts:
|
||||
return
|
||||
|
||||
from axolotl.monkeypatch.moe_quant import (
|
||||
patch_moe_quantization_on_load,
|
||||
patch_peft_target_parameters_matching,
|
||||
)
|
||||
|
||||
patch_moe_quantization_on_load(self.cfg)
|
||||
patch_peft_target_parameters_matching()
|
||||
|
||||
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
|
||||
"""Log quantization results and set model flag for downstream use."""
|
||||
import torch
|
||||
|
||||
model._moe_experts_quantized = False
|
||||
if self.cfg.quantize_moe_experts:
|
||||
from axolotl.monkeypatch.moe_quant import get_moe_quantized_count
|
||||
|
||||
count = get_moe_quantized_count()
|
||||
if count > 0:
|
||||
import gc
|
||||
|
||||
model._moe_experts_quantized = True
|
||||
LOG.info(
|
||||
"Quantized %d MoE expert parameter(s) to %s during model loading",
|
||||
count,
|
||||
"4-bit" if self.cfg.load_in_4bit else "8-bit",
|
||||
)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _apply_tiled_mlp(self, model_type: str):
|
||||
if self.cfg.tiled_mlp:
|
||||
|
||||
@@ -111,6 +111,7 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
state_dict: Optional[dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if state_dict is None:
|
||||
state_dict = self.state_dict()
|
||||
|
||||
@@ -150,13 +150,17 @@ def get_state_dict(self, model, unwrap=True):
|
||||
)
|
||||
elif self.is_fsdp2:
|
||||
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
state_dict = {}
|
||||
sharded_state_dict = model.state_dict()
|
||||
for param_name, param in sharded_state_dict.items():
|
||||
if param.is_cpu:
|
||||
param = param.to(torch.device("cuda"))
|
||||
|
||||
param = param.full_tensor()
|
||||
if isinstance(param, DTensor):
|
||||
param = param.full_tensor()
|
||||
|
||||
if torch.distributed.get_rank() == 0:
|
||||
state_dict[param_name] = param.cpu()
|
||||
torch.distributed.barrier()
|
||||
@@ -182,10 +186,56 @@ def get_state_dict(self, model, unwrap=True):
|
||||
return state_dict
|
||||
|
||||
|
||||
def patch_peft_param_wrapper_for_fsdp2():
    """Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.

    PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds
    delta_weight to the base weight W inside _LoraParameterProxy.forward().
    Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a
    regular Tensor (or vice versa), causing a RuntimeError on mixed types.

    This patch promotes the non-DTensor operand to match the DTensor's spec
    using DTensor.from_local(), which is free for Replicate placement (just
    metadata wrapping, no communication).
    """
    from peft.tuners.lora.layer import _LoraParameterProxy

    # Idempotency guard: never stack the patch on repeated calls.
    if getattr(_LoraParameterProxy, "_axolotl_fsdp2_patched", False):
        return

    # Fix: the original code bound the old forward to a dead local that was
    # never used. Stash it on the class instead so it stays reachable for
    # debugging or a manual restore.
    _LoraParameterProxy._axolotl_original_forward = _LoraParameterProxy.forward

    # NOTE: Replaces (not wraps) forward; assumes original is just `W + self.delta_weight`.
    def _patched_forward(self, W):
        from torch.distributed.tensor import DTensor

        delta = self.delta_weight
        w_is_dt = isinstance(W, DTensor)
        d_is_dt = isinstance(delta, DTensor)

        with torch.nn.utils.parametrize.cached():
            if w_is_dt == d_is_dt:
                # Both DTensor or both plain tensors — addition is well-defined.
                return W + delta
            if w_is_dt:
                # Promote the plain delta to a DTensor matching W's layout.
                return W + DTensor.from_local(delta, W.device_mesh, W.placements)
            # W is plain, delta is a DTensor — promote W instead.
            return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta

    _LoraParameterProxy.forward = _patched_forward
    _LoraParameterProxy._axolotl_fsdp2_patched = True
    LOG.info("Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility")
|
||||
|
||||
|
||||
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
||||
"""Helper function to process LoRA modules for FSDP2."""
|
||||
from peft.tuners.lora.layer import ParamWrapper
|
||||
from torch.distributed.fsdp import fully_shard
|
||||
|
||||
# Skip ParamWrapper — its lora_A/B must not be independently sharded.
|
||||
# The parent decoder layer's FSDP wrapper handles unsharding them.
|
||||
# TODO: review if we even need to shard them separately in first place.
|
||||
if isinstance(module, ParamWrapper):
|
||||
return False
|
||||
|
||||
log_bias_dtype_mismatch = False
|
||||
|
||||
# Linear4Bit will keep its bias term in fp32. If the weight dtype is in bf16 we are not able to
|
||||
@@ -327,6 +377,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
||||
|
||||
is_peft_model = isinstance(model, PeftModel)
|
||||
|
||||
# Patch PEFT's _LoraParameterProxy for DTensor compatibility if any
|
||||
# ParamWrapper modules exist (used for target_parameters / 3D expert params).
|
||||
if is_peft_model:
|
||||
from peft.tuners.lora.layer import ParamWrapper
|
||||
|
||||
if any(isinstance(m, ParamWrapper) for m in model.modules()):
|
||||
patch_peft_param_wrapper_for_fsdp2()
|
||||
|
||||
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
||||
log_bias_dtype_mismatch = False
|
||||
if auto_wrap_policy is not None:
|
||||
@@ -376,6 +434,43 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
||||
return model
|
||||
|
||||
|
||||
def patch_tied_keys_for_meta_device():
    """Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors.

    Meta tensors all share data_ptr()==0, causing every parameter to be incorrectly
    grouped as "tied". Skipping them is safe since they have no real storage.
    """
    from collections import defaultdict

    from transformers import PreTrainedModel

    def _patched_adjust_tied_keys_with_tied_pointers(self, missing_keys):
        # Group parameter names by underlying storage pointer; names sharing a
        # pointer are candidates for aliased ("tied") weights.
        param_pointers = defaultdict(list)
        for param_name, param_value in self.state_dict().items():
            # Meta tensors have no real storage, so they would all collapse
            # into the data_ptr()==0 bucket and be mis-detected as tied.
            if param_value.is_meta:
                continue
            param_pointers[param_value.data_ptr()].append(param_name)

        # Keep only groups that (a) actually alias (more than one name),
        # (b) are not already recorded in all_tied_weights_keys, and
        # (c) are not entirely absent from the checkpoint.
        tied_param_names = [
            names
            for names in param_pointers.values()
            if len(names) > 1
            and not any(name in self.all_tied_weights_keys.keys() for name in names)
            and not all(name in missing_keys for name in names)
        ]

        # Map each extra alias to the first name in its group, matching the
        # {alias: canonical} layout that all_tied_weights_keys uses.
        tied_weights_keys_by_pointers = {
            param_name: group[0]
            for group in tied_param_names
            for param_name in group[1:]
        }
        self.all_tied_weights_keys.update(tied_weights_keys_by_pointers)

    # NOTE(review): replaces (does not wrap) the transformers implementation —
    # keep in sync with upstream PreTrainedModel if it changes.
    PreTrainedModel._adjust_tied_keys_with_tied_pointers = (
        _patched_adjust_tied_keys_with_tied_pointers
    )
|
||||
|
||||
|
||||
def patch_accelerate_fsdp2():
|
||||
import accelerate
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""
|
||||
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
|
||||
our LoRA / QLoRA Triton kernels to work with FSDP2.
|
||||
Monkeypatch to add Params4bit and Int8Params support to FSDP2. This enables QLoRA + FSDP2
|
||||
and 8-bit LoRA + FSDP2, as well as our LoRA / QLoRA Triton kernels to work with FSDP2.
|
||||
|
||||
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
|
||||
Params4bit parameters.
|
||||
This patch modifies the _init_sharded_param and init_unsharded_param methods in FSDPParam
|
||||
to handle bitsandbytes Params4bit and Int8Params parameters, preserving their quantization
|
||||
metadata through the FSDP2 shard/unshard cycle.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
@@ -17,6 +18,8 @@ LOG = get_logger(__name__)
|
||||
|
||||
def apply_init_sharded_param_patch():
|
||||
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
|
||||
if getattr(apply_init_sharded_param_patch, "_axolotl_patched", False):
|
||||
return
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||
|
||||
# Get original source
|
||||
@@ -41,9 +44,20 @@ def apply_init_sharded_param_patch():
|
||||
bnb_quantized=param.bnb_quantized,
|
||||
)
|
||||
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
||||
elif isinstance(param, bnb.nn.modules.Int8Params):
|
||||
self.sharded_param = bnb.nn.modules.Int8Params(
|
||||
data=sharded_param,
|
||||
requires_grad=param.requires_grad,
|
||||
has_fp16_weights=param.has_fp16_weights,
|
||||
CB=None,
|
||||
SCB=param.SCB,
|
||||
)
|
||||
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
||||
else:
|
||||
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
||||
self.sharded_param.requires_grad_(param.requires_grad)"""
|
||||
self.sharded_param = nn.Parameter(
|
||||
self.to_sharded_dtensor(sharded_param),
|
||||
requires_grad=param.requires_grad,
|
||||
)"""
|
||||
|
||||
# Apply the replacement
|
||||
if original_param_creation in original_source:
|
||||
@@ -73,6 +87,7 @@ def apply_init_sharded_param_patch():
|
||||
|
||||
# Replace the method
|
||||
FSDPParam._init_sharded_param = patched_init_sharded_param
|
||||
apply_init_sharded_param_patch._axolotl_patched = True
|
||||
LOG.info("Successfully applied FSDP _init_sharded_param patch")
|
||||
else:
|
||||
LOG.warning("Could not find target code for _init_sharded_param patching")
|
||||
@@ -80,6 +95,8 @@ def apply_init_sharded_param_patch():
|
||||
|
||||
def apply_init_unsharded_param_patch():
|
||||
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
|
||||
if getattr(apply_init_unsharded_param_patch, "_axolotl_patched", False):
|
||||
return
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||
|
||||
# Get original source
|
||||
@@ -105,6 +122,14 @@ def apply_init_unsharded_param_patch():
|
||||
module=local_tensor.module,
|
||||
bnb_quantized=local_tensor.bnb_quantized,
|
||||
)
|
||||
elif isinstance(local_tensor, bnb.nn.modules.Int8Params):
|
||||
self._unsharded_param = bnb.nn.modules.Int8Params(
|
||||
data=unsharded_param,
|
||||
requires_grad=self.sharded_param.requires_grad,
|
||||
has_fp16_weights=local_tensor.has_fp16_weights,
|
||||
CB=unsharded_param,
|
||||
SCB=local_tensor.SCB,
|
||||
)
|
||||
else:
|
||||
self._unsharded_param = nn.Parameter(
|
||||
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
||||
@@ -138,6 +163,74 @@ def apply_init_unsharded_param_patch():
|
||||
|
||||
# Replace the method
|
||||
FSDPParam.init_unsharded_param = patched_init_unsharded_param
|
||||
apply_init_unsharded_param_patch._axolotl_patched = True
|
||||
LOG.info("Successfully applied FSDP init_unsharded_param patch")
|
||||
else:
|
||||
LOG.warning("Could not find target code for patching")
|
||||
|
||||
|
||||
def apply_linear8bitlt_save_patch():
    """Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.

    After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params.
    BnB's _save_to_state_dict accesses self.weight.SCB directly, but DTensor
    doesn't proxy custom attribute access to its _local_tensor. This patch
    temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.
    """
    # Idempotency guard: patch the global class at most once per process.
    if getattr(apply_linear8bitlt_save_patch, "_axolotl_patched", False):
        return
    import bitsandbytes as bnb
    from torch.distributed.tensor import DTensor

    original_save = bnb.nn.Linear8bitLt._save_to_state_dict

    def _patched_save_to_state_dict(self, destination, prefix, keep_vars):
        # Use _parameters dict directly to bypass nn.Module.__setattr__ type check.
        weight = self._parameters["weight"]
        unwrapped = False
        if isinstance(weight, DTensor) and hasattr(weight, "_local_tensor"):
            # Temporarily expose the local Int8Params so BnB can read .SCB.
            self._parameters["weight"] = weight._local_tensor
            unwrapped = True
        try:
            original_save(self, destination, prefix, keep_vars)
        finally:
            # Always restore the DTensor wrapper, even if saving raised.
            if unwrapped:
                self._parameters["weight"] = weight

    bnb.nn.Linear8bitLt._save_to_state_dict = _patched_save_to_state_dict
    apply_linear8bitlt_save_patch._axolotl_patched = True
    LOG.info("Patched Linear8bitLt._save_to_state_dict for DTensor compatibility")
|
||||
|
||||
|
||||
def apply_init_dtype_attrs_patch():
    """Prevent FSDP2 mixed precision from casting non-float quantized params.

    When mixed precision is enabled (e.g., bf16), FSDP2's init_dtype_attrs sets
    param_dtype=bf16 for ALL params. During all-gather, _to_dtype_if_needed casts
    the sharded param to param_dtype. For non-float params (uint8 packed 4-bit,
    int8 quantized) without FSDP2 extensions, this destroys the quantized data.

    Params4bit handles this via fsdp_pre/post_all_gather extensions, but our
    parametrize-based expert quantization uses plain nn.Parameter(uint8/int8)
    without extensions.
    """
    # Idempotency guard: avoid re-wrapping the already-patched method.
    if getattr(apply_init_dtype_attrs_patch, "_axolotl_patched", False):
        return
    from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam

    original_init_dtype_attrs = FSDPParam.init_dtype_attrs

    def patched_init_dtype_attrs(self, mp_policy):
        # Let FSDP2 set up the dtype attributes normally first.
        original_init_dtype_attrs(self, mp_policy)
        # Skip casting non-float quantized params (uint8/int8) without FSDP2
        # extensions — the parametrization chain handles dequantization.
        if self.param_dtype is not None and not self.sharded_param.is_floating_point():
            local = self.sharded_param
            # DTensor-wrapped param: inspect the local shard for extensions.
            if hasattr(local, "_local_tensor"):
                local = local._local_tensor
            if not hasattr(local, "fsdp_pre_all_gather"):
                # Clearing param_dtype disables the all-gather dtype cast for
                # this parameter, preserving the packed quantized bytes.
                self.param_dtype = None

    FSDPParam.init_dtype_attrs = patched_init_dtype_attrs
    apply_init_dtype_attrs_patch._axolotl_patched = True
    LOG.info("Patched FSDPParam.init_dtype_attrs for non-float quantized params")
|
||||
|
||||
188
src/axolotl/monkeypatch/moe_quant.py
Normal file
188
src/axolotl/monkeypatch/moe_quant.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
|
||||
|
||||
In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
|
||||
skips (only targets nn.Linear). This module patches weight loading to quantize them
|
||||
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
|
||||
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
|
||||
"""
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import torch.nn.utils.parametrize as P
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
# Module-level state for the loading-time quantization patch.
|
||||
_moe_load_state = {
|
||||
"count": 0,
|
||||
"mode": "4bit",
|
||||
"quant_type": "nf4",
|
||||
"compress_statistics": True,
|
||||
"patched": False,
|
||||
}
|
||||
|
||||
|
||||
class Bnb8bitParametrization(torch.nn.Module):
    """Dequantize int8 row-wise quantized data back to floating point on access."""

    def __init__(self, row_stats: torch.Tensor):
        super().__init__()
        # Per-row statistics required to undo BnB's row-wise int8 quantization.
        self.register_buffer("row_stats", row_stats)

    @torch.no_grad()
    def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
        shape = quantized_param.shape
        # BnB's int8 dequant expects a 2D matrix; collapse any leading dims.
        flat = (
            quantized_param.reshape(-1, shape[-1])
            if quantized_param.ndim > 2
            else quantized_param
        )
        dequantized = bnb.functional.int8_vectorwise_dequant(flat, self.row_stats)
        # Restore the caller-visible (possibly 3D+) expert shape.
        return dequantized.reshape(shape)
|
||||
|
||||
|
||||
def _enable_parametrization_cache(module, inputs):
|
||||
P._cache_enabled += 1
|
||||
|
||||
|
||||
def _disable_parametrization_cache(module, inputs, output):
|
||||
P._cache_enabled -= 1
|
||||
if not P._cache_enabled:
|
||||
P._cache = {}
|
||||
|
||||
|
||||
def replace_parameter_8bit(module, param_name):
    """Replace a module parameter with an 8-bit quantized version using parametrization.

    Quantizes ``module.<param_name>`` row-wise to int8 via bitsandbytes, stores the
    int8 data as the parameter, and registers a Bnb8bitParametrization so reads
    transparently dequantize. Forward hooks cache the dequantized value per forward.
    """
    original_param = getattr(module, param_name)
    # int8_vectorwise_quant takes fp16 input and returns (int8 data, row stats, _).
    int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(
        original_param.data.to(torch.float16)
    )

    # Frozen int8 storage: gradients flow through LoRA adapters, not the base weight.
    setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False))
    # Drop the last strong reference to the bf16 tensor so its memory can be freed.
    del original_param

    # unsafe=True because the parametrization changes dtype/shape semantics of
    # the stored tensor (int8 storage, float on access).
    P.register_parametrization(
        module, param_name, Bnb8bitParametrization(row_stats), unsafe=True
    )

    # Cache dequantized values during forward to avoid redundant dequantization.
    if not getattr(module, "_axolotl_8bit_hooks_registered", False):
        module.register_forward_pre_hook(_enable_parametrization_cache)
        module.register_forward_hook(_disable_parametrization_cache)
        # Guard so a module quantized for several params gets hooks only once.
        module._axolotl_8bit_hooks_registered = True
|
||||
|
||||
|
||||
def patch_moe_quantization_on_load(cfg):
    """Patch transformers' weight loading to quantize MoE expert params on-the-fly.

    Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
    name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
    """
    # Record the requested mode/count even if the patch is already installed,
    # so a later load reuses the wrapper with fresh settings.
    mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
    _moe_load_state["mode"] = mode
    _moe_load_state["count"] = 0

    if _moe_load_state["patched"]:
        LOG.debug("MoE loading-time quantization patch already active")
        return

    import transformers.core_model_loading
    import transformers.modeling_utils

    if mode == "4bit":
        # Imported here (closure-captured below); only needed in 4-bit mode.
        from bitsandbytes.nn.parametrize import replace_parameter_4bit

        quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
        compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
        if compress_statistics is None:
            # Default to double quantization when the config doesn't say.
            compress_statistics = True

        _moe_load_state["quant_type"] = quant_type
        _moe_load_state["compress_statistics"] = compress_statistics

    # Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
    # size for all params, defeating our on-load quantization VRAM savings.
    def _noop_warmup(*args, **kwargs):
        pass

    transformers.modeling_utils.caching_allocator_warmup = _noop_warmup

    original_set_param = transformers.core_model_loading.set_param_for_module

    def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
        # Let transformers place the parameter first, then quantize in place.
        original_set_param(model, target_name, param_value, *args, **kwargs)

        # Quantize 3D+ expert params that BnB skipped (only on CUDA).
        if param_value.ndim >= 3 and param_value.is_cuda:
            mod_path, _, pname = target_name.rpartition(".")
            mod = model.get_submodule(mod_path) if mod_path else model
            # BnB Linear modules already handle their own quantization.
            if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
                # Heuristic: expert tensors are identified by name. NOTE(review):
                # substring match — confirm no non-expert 3D params contain "expert".
                if "expert" not in target_name.lower():
                    LOG.debug(
                        "Skipping non-expert 3D param: %s (shape=%s)",
                        target_name,
                        list(param_value.shape),
                    )
                    return

                if _moe_load_state["mode"] == "4bit":
                    replace_parameter_4bit(
                        mod,
                        pname,
                        compress_statistics=_moe_load_state["compress_statistics"],
                        quant_type=_moe_load_state["quant_type"],
                    )
                else:
                    replace_parameter_8bit(mod, pname)
                _moe_load_state["count"] += 1

                # Release the bf16 tensor so CUDA memory is freed immediately.
                param_value.data = torch.empty(0, device="cpu")
                torch.cuda.empty_cache()

    transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
    _moe_load_state["patched"] = True
|
||||
|
||||
|
||||
def get_moe_quantized_count():
    """Return how many expert parameters were quantized during model loading."""
    count = _moe_load_state["count"]
    return count
|
||||
|
||||
|
||||
def patch_peft_target_parameters_matching():
    """Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
    # Idempotency guard: wrap BaseTuner._inject_parameters only once.
    if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
        return
    from peft.tuners.tuners_utils import BaseTuner

    original_inject = BaseTuner._inject_parameters

    def _patched_inject_parameters(
        self, peft_config, model, adapter_name, low_cpu_mem_usage
    ):
        # Patch target_parameters to use full paths for parametrized modules
        original_targets = list(peft_config.target_parameters)
        expanded = set(original_targets)

        # For every module that carries a torch parametrization, expand each
        # short target (e.g. "experts.gate_up_proj") into its full dotted path
        # so PEFT's matcher can find the wrapped parameter.
        for module_name, module in model.named_modules():
            if not hasattr(module, "parametrizations"):
                continue
            for target in original_targets:
                mod_path, _, param_name = target.rpartition(".")
                # Suffix match on the module path, then confirm the module
                # actually exposes the named parameter.
                if (
                    module_name == mod_path or module_name.endswith("." + mod_path)
                ) and hasattr(module, param_name):
                    expanded.add(f"{module_name}.{param_name}")

        peft_config.target_parameters = sorted(expanded)
        try:
            return original_inject(
                self, peft_config, model, adapter_name, low_cpu_mem_usage
            )
        finally:
            # Restore the user's original targets so the saved adapter config
            # is not polluted with model-specific expanded paths.
            peft_config.target_parameters = original_targets

    BaseTuner._inject_parameters = _patched_inject_parameters
    patch_peft_target_parameters_matching._axolotl_patched = True
    LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
|
||||
@@ -629,6 +629,17 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
|
||||
quantize_moe_experts: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "Quantize MoE expert weights on load to reduce VRAM. "
|
||||
"Requires adapter (lora/qlora) with load_in_4bit or load_in_8bit. "
|
||||
"Requires CUDA (not compatible with ROCm or other backends). "
|
||||
"Note: total parameter count may be reported incorrectly when enabled "
|
||||
"(trainable param count is correct)."
|
||||
},
|
||||
)
|
||||
|
||||
scaling_softmax: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -1289,6 +1300,26 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_quantize_moe_experts(cls, data):
|
||||
if data.get("quantize_moe_experts"):
|
||||
if data.get("adapter") not in ("lora", "qlora"):
|
||||
raise ValueError("quantize_moe_experts requires adapter: lora or qlora")
|
||||
if not (data.get("load_in_4bit") or data.get("load_in_8bit")):
|
||||
raise ValueError(
|
||||
"quantize_moe_experts requires load_in_4bit or load_in_8bit"
|
||||
)
|
||||
if (
|
||||
data.get("capabilities")
|
||||
and data["capabilities"].get("compute_capability")
|
||||
and not data["capabilities"]["compute_capability"].startswith("sm_")
|
||||
):
|
||||
raise ValueError(
|
||||
"quantize_moe_experts requires CUDA (not compatible with ROCm or other backends)"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_auto_enable_lora_kernels(cls, data):
|
||||
|
||||
@@ -209,6 +209,19 @@ class LoraConfig(BaseModel):
|
||||
data["lora_dropout"] = 0.0
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_lora_target_parameters_dropout(self):
|
||||
if (
|
||||
self.lora_target_parameters
|
||||
and self.lora_dropout
|
||||
and self.lora_dropout != 0.0
|
||||
):
|
||||
raise ValueError(
|
||||
"lora_dropout must be 0 when lora_target_parameters is set. "
|
||||
"PEFT's ParamWrapper does not support lora_dropout != 0."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class ReLoRAConfig(BaseModel):
|
||||
"""ReLoRA configuration subset"""
|
||||
|
||||
Reference in New Issue
Block a user