Compare commits
5 Commits

| Author | SHA1 | Date |
|---|---|---|
| | a4a3b618e7 | |
| | b6b8db805a | |
| | 653f90be25 | |
| | 945c8aeb10 | |
| | e672d37f33 | |
.github/workflows/base.yml (vendored, 16 changed lines)
@@ -51,6 +51,14 @@ jobs:
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-base"
             platforms: "linux/amd64,linux/arm64"
+          - cuda: "128"
+            cuda_version: 12.8.1
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: 2.10.0
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-base"
+            platforms: "linux/amd64,linux/arm64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
@@ -173,6 +181,14 @@ jobs:
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
             dockerfile: "Dockerfile-uv-base"
             platforms: "linux/amd64,linux/arm64"
+          - cuda: "128"
+            cuda_version: 12.8.1
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: 2.10.0
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+            dockerfile: "Dockerfile-uv-base"
+            platforms: "linux/amd64,linux/arm64"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
.github/workflows/tests.yml (vendored, 36 changed lines)
@@ -54,13 +54,13 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python_version: ["3.11", "3.12"]
-        pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
-        exclude:
-          - python_version: "3.12"
-            pytorch_version: "2.8.0"
-          - python_version: "3.12"
-            pytorch_version: "2.9.0"
+        python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
+        pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
+        # exclude:
+        #   - python_version: "3.14"
+        #     pytorch_version: "2.8.0"
+        #   - python_version: "3.14"
+        #     pytorch_version: "2.9.1"
     timeout-minutes: 20

     steps:
@@ -149,13 +149,13 @@ jobs:
     strategy:
      fail-fast: false
       matrix:
-        python_version: ["3.11", "3.12"]
-        pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
-        exclude:
-          - python_version: "3.12"
-            pytorch_version: "2.8.0"
-          - python_version: "3.12"
-            pytorch_version: "2.9.0"
+        python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
+        pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
+        # exclude:
+        #   - python_version: "3.14"
+        #     pytorch_version: "2.8.0"
+        #   - python_version: "3.14"
+        #     pytorch_version: "2.9.1"
     timeout-minutes: 20

     steps:
@@ -326,6 +326,12 @@ jobs:
             pytorch: 2.9.1
             num_gpus: 1
             axolotl_extras:
+          - cuda: 128
+            cuda_version: 12.8.1
+            python_version: "3.11"
+            pytorch: 2.10.0
+            num_gpus: 1
+            axolotl_extras:
           - cuda: 130
             cuda_version: 13.0.0
             python_version: "3.11"
@@ -371,7 +377,7 @@ jobs:
         include:
           - cuda: 129
             cuda_version: 12.9.1
-            python_version: "3.12"
+            python_version: "3.11"
             pytorch: 2.9.1
             num_gpus: 1
             axolotl_extras:
@@ -40,7 +40,7 @@
     "%%capture\n",
     "# This step can take ~5-10 minutes to install dependencies\n",
     "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572\""
+    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583\""
    ]
   },
   {
examples/glm4.7-flash/README.md (new file, 77 lines)
@@ -0,0 +1,77 @@
# Finetune Z.ai's GLM-4.7-Flash with Axolotl

[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.

This guide shows how to fine-tune it with Axolotl.

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage (an install sketch follows the commands below).

3. Run the finetuning example:

```bash
# QLoRA
# - no target experts (1x48GB @ ~24GiB/GPU)
# - target experts (1x48GB @ ~34GiB/GPU)
axolotl train examples/glm4.7-flash/qlora.yaml

# QLoRA FSDP2, no target experts (2x48GB @ ~29GiB/GPU)
axolotl train examples/glm4.7-flash/qlora_fsdp.yaml
```

```bash
# LoRA
# - no target experts (1x48GB @ ~35GiB/GPU)
# - target experts (1x48GB @ OOM; projected ~45-50GiB/GPU)
axolotl train examples/glm4.7-flash/lora.yaml

# LoRA FSDP2, no target experts (2x48GB @ ~43GiB/GPU)
axolotl train examples/glm4.7-flash/lora_fsdp.yaml
```
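Not part of the original README: a minimal sketch of the Cut Cross Entropy install from step 2, using the two commands that appear in this PR's docs (the pinned commit `a668583` is the one this PR updates to):

```bash
# Option 1: helper script from the repo root
python scripts/cutcrossentropy_install.py | sh

# Option 2: pin the Axolotl fork directly
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"
```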
### Expert LoRA

To also apply LoRA adapters to expert weights, add `lora_target_parameters` to your config.

Note: `lora_dropout` must be `0` when using `lora_target_parameters`.

```yaml
lora_target_parameters:
  - mlp.experts.gate_up_proj
  - mlp.experts.down_proj
  # - mlp.gate.weight # router; untested but should work, not normally targeted
```

## Limitations

- **FSDP VRAM**: FSDP2 may use more VRAM per GPU than single-GPU training. We suspect not all layers are properly sharded across ranks.
- **FSDP initial spike**: FSDP LoRA (8-bit) may show a large initial VRAM spike in the first 1-2 steps that then drops. FSDP QLoRA (4-bit) does not exhibit this.
- **cpu_ram_efficient_loading**: Must be set to `false` with FSDP2; loading hangs otherwise.
- **lora_target_linear**: Incompatible with this model.
- **LoRA kernels**: Incompatible with this model due to its non-standard attention projections (DSA); they must be explicitly disabled (`lora_*_kernel: false`).

### Tips

- For inference, the Z.ai team recommends these default settings for most tasks:
  - `temperature: 1.0`
  - `top_p: 0.95`
  - `max_new_tokens: 131072`
- You can run a full finetune by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config (a sketch follows this list). This is compute-heavy, so we have not tested it.
- Read more about loading your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
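Not part of the original README: a minimal sketch of the full-finetune variant from the tip above, assuming you start from `examples/glm4.7-flash/qlora.yaml`. Untested, as the tip notes, and the learning rate is an assumption:

```yaml
base_model: zai-org/GLM-4.7-Flash

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

# Dropped relative to qlora.yaml: adapter, load_in_4bit, quantize_moe_experts
datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

sequence_len: 2048
sample_packing: true

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.00002  # assumption: lower than the 2e-4 used for the adapter runs

bf16: auto
gradient_checkpointing: true
flash_attention: true
```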
## Optimization Guides

Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).

## Related Resources

- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
examples/glm4.7-flash/lora.yaml (new file, 65 lines)
@@ -0,0 +1,65 @@
base_model: zai-org/GLM-4.7-Flash

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_8bit: true
quantize_moe_experts: true

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-out

adapter: lora
lora_model_dir:

sequence_len: 2048
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
  - q_proj
  - v_proj
  - k_proj
  - o_proj

# Uncomment to also target MoE expert weights:
# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj

# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
examples/glm4.7-flash/lora_fsdp.yaml (new file, 75 lines)
@@ -0,0 +1,75 @@
base_model: zai-org/GLM-4.7-Flash

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_8bit: true
quantize_moe_experts: true

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out

adapter: lora
lora_model_dir:

sequence_len: 2048
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
  - q_proj
  - v_proj
  - k_proj
  - o_proj

# Uncomment to also target MoE expert weights:
# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj

# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

fsdp_config:
  fsdp_version: 2
  offload_params: false
  cpu_ram_efficient_loading: false
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true
examples/glm4.7-flash/qlora.yaml (new file, 65 lines)
@@ -0,0 +1,65 @@
base_model: zai-org/GLM-4.7-Flash

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_4bit: true
quantize_moe_experts: true

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
  - q_proj
  - v_proj
  - k_proj
  - o_proj

# Uncomment to also target MoE expert weights:
# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj

# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
examples/glm4.7-flash/qlora_fsdp.yaml (new file, 75 lines)
@@ -0,0 +1,75 @@
base_model: zai-org/GLM-4.7-Flash

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_4bit: true
quantize_moe_experts: true

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
  - q_proj
  - v_proj
  - k_proj
  - o_proj

# Uncomment to also target MoE expert weights:
# lora_target_parameters:
#   - mlp.experts.gate_up_proj
#   - mlp.experts.down_proj

# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

fsdp_config:
  fsdp_version: 2
  offload_params: false
  cpu_ram_efficient_loading: false
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD
  reshard_after_forward: true
  activation_checkpointing: true
@@ -6,30 +6,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations

 ## Getting started

-1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

-Here is an example of how to install from main for pip:
-
-```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-git clone https://github.com/axolotl-ai-cloud/axolotl.git
-cd axolotl
-
-pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
-pip3 install --no-build-isolation -e '.[flash-attn]'
-
-# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
-python scripts/cutcrossentropy_install.py | sh
-```
-
-2. Install Qwen3-Next transformers commit
-```bash
-pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
-```
+2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.

 3. Install FLA for improved performance
 ```bash
-pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
+pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
 ```

 4. Run the finetuning example:
@@ -38,7 +21,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
 axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
 ```

-This config uses about 45.62 GiB VRAM.
+This config uses ~47 GiB (no target experts) or ~71 GiB (target experts) of VRAM.

 Let us know how it goes. Happy finetuning! 🚀
@@ -9,6 +9,8 @@ plugins:
 load_in_8bit: false
 load_in_4bit: true

+quantize_moe_experts: true
+
 datasets:
   - path: fozziethebeat/alpaca_messages_2k_test
     type: chat_template
@@ -25,7 +27,7 @@ sample_packing: true

 lora_r: 16
 lora_alpha: 8
-lora_dropout: 0.05
+lora_dropout: 0
 lora_target_modules:
   - linear_attn.in_proj_ba
   - linear_attn.in_proj_qkvz
@@ -34,12 +36,19 @@ lora_target_modules:
   - shared_expert.down_proj
   - shared_expert.gate_proj
   - shared_expert_gate
   - mlp.gate
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj

+# lora_target_parameters:
+#   - mlp.experts.gate_up_proj
+#   - mlp.experts.down_proj
+
 lora_mlp_kernel: false
 lora_qkv_kernel: false
 lora_o_kernel: false

 wandb_project:
 wandb_entity:
 wandb_watch:
@@ -8,13 +8,15 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations

 1. Install Axolotl from main, following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).

-2. Run the finetuning example:
+2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
+
+3. Run the finetuning example:

 ```bash
 axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
 ```

-This config uses about 24.9 GiB VRAM.
+This config uses about 24.9 GiB VRAM (without CCE).

 Let us know how it goes. Happy finetuning! 🚀
@@ -29,10 +31,6 @@ Let us know how it goes. Happy finetuning! 🚀

 Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).

-## Limitations
-
-**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
-
 ## Related Resources

 - [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
@@ -1,5 +1,4 @@
 base_model: arcee-ai/Trinity-Nano-Preview
 trust_remote_code: true
-revision_of_model: 2ee94b0

 # Automatically upload checkpoint and final model to HF
@@ -63,3 +63,5 @@ docstring-code-format = false

 [tool.uv.extra-build-dependencies]
+axolotl = ["huggingface_hub"]
 flash-attn = [{ requirement = "torch", match-runtime = true }]
+deepspeed = [{ requirement = "torch", match-runtime = true }]
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""

 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"'
 )
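For context (not part of the diff): the script prints an install command rather than running it, so it is piped into a shell, as the docs elsewhere in this PR show:

```bash
python scripts/cutcrossentropy_install.py | sh
```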
@@ -18,4 +18,7 @@ MOE_ARCH_BLOCK = {
     "gpt_oss": "GptOssDecoderLayer",
     "lfm2_moe": "Lfm2MoeSparseMoeBlock",
     "afmoe": "AfmoeMoE",
+    "glm4_moe": "Glm4MoeDecoderLayer",
+    "glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
+    "glm_moe_dsa": "GlmMoeDsaDecoderLayer",
 }
@@ -720,12 +720,16 @@ class AxolotlTrainer(
         os.makedirs(output_dir, exist_ok=True)
         LOG.info(f"Saving model checkpoint to {output_dir}")

-        # fix for Context Parallel save
         if state_dict is None:
             state_dict = self.accelerator.get_state_dict(self.model)
-        if state_dict is not None:
+        # fix for Context Parallel save: CP eval invalidates tensor storage
+        # pointers, so clone to CPU to get fresh valid storage for safetensors
+        if (
+            state_dict is not None
+            and self.axolotl_cfg
+            and self.axolotl_cfg.context_parallel_size
+            and self.axolotl_cfg.context_parallel_size > 1
+        ):
             state_dict = {
-                k: v.clone() if isinstance(v, torch.Tensor) else v
+                k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
                 for k, v in state_dict.items()
             }
@@ -761,7 +765,11 @@ class AxolotlTrainer(
                 metadata={"format": "pt"},
             )
         else:
-            self.model.save_pretrained(output_dir, state_dict=state_dict)
+            self.model.save_pretrained(
+                output_dir,
+                state_dict=state_dict,
+                is_main_process=self.accelerator.is_main_process,
+            )

         if self.processing_class is not None:
             self.processing_class.save_pretrained(output_dir)
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh

 - If you are installing from pip
 ```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"
 ```

 ## Usage
@@ -88,9 +88,9 @@ plugins:
   - qwen2_vl
   - qwen3
   - qwen3_5
-  - qwen3_5_text
   - qwen3_5_moe
-  - qwen3_5_moe_text
+  - qwen3_5_moe_vl
+  - qwen3_5_vl
   - qwen3_moe
   - qwen3_next
   - qwen3_vl
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)

 _CCE_INSTALL_MESSAGE = (
     "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
-    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"`'
+    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`'
 )
@@ -39,6 +39,8 @@ This works for any MoE model in transformers that uses a `SparseMoeBlock` class

 ScatterMoE uses softmax -> topk routing, so results may differ from the baseline for some model architectures (GPT-OSS, GLM_MOE_DSA).

+ScatterMoE does not work for GLM4.7 Flash (glm4_moe_lite) at the moment.
+
 ## Note on MegaBlocks

 We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.
@@ -34,7 +34,7 @@ def setup_quantized_meta_for_peft(model: torch.nn.Module):
         return self

     for param in model.parameters():
-        if isinstance(param, Params4bit):
+        if isinstance(param, Params4bit) and param.quant_state is not None:
             param.quant_state._orig_to = param.quant_state.to
             param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
@@ -172,7 +172,10 @@ class ModelLoader:
         # Build the model
         PLUGIN_MANAGER.pre_model_load(self.cfg)
         self.patch_manager.apply_post_plugin_pre_model_load_patches()

         skip_move_to_device = self._build_model()
+        self.patch_manager.apply_post_model_build_patches(self.model)
+
+        PLUGIN_MANAGER.post_model_build(self.cfg, self.model)

         # Post-build model configuration
@@ -860,6 +863,10 @@ class ModelLoader:
             # Make sure everything is in the same dtype
             skip_prepare_model_for_kbit_training = True

+        if getattr(self.model, "_moe_experts_quantized", False):
+            # Parametrized expert tensors dequantize on access — would OOM.
+            skip_prepare_model_for_kbit_training = True
+
         if (
             not skip_prepare_model_for_kbit_training
             and self.cfg.adapter in ["lora", "qlora"]
@@ -118,6 +118,7 @@ class PatchManager:
     def apply_post_plugin_pre_model_load_patches(self):
         """Apply post plugin-pre_model_load load patches based on config."""
         self._apply_tiled_mlp(self.cfg.model_config_type)
+        self._apply_moe_expert_quantization_patch()

     def _apply_transformers_patches(self):
         from axolotl.monkeypatch.transformers.trainer_loss_calc import (
@@ -135,6 +136,10 @@ class PatchManager:

         patch_prepare_context_parallel_inputs()

+    def apply_post_model_build_patches(self, model: PreTrainedModel):
+        """Apply patches right after model build, before post-load setup."""
+        self._finalize_moe_expert_quantization(model)
+
     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
         self._apply_llama_flash_attn_patches(model)
@@ -170,9 +175,14 @@ class PatchManager:

         patch_parallelism_config()
         if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
-            from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
+            from axolotl.monkeypatch.accelerate.fsdp2 import (
+                patch_accelerate_fsdp2,
+                patch_tied_keys_for_meta_device,
+            )

             patch_accelerate_fsdp2()
+            if self.cfg.fsdp_config.cpu_ram_efficient_loading:
+                patch_tied_keys_for_meta_device()
         if self.cfg.rl:
             from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
@@ -352,15 +362,54 @@ class PatchManager:
         if (
             self.cfg.fsdp_config
             and str(self.cfg.fsdp_version) == "2"
-            and self.cfg.adapter == "qlora"
+            and (self.cfg.load_in_4bit or self.cfg.load_in_8bit)
         ):
             from axolotl.monkeypatch.fsdp2_qlora import (
+                apply_init_dtype_attrs_patch,
                 apply_init_sharded_param_patch,
                 apply_init_unsharded_param_patch,
+                apply_linear8bitlt_save_patch,
             )

             apply_init_sharded_param_patch()
             apply_init_unsharded_param_patch()
+            apply_init_dtype_attrs_patch()
+            if self.cfg.load_in_8bit:
+                apply_linear8bitlt_save_patch()
+
+    def _apply_moe_expert_quantization_patch(self):
+        """Patch transformers weight loading to quantize MoE expert params on-the-fly."""
+        if not self.cfg.quantize_moe_experts:
+            return
+
+        from axolotl.monkeypatch.moe_quant import (
+            patch_moe_quantization_on_load,
+            patch_peft_target_parameters_matching,
+        )
+
+        patch_moe_quantization_on_load(self.cfg)
+        patch_peft_target_parameters_matching()
+
+    def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
+        """Log quantization results and set model flag for downstream use."""
+        import torch
+
+        model._moe_experts_quantized = False
+        if self.cfg.quantize_moe_experts:
+            from axolotl.monkeypatch.moe_quant import get_moe_quantized_count
+
+            count = get_moe_quantized_count()
+            if count > 0:
+                import gc
+
+                model._moe_experts_quantized = True
+                LOG.info(
+                    "Quantized %d MoE expert parameter(s) to %s during model loading",
+                    count,
+                    "4-bit" if self.cfg.load_in_4bit else "8-bit",
+                )
+                gc.collect()
+                torch.cuda.empty_cache()

     def _apply_tiled_mlp(self, model_type: str):
         if self.cfg.tiled_mlp:
@@ -111,6 +111,7 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
         self,
         save_directory: Union[str, os.PathLike],
+        state_dict: Optional[dict] = None,
         **kwargs,
     ):
         if state_dict is None:
             state_dict = self.state_dict()
@@ -150,13 +150,17 @@ def get_state_dict(self, model, unwrap=True):
             )
         elif self.is_fsdp2:
             # https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
             from torch.distributed.tensor import DTensor

             state_dict = {}
             sharded_state_dict = model.state_dict()
             for param_name, param in sharded_state_dict.items():
+                if param.is_cpu:
+                    param = param.to(torch.device("cuda"))
+
-                param = param.full_tensor()
+                if isinstance(param, DTensor):
+                    param = param.full_tensor()

                 if torch.distributed.get_rank() == 0:
                     state_dict[param_name] = param.cpu()
                 torch.distributed.barrier()
@@ -182,10 +186,56 @@ def get_state_dict(self, model, unwrap=True):
     return state_dict


+def patch_peft_param_wrapper_for_fsdp2():
+    """Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.
+
+    PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds
+    delta_weight to the base weight W inside _LoraParameterProxy.forward().
+    Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a
+    regular Tensor (or vice versa), causing a RuntimeError on mixed types.
+
+    This patch promotes the non-DTensor operand to match the DTensor's spec
+    using DTensor.from_local(), which is free for Replicate placement (just
+    metadata wrapping, no communication).
+    """
+    from peft.tuners.lora.layer import _LoraParameterProxy
+
+    if getattr(_LoraParameterProxy, "_axolotl_fsdp2_patched", False):
+        return
+
+    _original_forward = _LoraParameterProxy.forward
+
+    # NOTE: Replaces (not wraps) forward; assumes original is just `W + self.delta_weight`.
+    def _patched_forward(self, W):
+        from torch.distributed.tensor import DTensor
+
+        delta = self.delta_weight
+        w_is_dt = isinstance(W, DTensor)
+        d_is_dt = isinstance(delta, DTensor)
+
+        with torch.nn.utils.parametrize.cached():
+            if w_is_dt == d_is_dt:
+                return W + delta
+            if w_is_dt:
+                return W + DTensor.from_local(delta, W.device_mesh, W.placements)
+            return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta
+
+    _LoraParameterProxy.forward = _patched_forward
+    _LoraParameterProxy._axolotl_fsdp2_patched = True
+    LOG.info("Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility")
+
+
 def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
     """Helper function to process LoRA modules for FSDP2."""
+    from peft.tuners.lora.layer import ParamWrapper
     from torch.distributed.fsdp import fully_shard

+    # Skip ParamWrapper — its lora_A/B must not be independently sharded.
+    # The parent decoder layer's FSDP wrapper handles unsharding them.
+    # TODO: review if we even need to shard them separately in first place.
+    if isinstance(module, ParamWrapper):
+        return False
+
     log_bias_dtype_mismatch = False

     # Linear4Bit will keep its bias term in fp32. If the weight dtype is in bf16 we are not able to
@@ -327,6 +377,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:

     is_peft_model = isinstance(model, PeftModel)

+    # Patch PEFT's _LoraParameterProxy for DTensor compatibility if any
+    # ParamWrapper modules exist (used for target_parameters / 3D expert params).
+    if is_peft_model:
+        from peft.tuners.lora.layer import ParamWrapper
+
+        if any(isinstance(m, ParamWrapper) for m in model.modules()):
+            patch_peft_param_wrapper_for_fsdp2()
+
     auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
     log_bias_dtype_mismatch = False
     if auto_wrap_policy is not None:
@@ -376,6 +434,43 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
     return model


+def patch_tied_keys_for_meta_device():
+    """Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors.
+
+    Meta tensors all share data_ptr()==0, causing every parameter to be incorrectly
+    grouped as "tied". Skipping them is safe since they have no real storage.
+    """
+    from collections import defaultdict
+
+    from transformers import PreTrainedModel
+
+    def _patched_adjust_tied_keys_with_tied_pointers(self, missing_keys):
+        param_pointers = defaultdict(list)
+        for param_name, param_value in self.state_dict().items():
+            if param_value.is_meta:
+                continue
+            param_pointers[param_value.data_ptr()].append(param_name)
+
+        tied_param_names = [
+            names
+            for names in param_pointers.values()
+            if len(names) > 1
+            and not any(name in self.all_tied_weights_keys.keys() for name in names)
+            and not all(name in missing_keys for name in names)
+        ]
+
+        tied_weights_keys_by_pointers = {
+            param_name: group[0]
+            for group in tied_param_names
+            for param_name in group[1:]
+        }
+        self.all_tied_weights_keys.update(tied_weights_keys_by_pointers)
+
+    PreTrainedModel._adjust_tied_keys_with_tied_pointers = (
+        _patched_adjust_tied_keys_with_tied_pointers
+    )
+
+
 def patch_accelerate_fsdp2():
     import accelerate
@@ -1,9 +1,10 @@
 """
-Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
-our LoRA / QLoRA Triton kernels to work with FSDP2.
+Monkeypatch to add Params4bit and Int8Params support to FSDP2. This enables QLoRA + FSDP2
+and 8-bit LoRA + FSDP2, as well as our LoRA / QLoRA Triton kernels to work with FSDP2.

-This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
-Params4bit parameters.
+This patch modifies the _init_sharded_param and init_unsharded_param methods in FSDPParam
+to handle bitsandbytes Params4bit and Int8Params parameters, preserving their quantization
+metadata through the FSDP2 shard/unshard cycle.
 """

 import importlib
@@ -17,6 +18,8 @@ LOG = get_logger(__name__)

 def apply_init_sharded_param_patch():
     """Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
+    if getattr(apply_init_sharded_param_patch, "_axolotl_patched", False):
+        return
     from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam

     # Get original source
@@ -41,9 +44,20 @@ def apply_init_sharded_param_patch():
                 bnb_quantized=param.bnb_quantized,
             )
             self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
+        elif isinstance(param, bnb.nn.modules.Int8Params):
+            self.sharded_param = bnb.nn.modules.Int8Params(
+                data=sharded_param,
+                requires_grad=param.requires_grad,
+                has_fp16_weights=param.has_fp16_weights,
+                CB=None,
+                SCB=param.SCB,
+            )
+            self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
         else:
-            self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
-            self.sharded_param.requires_grad_(param.requires_grad)"""
+            self.sharded_param = nn.Parameter(
+                self.to_sharded_dtensor(sharded_param),
+                requires_grad=param.requires_grad,
+            )"""

     # Apply the replacement
     if original_param_creation in original_source:
@@ -73,6 +87,7 @@ def apply_init_sharded_param_patch():

         # Replace the method
         FSDPParam._init_sharded_param = patched_init_sharded_param
+        apply_init_sharded_param_patch._axolotl_patched = True
         LOG.info("Successfully applied FSDP _init_sharded_param patch")
     else:
         LOG.warning("Could not find target code for _init_sharded_param patching")
@@ -80,6 +95,8 @@ def apply_init_sharded_param_patch():

 def apply_init_unsharded_param_patch():
     """Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
+    if getattr(apply_init_unsharded_param_patch, "_axolotl_patched", False):
+        return
     from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam

     # Get original source
@@ -105,6 +122,14 @@ def apply_init_unsharded_param_patch():
                 module=local_tensor.module,
                 bnb_quantized=local_tensor.bnb_quantized,
             )
+        elif isinstance(local_tensor, bnb.nn.modules.Int8Params):
+            self._unsharded_param = bnb.nn.modules.Int8Params(
+                data=unsharded_param,
+                requires_grad=self.sharded_param.requires_grad,
+                has_fp16_weights=local_tensor.has_fp16_weights,
+                CB=unsharded_param,
+                SCB=local_tensor.SCB,
+            )
         else:
             self._unsharded_param = nn.Parameter(
                 unsharded_param, requires_grad=self.sharded_param.requires_grad
@@ -138,6 +163,74 @@ def apply_init_unsharded_param_patch():

         # Replace the method
         FSDPParam.init_unsharded_param = patched_init_unsharded_param
+        apply_init_unsharded_param_patch._axolotl_patched = True
         LOG.info("Successfully applied FSDP init_unsharded_param patch")
     else:
         LOG.warning("Could not find target code for patching")


+def apply_linear8bitlt_save_patch():
+    """Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.
+
+    After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params.
+    BnB's _save_to_state_dict accesses self.weight.SCB directly, but DTensor
+    doesn't proxy custom attribute access to its _local_tensor. This patch
+    temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.
+    """
+    if getattr(apply_linear8bitlt_save_patch, "_axolotl_patched", False):
+        return
+    import bitsandbytes as bnb
+    from torch.distributed.tensor import DTensor
+
+    original_save = bnb.nn.Linear8bitLt._save_to_state_dict
+
+    def _patched_save_to_state_dict(self, destination, prefix, keep_vars):
+        # Use _parameters dict directly to bypass nn.Module.__setattr__ type check.
+        weight = self._parameters["weight"]
+        unwrapped = False
+        if isinstance(weight, DTensor) and hasattr(weight, "_local_tensor"):
+            self._parameters["weight"] = weight._local_tensor
+            unwrapped = True
+        try:
+            original_save(self, destination, prefix, keep_vars)
+        finally:
+            if unwrapped:
+                self._parameters["weight"] = weight
+
+    bnb.nn.Linear8bitLt._save_to_state_dict = _patched_save_to_state_dict
+    apply_linear8bitlt_save_patch._axolotl_patched = True
+    LOG.info("Patched Linear8bitLt._save_to_state_dict for DTensor compatibility")
+
+
+def apply_init_dtype_attrs_patch():
+    """Prevent FSDP2 mixed precision from casting non-float quantized params.
+
+    When mixed precision is enabled (e.g., bf16), FSDP2's init_dtype_attrs sets
+    param_dtype=bf16 for ALL params. During all-gather, _to_dtype_if_needed casts
+    the sharded param to param_dtype. For non-float params (uint8 packed 4-bit,
+    int8 quantized) without FSDP2 extensions, this destroys the quantized data.
+
+    Params4bit handles this via fsdp_pre/post_all_gather extensions, but our
+    parametrize-based expert quantization uses plain nn.Parameter(uint8/int8)
+    without extensions.
+    """
+    if getattr(apply_init_dtype_attrs_patch, "_axolotl_patched", False):
+        return
+    from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+
+    original_init_dtype_attrs = FSDPParam.init_dtype_attrs
+
+    def patched_init_dtype_attrs(self, mp_policy):
+        original_init_dtype_attrs(self, mp_policy)
+        # Skip casting non-float quantized params (uint8/int8) without FSDP2
+        # extensions — the parametrization chain handles dequantization.
+        if self.param_dtype is not None and not self.sharded_param.is_floating_point():
+            local = self.sharded_param
+            if hasattr(local, "_local_tensor"):
+                local = local._local_tensor
+            if not hasattr(local, "fsdp_pre_all_gather"):
+                self.param_dtype = None
+
+    FSDPParam.init_dtype_attrs = patched_init_dtype_attrs
+    apply_init_dtype_attrs_patch._axolotl_patched = True
+    LOG.info("Patched FSDPParam.init_dtype_attrs for non-float quantized params")
@@ -9,6 +9,11 @@ from axolotl.utils.logging import get_logger

 LOG = get_logger(__name__)

+try:
+    from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
+except ImportError:
+    fla_causal_conv1d = None
+

 def get_cu_seqlens(position_ids):
     """
@@ -137,6 +142,11 @@ def patch_qwen3_next_gateddelta_layer():
             and cache_position is not None
         )

+        # Compute cu_seqlens early for use by both causal_conv1d and chunk_gated_delta_rule
+        cu_seqlens = None
+        if not use_precomputed_states and position_ids is not None:
+            cu_seqlens = get_cu_seqlens(position_ids=position_ids)
+
         # getting projected states from cache if it exists
         if cache_params is not None:
             conv_state = cache_params.conv_states[self.layer_idx]
@@ -151,12 +161,11 @@ def patch_qwen3_next_gateddelta_layer():
                 x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
             )

-            mixed_qkv = torch.cat((query, key, value), dim=-1)
-            mixed_qkv = mixed_qkv.transpose(1, 2)
+            mixed_qkv = torch.cat((query, key, value), dim=-1)  # [B, T, D]

             if use_precomputed_states:
                 # 2. Convolution sequence transformation
                 # NOTE: the conv state is updated in `causal_conv1d_update`
+                # Inference single-token path: causal_conv1d_update expects [B, D, T]
+                mixed_qkv = mixed_qkv.transpose(1, 2)
                 mixed_qkv = self.causal_conv1d_update(
                     mixed_qkv,
                     conv_state,
@@ -164,24 +173,41 @@ def patch_qwen3_next_gateddelta_layer():
                     self.conv1d.bias,
                     self.activation,
                 )
+                mixed_qkv = mixed_qkv.transpose(1, 2)
             else:
                 if cache_params is not None:
+                    # Cache state expects [B, D, T] for the inference update path
+                    mixed_qkv_t = mixed_qkv.transpose(1, 2)
                     conv_state = F.pad(
-                        mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
+                        mixed_qkv_t,
+                        (self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
                     )
                     cache_params.conv_states[self.layer_idx] = conv_state
-                if self.causal_conv1d_fn is not None:
-                    mixed_qkv = self.causal_conv1d_fn(
+
+                if fla_causal_conv1d is not None:
+                    # FLA Triton causal_conv1d: [B, T, D] in/out, with cu_seqlens support
+                    mixed_qkv, _ = fla_causal_conv1d(
                         x=mixed_qkv,
                         weight=self.conv1d.weight.squeeze(1),
                         bias=self.conv1d.bias,
                         activation=self.activation,
                         seq_idx=None,
                         cu_seqlens=cu_seqlens,
                     )
                 else:
+                    # PyTorch fallback (no cu_seqlens support)
+                    if cu_seqlens is not None and cu_seqlens.shape[0] > batch_size + 1:
+                        raise RuntimeError(
+                            "Packed sequences require fla.modules.convolution.causal_conv1d "
+                            "(cu_seqlens support). Install flash-linear-attention or disable packing."
+                        )
+                    LOG.warning_once(
+                        "FLA causal_conv1d not available. Falling back to PyTorch conv1d."
+                    )
                     mixed_qkv = mixed_qkv.transpose(1, 2)
                     mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
                     mixed_qkv = mixed_qkv.transpose(1, 2)

-            mixed_qkv = mixed_qkv.transpose(1, 2)
+            # mixed_qkv is [B, T, D] in all paths
             query, key, value = torch.split(
                 mixed_qkv,
                 [
@@ -203,7 +229,6 @@ def patch_qwen3_next_gateddelta_layer():
             key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

-            if not use_precomputed_states:
-                cu_seqlens = get_cu_seqlens(position_ids=position_ids)
             core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
                 query,
                 key,
src/axolotl/monkeypatch/moe_quant.py (new file, 188 lines)
@@ -0,0 +1,188 @@
"""
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.

In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
skips (only targets nn.Linear). This module patches weight loading to quantize them
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
"""

import bitsandbytes as bnb
import torch
import torch.nn.utils.parametrize as P

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

# Module-level state for the loading-time quantization patch.
_moe_load_state = {
    "count": 0,
    "mode": "4bit",
    "quant_type": "nf4",
    "compress_statistics": True,
    "patched": False,
}


class Bnb8bitParametrization(torch.nn.Module):
    """Parametrization that dequantizes int8 row-wise quantized data on access."""

    def __init__(self, row_stats: torch.Tensor):
        super().__init__()
        self.register_buffer("row_stats", row_stats)

    @torch.no_grad()
    def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
        # Flatten 3D+ to 2D for BnB's dequant, then reshape back.
        orig_shape = quantized_param.shape
        if quantized_param.ndim > 2:
            quantized_param = quantized_param.reshape(-1, orig_shape[-1])
        result = bnb.functional.int8_vectorwise_dequant(quantized_param, self.row_stats)
        return result.reshape(orig_shape)


def _enable_parametrization_cache(module, inputs):
    P._cache_enabled += 1


def _disable_parametrization_cache(module, inputs, output):
    P._cache_enabled -= 1
    if not P._cache_enabled:
        P._cache = {}


def replace_parameter_8bit(module, param_name):
    """Replace a module parameter with an 8-bit quantized version using parametrization."""
    original_param = getattr(module, param_name)
    int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(
        original_param.data.to(torch.float16)
    )

    setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False))
    del original_param

    P.register_parametrization(
        module, param_name, Bnb8bitParametrization(row_stats), unsafe=True
    )

    # Cache dequantized values during forward to avoid redundant dequantization.
    if not getattr(module, "_axolotl_8bit_hooks_registered", False):
        module.register_forward_pre_hook(_enable_parametrization_cache)
        module.register_forward_hook(_disable_parametrization_cache)
        module._axolotl_8bit_hooks_registered = True


def patch_moe_quantization_on_load(cfg):
    """Patch transformers' weight loading to quantize MoE expert params on-the-fly.

    Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
    name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
    """
    mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
    _moe_load_state["mode"] = mode
    _moe_load_state["count"] = 0

    if _moe_load_state["patched"]:
        LOG.debug("MoE loading-time quantization patch already active")
        return

    import transformers.core_model_loading
    import transformers.modeling_utils

    if mode == "4bit":
        from bitsandbytes.nn.parametrize import replace_parameter_4bit

        quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
        compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
        if compress_statistics is None:
            compress_statistics = True

        _moe_load_state["quant_type"] = quant_type
        _moe_load_state["compress_statistics"] = compress_statistics

    # Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
    # size for all params, defeating our on-load quantization VRAM savings.
    def _noop_warmup(*args, **kwargs):
        pass

    transformers.modeling_utils.caching_allocator_warmup = _noop_warmup

    original_set_param = transformers.core_model_loading.set_param_for_module

    def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
        original_set_param(model, target_name, param_value, *args, **kwargs)

        # Quantize 3D+ expert params that BnB skipped (only on CUDA).
        if param_value.ndim >= 3 and param_value.is_cuda:
            mod_path, _, pname = target_name.rpartition(".")
            mod = model.get_submodule(mod_path) if mod_path else model
            if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
                if "expert" not in target_name.lower():
                    LOG.debug(
                        "Skipping non-expert 3D param: %s (shape=%s)",
                        target_name,
                        list(param_value.shape),
                    )
                    return

                if _moe_load_state["mode"] == "4bit":
                    replace_parameter_4bit(
                        mod,
                        pname,
                        compress_statistics=_moe_load_state["compress_statistics"],
                        quant_type=_moe_load_state["quant_type"],
                    )
                else:
                    replace_parameter_8bit(mod, pname)
                _moe_load_state["count"] += 1

                # Release the bf16 tensor so CUDA memory is freed immediately.
                param_value.data = torch.empty(0, device="cpu")
                torch.cuda.empty_cache()

    transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
    _moe_load_state["patched"] = True


def get_moe_quantized_count():
    """Return the number of expert parameters quantized during loading."""
    return _moe_load_state["count"]


def patch_peft_target_parameters_matching():
    """Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
    if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
        return
    from peft.tuners.tuners_utils import BaseTuner

    original_inject = BaseTuner._inject_parameters

    def _patched_inject_parameters(
        self, peft_config, model, adapter_name, low_cpu_mem_usage
    ):
        # Patch target_parameters to use full paths for parametrized modules
        original_targets = list(peft_config.target_parameters)
        expanded = set(original_targets)

        for module_name, module in model.named_modules():
            if not hasattr(module, "parametrizations"):
                continue
            for target in original_targets:
                mod_path, _, param_name = target.rpartition(".")
                if (
                    module_name == mod_path or module_name.endswith("." + mod_path)
                ) and hasattr(module, param_name):
                    expanded.add(f"{module_name}.{param_name}")

        peft_config.target_parameters = sorted(expanded)
        try:
            return original_inject(
                self, peft_config, model, adapter_name, low_cpu_mem_usage
            )
        finally:
            peft_config.target_parameters = original_targets

    BaseTuner._inject_parameters = _patched_inject_parameters
    patch_peft_target_parameters_matching._axolotl_patched = True
    LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
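For orientation (not part of the diff): a minimal sketch of how the 8-bit path above behaves once `replace_parameter_8bit` has run. The module, parameter name, and shapes here are illustrative assumptions; the import path matches the new file:

```python
import torch
import torch.nn as nn

from axolotl.monkeypatch.moe_quant import replace_parameter_8bit

mod = nn.Module()
# Stand-in for a fused 3D expert tensor (num_experts, in, out); needs a CUDA device.
mod.experts = nn.Parameter(torch.randn(4, 64, 64, device="cuda"))
replace_parameter_8bit(mod, "experts")

# The int8 storage lives under the parametrization...
print(mod.parametrizations.experts.original.dtype)  # torch.int8
# ...and plain attribute access dequantizes on the fly via Bnb8bitParametrization.
print(mod.experts.shape, mod.experts.dtype)  # torch.Size([4, 64, 64]) and a float dtype
```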
@@ -48,9 +48,9 @@ class ChatTemplatePrompter(Prompter):
     ):
         # check if message_property_mappings is None or empty dict
         if message_property_mappings is None or (not message_property_mappings):
-            default_message_property_mappings_keys = ["role", "content", "tool"]
             message_property_mappings = {
-                prop: prop for prop in default_message_property_mappings_keys
+                "role": "role",
+                "content": "content",
             }
         if template_thinking_key and field_thinking:
             message_property_mappings[template_thinking_key] = field_thinking
@@ -629,6 +629,17 @@ class AxolotlInputConfig(
         },
     )

+    quantize_moe_experts: bool = Field(
+        default=False,
+        json_schema_extra={
+            "description": "Quantize MoE expert weights on load to reduce VRAM. "
+            "Requires adapter (lora/qlora) with load_in_4bit or load_in_8bit. "
+            "Requires CUDA (not compatible with ROCm or other backends). "
+            "Note: total parameter count may be reported incorrectly when enabled "
+            "(trainable param count is correct)."
+        },
+    )
+
     scaling_softmax: bool | None = Field(
         default=None,
         json_schema_extra={
@@ -1289,6 +1300,26 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         )
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_quantize_moe_experts(cls, data):
+        if data.get("quantize_moe_experts"):
+            if data.get("adapter") not in ("lora", "qlora"):
+                raise ValueError("quantize_moe_experts requires adapter: lora or qlora")
+            if not (data.get("load_in_4bit") or data.get("load_in_8bit")):
+                raise ValueError(
+                    "quantize_moe_experts requires load_in_4bit or load_in_8bit"
+                )
+            if (
+                data.get("capabilities")
+                and data["capabilities"].get("compute_capability")
+                and not data["capabilities"]["compute_capability"].startswith("sm_")
+            ):
+                raise ValueError(
+                    "quantize_moe_experts requires CUDA (not compatible with ROCm or other backends)"
+                )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_auto_enable_lora_kernels(cls, data):
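Not part of the diff: taken together with the example configs earlier in this PR, the new field and validator above are exercised like this (values from `examples/glm4.7-flash/qlora.yaml`):

```yaml
adapter: qlora          # validator requires lora or qlora
load_in_4bit: true      # or load_in_8bit: true
quantize_moe_experts: true
```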
@@ -209,6 +209,19 @@ class LoraConfig(BaseModel):
             data["lora_dropout"] = 0.0
         return data

+    @model_validator(mode="after")
+    def validate_lora_target_parameters_dropout(self):
+        if (
+            self.lora_target_parameters
+            and self.lora_dropout
+            and self.lora_dropout != 0.0
+        ):
+            raise ValueError(
+                "lora_dropout must be 0 when lora_target_parameters is set. "
+                "PEFT's ParamWrapper does not support lora_dropout != 0."
+            )
+        return self
+

 class ReLoRAConfig(BaseModel):
     """ReLoRA configuration subset"""
tests/utils/schemas/validation/test_moe_quant.py (new file, 142 lines)
@@ -0,0 +1,142 @@
"""Tests for MoE expert quantization config validation and PEFT patch idempotency."""

import pytest

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


@pytest.fixture()
def gpu_caps():
    return {"compute_capability": "sm_89", "bf16": True, "n_gpu": 1, "n_node": 1}


@pytest.fixture()
def env_caps():
    return {"torch_version": "2.7.0"}


class TestQuantizeMoeExpertsValidation:
    """Test suite for quantize_moe_experts config validator."""

    def test_requires_adapter(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts without adapter should fail."""
        cfg = (
            DictDefault(
                quantize_moe_experts=True,
            )
            | min_base_cfg
        )
        with pytest.raises(ValueError, match="requires adapter"):
            validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)

    def test_requires_quantization(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts without load_in_4bit/8bit should fail."""
        cfg = (
            DictDefault(
                quantize_moe_experts=True,
                adapter="lora",
            )
            | min_base_cfg
        )
        with pytest.raises(ValueError, match="requires load_in_4bit or load_in_8bit"):
            validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)

    def test_valid_qlora_4bit(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts with qlora + 4bit should pass."""
        cfg = (
            DictDefault(
                quantize_moe_experts=True,
                adapter="qlora",
                load_in_4bit=True,
            )
            | min_base_cfg
        )
        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
        assert result["quantize_moe_experts"] is True

    def test_valid_lora_8bit(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts with lora + 8bit should pass."""
        cfg = (
            DictDefault(
                quantize_moe_experts=True,
                adapter="lora",
                load_in_8bit=True,
            )
            | min_base_cfg
        )
        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
        assert result["quantize_moe_experts"] is True

    def test_false_skips_validation(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts=false should not check adapter/quantization."""
        cfg = (
            DictDefault(
                quantize_moe_experts=False,
            )
            | min_base_cfg
        )
        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
        assert result["quantize_moe_experts"] is False

    def test_default_is_false(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts should default to false."""
        cfg = DictDefault({}) | min_base_cfg
        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
        assert result["quantize_moe_experts"] is False


class TestLoraTargetParametersDropout:
    """Test that lora_dropout must be 0 when lora_target_parameters is set."""

    def test_rejects_nonzero_dropout(self, min_base_cfg):
        """lora_dropout > 0 with lora_target_parameters should fail."""
        cfg = (
            DictDefault(
                adapter="lora",
                lora_target_parameters=["mlp.experts.gate_up_proj"],
                lora_dropout=0.1,
                load_in_8bit=True,
            )
            | min_base_cfg
        )
        with pytest.raises(ValueError, match="lora_dropout must be 0"):
            validate_config(cfg)

    def test_zero_dropout_passes(self, min_base_cfg):
        """lora_dropout=0 with lora_target_parameters should pass."""
        cfg = (
            DictDefault(
                adapter="lora",
                lora_target_parameters=["mlp.experts.gate_up_proj"],
                lora_dropout=0.0,
                load_in_8bit=True,
            )
            | min_base_cfg
        )
        result = validate_config(cfg)
        assert result["lora_dropout"] == 0.0


class TestPeftPatchIdempotency:
    """Test that patch_peft_target_parameters_matching is idempotent."""

    def test_double_call_does_not_stack_wrappers(self):
        """Calling patch twice should not double-wrap _inject_parameters."""
        from peft.tuners.tuners_utils import BaseTuner

        from axolotl.monkeypatch.moe_quant import (
            patch_peft_target_parameters_matching,
        )

        original = BaseTuner._inject_parameters
        try:
            patch_peft_target_parameters_matching()
            first_patched = BaseTuner._inject_parameters
            patch_peft_target_parameters_matching()
            second_patched = BaseTuner._inject_parameters
            # Should be same function, not double-wrapped
            assert first_patched is second_patched
        finally:
            BaseTuner._inject_parameters = original
            patch_peft_target_parameters_matching._axolotl_patched = False