Compare commits

..

5 Commits

Author SHA1 Message Date
Wing Lian
a4a3b618e7 force torch to match when installing fa and deepspeed using uv 2026-03-04 10:00:08 -05:00
Wing Lian
b6b8db805a fix python version typo for building 3.11 (#3454) 2026-03-04 09:53:35 -05:00
Wing Lian
653f90be25 Add torch 2.10.0 to unit tests and use python 3.14 (#3450)
* Add torch 2.10.0 to unit tests and use python 3.14

* hold on python 3.14 checks due to mistral common

* add base option to matrix
2026-03-03 13:01:52 -05:00
NanoCode012
945c8aeb10 Fix: quantize and target moe layers in transformers v5 for adapters and many misc fixes (#3439)
* fix: saving clones state dict

* fix: apply fix for only CP mode

* fix: add dropout check when using lora target param

* fix: re-add patch from transformers PR #39866

* feat: add moe quant to test by ved

* fix: try match target param properly end with

* fix: clear cache per param quant

* fix: attempt on-load quantize experts instead of post-load

* fix: attempt disable async load

* chore: add log

* chore: adjust log

* fix: remove cuda alloc for moe and enable async load

* chore: remove leftover logs

* chore: add extra empty cache

* fix(doc): clarify support

* fix: handle fsdp2 for paramwrapper dtensor

* feat: attempt to quant experts in 8bit mode too

* feat: attempt to release bf16 experts from vram

* feat: upgrade cce

* fix: fsdp2 init_sharded_param load int8/uint4 dtensor as
require_grad=true on init

* fix: remove unnecessary gc and empty cache

* Revert "fix: remove unnecessary gc and empty cache"

This reverts commit 1d54518990.

* fix: do not call full_tensor on non-dtensors

* fix: attempt to address fsdp2 with quant exp high loss

* fix: attempt lora quant experts wrong dim

* fix: ensure require_grad patch applied for lora 8bit

* fix: attempt lora 8bit fsdp2

* fix: attribute access on save for lora 8bit fsdp2

* fix: wrong weight attrib access

* chore(refactor): add config, re-arrange position of patches, clean
comments

* feat: add example docs

* chore: cherry pick trinity fixes from PR 3399

* chore: comments refactor; add guards

* fix: guard using wrong key

* fix: mamba save does not accept main process param

* fix: guard prevent double hook

* fix: move gc to upper scope

* chore: add comment on proxy forward patch

* fix: add comment to clarify

* feat: add test idempotency

* fix: AttributeError: `e_score_correction_bias` is not an nn.Parameter

* fix: AttributeError: 'NoneType' object has no attribute 'to'

* fix: update docs on cpu_ram_efficient_loading
2026-03-03 10:06:23 -05:00
NanoCode012
e672d37f33 fix: qwen3-next to use fla causal-conv1d to support packing (#3437
* fix: qwen3-next to use fla causal-conv1d to support packing

* fix: causal import and update doc for v5

* fix: hard fail for packing without fla
2026-03-03 09:26:46 -05:00
31 changed files with 1106 additions and 79 deletions

View File

@@ -51,6 +51,14 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -173,6 +181,14 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""

View File

@@ -54,13 +54,13 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.11", "3.12"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude:
# - python_version: "3.14"
# pytorch_version: "2.8.0"
# - python_version: "3.14"
# pytorch_version: "2.9.1"
timeout-minutes: 20
steps:
@@ -149,13 +149,13 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.11", "3.12"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude:
# - python_version: "3.14"
# pytorch_version: "2.8.0"
# - python_version: "3.14"
# pytorch_version: "2.9.1"
timeout-minutes: 20
steps:
@@ -326,6 +326,12 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.10.0
num_gpus: 1
axolotl_extras:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
@@ -371,7 +377,7 @@ jobs:
include:
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
python_version: "3.11"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583\""
]
},
{

View File

@@ -0,0 +1,77 @@
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.
This guide shows how to fine-tune it with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
# QLoRA
# - no target experts (1x48GB @ ~24GiB/GPU)
# - target experts (1x48GB @ ~34GiB/GPU)
axolotl train examples/glm4.7-flash/qlora.yaml
# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
axolotl train examples/glm4.7-flash/qlora_fsdp.yaml
```
```bash
# LoRA
# - no target experts (1x48GB @ ~35GiB/GPU)
# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
axolotl train examples/glm4.7-flash/lora.yaml
# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
axolotl train examples/glm4.7-flash/lora_fsdp.yaml
```
### Expert LoRA
To also apply LoRA adapters to expert weights, add `lora_target_parameters` to your config.
Note: `lora_dropout` must be `0` when using `lora_target_parameters`.
```yaml
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
# - mlp.gate.weight # router, untested but should work, not normally targeted
```
## Limitations
- **FSDP VRAM**: FSDP2 may use more VRAM per GPU than single GPU training. We suspect not all layers are properly sharded across ranks.
- **FSDP initial spike**: FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps that then drops. FSDP QLoRA (4-bit) does not exhibit this.
- **cpu_ram_efficient_loading**: Must be set to `false` with FSDP2 — causes hang otherwise.
- **lora_target_linear**: Incompatible for this model.
- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`).
### TIPS
- For inference, the official Z.ai team recommends these default settings (most tasks):
- `temperature: 1.0`
- `top_p: 0.95`
- `max_new_tokens: 131072`
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy, so we have not tested this.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,65 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -0,0 +1,75 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -0,0 +1,65 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -0,0 +1,75 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -6,30 +6,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Install Qwen3-Next transformers commit
```bash
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Install FLA for improved performance
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```
4. Run the finetuning example:
@@ -38,7 +21,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
```
This config uses about 45.62 GiB VRAM.
This config uses about ~47 GiB (no target experts) and ~71GiB (target experts) VRAM.
Let us know how it goes. Happy finetuning! 🚀

View File

@@ -9,6 +9,8 @@ plugins:
load_in_8bit: false
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
@@ -25,7 +27,7 @@ sample_packing: true
lora_r: 16
lora_alpha: 8
lora_dropout: 0.05
lora_dropout: 0
lora_target_modules:
- linear_attn.in_proj_ba
- linear_attn.in_proj_qkvz
@@ -34,12 +36,19 @@ lora_target_modules:
- shared_expert.down_proj
- shared_expert.gate_proj
- shared_expert_gate
- mlp.gate
- q_proj
- v_proj
- k_proj
- o_proj
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:

View File

@@ -8,13 +8,15 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Run the finetuning example:
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
```
This config uses about 24.9 GiB VRAM.
This config uses about 24.9 GiB VRAM (w/o CCE).
Let us know how it goes. Happy finetuning! 🚀
@@ -29,10 +31,6 @@ Let us know how it goes. Happy finetuning! 🚀
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
## Related Resources
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)

View File

@@ -1,5 +1,4 @@
base_model: arcee-ai/Trinity-Nano-Preview
trust_remote_code: true
revision_of_model: 2ee94b0
# Automatically upload checkpoint and final model to HF

View File

@@ -63,3 +63,5 @@ docstring-code-format = false
[tool.uv.extra-build-dependencies]
axolotl = ["huggingface_hub"]
flash-attn = [{ requirement = "torch", match-runtime = true }]
deepspeed = [{ requirement = "torch", match-runtime = true }]

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"'
)

View File

@@ -18,4 +18,7 @@ MOE_ARCH_BLOCK = {
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
"afmoe": "AfmoeMoE",
"glm4_moe": "Glm4MoeDecoderLayer",
"glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
"glm_moe_dsa": "GlmMoeDsaDecoderLayer",
}

View File

@@ -720,12 +720,16 @@ class AxolotlTrainer(
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
# fix for Context Parallel save
if state_dict is None:
state_dict = self.accelerator.get_state_dict(self.model)
if state_dict is not None:
# fix for Context Parallel save: CP eval invalidates tensor storage
# pointers, so clone to CPU to get fresh valid storage for safetensors
if (
state_dict is not None
and self.axolotl_cfg
and self.axolotl_cfg.context_parallel_size
and self.axolotl_cfg.context_parallel_size > 1
):
state_dict = {
k: v.clone() if isinstance(v, torch.Tensor) else v
k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items()
}
@@ -761,7 +765,11 @@ class AxolotlTrainer(
metadata={"format": "pt"},
)
else:
self.model.save_pretrained(output_dir, state_dict=state_dict)
self.model.save_pretrained(
output_dir,
state_dict=state_dict,
is_main_process=self.accelerator.is_main_process,
)
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"
```
## Usage
@@ -88,9 +88,9 @@ plugins:
- qwen2_vl
- qwen3
- qwen3_5
- qwen3_5_text
- qwen3_5_moe
- qwen3_5_moe_vl
- qwen3_5_vl
- qwen3_5_moe_text
- qwen3_moe
- qwen3_next
- qwen3_vl

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`'
)

View File

@@ -39,6 +39,8 @@ This works for any MoE model in transformers that uses a `SparseMoeBlock` class
ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
ScatterMoE does not work for GLM4.7 Flash (glm4_moe_lite) atm.
## Note on MegaBlocks
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.

View File

@@ -34,7 +34,7 @@ def setup_quantized_meta_for_peft(model: torch.nn.Module):
return self
for param in model.parameters():
if isinstance(param, Params4bit):
if isinstance(param, Params4bit) and param.quant_state is not None:
param.quant_state._orig_to = param.quant_state.to
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)

View File

@@ -172,7 +172,10 @@ class ModelLoader:
# Build the model
PLUGIN_MANAGER.pre_model_load(self.cfg)
self.patch_manager.apply_post_plugin_pre_model_load_patches()
skip_move_to_device = self._build_model()
self.patch_manager.apply_post_model_build_patches(self.model)
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
# Post-build model configuration
@@ -860,6 +863,10 @@ class ModelLoader:
# Make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True
if getattr(self.model, "_moe_experts_quantized", False):
# Parametrized expert tensors dequantize on access — would OOM.
skip_prepare_model_for_kbit_training = True
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]

View File

@@ -118,6 +118,7 @@ class PatchManager:
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
self._apply_tiled_mlp(self.cfg.model_config_type)
self._apply_moe_expert_quantization_patch()
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
@@ -135,6 +136,10 @@ class PatchManager:
patch_prepare_context_parallel_inputs()
def apply_post_model_build_patches(self, model: PreTrainedModel):
"""Apply patches right after model build, before post-load setup."""
self._finalize_moe_expert_quantization(model)
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)
@@ -170,9 +175,14 @@ class PatchManager:
patch_parallelism_config()
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
from axolotl.monkeypatch.accelerate.fsdp2 import (
patch_accelerate_fsdp2,
patch_tied_keys_for_meta_device,
)
patch_accelerate_fsdp2()
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
patch_tied_keys_for_meta_device()
if self.cfg.rl:
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
@@ -352,15 +362,54 @@ class PatchManager:
if (
self.cfg.fsdp_config
and str(self.cfg.fsdp_version) == "2"
and self.cfg.adapter == "qlora"
and (self.cfg.load_in_4bit or self.cfg.load_in_8bit)
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_dtype_attrs_patch,
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
apply_linear8bitlt_save_patch,
)
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
apply_init_dtype_attrs_patch()
if self.cfg.load_in_8bit:
apply_linear8bitlt_save_patch()
def _apply_moe_expert_quantization_patch(self):
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
if not self.cfg.quantize_moe_experts:
return
from axolotl.monkeypatch.moe_quant import (
patch_moe_quantization_on_load,
patch_peft_target_parameters_matching,
)
patch_moe_quantization_on_load(self.cfg)
patch_peft_target_parameters_matching()
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
"""Log quantization results and set model flag for downstream use."""
import torch
model._moe_experts_quantized = False
if self.cfg.quantize_moe_experts:
from axolotl.monkeypatch.moe_quant import get_moe_quantized_count
count = get_moe_quantized_count()
if count > 0:
import gc
model._moe_experts_quantized = True
LOG.info(
"Quantized %d MoE expert parameter(s) to %s during model loading",
count,
"4-bit" if self.cfg.load_in_4bit else "8-bit",
)
gc.collect()
torch.cuda.empty_cache()
def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:

View File

@@ -111,6 +111,7 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
**kwargs,
):
if state_dict is None:
state_dict = self.state_dict()

View File

@@ -150,13 +150,17 @@ def get_state_dict(self, model, unwrap=True):
)
elif self.is_fsdp2:
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
from torch.distributed.tensor import DTensor
state_dict = {}
sharded_state_dict = model.state_dict()
for param_name, param in sharded_state_dict.items():
if param.is_cpu:
param = param.to(torch.device("cuda"))
param = param.full_tensor()
if isinstance(param, DTensor):
param = param.full_tensor()
if torch.distributed.get_rank() == 0:
state_dict[param_name] = param.cpu()
torch.distributed.barrier()
@@ -182,10 +186,56 @@ def get_state_dict(self, model, unwrap=True):
return state_dict
def patch_peft_param_wrapper_for_fsdp2():
"""Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.
PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds
delta_weight to the base weight W inside _LoraParameterProxy.forward().
Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a
regular Tensor (or vice versa), causing a RuntimeError on mixed types.
This patch promotes the non-DTensor operand to match the DTensor's spec
using DTensor.from_local(), which is free for Replicate placement (just
metadata wrapping, no communication).
"""
from peft.tuners.lora.layer import _LoraParameterProxy
if getattr(_LoraParameterProxy, "_axolotl_fsdp2_patched", False):
return
_original_forward = _LoraParameterProxy.forward
# NOTE: Replaces (not wraps) forward; assumes original is just `W + self.delta_weight`.
def _patched_forward(self, W):
from torch.distributed.tensor import DTensor
delta = self.delta_weight
w_is_dt = isinstance(W, DTensor)
d_is_dt = isinstance(delta, DTensor)
with torch.nn.utils.parametrize.cached():
if w_is_dt == d_is_dt:
return W + delta
if w_is_dt:
return W + DTensor.from_local(delta, W.device_mesh, W.placements)
return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta
_LoraParameterProxy.forward = _patched_forward
_LoraParameterProxy._axolotl_fsdp2_patched = True
LOG.info("Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility")
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
"""Helper function to process LoRA modules for FSDP2."""
from peft.tuners.lora.layer import ParamWrapper
from torch.distributed.fsdp import fully_shard
# Skip ParamWrapper — its lora_A/B must not be independently sharded.
# The parent decoder layer's FSDP wrapper handles unsharding them.
# TODO: review if we even need to shard them separately in first place.
if isinstance(module, ParamWrapper):
return False
log_bias_dtype_mismatch = False
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
@@ -327,6 +377,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
is_peft_model = isinstance(model, PeftModel)
# Patch PEFT's _LoraParameterProxy for DTensor compatibility if any
# ParamWrapper modules exist (used for target_parameters / 3D expert params).
if is_peft_model:
from peft.tuners.lora.layer import ParamWrapper
if any(isinstance(m, ParamWrapper) for m in model.modules()):
patch_peft_param_wrapper_for_fsdp2()
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
log_bias_dtype_mismatch = False
if auto_wrap_policy is not None:
@@ -376,6 +434,43 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
return model
def patch_tied_keys_for_meta_device():
"""Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors.
Meta tensors all share data_ptr()==0, causing every parameter to be incorrectly
grouped as "tied". Skipping them is safe since they have no real storage.
"""
from collections import defaultdict
from transformers import PreTrainedModel
def _patched_adjust_tied_keys_with_tied_pointers(self, missing_keys):
param_pointers = defaultdict(list)
for param_name, param_value in self.state_dict().items():
if param_value.is_meta:
continue
param_pointers[param_value.data_ptr()].append(param_name)
tied_param_names = [
names
for names in param_pointers.values()
if len(names) > 1
and not any(name in self.all_tied_weights_keys.keys() for name in names)
and not all(name in missing_keys for name in names)
]
tied_weights_keys_by_pointers = {
param_name: group[0]
for group in tied_param_names
for param_name in group[1:]
}
self.all_tied_weights_keys.update(tied_weights_keys_by_pointers)
PreTrainedModel._adjust_tied_keys_with_tied_pointers = (
_patched_adjust_tied_keys_with_tied_pointers
)
def patch_accelerate_fsdp2():
import accelerate

View File

@@ -1,9 +1,10 @@
"""
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
our LoRA / QLoRA Triton kernels to work with FSDP2.
Monkeypatch to add Params4bit and Int8Params support to FSDP2. This enables QLoRA + FSDP2
and 8-bit LoRA + FSDP2, as well as our LoRA / QLoRA Triton kernels to work with FSDP2.
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
Params4bit parameters.
This patch modifies the _init_sharded_param and init_unsharded_param methods in FSDPParam
to handle bitsandbytes Params4bit and Int8Params parameters, preserving their quantization
metadata through the FSDP2 shard/unshard cycle.
"""
import importlib
@@ -17,6 +18,8 @@ LOG = get_logger(__name__)
def apply_init_sharded_param_patch():
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
if getattr(apply_init_sharded_param_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
@@ -41,9 +44,20 @@ def apply_init_sharded_param_patch():
bnb_quantized=param.bnb_quantized,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
elif isinstance(param, bnb.nn.modules.Int8Params):
self.sharded_param = bnb.nn.modules.Int8Params(
data=sharded_param,
requires_grad=param.requires_grad,
has_fp16_weights=param.has_fp16_weights,
CB=None,
SCB=param.SCB,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
else:
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)"""
self.sharded_param = nn.Parameter(
self.to_sharded_dtensor(sharded_param),
requires_grad=param.requires_grad,
)"""
# Apply the replacement
if original_param_creation in original_source:
@@ -73,6 +87,7 @@ def apply_init_sharded_param_patch():
# Replace the method
FSDPParam._init_sharded_param = patched_init_sharded_param
apply_init_sharded_param_patch._axolotl_patched = True
LOG.info("Successfully applied FSDP _init_sharded_param patch")
else:
LOG.warning("Could not find target code for _init_sharded_param patching")
@@ -80,6 +95,8 @@ def apply_init_sharded_param_patch():
def apply_init_unsharded_param_patch():
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
if getattr(apply_init_unsharded_param_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
@@ -105,6 +122,14 @@ def apply_init_unsharded_param_patch():
module=local_tensor.module,
bnb_quantized=local_tensor.bnb_quantized,
)
elif isinstance(local_tensor, bnb.nn.modules.Int8Params):
self._unsharded_param = bnb.nn.modules.Int8Params(
data=unsharded_param,
requires_grad=self.sharded_param.requires_grad,
has_fp16_weights=local_tensor.has_fp16_weights,
CB=unsharded_param,
SCB=local_tensor.SCB,
)
else:
self._unsharded_param = nn.Parameter(
unsharded_param, requires_grad=self.sharded_param.requires_grad
@@ -138,6 +163,74 @@ def apply_init_unsharded_param_patch():
# Replace the method
FSDPParam.init_unsharded_param = patched_init_unsharded_param
apply_init_unsharded_param_patch._axolotl_patched = True
LOG.info("Successfully applied FSDP init_unsharded_param patch")
else:
LOG.warning("Could not find target code for patching")
def apply_linear8bitlt_save_patch():
"""Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.
After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params.
BnB's _save_to_state_dict accesses self.weight.SCB directly, but DTensor
doesn't proxy custom attribute access to its _local_tensor. This patch
temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.
"""
if getattr(apply_linear8bitlt_save_patch, "_axolotl_patched", False):
return
import bitsandbytes as bnb
from torch.distributed.tensor import DTensor
original_save = bnb.nn.Linear8bitLt._save_to_state_dict
def _patched_save_to_state_dict(self, destination, prefix, keep_vars):
# Use _parameters dict directly to bypass nn.Module.__setattr__ type check.
weight = self._parameters["weight"]
unwrapped = False
if isinstance(weight, DTensor) and hasattr(weight, "_local_tensor"):
self._parameters["weight"] = weight._local_tensor
unwrapped = True
try:
original_save(self, destination, prefix, keep_vars)
finally:
if unwrapped:
self._parameters["weight"] = weight
bnb.nn.Linear8bitLt._save_to_state_dict = _patched_save_to_state_dict
apply_linear8bitlt_save_patch._axolotl_patched = True
LOG.info("Patched Linear8bitLt._save_to_state_dict for DTensor compatibility")
def apply_init_dtype_attrs_patch():
"""Prevent FSDP2 mixed precision from casting non-float quantized params.
When mixed precision is enabled (e.g., bf16), FSDP2's init_dtype_attrs sets
param_dtype=bf16 for ALL params. During all-gather, _to_dtype_if_needed casts
the sharded param to param_dtype. For non-float params (uint8 packed 4-bit,
int8 quantized) without FSDP2 extensions, this destroys the quantized data.
Params4bit handles this via fsdp_pre/post_all_gather extensions, but our
parametrize-based expert quantization uses plain nn.Parameter(uint8/int8)
without extensions.
"""
if getattr(apply_init_dtype_attrs_patch, "_axolotl_patched", False):
return
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
original_init_dtype_attrs = FSDPParam.init_dtype_attrs
def patched_init_dtype_attrs(self, mp_policy):
original_init_dtype_attrs(self, mp_policy)
# Skip casting non-float quantized params (uint8/int8) without FSDP2
# extensions — the parametrization chain handles dequantization.
if self.param_dtype is not None and not self.sharded_param.is_floating_point():
local = self.sharded_param
if hasattr(local, "_local_tensor"):
local = local._local_tensor
if not hasattr(local, "fsdp_pre_all_gather"):
self.param_dtype = None
FSDPParam.init_dtype_attrs = patched_init_dtype_attrs
apply_init_dtype_attrs_patch._axolotl_patched = True
LOG.info("Patched FSDPParam.init_dtype_attrs for non-float quantized params")

View File

@@ -9,6 +9,11 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
try:
from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
except ImportError:
fla_causal_conv1d = None
def get_cu_seqlens(position_ids):
"""
@@ -137,6 +142,11 @@ def patch_qwen3_next_gateddelta_layer():
and cache_position is not None
)
# Compute cu_seqlens early for use by both causal_conv1d and chunk_gated_delta_rule
cu_seqlens = None
if not use_precomputed_states and position_ids is not None:
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
# getting projected states from cache if it exists
if cache_params is not None:
conv_state = cache_params.conv_states[self.layer_idx]
@@ -151,12 +161,11 @@ def patch_qwen3_next_gateddelta_layer():
x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
)
mixed_qkv = torch.cat((query, key, value), dim=-1)
mixed_qkv = mixed_qkv.transpose(1, 2)
mixed_qkv = torch.cat((query, key, value), dim=-1) # [B, T, D]
if use_precomputed_states:
# 2. Convolution sequence transformation
# NOTE: the conv state is updated in `causal_conv1d_update`
# Inference single-token path: causal_conv1d_update expects [B, D, T]
mixed_qkv = mixed_qkv.transpose(1, 2)
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
@@ -164,24 +173,41 @@ def patch_qwen3_next_gateddelta_layer():
self.conv1d.bias,
self.activation,
)
mixed_qkv = mixed_qkv.transpose(1, 2)
else:
if cache_params is not None:
# Cache state expects [B, D, T] for the inference update path
mixed_qkv_t = mixed_qkv.transpose(1, 2)
conv_state = F.pad(
mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
mixed_qkv_t,
(self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
)
cache_params.conv_states[self.layer_idx] = conv_state
if self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
if fla_causal_conv1d is not None:
# FLA Triton causal_conv1d: [B, T, D] in/out, with cu_seqlens support
mixed_qkv, _ = fla_causal_conv1d(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=None,
cu_seqlens=cu_seqlens,
)
else:
# PyTorch fallback (no cu_seqlens support)
if cu_seqlens is not None and cu_seqlens.shape[0] > batch_size + 1:
raise RuntimeError(
"Packed sequences require fla.modules.convolution.causal_conv1d "
"(cu_seqlens support). Install flash-linear-attention or disable packing."
)
LOG.warning_once(
"FLA causal_conv1d not available. Falling back to PyTorch conv1d."
)
mixed_qkv = mixed_qkv.transpose(1, 2)
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
mixed_qkv = mixed_qkv.transpose(1, 2)
mixed_qkv = mixed_qkv.transpose(1, 2)
# mixed_qkv is [B, T, D] in all paths
query, key, value = torch.split(
mixed_qkv,
[
@@ -203,7 +229,6 @@ def patch_qwen3_next_gateddelta_layer():
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
if not use_precomputed_states:
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
query,
key,

View File

@@ -0,0 +1,188 @@
"""
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
skips (only targets nn.Linear). This module patches weight loading to quantize them
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
"""
import bitsandbytes as bnb
import torch
import torch.nn.utils.parametrize as P
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# Module-level state for the loading-time quantization patch.
_moe_load_state = {
"count": 0,
"mode": "4bit",
"quant_type": "nf4",
"compress_statistics": True,
"patched": False,
}
class Bnb8bitParametrization(torch.nn.Module):
"""Parametrization that dequantizes int8 row-wise quantized data on access."""
def __init__(self, row_stats: torch.Tensor):
super().__init__()
self.register_buffer("row_stats", row_stats)
@torch.no_grad()
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
# Flatten 3D+ to 2D for BnB's dequant, then reshape back.
orig_shape = quantized_param.shape
if quantized_param.ndim > 2:
quantized_param = quantized_param.reshape(-1, orig_shape[-1])
result = bnb.functional.int8_vectorwise_dequant(quantized_param, self.row_stats)
return result.reshape(orig_shape)
def _enable_parametrization_cache(module, inputs):
P._cache_enabled += 1
def _disable_parametrization_cache(module, inputs, output):
P._cache_enabled -= 1
if not P._cache_enabled:
P._cache = {}
def replace_parameter_8bit(module, param_name):
"""Replace a module parameter with an 8-bit quantized version using parametrization."""
original_param = getattr(module, param_name)
int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(
original_param.data.to(torch.float16)
)
setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False))
del original_param
P.register_parametrization(
module, param_name, Bnb8bitParametrization(row_stats), unsafe=True
)
# Cache dequantized values during forward to avoid redundant dequantization.
if not getattr(module, "_axolotl_8bit_hooks_registered", False):
module.register_forward_pre_hook(_enable_parametrization_cache)
module.register_forward_hook(_disable_parametrization_cache)
module._axolotl_8bit_hooks_registered = True
def patch_moe_quantization_on_load(cfg):
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly.
Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
"""
mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
_moe_load_state["mode"] = mode
_moe_load_state["count"] = 0
if _moe_load_state["patched"]:
LOG.debug("MoE loading-time quantization patch already active")
return
import transformers.core_model_loading
import transformers.modeling_utils
if mode == "4bit":
from bitsandbytes.nn.parametrize import replace_parameter_4bit
quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
if compress_statistics is None:
compress_statistics = True
_moe_load_state["quant_type"] = quant_type
_moe_load_state["compress_statistics"] = compress_statistics
# Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
# size for all params, defeating our on-load quantization VRAM savings.
def _noop_warmup(*args, **kwargs):
pass
transformers.modeling_utils.caching_allocator_warmup = _noop_warmup
original_set_param = transformers.core_model_loading.set_param_for_module
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
original_set_param(model, target_name, param_value, *args, **kwargs)
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
if param_value.ndim >= 3 and param_value.is_cuda:
mod_path, _, pname = target_name.rpartition(".")
mod = model.get_submodule(mod_path) if mod_path else model
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
if "expert" not in target_name.lower():
LOG.debug(
"Skipping non-expert 3D param: %s (shape=%s)",
target_name,
list(param_value.shape),
)
return
if _moe_load_state["mode"] == "4bit":
replace_parameter_4bit(
mod,
pname,
compress_statistics=_moe_load_state["compress_statistics"],
quant_type=_moe_load_state["quant_type"],
)
else:
replace_parameter_8bit(mod, pname)
_moe_load_state["count"] += 1
# Release the bf16 tensor so CUDA memory is freed immediately.
param_value.data = torch.empty(0, device="cpu")
torch.cuda.empty_cache()
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
_moe_load_state["patched"] = True
def get_moe_quantized_count():
"""Return the number of expert parameters quantized during loading."""
return _moe_load_state["count"]
def patch_peft_target_parameters_matching():
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
return
from peft.tuners.tuners_utils import BaseTuner
original_inject = BaseTuner._inject_parameters
def _patched_inject_parameters(
self, peft_config, model, adapter_name, low_cpu_mem_usage
):
# Patch target_parameters to use full paths for parametrized modules
original_targets = list(peft_config.target_parameters)
expanded = set(original_targets)
for module_name, module in model.named_modules():
if not hasattr(module, "parametrizations"):
continue
for target in original_targets:
mod_path, _, param_name = target.rpartition(".")
if (
module_name == mod_path or module_name.endswith("." + mod_path)
) and hasattr(module, param_name):
expanded.add(f"{module_name}.{param_name}")
peft_config.target_parameters = sorted(expanded)
try:
return original_inject(
self, peft_config, model, adapter_name, low_cpu_mem_usage
)
finally:
peft_config.target_parameters = original_targets
BaseTuner._inject_parameters = _patched_inject_parameters
patch_peft_target_parameters_matching._axolotl_patched = True
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")

View File

@@ -48,9 +48,9 @@ class ChatTemplatePrompter(Prompter):
):
# check if message_property_mappings is None or empty dict
if message_property_mappings is None or (not message_property_mappings):
default_message_property_mappings_keys = ["role", "content", "tool"]
message_property_mappings = {
prop: prop for prop in default_message_property_mappings_keys
"role": "role",
"content": "content",
}
if template_thinking_key and field_thinking:
message_property_mappings[template_thinking_key] = field_thinking

View File

@@ -629,6 +629,17 @@ class AxolotlInputConfig(
},
)
quantize_moe_experts: bool = Field(
default=False,
json_schema_extra={
"description": "Quantize MoE expert weights on load to reduce VRAM. "
"Requires adapter (lora/qlora) with load_in_4bit or load_in_8bit. "
"Requires CUDA (not compatible with ROCm or other backends). "
"Note: total parameter count may be reported incorrectly when enabled "
"(trainable param count is correct)."
},
)
scaling_softmax: bool | None = Field(
default=None,
json_schema_extra={
@@ -1289,6 +1300,26 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return data
@model_validator(mode="before")
@classmethod
def check_quantize_moe_experts(cls, data):
if data.get("quantize_moe_experts"):
if data.get("adapter") not in ("lora", "qlora"):
raise ValueError("quantize_moe_experts requires adapter: lora or qlora")
if not (data.get("load_in_4bit") or data.get("load_in_8bit")):
raise ValueError(
"quantize_moe_experts requires load_in_4bit or load_in_8bit"
)
if (
data.get("capabilities")
and data["capabilities"].get("compute_capability")
and not data["capabilities"]["compute_capability"].startswith("sm_")
):
raise ValueError(
"quantize_moe_experts requires CUDA (not compatible with ROCm or other backends)"
)
return data
@model_validator(mode="before")
@classmethod
def check_auto_enable_lora_kernels(cls, data):

View File

@@ -209,6 +209,19 @@ class LoraConfig(BaseModel):
data["lora_dropout"] = 0.0
return data
@model_validator(mode="after")
def validate_lora_target_parameters_dropout(self):
if (
self.lora_target_parameters
and self.lora_dropout
and self.lora_dropout != 0.0
):
raise ValueError(
"lora_dropout must be 0 when lora_target_parameters is set. "
"PEFT's ParamWrapper does not support lora_dropout != 0."
)
return self
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""

View File

@@ -0,0 +1,142 @@
"""Tests for MoE expert quantization config validation and PEFT patch idempotency."""
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture()
def gpu_caps():
return {"compute_capability": "sm_89", "bf16": True, "n_gpu": 1, "n_node": 1}
@pytest.fixture()
def env_caps():
return {"torch_version": "2.7.0"}
class TestQuantizeMoeExpertsValidation:
"""Test suite for quantize_moe_experts config validator."""
def test_requires_adapter(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts without adapter should fail."""
cfg = (
DictDefault(
quantize_moe_experts=True,
)
| min_base_cfg
)
with pytest.raises(ValueError, match="requires adapter"):
validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
def test_requires_quantization(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts without load_in_4bit/8bit should fail."""
cfg = (
DictDefault(
quantize_moe_experts=True,
adapter="lora",
)
| min_base_cfg
)
with pytest.raises(ValueError, match="requires load_in_4bit or load_in_8bit"):
validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
def test_valid_qlora_4bit(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts with qlora + 4bit should pass."""
cfg = (
DictDefault(
quantize_moe_experts=True,
adapter="qlora",
load_in_4bit=True,
)
| min_base_cfg
)
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
assert result["quantize_moe_experts"] is True
def test_valid_lora_8bit(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts with lora + 8bit should pass."""
cfg = (
DictDefault(
quantize_moe_experts=True,
adapter="lora",
load_in_8bit=True,
)
| min_base_cfg
)
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
assert result["quantize_moe_experts"] is True
def test_false_skips_validation(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts=false should not check adapter/quantization."""
cfg = (
DictDefault(
quantize_moe_experts=False,
)
| min_base_cfg
)
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
assert result["quantize_moe_experts"] is False
def test_default_is_false(self, min_base_cfg, gpu_caps, env_caps):
"""quantize_moe_experts should default to false."""
cfg = DictDefault({}) | min_base_cfg
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
assert result["quantize_moe_experts"] is False
class TestLoraTargetParametersDropout:
"""Test that lora_dropout must be 0 when lora_target_parameters is set."""
def test_rejects_nonzero_dropout(self, min_base_cfg):
"""lora_dropout > 0 with lora_target_parameters should fail."""
cfg = (
DictDefault(
adapter="lora",
lora_target_parameters=["mlp.experts.gate_up_proj"],
lora_dropout=0.1,
load_in_8bit=True,
)
| min_base_cfg
)
with pytest.raises(ValueError, match="lora_dropout must be 0"):
validate_config(cfg)
def test_zero_dropout_passes(self, min_base_cfg):
"""lora_dropout=0 with lora_target_parameters should pass."""
cfg = (
DictDefault(
adapter="lora",
lora_target_parameters=["mlp.experts.gate_up_proj"],
lora_dropout=0.0,
load_in_8bit=True,
)
| min_base_cfg
)
result = validate_config(cfg)
assert result["lora_dropout"] == 0.0
class TestPeftPatchIdempotency:
"""Test that patch_peft_target_parameters_matching is idempotent."""
def test_double_call_does_not_stack_wrappers(self):
"""Calling patch twice should not double-wrap _inject_parameters."""
from peft.tuners.tuners_utils import BaseTuner
from axolotl.monkeypatch.moe_quant import (
patch_peft_target_parameters_matching,
)
original = BaseTuner._inject_parameters
try:
patch_peft_target_parameters_matching()
first_patched = BaseTuner._inject_parameters
patch_peft_target_parameters_matching()
second_patched = BaseTuner._inject_parameters
# Should be same function, not double-wrapped
assert first_patched is second_patched
finally:
BaseTuner._inject_parameters = original
patch_peft_target_parameters_matching._axolotl_patched = False