Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4a3b618e7 | ||
|
|
b6b8db805a | ||
|
|
653f90be25 | ||
|
|
945c8aeb10 | ||
|
|
e672d37f33 |
16
.github/workflows/base.yml
vendored
16
.github/workflows/base.yml
vendored
@@ -51,6 +51,14 @@ jobs:
|
|||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base"
|
||||||
platforms: "linux/amd64,linux/arm64"
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
- cuda: "128"
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.10.0
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
dockerfile: "Dockerfile-base"
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
@@ -173,6 +181,14 @@ jobs:
|
|||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-uv-base"
|
dockerfile: "Dockerfile-uv-base"
|
||||||
platforms: "linux/amd64,linux/arm64"
|
platforms: "linux/amd64,linux/arm64"
|
||||||
|
- cuda: "128"
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.10.0
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
dockerfile: "Dockerfile-uv-base"
|
||||||
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
|
|||||||
36
.github/workflows/tests.yml
vendored
36
.github/workflows/tests.yml
vendored
@@ -54,13 +54,13 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11", "3.12"]
|
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||||
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
||||||
exclude:
|
# exclude:
|
||||||
- python_version: "3.12"
|
# - python_version: "3.14"
|
||||||
pytorch_version: "2.8.0"
|
# pytorch_version: "2.8.0"
|
||||||
- python_version: "3.12"
|
# - python_version: "3.14"
|
||||||
pytorch_version: "2.9.0"
|
# pytorch_version: "2.9.1"
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -149,13 +149,13 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11", "3.12"]
|
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||||
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
||||||
exclude:
|
# exclude:
|
||||||
- python_version: "3.12"
|
# - python_version: "3.14"
|
||||||
pytorch_version: "2.8.0"
|
# pytorch_version: "2.8.0"
|
||||||
- python_version: "3.12"
|
# - python_version: "3.14"
|
||||||
pytorch_version: "2.9.0"
|
# pytorch_version: "2.9.1"
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -326,6 +326,12 @@ jobs:
|
|||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 128
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.10.0
|
||||||
|
num_gpus: 1
|
||||||
|
axolotl_extras:
|
||||||
- cuda: 130
|
- cuda: 130
|
||||||
cuda_version: 13.0.0
|
cuda_version: 13.0.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -371,7 +377,7 @@ jobs:
|
|||||||
include:
|
include:
|
||||||
- cuda: 129
|
- cuda: 129
|
||||||
cuda_version: 12.9.1
|
cuda_version: 12.9.1
|
||||||
python_version: "3.12"
|
python_version: "3.11"
|
||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
|||||||
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
77
examples/glm4.7-flash/README.md
Normal file
77
examples/glm4.7-flash/README.md
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
|
||||||
|
|
||||||
|
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.
|
||||||
|
|
||||||
|
This guide shows how to fine-tune it with Axolotl.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||||
|
|
||||||
|
3. Run the finetuning example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# QLoRA
|
||||||
|
# - no target experts (1x48GB @ ~24GiB/GPU)
|
||||||
|
# - target experts (1x48GB @ ~34GiB/GPU)
|
||||||
|
axolotl train examples/glm4.7-flash/qlora.yaml
|
||||||
|
|
||||||
|
# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
|
||||||
|
axolotl train examples/glm4.7-flash/qlora_fsdp.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# LoRA
|
||||||
|
# - no target experts (1x48GB @ ~35GiB/GPU)
|
||||||
|
# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
|
||||||
|
axolotl train examples/glm4.7-flash/lora.yaml
|
||||||
|
|
||||||
|
# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
|
||||||
|
axolotl train examples/glm4.7-flash/lora_fsdp.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### Expert LoRA
|
||||||
|
|
||||||
|
To also apply LoRA adapters to expert weights, add `lora_target_parameters` to your config.
|
||||||
|
|
||||||
|
Note: `lora_dropout` must be `0` when using `lora_target_parameters`.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
lora_target_parameters:
|
||||||
|
- mlp.experts.gate_up_proj
|
||||||
|
- mlp.experts.down_proj
|
||||||
|
# - mlp.gate.weight # router, untested but should work, not normally targeted
|
||||||
|
```
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- **FSDP VRAM**: FSDP2 may use more VRAM per GPU than single GPU training. We suspect not all layers are properly sharded across ranks.
|
||||||
|
- **FSDP initial spike**: FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps that then drops. FSDP QLoRA (4-bit) does not exhibit this.
|
||||||
|
- **cpu_ram_efficient_loading**: Must be set to `false` with FSDP2 — causes hang otherwise.
|
||||||
|
- **lora_target_linear**: Incompatible for this model.
|
||||||
|
- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`).
|
||||||
|
|
||||||
|
|
||||||
|
### TIPS
|
||||||
|
|
||||||
|
- For inference, the official Z.ai team recommends these default settings (most tasks):
|
||||||
|
- `temperature: 1.0`
|
||||||
|
- `top_p: 0.95`
|
||||||
|
- `max_new_tokens: 131072`
|
||||||
|
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy, so we have not tested this.
|
||||||
|
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
|
|
||||||
|
## Optimization Guides
|
||||||
|
|
||||||
|
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||||
|
|
||||||
|
## Related Resources
|
||||||
|
|
||||||
|
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
|
||||||
|
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
|
||||||
|
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||||
|
- [Axolotl Website](https://axolotl.ai)
|
||||||
|
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||||
65
examples/glm4.7-flash/lora.yaml
Normal file
65
examples/glm4.7-flash/lora.yaml
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
base_model: zai-org/GLM-4.7-Flash
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
quantize_moe_experts: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.1
|
||||||
|
output_dir: ./outputs/glm4.7-flash-lora-8bit-out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
# Uncomment to also target MoE expert weights:
|
||||||
|
# lora_target_parameters:
|
||||||
|
# - mlp.experts.gate_up_proj
|
||||||
|
# - mlp.experts.down_proj
|
||||||
|
|
||||||
|
# LoRA kernels incompatible with DSA attention
|
||||||
|
lora_mlp_kernel: false
|
||||||
|
lora_qkv_kernel: false
|
||||||
|
lora_o_kernel: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
75
examples/glm4.7-flash/lora_fsdp.yaml
Normal file
75
examples/glm4.7-flash/lora_fsdp.yaml
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
base_model: zai-org/GLM-4.7-Flash
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
quantize_moe_experts: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.1
|
||||||
|
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
# Uncomment to also target MoE expert weights:
|
||||||
|
# lora_target_parameters:
|
||||||
|
# - mlp.experts.gate_up_proj
|
||||||
|
# - mlp.experts.down_proj
|
||||||
|
|
||||||
|
# LoRA kernels incompatible with DSA attention
|
||||||
|
lora_mlp_kernel: false
|
||||||
|
lora_qkv_kernel: false
|
||||||
|
lora_o_kernel: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 2
|
||||||
|
offload_params: false
|
||||||
|
cpu_ram_efficient_loading: false
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
|
||||||
|
state_dict_type: FULL_STATE_DICT
|
||||||
|
sharding_strategy: FULL_SHARD
|
||||||
|
reshard_after_forward: true
|
||||||
|
activation_checkpointing: true
|
||||||
65
examples/glm4.7-flash/qlora.yaml
Normal file
65
examples/glm4.7-flash/qlora.yaml
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
base_model: zai-org/GLM-4.7-Flash
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
quantize_moe_experts: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.1
|
||||||
|
output_dir: ./outputs/glm4.7-flash-qlora-out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
# Uncomment to also target MoE expert weights:
|
||||||
|
# lora_target_parameters:
|
||||||
|
# - mlp.experts.gate_up_proj
|
||||||
|
# - mlp.experts.down_proj
|
||||||
|
|
||||||
|
# LoRA kernels incompatible with DSA attention
|
||||||
|
lora_mlp_kernel: false
|
||||||
|
lora_qkv_kernel: false
|
||||||
|
lora_o_kernel: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
75
examples/glm4.7-flash/qlora_fsdp.yaml
Normal file
75
examples/glm4.7-flash/qlora_fsdp.yaml
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
base_model: zai-org/GLM-4.7-Flash
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
load_in_4bit: true
|
||||||
|
quantize_moe_experts: true
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.1
|
||||||
|
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
|
||||||
|
# Uncomment to also target MoE expert weights:
|
||||||
|
# lora_target_parameters:
|
||||||
|
# - mlp.experts.gate_up_proj
|
||||||
|
# - mlp.experts.down_proj
|
||||||
|
|
||||||
|
# LoRA kernels incompatible with DSA attention
|
||||||
|
lora_mlp_kernel: false
|
||||||
|
lora_qkv_kernel: false
|
||||||
|
lora_o_kernel: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 2
|
||||||
|
offload_params: false
|
||||||
|
cpu_ram_efficient_loading: false
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
|
||||||
|
state_dict_type: FULL_STATE_DICT
|
||||||
|
sharding_strategy: FULL_SHARD
|
||||||
|
reshard_after_forward: true
|
||||||
|
activation_checkpointing: true
|
||||||
@@ -6,30 +6,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
Here is an example of how to install from main for pip:
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||||
|
|
||||||
```bash
|
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
|
||||||
cd axolotl
|
|
||||||
|
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Install Qwen3-Next transformers commit
|
|
||||||
```bash
|
|
||||||
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Install FLA for improved performance
|
3. Install FLA for improved performance
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
|
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||||
```
|
```
|
||||||
|
|
||||||
4. Run the finetuning example:
|
4. Run the finetuning example:
|
||||||
@@ -38,7 +21,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
|
|||||||
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
|
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
This config uses about 45.62 GiB VRAM.
|
This config uses about ~47 GiB (no target experts) and ~71GiB (target experts) VRAM.
|
||||||
|
|
||||||
Let us know how it goes. Happy finetuning! 🚀
|
Let us know how it goes. Happy finetuning! 🚀
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ plugins:
|
|||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
|
|
||||||
|
quantize_moe_experts: true
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
type: chat_template
|
type: chat_template
|
||||||
@@ -25,7 +27,7 @@ sample_packing: true
|
|||||||
|
|
||||||
lora_r: 16
|
lora_r: 16
|
||||||
lora_alpha: 8
|
lora_alpha: 8
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0
|
||||||
lora_target_modules:
|
lora_target_modules:
|
||||||
- linear_attn.in_proj_ba
|
- linear_attn.in_proj_ba
|
||||||
- linear_attn.in_proj_qkvz
|
- linear_attn.in_proj_qkvz
|
||||||
@@ -34,12 +36,19 @@ lora_target_modules:
|
|||||||
- shared_expert.down_proj
|
- shared_expert.down_proj
|
||||||
- shared_expert.gate_proj
|
- shared_expert.gate_proj
|
||||||
- shared_expert_gate
|
- shared_expert_gate
|
||||||
- mlp.gate
|
|
||||||
- q_proj
|
- q_proj
|
||||||
- v_proj
|
- v_proj
|
||||||
- k_proj
|
- k_proj
|
||||||
- o_proj
|
- o_proj
|
||||||
|
|
||||||
|
# lora_target_parameters:
|
||||||
|
# - mlp.experts.gate_up_proj
|
||||||
|
# - mlp.experts.down_proj
|
||||||
|
|
||||||
|
lora_mlp_kernel: false
|
||||||
|
lora_qkv_kernel: false
|
||||||
|
lora_o_kernel: false
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
|
|||||||
@@ -8,13 +8,15 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
|
|
||||||
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
|
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
|
||||||
|
|
||||||
2. Run the finetuning example:
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||||
|
|
||||||
|
3. Run the finetuning example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
|
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
This config uses about 24.9 GiB VRAM.
|
This config uses about 24.9 GiB VRAM (w/o CCE).
|
||||||
|
|
||||||
Let us know how it goes. Happy finetuning! 🚀
|
Let us know how it goes. Happy finetuning! 🚀
|
||||||
|
|
||||||
@@ -29,10 +31,6 @@ Let us know how it goes. Happy finetuning! 🚀
|
|||||||
|
|
||||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||||
|
|
||||||
## Limitations
|
|
||||||
|
|
||||||
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
|
|
||||||
|
|
||||||
## Related Resources
|
## Related Resources
|
||||||
|
|
||||||
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
|
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
base_model: arcee-ai/Trinity-Nano-Preview
|
base_model: arcee-ai/Trinity-Nano-Preview
|
||||||
trust_remote_code: true
|
|
||||||
revision_of_model: 2ee94b0
|
revision_of_model: 2ee94b0
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
|||||||
@@ -63,3 +63,5 @@ docstring-code-format = false
|
|||||||
|
|
||||||
[tool.uv.extra-build-dependencies]
|
[tool.uv.extra-build-dependencies]
|
||||||
axolotl = ["huggingface_hub"]
|
axolotl = ["huggingface_hub"]
|
||||||
|
flash-attn = [{ requirement = "torch", match-runtime = true }]
|
||||||
|
deepspeed = [{ requirement = "torch", match-runtime = true }]
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,4 +18,7 @@ MOE_ARCH_BLOCK = {
|
|||||||
"gpt_oss": "GptOssDecoderLayer",
|
"gpt_oss": "GptOssDecoderLayer",
|
||||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||||
"afmoe": "AfmoeMoE",
|
"afmoe": "AfmoeMoE",
|
||||||
|
"glm4_moe": "Glm4MoeDecoderLayer",
|
||||||
|
"glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
|
||||||
|
"glm_moe_dsa": "GlmMoeDsaDecoderLayer",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -720,12 +720,16 @@ class AxolotlTrainer(
|
|||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
LOG.info(f"Saving model checkpoint to {output_dir}")
|
LOG.info(f"Saving model checkpoint to {output_dir}")
|
||||||
|
|
||||||
# fix for Context Parallel save
|
# fix for Context Parallel save: CP eval invalidates tensor storage
|
||||||
if state_dict is None:
|
# pointers, so clone to CPU to get fresh valid storage for safetensors
|
||||||
state_dict = self.accelerator.get_state_dict(self.model)
|
if (
|
||||||
if state_dict is not None:
|
state_dict is not None
|
||||||
|
and self.axolotl_cfg
|
||||||
|
and self.axolotl_cfg.context_parallel_size
|
||||||
|
and self.axolotl_cfg.context_parallel_size > 1
|
||||||
|
):
|
||||||
state_dict = {
|
state_dict = {
|
||||||
k: v.clone() if isinstance(v, torch.Tensor) else v
|
k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
|
||||||
for k, v in state_dict.items()
|
for k, v in state_dict.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -761,7 +765,11 @@ class AxolotlTrainer(
|
|||||||
metadata={"format": "pt"},
|
metadata={"format": "pt"},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.model.save_pretrained(output_dir, state_dict=state_dict)
|
self.model.save_pretrained(
|
||||||
|
output_dir,
|
||||||
|
state_dict=state_dict,
|
||||||
|
is_main_process=self.accelerator.is_main_process,
|
||||||
|
)
|
||||||
|
|
||||||
if self.processing_class is not None:
|
if self.processing_class is not None:
|
||||||
self.processing_class.save_pretrained(output_dir)
|
self.processing_class.save_pretrained(output_dir)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -88,9 +88,9 @@ plugins:
|
|||||||
- qwen2_vl
|
- qwen2_vl
|
||||||
- qwen3
|
- qwen3
|
||||||
- qwen3_5
|
- qwen3_5
|
||||||
|
- qwen3_5_text
|
||||||
- qwen3_5_moe
|
- qwen3_5_moe
|
||||||
- qwen3_5_moe_vl
|
- qwen3_5_moe_text
|
||||||
- qwen3_5_vl
|
|
||||||
- qwen3_moe
|
- qwen3_moe
|
||||||
- qwen3_next
|
- qwen3_next
|
||||||
- qwen3_vl
|
- qwen3_vl
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ This works for any MoE model in transformers that uses a `SparseMoeBlock` class
|
|||||||
|
|
||||||
ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
|
ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
|
||||||
|
|
||||||
|
ScatterMoE does not work for GLM4.7 Flash (glm4_moe_lite) atm.
|
||||||
|
|
||||||
## Note on MegaBlocks
|
## Note on MegaBlocks
|
||||||
|
|
||||||
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.
|
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ def setup_quantized_meta_for_peft(model: torch.nn.Module):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
if isinstance(param, Params4bit):
|
if isinstance(param, Params4bit) and param.quant_state is not None:
|
||||||
param.quant_state._orig_to = param.quant_state.to
|
param.quant_state._orig_to = param.quant_state.to
|
||||||
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
|
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
|
||||||
|
|
||||||
|
|||||||
@@ -172,7 +172,10 @@ class ModelLoader:
|
|||||||
# Build the model
|
# Build the model
|
||||||
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
||||||
self.patch_manager.apply_post_plugin_pre_model_load_patches()
|
self.patch_manager.apply_post_plugin_pre_model_load_patches()
|
||||||
|
|
||||||
skip_move_to_device = self._build_model()
|
skip_move_to_device = self._build_model()
|
||||||
|
self.patch_manager.apply_post_model_build_patches(self.model)
|
||||||
|
|
||||||
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
||||||
|
|
||||||
# Post-build model configuration
|
# Post-build model configuration
|
||||||
@@ -860,6 +863,10 @@ class ModelLoader:
|
|||||||
# Make sure everything is in the same dtype
|
# Make sure everything is in the same dtype
|
||||||
skip_prepare_model_for_kbit_training = True
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
|
if getattr(self.model, "_moe_experts_quantized", False):
|
||||||
|
# Parametrized expert tensors dequantize on access — would OOM.
|
||||||
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not skip_prepare_model_for_kbit_training
|
not skip_prepare_model_for_kbit_training
|
||||||
and self.cfg.adapter in ["lora", "qlora"]
|
and self.cfg.adapter in ["lora", "qlora"]
|
||||||
|
|||||||
@@ -118,6 +118,7 @@ class PatchManager:
|
|||||||
def apply_post_plugin_pre_model_load_patches(self):
|
def apply_post_plugin_pre_model_load_patches(self):
|
||||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||||
|
self._apply_moe_expert_quantization_patch()
|
||||||
|
|
||||||
def _apply_transformers_patches(self):
|
def _apply_transformers_patches(self):
|
||||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||||
@@ -135,6 +136,10 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_prepare_context_parallel_inputs()
|
patch_prepare_context_parallel_inputs()
|
||||||
|
|
||||||
|
def apply_post_model_build_patches(self, model: PreTrainedModel):
|
||||||
|
"""Apply patches right after model build, before post-load setup."""
|
||||||
|
self._finalize_moe_expert_quantization(model)
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches that require the model instance."""
|
"""Apply patches that require the model instance."""
|
||||||
self._apply_llama_flash_attn_patches(model)
|
self._apply_llama_flash_attn_patches(model)
|
||||||
@@ -170,9 +175,14 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_parallelism_config()
|
patch_parallelism_config()
|
||||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
||||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
from axolotl.monkeypatch.accelerate.fsdp2 import (
|
||||||
|
patch_accelerate_fsdp2,
|
||||||
|
patch_tied_keys_for_meta_device,
|
||||||
|
)
|
||||||
|
|
||||||
patch_accelerate_fsdp2()
|
patch_accelerate_fsdp2()
|
||||||
|
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||||
|
patch_tied_keys_for_meta_device()
|
||||||
if self.cfg.rl:
|
if self.cfg.rl:
|
||||||
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
|
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
|
||||||
|
|
||||||
@@ -352,15 +362,54 @@ class PatchManager:
|
|||||||
if (
|
if (
|
||||||
self.cfg.fsdp_config
|
self.cfg.fsdp_config
|
||||||
and str(self.cfg.fsdp_version) == "2"
|
and str(self.cfg.fsdp_version) == "2"
|
||||||
and self.cfg.adapter == "qlora"
|
and (self.cfg.load_in_4bit or self.cfg.load_in_8bit)
|
||||||
):
|
):
|
||||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||||
|
apply_init_dtype_attrs_patch,
|
||||||
apply_init_sharded_param_patch,
|
apply_init_sharded_param_patch,
|
||||||
apply_init_unsharded_param_patch,
|
apply_init_unsharded_param_patch,
|
||||||
|
apply_linear8bitlt_save_patch,
|
||||||
)
|
)
|
||||||
|
|
||||||
apply_init_sharded_param_patch()
|
apply_init_sharded_param_patch()
|
||||||
apply_init_unsharded_param_patch()
|
apply_init_unsharded_param_patch()
|
||||||
|
apply_init_dtype_attrs_patch()
|
||||||
|
if self.cfg.load_in_8bit:
|
||||||
|
apply_linear8bitlt_save_patch()
|
||||||
|
|
||||||
|
def _apply_moe_expert_quantization_patch(self):
|
||||||
|
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
|
||||||
|
if not self.cfg.quantize_moe_experts:
|
||||||
|
return
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.moe_quant import (
|
||||||
|
patch_moe_quantization_on_load,
|
||||||
|
patch_peft_target_parameters_matching,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_moe_quantization_on_load(self.cfg)
|
||||||
|
patch_peft_target_parameters_matching()
|
||||||
|
|
||||||
|
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
|
||||||
|
"""Log quantization results and set model flag for downstream use."""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
model._moe_experts_quantized = False
|
||||||
|
if self.cfg.quantize_moe_experts:
|
||||||
|
from axolotl.monkeypatch.moe_quant import get_moe_quantized_count
|
||||||
|
|
||||||
|
count = get_moe_quantized_count()
|
||||||
|
if count > 0:
|
||||||
|
import gc
|
||||||
|
|
||||||
|
model._moe_experts_quantized = True
|
||||||
|
LOG.info(
|
||||||
|
"Quantized %d MoE expert parameter(s) to %s during model loading",
|
||||||
|
count,
|
||||||
|
"4-bit" if self.cfg.load_in_4bit else "8-bit",
|
||||||
|
)
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def _apply_tiled_mlp(self, model_type: str):
|
def _apply_tiled_mlp(self, model_type: str):
|
||||||
if self.cfg.tiled_mlp:
|
if self.cfg.tiled_mlp:
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
|
|||||||
self,
|
self,
|
||||||
save_directory: Union[str, os.PathLike],
|
save_directory: Union[str, os.PathLike],
|
||||||
state_dict: Optional[dict] = None,
|
state_dict: Optional[dict] = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
if state_dict is None:
|
if state_dict is None:
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|||||||
@@ -150,13 +150,17 @@ def get_state_dict(self, model, unwrap=True):
|
|||||||
)
|
)
|
||||||
elif self.is_fsdp2:
|
elif self.is_fsdp2:
|
||||||
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
|
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
|
||||||
|
from torch.distributed.tensor import DTensor
|
||||||
|
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
sharded_state_dict = model.state_dict()
|
sharded_state_dict = model.state_dict()
|
||||||
for param_name, param in sharded_state_dict.items():
|
for param_name, param in sharded_state_dict.items():
|
||||||
if param.is_cpu:
|
if param.is_cpu:
|
||||||
param = param.to(torch.device("cuda"))
|
param = param.to(torch.device("cuda"))
|
||||||
|
|
||||||
param = param.full_tensor()
|
if isinstance(param, DTensor):
|
||||||
|
param = param.full_tensor()
|
||||||
|
|
||||||
if torch.distributed.get_rank() == 0:
|
if torch.distributed.get_rank() == 0:
|
||||||
state_dict[param_name] = param.cpu()
|
state_dict[param_name] = param.cpu()
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
@@ -182,10 +186,56 @@ def get_state_dict(self, model, unwrap=True):
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def patch_peft_param_wrapper_for_fsdp2():
|
||||||
|
"""Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.
|
||||||
|
|
||||||
|
PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds
|
||||||
|
delta_weight to the base weight W inside _LoraParameterProxy.forward().
|
||||||
|
Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a
|
||||||
|
regular Tensor (or vice versa), causing a RuntimeError on mixed types.
|
||||||
|
|
||||||
|
This patch promotes the non-DTensor operand to match the DTensor's spec
|
||||||
|
using DTensor.from_local(), which is free for Replicate placement (just
|
||||||
|
metadata wrapping, no communication).
|
||||||
|
"""
|
||||||
|
from peft.tuners.lora.layer import _LoraParameterProxy
|
||||||
|
|
||||||
|
if getattr(_LoraParameterProxy, "_axolotl_fsdp2_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
_original_forward = _LoraParameterProxy.forward
|
||||||
|
|
||||||
|
# NOTE: Replaces (not wraps) forward; assumes original is just `W + self.delta_weight`.
|
||||||
|
def _patched_forward(self, W):
|
||||||
|
from torch.distributed.tensor import DTensor
|
||||||
|
|
||||||
|
delta = self.delta_weight
|
||||||
|
w_is_dt = isinstance(W, DTensor)
|
||||||
|
d_is_dt = isinstance(delta, DTensor)
|
||||||
|
|
||||||
|
with torch.nn.utils.parametrize.cached():
|
||||||
|
if w_is_dt == d_is_dt:
|
||||||
|
return W + delta
|
||||||
|
if w_is_dt:
|
||||||
|
return W + DTensor.from_local(delta, W.device_mesh, W.placements)
|
||||||
|
return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta
|
||||||
|
|
||||||
|
_LoraParameterProxy.forward = _patched_forward
|
||||||
|
_LoraParameterProxy._axolotl_fsdp2_patched = True
|
||||||
|
LOG.info("Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility")
|
||||||
|
|
||||||
|
|
||||||
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
||||||
"""Helper function to process LoRA modules for FSDP2."""
|
"""Helper function to process LoRA modules for FSDP2."""
|
||||||
|
from peft.tuners.lora.layer import ParamWrapper
|
||||||
from torch.distributed.fsdp import fully_shard
|
from torch.distributed.fsdp import fully_shard
|
||||||
|
|
||||||
|
# Skip ParamWrapper — its lora_A/B must not be independently sharded.
|
||||||
|
# The parent decoder layer's FSDP wrapper handles unsharding them.
|
||||||
|
# TODO: review if we even need to shard them separately in first place.
|
||||||
|
if isinstance(module, ParamWrapper):
|
||||||
|
return False
|
||||||
|
|
||||||
log_bias_dtype_mismatch = False
|
log_bias_dtype_mismatch = False
|
||||||
|
|
||||||
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
|
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
|
||||||
@@ -327,6 +377,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
|
|
||||||
is_peft_model = isinstance(model, PeftModel)
|
is_peft_model = isinstance(model, PeftModel)
|
||||||
|
|
||||||
|
# Patch PEFT's _LoraParameterProxy for DTensor compatibility if any
|
||||||
|
# ParamWrapper modules exist (used for target_parameters / 3D expert params).
|
||||||
|
if is_peft_model:
|
||||||
|
from peft.tuners.lora.layer import ParamWrapper
|
||||||
|
|
||||||
|
if any(isinstance(m, ParamWrapper) for m in model.modules()):
|
||||||
|
patch_peft_param_wrapper_for_fsdp2()
|
||||||
|
|
||||||
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
||||||
log_bias_dtype_mismatch = False
|
log_bias_dtype_mismatch = False
|
||||||
if auto_wrap_policy is not None:
|
if auto_wrap_policy is not None:
|
||||||
@@ -376,6 +434,43 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def patch_tied_keys_for_meta_device():
|
||||||
|
"""Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors.
|
||||||
|
|
||||||
|
Meta tensors all share data_ptr()==0, causing every parameter to be incorrectly
|
||||||
|
grouped as "tied". Skipping them is safe since they have no real storage.
|
||||||
|
"""
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
def _patched_adjust_tied_keys_with_tied_pointers(self, missing_keys):
|
||||||
|
param_pointers = defaultdict(list)
|
||||||
|
for param_name, param_value in self.state_dict().items():
|
||||||
|
if param_value.is_meta:
|
||||||
|
continue
|
||||||
|
param_pointers[param_value.data_ptr()].append(param_name)
|
||||||
|
|
||||||
|
tied_param_names = [
|
||||||
|
names
|
||||||
|
for names in param_pointers.values()
|
||||||
|
if len(names) > 1
|
||||||
|
and not any(name in self.all_tied_weights_keys.keys() for name in names)
|
||||||
|
and not all(name in missing_keys for name in names)
|
||||||
|
]
|
||||||
|
|
||||||
|
tied_weights_keys_by_pointers = {
|
||||||
|
param_name: group[0]
|
||||||
|
for group in tied_param_names
|
||||||
|
for param_name in group[1:]
|
||||||
|
}
|
||||||
|
self.all_tied_weights_keys.update(tied_weights_keys_by_pointers)
|
||||||
|
|
||||||
|
PreTrainedModel._adjust_tied_keys_with_tied_pointers = (
|
||||||
|
_patched_adjust_tied_keys_with_tied_pointers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def patch_accelerate_fsdp2():
|
def patch_accelerate_fsdp2():
|
||||||
import accelerate
|
import accelerate
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
|
Monkeypatch to add Params4bit and Int8Params support to FSDP2. This enables QLoRA + FSDP2
|
||||||
our LoRA / QLoRA Triton kernels to work with FSDP2.
|
and 8-bit LoRA + FSDP2, as well as our LoRA / QLoRA Triton kernels to work with FSDP2.
|
||||||
|
|
||||||
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
|
This patch modifies the _init_sharded_param and init_unsharded_param methods in FSDPParam
|
||||||
Params4bit parameters.
|
to handle bitsandbytes Params4bit and Int8Params parameters, preserving their quantization
|
||||||
|
metadata through the FSDP2 shard/unshard cycle.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
@@ -17,6 +18,8 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
def apply_init_sharded_param_patch():
|
def apply_init_sharded_param_patch():
|
||||||
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
|
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
|
||||||
|
if getattr(apply_init_sharded_param_patch, "_axolotl_patched", False):
|
||||||
|
return
|
||||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||||
|
|
||||||
# Get original source
|
# Get original source
|
||||||
@@ -41,9 +44,20 @@ def apply_init_sharded_param_patch():
|
|||||||
bnb_quantized=param.bnb_quantized,
|
bnb_quantized=param.bnb_quantized,
|
||||||
)
|
)
|
||||||
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
||||||
|
elif isinstance(param, bnb.nn.modules.Int8Params):
|
||||||
|
self.sharded_param = bnb.nn.modules.Int8Params(
|
||||||
|
data=sharded_param,
|
||||||
|
requires_grad=param.requires_grad,
|
||||||
|
has_fp16_weights=param.has_fp16_weights,
|
||||||
|
CB=None,
|
||||||
|
SCB=param.SCB,
|
||||||
|
)
|
||||||
|
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
||||||
else:
|
else:
|
||||||
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
self.sharded_param = nn.Parameter(
|
||||||
self.sharded_param.requires_grad_(param.requires_grad)"""
|
self.to_sharded_dtensor(sharded_param),
|
||||||
|
requires_grad=param.requires_grad,
|
||||||
|
)"""
|
||||||
|
|
||||||
# Apply the replacement
|
# Apply the replacement
|
||||||
if original_param_creation in original_source:
|
if original_param_creation in original_source:
|
||||||
@@ -73,6 +87,7 @@ def apply_init_sharded_param_patch():
|
|||||||
|
|
||||||
# Replace the method
|
# Replace the method
|
||||||
FSDPParam._init_sharded_param = patched_init_sharded_param
|
FSDPParam._init_sharded_param = patched_init_sharded_param
|
||||||
|
apply_init_sharded_param_patch._axolotl_patched = True
|
||||||
LOG.info("Successfully applied FSDP _init_sharded_param patch")
|
LOG.info("Successfully applied FSDP _init_sharded_param patch")
|
||||||
else:
|
else:
|
||||||
LOG.warning("Could not find target code for _init_sharded_param patching")
|
LOG.warning("Could not find target code for _init_sharded_param patching")
|
||||||
@@ -80,6 +95,8 @@ def apply_init_sharded_param_patch():
|
|||||||
|
|
||||||
def apply_init_unsharded_param_patch():
|
def apply_init_unsharded_param_patch():
|
||||||
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
|
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
|
||||||
|
if getattr(apply_init_unsharded_param_patch, "_axolotl_patched", False):
|
||||||
|
return
|
||||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||||
|
|
||||||
# Get original source
|
# Get original source
|
||||||
@@ -105,6 +122,14 @@ def apply_init_unsharded_param_patch():
|
|||||||
module=local_tensor.module,
|
module=local_tensor.module,
|
||||||
bnb_quantized=local_tensor.bnb_quantized,
|
bnb_quantized=local_tensor.bnb_quantized,
|
||||||
)
|
)
|
||||||
|
elif isinstance(local_tensor, bnb.nn.modules.Int8Params):
|
||||||
|
self._unsharded_param = bnb.nn.modules.Int8Params(
|
||||||
|
data=unsharded_param,
|
||||||
|
requires_grad=self.sharded_param.requires_grad,
|
||||||
|
has_fp16_weights=local_tensor.has_fp16_weights,
|
||||||
|
CB=unsharded_param,
|
||||||
|
SCB=local_tensor.SCB,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self._unsharded_param = nn.Parameter(
|
self._unsharded_param = nn.Parameter(
|
||||||
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
||||||
@@ -138,6 +163,74 @@ def apply_init_unsharded_param_patch():
|
|||||||
|
|
||||||
# Replace the method
|
# Replace the method
|
||||||
FSDPParam.init_unsharded_param = patched_init_unsharded_param
|
FSDPParam.init_unsharded_param = patched_init_unsharded_param
|
||||||
|
apply_init_unsharded_param_patch._axolotl_patched = True
|
||||||
LOG.info("Successfully applied FSDP init_unsharded_param patch")
|
LOG.info("Successfully applied FSDP init_unsharded_param patch")
|
||||||
else:
|
else:
|
||||||
LOG.warning("Could not find target code for patching")
|
LOG.warning("Could not find target code for patching")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_linear8bitlt_save_patch():
|
||||||
|
"""Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.
|
||||||
|
|
||||||
|
After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params.
|
||||||
|
BnB's _save_to_state_dict accesses self.weight.SCB directly, but DTensor
|
||||||
|
doesn't proxy custom attribute access to its _local_tensor. This patch
|
||||||
|
temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.
|
||||||
|
"""
|
||||||
|
if getattr(apply_linear8bitlt_save_patch, "_axolotl_patched", False):
|
||||||
|
return
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
from torch.distributed.tensor import DTensor
|
||||||
|
|
||||||
|
original_save = bnb.nn.Linear8bitLt._save_to_state_dict
|
||||||
|
|
||||||
|
def _patched_save_to_state_dict(self, destination, prefix, keep_vars):
|
||||||
|
# Use _parameters dict directly to bypass nn.Module.__setattr__ type check.
|
||||||
|
weight = self._parameters["weight"]
|
||||||
|
unwrapped = False
|
||||||
|
if isinstance(weight, DTensor) and hasattr(weight, "_local_tensor"):
|
||||||
|
self._parameters["weight"] = weight._local_tensor
|
||||||
|
unwrapped = True
|
||||||
|
try:
|
||||||
|
original_save(self, destination, prefix, keep_vars)
|
||||||
|
finally:
|
||||||
|
if unwrapped:
|
||||||
|
self._parameters["weight"] = weight
|
||||||
|
|
||||||
|
bnb.nn.Linear8bitLt._save_to_state_dict = _patched_save_to_state_dict
|
||||||
|
apply_linear8bitlt_save_patch._axolotl_patched = True
|
||||||
|
LOG.info("Patched Linear8bitLt._save_to_state_dict for DTensor compatibility")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_init_dtype_attrs_patch():
|
||||||
|
"""Prevent FSDP2 mixed precision from casting non-float quantized params.
|
||||||
|
|
||||||
|
When mixed precision is enabled (e.g., bf16), FSDP2's init_dtype_attrs sets
|
||||||
|
param_dtype=bf16 for ALL params. During all-gather, _to_dtype_if_needed casts
|
||||||
|
the sharded param to param_dtype. For non-float params (uint8 packed 4-bit,
|
||||||
|
int8 quantized) without FSDP2 extensions, this destroys the quantized data.
|
||||||
|
|
||||||
|
Params4bit handles this via fsdp_pre/post_all_gather extensions, but our
|
||||||
|
parametrize-based expert quantization uses plain nn.Parameter(uint8/int8)
|
||||||
|
without extensions.
|
||||||
|
"""
|
||||||
|
if getattr(apply_init_dtype_attrs_patch, "_axolotl_patched", False):
|
||||||
|
return
|
||||||
|
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||||
|
|
||||||
|
original_init_dtype_attrs = FSDPParam.init_dtype_attrs
|
||||||
|
|
||||||
|
def patched_init_dtype_attrs(self, mp_policy):
|
||||||
|
original_init_dtype_attrs(self, mp_policy)
|
||||||
|
# Skip casting non-float quantized params (uint8/int8) without FSDP2
|
||||||
|
# extensions — the parametrization chain handles dequantization.
|
||||||
|
if self.param_dtype is not None and not self.sharded_param.is_floating_point():
|
||||||
|
local = self.sharded_param
|
||||||
|
if hasattr(local, "_local_tensor"):
|
||||||
|
local = local._local_tensor
|
||||||
|
if not hasattr(local, "fsdp_pre_all_gather"):
|
||||||
|
self.param_dtype = None
|
||||||
|
|
||||||
|
FSDPParam.init_dtype_attrs = patched_init_dtype_attrs
|
||||||
|
apply_init_dtype_attrs_patch._axolotl_patched = True
|
||||||
|
LOG.info("Patched FSDPParam.init_dtype_attrs for non-float quantized params")
|
||||||
|
|||||||
@@ -9,6 +9,11 @@ from axolotl.utils.logging import get_logger
|
|||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
|
||||||
|
except ImportError:
|
||||||
|
fla_causal_conv1d = None
|
||||||
|
|
||||||
|
|
||||||
def get_cu_seqlens(position_ids):
|
def get_cu_seqlens(position_ids):
|
||||||
"""
|
"""
|
||||||
@@ -137,6 +142,11 @@ def patch_qwen3_next_gateddelta_layer():
|
|||||||
and cache_position is not None
|
and cache_position is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Compute cu_seqlens early for use by both causal_conv1d and chunk_gated_delta_rule
|
||||||
|
cu_seqlens = None
|
||||||
|
if not use_precomputed_states and position_ids is not None:
|
||||||
|
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
|
||||||
|
|
||||||
# getting projected states from cache if it exists
|
# getting projected states from cache if it exists
|
||||||
if cache_params is not None:
|
if cache_params is not None:
|
||||||
conv_state = cache_params.conv_states[self.layer_idx]
|
conv_state = cache_params.conv_states[self.layer_idx]
|
||||||
@@ -151,12 +161,11 @@ def patch_qwen3_next_gateddelta_layer():
|
|||||||
x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
|
x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
|
||||||
)
|
)
|
||||||
|
|
||||||
mixed_qkv = torch.cat((query, key, value), dim=-1)
|
mixed_qkv = torch.cat((query, key, value), dim=-1) # [B, T, D]
|
||||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
|
||||||
|
|
||||||
if use_precomputed_states:
|
if use_precomputed_states:
|
||||||
# 2. Convolution sequence transformation
|
# Inference single-token path: causal_conv1d_update expects [B, D, T]
|
||||||
# NOTE: the conv state is updated in `causal_conv1d_update`
|
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||||
mixed_qkv = self.causal_conv1d_update(
|
mixed_qkv = self.causal_conv1d_update(
|
||||||
mixed_qkv,
|
mixed_qkv,
|
||||||
conv_state,
|
conv_state,
|
||||||
@@ -164,24 +173,41 @@ def patch_qwen3_next_gateddelta_layer():
|
|||||||
self.conv1d.bias,
|
self.conv1d.bias,
|
||||||
self.activation,
|
self.activation,
|
||||||
)
|
)
|
||||||
|
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||||
else:
|
else:
|
||||||
if cache_params is not None:
|
if cache_params is not None:
|
||||||
|
# Cache state expects [B, D, T] for the inference update path
|
||||||
|
mixed_qkv_t = mixed_qkv.transpose(1, 2)
|
||||||
conv_state = F.pad(
|
conv_state = F.pad(
|
||||||
mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
|
mixed_qkv_t,
|
||||||
|
(self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
|
||||||
)
|
)
|
||||||
cache_params.conv_states[self.layer_idx] = conv_state
|
cache_params.conv_states[self.layer_idx] = conv_state
|
||||||
if self.causal_conv1d_fn is not None:
|
|
||||||
mixed_qkv = self.causal_conv1d_fn(
|
if fla_causal_conv1d is not None:
|
||||||
|
# FLA Triton causal_conv1d: [B, T, D] in/out, with cu_seqlens support
|
||||||
|
mixed_qkv, _ = fla_causal_conv1d(
|
||||||
x=mixed_qkv,
|
x=mixed_qkv,
|
||||||
weight=self.conv1d.weight.squeeze(1),
|
weight=self.conv1d.weight.squeeze(1),
|
||||||
bias=self.conv1d.bias,
|
bias=self.conv1d.bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
seq_idx=None,
|
cu_seqlens=cu_seqlens,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# PyTorch fallback (no cu_seqlens support)
|
||||||
|
if cu_seqlens is not None and cu_seqlens.shape[0] > batch_size + 1:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Packed sequences require fla.modules.convolution.causal_conv1d "
|
||||||
|
"(cu_seqlens support). Install flash-linear-attention or disable packing."
|
||||||
|
)
|
||||||
|
LOG.warning_once(
|
||||||
|
"FLA causal_conv1d not available. Falling back to PyTorch conv1d."
|
||||||
|
)
|
||||||
|
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||||
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
|
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
|
||||||
|
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||||
|
|
||||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
# mixed_qkv is [B, T, D] in all paths
|
||||||
query, key, value = torch.split(
|
query, key, value = torch.split(
|
||||||
mixed_qkv,
|
mixed_qkv,
|
||||||
[
|
[
|
||||||
@@ -203,7 +229,6 @@ def patch_qwen3_next_gateddelta_layer():
|
|||||||
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
||||||
|
|
||||||
if not use_precomputed_states:
|
if not use_precomputed_states:
|
||||||
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
|
|
||||||
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
|
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
|
|||||||
188
src/axolotl/monkeypatch/moe_quant.py
Normal file
188
src/axolotl/monkeypatch/moe_quant.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
"""
|
||||||
|
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
|
||||||
|
|
||||||
|
In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
|
||||||
|
skips (only targets nn.Linear). This module patches weight loading to quantize them
|
||||||
|
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
|
||||||
|
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
|
||||||
|
"""
|
||||||
|
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
import torch
|
||||||
|
import torch.nn.utils.parametrize as P
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
# Module-level state for the loading-time quantization patch.
|
||||||
|
_moe_load_state = {
|
||||||
|
"count": 0,
|
||||||
|
"mode": "4bit",
|
||||||
|
"quant_type": "nf4",
|
||||||
|
"compress_statistics": True,
|
||||||
|
"patched": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Bnb8bitParametrization(torch.nn.Module):
|
||||||
|
"""Parametrization that dequantizes int8 row-wise quantized data on access."""
|
||||||
|
|
||||||
|
def __init__(self, row_stats: torch.Tensor):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("row_stats", row_stats)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
|
||||||
|
# Flatten 3D+ to 2D for BnB's dequant, then reshape back.
|
||||||
|
orig_shape = quantized_param.shape
|
||||||
|
if quantized_param.ndim > 2:
|
||||||
|
quantized_param = quantized_param.reshape(-1, orig_shape[-1])
|
||||||
|
result = bnb.functional.int8_vectorwise_dequant(quantized_param, self.row_stats)
|
||||||
|
return result.reshape(orig_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def _enable_parametrization_cache(module, inputs):
|
||||||
|
P._cache_enabled += 1
|
||||||
|
|
||||||
|
|
||||||
|
def _disable_parametrization_cache(module, inputs, output):
|
||||||
|
P._cache_enabled -= 1
|
||||||
|
if not P._cache_enabled:
|
||||||
|
P._cache = {}
|
||||||
|
|
||||||
|
|
||||||
|
def replace_parameter_8bit(module, param_name):
|
||||||
|
"""Replace a module parameter with an 8-bit quantized version using parametrization."""
|
||||||
|
original_param = getattr(module, param_name)
|
||||||
|
int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(
|
||||||
|
original_param.data.to(torch.float16)
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False))
|
||||||
|
del original_param
|
||||||
|
|
||||||
|
P.register_parametrization(
|
||||||
|
module, param_name, Bnb8bitParametrization(row_stats), unsafe=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache dequantized values during forward to avoid redundant dequantization.
|
||||||
|
if not getattr(module, "_axolotl_8bit_hooks_registered", False):
|
||||||
|
module.register_forward_pre_hook(_enable_parametrization_cache)
|
||||||
|
module.register_forward_hook(_disable_parametrization_cache)
|
||||||
|
module._axolotl_8bit_hooks_registered = True
|
||||||
|
|
||||||
|
|
||||||
|
def patch_moe_quantization_on_load(cfg):
|
||||||
|
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly.
|
||||||
|
|
||||||
|
Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
|
||||||
|
name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
|
||||||
|
"""
|
||||||
|
mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
|
||||||
|
_moe_load_state["mode"] = mode
|
||||||
|
_moe_load_state["count"] = 0
|
||||||
|
|
||||||
|
if _moe_load_state["patched"]:
|
||||||
|
LOG.debug("MoE loading-time quantization patch already active")
|
||||||
|
return
|
||||||
|
|
||||||
|
import transformers.core_model_loading
|
||||||
|
import transformers.modeling_utils
|
||||||
|
|
||||||
|
if mode == "4bit":
|
||||||
|
from bitsandbytes.nn.parametrize import replace_parameter_4bit
|
||||||
|
|
||||||
|
quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
|
||||||
|
compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
|
||||||
|
if compress_statistics is None:
|
||||||
|
compress_statistics = True
|
||||||
|
|
||||||
|
_moe_load_state["quant_type"] = quant_type
|
||||||
|
_moe_load_state["compress_statistics"] = compress_statistics
|
||||||
|
|
||||||
|
# Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
|
||||||
|
# size for all params, defeating our on-load quantization VRAM savings.
|
||||||
|
def _noop_warmup(*args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
transformers.modeling_utils.caching_allocator_warmup = _noop_warmup
|
||||||
|
|
||||||
|
original_set_param = transformers.core_model_loading.set_param_for_module
|
||||||
|
|
||||||
|
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
|
||||||
|
original_set_param(model, target_name, param_value, *args, **kwargs)
|
||||||
|
|
||||||
|
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
|
||||||
|
if param_value.ndim >= 3 and param_value.is_cuda:
|
||||||
|
mod_path, _, pname = target_name.rpartition(".")
|
||||||
|
mod = model.get_submodule(mod_path) if mod_path else model
|
||||||
|
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
|
||||||
|
if "expert" not in target_name.lower():
|
||||||
|
LOG.debug(
|
||||||
|
"Skipping non-expert 3D param: %s (shape=%s)",
|
||||||
|
target_name,
|
||||||
|
list(param_value.shape),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if _moe_load_state["mode"] == "4bit":
|
||||||
|
replace_parameter_4bit(
|
||||||
|
mod,
|
||||||
|
pname,
|
||||||
|
compress_statistics=_moe_load_state["compress_statistics"],
|
||||||
|
quant_type=_moe_load_state["quant_type"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
replace_parameter_8bit(mod, pname)
|
||||||
|
_moe_load_state["count"] += 1
|
||||||
|
|
||||||
|
# Release the bf16 tensor so CUDA memory is freed immediately.
|
||||||
|
param_value.data = torch.empty(0, device="cpu")
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
|
||||||
|
_moe_load_state["patched"] = True
|
||||||
|
|
||||||
|
|
||||||
|
def get_moe_quantized_count():
|
||||||
|
"""Return the number of expert parameters quantized during loading."""
|
||||||
|
return _moe_load_state["count"]
|
||||||
|
|
||||||
|
|
||||||
|
def patch_peft_target_parameters_matching():
|
||||||
|
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
|
||||||
|
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
|
||||||
|
return
|
||||||
|
from peft.tuners.tuners_utils import BaseTuner
|
||||||
|
|
||||||
|
original_inject = BaseTuner._inject_parameters
|
||||||
|
|
||||||
|
def _patched_inject_parameters(
|
||||||
|
self, peft_config, model, adapter_name, low_cpu_mem_usage
|
||||||
|
):
|
||||||
|
# Patch target_parameters to use full paths for parametrized modules
|
||||||
|
original_targets = list(peft_config.target_parameters)
|
||||||
|
expanded = set(original_targets)
|
||||||
|
|
||||||
|
for module_name, module in model.named_modules():
|
||||||
|
if not hasattr(module, "parametrizations"):
|
||||||
|
continue
|
||||||
|
for target in original_targets:
|
||||||
|
mod_path, _, param_name = target.rpartition(".")
|
||||||
|
if (
|
||||||
|
module_name == mod_path or module_name.endswith("." + mod_path)
|
||||||
|
) and hasattr(module, param_name):
|
||||||
|
expanded.add(f"{module_name}.{param_name}")
|
||||||
|
|
||||||
|
peft_config.target_parameters = sorted(expanded)
|
||||||
|
try:
|
||||||
|
return original_inject(
|
||||||
|
self, peft_config, model, adapter_name, low_cpu_mem_usage
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
peft_config.target_parameters = original_targets
|
||||||
|
|
||||||
|
BaseTuner._inject_parameters = _patched_inject_parameters
|
||||||
|
patch_peft_target_parameters_matching._axolotl_patched = True
|
||||||
|
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
|
||||||
@@ -48,9 +48,9 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
):
|
):
|
||||||
# check if message_property_mappings is None or empty dict
|
# check if message_property_mappings is None or empty dict
|
||||||
if message_property_mappings is None or (not message_property_mappings):
|
if message_property_mappings is None or (not message_property_mappings):
|
||||||
default_message_property_mappings_keys = ["role", "content", "tool"]
|
|
||||||
message_property_mappings = {
|
message_property_mappings = {
|
||||||
prop: prop for prop in default_message_property_mappings_keys
|
"role": "role",
|
||||||
|
"content": "content",
|
||||||
}
|
}
|
||||||
if template_thinking_key and field_thinking:
|
if template_thinking_key and field_thinking:
|
||||||
message_property_mappings[template_thinking_key] = field_thinking
|
message_property_mappings[template_thinking_key] = field_thinking
|
||||||
|
|||||||
@@ -629,6 +629,17 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
quantize_moe_experts: bool = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Quantize MoE expert weights on load to reduce VRAM. "
|
||||||
|
"Requires adapter (lora/qlora) with load_in_4bit or load_in_8bit. "
|
||||||
|
"Requires CUDA (not compatible with ROCm or other backends). "
|
||||||
|
"Note: total parameter count may be reported incorrectly when enabled "
|
||||||
|
"(trainable param count is correct)."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
scaling_softmax: bool | None = Field(
|
scaling_softmax: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -1289,6 +1300,26 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_quantize_moe_experts(cls, data):
|
||||||
|
if data.get("quantize_moe_experts"):
|
||||||
|
if data.get("adapter") not in ("lora", "qlora"):
|
||||||
|
raise ValueError("quantize_moe_experts requires adapter: lora or qlora")
|
||||||
|
if not (data.get("load_in_4bit") or data.get("load_in_8bit")):
|
||||||
|
raise ValueError(
|
||||||
|
"quantize_moe_experts requires load_in_4bit or load_in_8bit"
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
data.get("capabilities")
|
||||||
|
and data["capabilities"].get("compute_capability")
|
||||||
|
and not data["capabilities"]["compute_capability"].startswith("sm_")
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"quantize_moe_experts requires CUDA (not compatible with ROCm or other backends)"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_auto_enable_lora_kernels(cls, data):
|
def check_auto_enable_lora_kernels(cls, data):
|
||||||
|
|||||||
@@ -209,6 +209,19 @@ class LoraConfig(BaseModel):
|
|||||||
data["lora_dropout"] = 0.0
|
data["lora_dropout"] = 0.0
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_lora_target_parameters_dropout(self):
|
||||||
|
if (
|
||||||
|
self.lora_target_parameters
|
||||||
|
and self.lora_dropout
|
||||||
|
and self.lora_dropout != 0.0
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"lora_dropout must be 0 when lora_target_parameters is set. "
|
||||||
|
"PEFT's ParamWrapper does not support lora_dropout != 0."
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class ReLoRAConfig(BaseModel):
|
class ReLoRAConfig(BaseModel):
|
||||||
"""ReLoRA configuration subset"""
|
"""ReLoRA configuration subset"""
|
||||||
|
|||||||
142
tests/utils/schemas/validation/test_moe_quant.py
Normal file
142
tests/utils/schemas/validation/test_moe_quant.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
"""Tests for MoE expert quantization config validation and PEFT patch idempotency."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.config import validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def gpu_caps():
|
||||||
|
return {"compute_capability": "sm_89", "bf16": True, "n_gpu": 1, "n_node": 1}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def env_caps():
|
||||||
|
return {"torch_version": "2.7.0"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuantizeMoeExpertsValidation:
|
||||||
|
"""Test suite for quantize_moe_experts config validator."""
|
||||||
|
|
||||||
|
def test_requires_adapter(self, min_base_cfg, gpu_caps, env_caps):
|
||||||
|
"""quantize_moe_experts without adapter should fail."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
quantize_moe_experts=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="requires adapter"):
|
||||||
|
validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
||||||
|
|
||||||
|
def test_requires_quantization(self, min_base_cfg, gpu_caps, env_caps):
|
||||||
|
"""quantize_moe_experts without load_in_4bit/8bit should fail."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
quantize_moe_experts=True,
|
||||||
|
adapter="lora",
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="requires load_in_4bit or load_in_8bit"):
|
||||||
|
validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
||||||
|
|
||||||
|
def test_valid_qlora_4bit(self, min_base_cfg, gpu_caps, env_caps):
|
||||||
|
"""quantize_moe_experts with qlora + 4bit should pass."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
quantize_moe_experts=True,
|
||||||
|
adapter="qlora",
|
||||||
|
load_in_4bit=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
||||||
|
assert result["quantize_moe_experts"] is True
|
||||||
|
|
||||||
|
def test_valid_lora_8bit(self, min_base_cfg, gpu_caps, env_caps):
|
||||||
|
"""quantize_moe_experts with lora + 8bit should pass."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
quantize_moe_experts=True,
|
||||||
|
adapter="lora",
|
||||||
|
load_in_8bit=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
||||||
|
assert result["quantize_moe_experts"] is True
|
||||||
|
|
||||||
|
def test_false_skips_validation(self, min_base_cfg, gpu_caps, env_caps):
|
||||||
|
"""quantize_moe_experts=false should not check adapter/quantization."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
quantize_moe_experts=False,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
||||||
|
assert result["quantize_moe_experts"] is False
|
||||||
|
|
||||||
|
def test_default_is_false(self, min_base_cfg, gpu_caps, env_caps):
|
||||||
|
"""quantize_moe_experts should default to false."""
|
||||||
|
cfg = DictDefault({}) | min_base_cfg
|
||||||
|
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
||||||
|
assert result["quantize_moe_experts"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoraTargetParametersDropout:
|
||||||
|
"""Test that lora_dropout must be 0 when lora_target_parameters is set."""
|
||||||
|
|
||||||
|
def test_rejects_nonzero_dropout(self, min_base_cfg):
|
||||||
|
"""lora_dropout > 0 with lora_target_parameters should fail."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
adapter="lora",
|
||||||
|
lora_target_parameters=["mlp.experts.gate_up_proj"],
|
||||||
|
lora_dropout=0.1,
|
||||||
|
load_in_8bit=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="lora_dropout must be 0"):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_zero_dropout_passes(self, min_base_cfg):
|
||||||
|
"""lora_dropout=0 with lora_target_parameters should pass."""
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
adapter="lora",
|
||||||
|
lora_target_parameters=["mlp.experts.gate_up_proj"],
|
||||||
|
lora_dropout=0.0,
|
||||||
|
load_in_8bit=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
result = validate_config(cfg)
|
||||||
|
assert result["lora_dropout"] == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestPeftPatchIdempotency:
|
||||||
|
"""Test that patch_peft_target_parameters_matching is idempotent."""
|
||||||
|
|
||||||
|
def test_double_call_does_not_stack_wrappers(self):
|
||||||
|
"""Calling patch twice should not double-wrap _inject_parameters."""
|
||||||
|
from peft.tuners.tuners_utils import BaseTuner
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.moe_quant import (
|
||||||
|
patch_peft_target_parameters_matching,
|
||||||
|
)
|
||||||
|
|
||||||
|
original = BaseTuner._inject_parameters
|
||||||
|
try:
|
||||||
|
patch_peft_target_parameters_matching()
|
||||||
|
first_patched = BaseTuner._inject_parameters
|
||||||
|
patch_peft_target_parameters_matching()
|
||||||
|
second_patched = BaseTuner._inject_parameters
|
||||||
|
# Should be same function, not double-wrapped
|
||||||
|
assert first_patched is second_patched
|
||||||
|
finally:
|
||||||
|
BaseTuner._inject_parameters = original
|
||||||
|
patch_peft_target_parameters_matching._axolotl_patched = False
|
||||||
Reference in New Issue
Block a user