Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f94ec0434c |
16
.github/workflows/base.yml
vendored
16
.github/workflows/base.yml
vendored
@@ -51,14 +51,6 @@ jobs:
|
|||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base"
|
||||||
platforms: "linux/amd64,linux/arm64"
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "128"
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.10.0
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
dockerfile: "Dockerfile-base"
|
|
||||||
platforms: "linux/amd64,linux/arm64"
|
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
@@ -181,14 +173,6 @@ jobs:
|
|||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-uv-base"
|
dockerfile: "Dockerfile-uv-base"
|
||||||
platforms: "linux/amd64,linux/arm64"
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "128"
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.10.0
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
dockerfile: "Dockerfile-uv-base"
|
|
||||||
platforms: "linux/amd64,linux/arm64"
|
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
|
|||||||
36
.github/workflows/tests.yml
vendored
36
.github/workflows/tests.yml
vendored
@@ -54,13 +54,13 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
python_version: ["3.11", "3.12"]
|
||||||
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||||
# exclude:
|
exclude:
|
||||||
# - python_version: "3.14"
|
- python_version: "3.12"
|
||||||
# pytorch_version: "2.8.0"
|
pytorch_version: "2.8.0"
|
||||||
# - python_version: "3.14"
|
- python_version: "3.12"
|
||||||
# pytorch_version: "2.9.1"
|
pytorch_version: "2.9.0"
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -149,13 +149,13 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
python_version: ["3.11", "3.12"]
|
||||||
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||||
# exclude:
|
exclude:
|
||||||
# - python_version: "3.14"
|
- python_version: "3.12"
|
||||||
# pytorch_version: "2.8.0"
|
pytorch_version: "2.8.0"
|
||||||
# - python_version: "3.14"
|
- python_version: "3.12"
|
||||||
# pytorch_version: "2.9.1"
|
pytorch_version: "2.9.0"
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -326,12 +326,6 @@ jobs:
|
|||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 128
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.10.0
|
|
||||||
num_gpus: 1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 130
|
- cuda: 130
|
||||||
cuda_version: 13.0.0
|
cuda_version: 13.0.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -377,7 +371,7 @@ jobs:
|
|||||||
include:
|
include:
|
||||||
- cuda: 129
|
- cuda: 129
|
||||||
cuda_version: 12.9.1
|
cuda_version: 12.9.1
|
||||||
python_version: "3.11"
|
python_version: "3.12"
|
||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
|||||||
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1,77 +0,0 @@
|
|||||||
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
|
|
||||||
|
|
||||||
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.
|
|
||||||
|
|
||||||
This guide shows how to fine-tune it with Axolotl.
|
|
||||||
|
|
||||||
## Getting started
|
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
|
||||||
|
|
||||||
3. Run the finetuning example:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# QLoRA
|
|
||||||
# - no target experts (1x48GB @ ~24GiB/GPU)
|
|
||||||
# - target experts (1x48GB @ ~34GiB/GPU)
|
|
||||||
axolotl train examples/glm4.7-flash/qlora.yaml
|
|
||||||
|
|
||||||
# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
|
|
||||||
axolotl train examples/glm4.7-flash/qlora_fsdp.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# LoRA
|
|
||||||
# - no target experts (1x48GB @ ~35GiB/GPU)
|
|
||||||
# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
|
|
||||||
axolotl train examples/glm4.7-flash/lora.yaml
|
|
||||||
|
|
||||||
# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
|
|
||||||
axolotl train examples/glm4.7-flash/lora_fsdp.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
### Expert LoRA
|
|
||||||
|
|
||||||
To also apply LoRA adapters to expert weights, add `lora_target_parameters` to your config.
|
|
||||||
|
|
||||||
Note: `lora_dropout` must be `0` when using `lora_target_parameters`.
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
lora_target_parameters:
|
|
||||||
- mlp.experts.gate_up_proj
|
|
||||||
- mlp.experts.down_proj
|
|
||||||
# - mlp.gate.weight # router, untested but should work, not normally targeted
|
|
||||||
```
|
|
||||||
|
|
||||||
## Limitations
|
|
||||||
|
|
||||||
- **FSDP VRAM**: FSDP2 may use more VRAM per GPU than single GPU training. We suspect not all layers are properly sharded across ranks.
|
|
||||||
- **FSDP initial spike**: FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps that then drops. FSDP QLoRA (4-bit) does not exhibit this.
|
|
||||||
- **cpu_ram_efficient_loading**: Must be set to `false` with FSDP2 — causes hang otherwise.
|
|
||||||
- **lora_target_linear**: Incompatible for this model.
|
|
||||||
- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`).
|
|
||||||
|
|
||||||
|
|
||||||
### TIPS
|
|
||||||
|
|
||||||
- For inference, the official Z.ai team recommends these default settings (most tasks):
|
|
||||||
- `temperature: 1.0`
|
|
||||||
- `top_p: 0.95`
|
|
||||||
- `max_new_tokens: 131072`
|
|
||||||
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy, so we have not tested this.
|
|
||||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
|
||||||
|
|
||||||
## Optimization Guides
|
|
||||||
|
|
||||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
|
||||||
|
|
||||||
## Related Resources
|
|
||||||
|
|
||||||
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
|
|
||||||
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
|
|
||||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
|
||||||
- [Axolotl Website](https://axolotl.ai)
|
|
||||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
|
||||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
base_model: zai-org/GLM-4.7-Flash
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
load_in_8bit: true
|
|
||||||
quantize_moe_experts: true
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
|
||||||
type: chat_template
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./outputs/glm4.7-flash-lora-8bit-out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
# Uncomment to also target MoE expert weights:
|
|
||||||
# lora_target_parameters:
|
|
||||||
# - mlp.experts.gate_up_proj
|
|
||||||
# - mlp.experts.down_proj
|
|
||||||
|
|
||||||
# LoRA kernels incompatible with DSA attention
|
|
||||||
lora_mlp_kernel: false
|
|
||||||
lora_qkv_kernel: false
|
|
||||||
lora_o_kernel: false
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
base_model: zai-org/GLM-4.7-Flash
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
load_in_8bit: true
|
|
||||||
quantize_moe_experts: true
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
|
||||||
type: chat_template
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
# Uncomment to also target MoE expert weights:
|
|
||||||
# lora_target_parameters:
|
|
||||||
# - mlp.experts.gate_up_proj
|
|
||||||
# - mlp.experts.down_proj
|
|
||||||
|
|
||||||
# LoRA kernels incompatible with DSA attention
|
|
||||||
lora_mlp_kernel: false
|
|
||||||
lora_qkv_kernel: false
|
|
||||||
lora_o_kernel: false
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_version: 2
|
|
||||||
offload_params: false
|
|
||||||
cpu_ram_efficient_loading: false
|
|
||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
|
|
||||||
state_dict_type: FULL_STATE_DICT
|
|
||||||
sharding_strategy: FULL_SHARD
|
|
||||||
reshard_after_forward: true
|
|
||||||
activation_checkpointing: true
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
base_model: zai-org/GLM-4.7-Flash
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
load_in_4bit: true
|
|
||||||
quantize_moe_experts: true
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
|
||||||
type: chat_template
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./outputs/glm4.7-flash-qlora-out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
# Uncomment to also target MoE expert weights:
|
|
||||||
# lora_target_parameters:
|
|
||||||
# - mlp.experts.gate_up_proj
|
|
||||||
# - mlp.experts.down_proj
|
|
||||||
|
|
||||||
# LoRA kernels incompatible with DSA attention
|
|
||||||
lora_mlp_kernel: false
|
|
||||||
lora_qkv_kernel: false
|
|
||||||
lora_o_kernel: false
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
base_model: zai-org/GLM-4.7-Flash
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
load_in_4bit: true
|
|
||||||
quantize_moe_experts: true
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
|
||||||
type: chat_template
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
# Uncomment to also target MoE expert weights:
|
|
||||||
# lora_target_parameters:
|
|
||||||
# - mlp.experts.gate_up_proj
|
|
||||||
# - mlp.experts.down_proj
|
|
||||||
|
|
||||||
# LoRA kernels incompatible with DSA attention
|
|
||||||
lora_mlp_kernel: false
|
|
||||||
lora_qkv_kernel: false
|
|
||||||
lora_o_kernel: false
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_version: 2
|
|
||||||
offload_params: false
|
|
||||||
cpu_ram_efficient_loading: false
|
|
||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
|
|
||||||
state_dict_type: FULL_STATE_DICT
|
|
||||||
sharding_strategy: FULL_SHARD
|
|
||||||
reshard_after_forward: true
|
|
||||||
activation_checkpointing: true
|
|
||||||
@@ -6,13 +6,30 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
cd axolotl
|
||||||
|
|
||||||
|
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||||
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
|
||||||
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Install Qwen3-Next transformers commit
|
||||||
|
```bash
|
||||||
|
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
|
||||||
|
```
|
||||||
|
|
||||||
3. Install FLA for improved performance
|
3. Install FLA for improved performance
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
|
||||||
```
|
```
|
||||||
|
|
||||||
4. Run the finetuning example:
|
4. Run the finetuning example:
|
||||||
@@ -21,7 +38,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
|||||||
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
|
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
This config uses about ~47 GiB (no target experts) and ~71GiB (target experts) VRAM.
|
This config uses about 45.62 GiB VRAM.
|
||||||
|
|
||||||
Let us know how it goes. Happy finetuning! 🚀
|
Let us know how it goes. Happy finetuning! 🚀
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ plugins:
|
|||||||
load_in_8bit: false
|
load_in_8bit: false
|
||||||
load_in_4bit: true
|
load_in_4bit: true
|
||||||
|
|
||||||
quantize_moe_experts: true
|
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
type: chat_template
|
type: chat_template
|
||||||
@@ -27,7 +25,7 @@ sample_packing: true
|
|||||||
|
|
||||||
lora_r: 16
|
lora_r: 16
|
||||||
lora_alpha: 8
|
lora_alpha: 8
|
||||||
lora_dropout: 0
|
lora_dropout: 0.05
|
||||||
lora_target_modules:
|
lora_target_modules:
|
||||||
- linear_attn.in_proj_ba
|
- linear_attn.in_proj_ba
|
||||||
- linear_attn.in_proj_qkvz
|
- linear_attn.in_proj_qkvz
|
||||||
@@ -36,19 +34,12 @@ lora_target_modules:
|
|||||||
- shared_expert.down_proj
|
- shared_expert.down_proj
|
||||||
- shared_expert.gate_proj
|
- shared_expert.gate_proj
|
||||||
- shared_expert_gate
|
- shared_expert_gate
|
||||||
|
- mlp.gate
|
||||||
- q_proj
|
- q_proj
|
||||||
- v_proj
|
- v_proj
|
||||||
- k_proj
|
- k_proj
|
||||||
- o_proj
|
- o_proj
|
||||||
|
|
||||||
# lora_target_parameters:
|
|
||||||
# - mlp.experts.gate_up_proj
|
|
||||||
# - mlp.experts.down_proj
|
|
||||||
|
|
||||||
lora_mlp_kernel: false
|
|
||||||
lora_qkv_kernel: false
|
|
||||||
lora_o_kernel: false
|
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
|
|||||||
@@ -8,15 +8,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
|
|
||||||
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
|
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
2. Run the finetuning example:
|
||||||
|
|
||||||
3. Run the finetuning example:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
|
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
This config uses about 24.9 GiB VRAM (w/o CCE).
|
This config uses about 24.9 GiB VRAM.
|
||||||
|
|
||||||
Let us know how it goes. Happy finetuning! 🚀
|
Let us know how it goes. Happy finetuning! 🚀
|
||||||
|
|
||||||
@@ -31,6 +29,10 @@ Let us know how it goes. Happy finetuning! 🚀
|
|||||||
|
|
||||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
|
||||||
|
|
||||||
## Related Resources
|
## Related Resources
|
||||||
|
|
||||||
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
|
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
base_model: arcee-ai/Trinity-Nano-Preview
|
base_model: arcee-ai/Trinity-Nano-Preview
|
||||||
|
trust_remote_code: true
|
||||||
revision_of_model: 2ee94b0
|
revision_of_model: 2ee94b0
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
|||||||
@@ -63,5 +63,3 @@ docstring-code-format = false
|
|||||||
|
|
||||||
[tool.uv.extra-build-dependencies]
|
[tool.uv.extra-build-dependencies]
|
||||||
axolotl = ["huggingface_hub"]
|
axolotl = ["huggingface_hub"]
|
||||||
flash-attn = [{ requirement = "torch", match-runtime = true }]
|
|
||||||
deepspeed = [{ requirement = "torch", match-runtime = true }]
|
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,7 +18,4 @@ MOE_ARCH_BLOCK = {
|
|||||||
"gpt_oss": "GptOssDecoderLayer",
|
"gpt_oss": "GptOssDecoderLayer",
|
||||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||||
"afmoe": "AfmoeMoE",
|
"afmoe": "AfmoeMoE",
|
||||||
"glm4_moe": "Glm4MoeDecoderLayer",
|
|
||||||
"glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
|
|
||||||
"glm_moe_dsa": "GlmMoeDsaDecoderLayer",
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -720,16 +720,12 @@ class AxolotlTrainer(
|
|||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
LOG.info(f"Saving model checkpoint to {output_dir}")
|
LOG.info(f"Saving model checkpoint to {output_dir}")
|
||||||
|
|
||||||
# fix for Context Parallel save: CP eval invalidates tensor storage
|
# fix for Context Parallel save
|
||||||
# pointers, so clone to CPU to get fresh valid storage for safetensors
|
if state_dict is None:
|
||||||
if (
|
state_dict = self.accelerator.get_state_dict(self.model)
|
||||||
state_dict is not None
|
if state_dict is not None:
|
||||||
and self.axolotl_cfg
|
|
||||||
and self.axolotl_cfg.context_parallel_size
|
|
||||||
and self.axolotl_cfg.context_parallel_size > 1
|
|
||||||
):
|
|
||||||
state_dict = {
|
state_dict = {
|
||||||
k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
|
k: v.clone() if isinstance(v, torch.Tensor) else v
|
||||||
for k, v in state_dict.items()
|
for k, v in state_dict.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -765,11 +761,7 @@ class AxolotlTrainer(
|
|||||||
metadata={"format": "pt"},
|
metadata={"format": "pt"},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.model.save_pretrained(
|
self.model.save_pretrained(output_dir, state_dict=state_dict)
|
||||||
output_dir,
|
|
||||||
state_dict=state_dict,
|
|
||||||
is_main_process=self.accelerator.is_main_process,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.processing_class is not None:
|
if self.processing_class is not None:
|
||||||
self.processing_class.save_pretrained(output_dir)
|
self.processing_class.save_pretrained(output_dir)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -88,9 +88,9 @@ plugins:
|
|||||||
- qwen2_vl
|
- qwen2_vl
|
||||||
- qwen3
|
- qwen3
|
||||||
- qwen3_5
|
- qwen3_5
|
||||||
- qwen3_5_text
|
|
||||||
- qwen3_5_moe
|
- qwen3_5_moe
|
||||||
- qwen3_5_moe_text
|
- qwen3_5_moe_vl
|
||||||
|
- qwen3_5_vl
|
||||||
- qwen3_moe
|
- qwen3_moe
|
||||||
- qwen3_next
|
- qwen3_next
|
||||||
- qwen3_vl
|
- qwen3_vl
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a668583"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -39,8 +39,6 @@ This works for any MoE model in transformers that uses a `SparseMoeBlock` class
|
|||||||
|
|
||||||
ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
|
ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
|
||||||
|
|
||||||
ScatterMoE does not work for GLM4.7 Flash (glm4_moe_lite) atm.
|
|
||||||
|
|
||||||
## Note on MegaBlocks
|
## Note on MegaBlocks
|
||||||
|
|
||||||
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.
|
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ def setup_quantized_meta_for_peft(model: torch.nn.Module):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
if isinstance(param, Params4bit) and param.quant_state is not None:
|
if isinstance(param, Params4bit):
|
||||||
param.quant_state._orig_to = param.quant_state.to
|
param.quant_state._orig_to = param.quant_state.to
|
||||||
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
|
param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)
|
||||||
|
|
||||||
|
|||||||
@@ -172,10 +172,7 @@ class ModelLoader:
|
|||||||
# Build the model
|
# Build the model
|
||||||
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
||||||
self.patch_manager.apply_post_plugin_pre_model_load_patches()
|
self.patch_manager.apply_post_plugin_pre_model_load_patches()
|
||||||
|
|
||||||
skip_move_to_device = self._build_model()
|
skip_move_to_device = self._build_model()
|
||||||
self.patch_manager.apply_post_model_build_patches(self.model)
|
|
||||||
|
|
||||||
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
||||||
|
|
||||||
# Post-build model configuration
|
# Post-build model configuration
|
||||||
@@ -863,10 +860,6 @@ class ModelLoader:
|
|||||||
# Make sure everything is in the same dtype
|
# Make sure everything is in the same dtype
|
||||||
skip_prepare_model_for_kbit_training = True
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
if getattr(self.model, "_moe_experts_quantized", False):
|
|
||||||
# Parametrized expert tensors dequantize on access — would OOM.
|
|
||||||
skip_prepare_model_for_kbit_training = True
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not skip_prepare_model_for_kbit_training
|
not skip_prepare_model_for_kbit_training
|
||||||
and self.cfg.adapter in ["lora", "qlora"]
|
and self.cfg.adapter in ["lora", "qlora"]
|
||||||
|
|||||||
@@ -118,7 +118,6 @@ class PatchManager:
|
|||||||
def apply_post_plugin_pre_model_load_patches(self):
|
def apply_post_plugin_pre_model_load_patches(self):
|
||||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||||
self._apply_moe_expert_quantization_patch()
|
|
||||||
|
|
||||||
def _apply_transformers_patches(self):
|
def _apply_transformers_patches(self):
|
||||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||||
@@ -136,10 +135,6 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_prepare_context_parallel_inputs()
|
patch_prepare_context_parallel_inputs()
|
||||||
|
|
||||||
def apply_post_model_build_patches(self, model: PreTrainedModel):
|
|
||||||
"""Apply patches right after model build, before post-load setup."""
|
|
||||||
self._finalize_moe_expert_quantization(model)
|
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches that require the model instance."""
|
"""Apply patches that require the model instance."""
|
||||||
self._apply_llama_flash_attn_patches(model)
|
self._apply_llama_flash_attn_patches(model)
|
||||||
@@ -175,14 +170,9 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_parallelism_config()
|
patch_parallelism_config()
|
||||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
||||||
from axolotl.monkeypatch.accelerate.fsdp2 import (
|
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
||||||
patch_accelerate_fsdp2,
|
|
||||||
patch_tied_keys_for_meta_device,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_accelerate_fsdp2()
|
patch_accelerate_fsdp2()
|
||||||
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
|
||||||
patch_tied_keys_for_meta_device()
|
|
||||||
if self.cfg.rl:
|
if self.cfg.rl:
|
||||||
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
|
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
|
||||||
|
|
||||||
@@ -362,54 +352,15 @@ class PatchManager:
|
|||||||
if (
|
if (
|
||||||
self.cfg.fsdp_config
|
self.cfg.fsdp_config
|
||||||
and str(self.cfg.fsdp_version) == "2"
|
and str(self.cfg.fsdp_version) == "2"
|
||||||
and (self.cfg.load_in_4bit or self.cfg.load_in_8bit)
|
and self.cfg.adapter == "qlora"
|
||||||
):
|
):
|
||||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||||
apply_init_dtype_attrs_patch,
|
|
||||||
apply_init_sharded_param_patch,
|
apply_init_sharded_param_patch,
|
||||||
apply_init_unsharded_param_patch,
|
apply_init_unsharded_param_patch,
|
||||||
apply_linear8bitlt_save_patch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
apply_init_sharded_param_patch()
|
apply_init_sharded_param_patch()
|
||||||
apply_init_unsharded_param_patch()
|
apply_init_unsharded_param_patch()
|
||||||
apply_init_dtype_attrs_patch()
|
|
||||||
if self.cfg.load_in_8bit:
|
|
||||||
apply_linear8bitlt_save_patch()
|
|
||||||
|
|
||||||
def _apply_moe_expert_quantization_patch(self):
|
|
||||||
"""Patch transformers weight loading to quantize MoE expert params on-the-fly."""
|
|
||||||
if not self.cfg.quantize_moe_experts:
|
|
||||||
return
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.moe_quant import (
|
|
||||||
patch_moe_quantization_on_load,
|
|
||||||
patch_peft_target_parameters_matching,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_moe_quantization_on_load(self.cfg)
|
|
||||||
patch_peft_target_parameters_matching()
|
|
||||||
|
|
||||||
def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
|
|
||||||
"""Log quantization results and set model flag for downstream use."""
|
|
||||||
import torch
|
|
||||||
|
|
||||||
model._moe_experts_quantized = False
|
|
||||||
if self.cfg.quantize_moe_experts:
|
|
||||||
from axolotl.monkeypatch.moe_quant import get_moe_quantized_count
|
|
||||||
|
|
||||||
count = get_moe_quantized_count()
|
|
||||||
if count > 0:
|
|
||||||
import gc
|
|
||||||
|
|
||||||
model._moe_experts_quantized = True
|
|
||||||
LOG.info(
|
|
||||||
"Quantized %d MoE expert parameter(s) to %s during model loading",
|
|
||||||
count,
|
|
||||||
"4-bit" if self.cfg.load_in_4bit else "8-bit",
|
|
||||||
)
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
def _apply_tiled_mlp(self, model_type: str):
|
def _apply_tiled_mlp(self, model_type: str):
|
||||||
if self.cfg.tiled_mlp:
|
if self.cfg.tiled_mlp:
|
||||||
|
|||||||
@@ -111,7 +111,6 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
|
|||||||
self,
|
self,
|
||||||
save_directory: Union[str, os.PathLike],
|
save_directory: Union[str, os.PathLike],
|
||||||
state_dict: Optional[dict] = None,
|
state_dict: Optional[dict] = None,
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
if state_dict is None:
|
if state_dict is None:
|
||||||
state_dict = self.state_dict()
|
state_dict = self.state_dict()
|
||||||
|
|||||||
@@ -150,17 +150,13 @@ def get_state_dict(self, model, unwrap=True):
|
|||||||
)
|
)
|
||||||
elif self.is_fsdp2:
|
elif self.is_fsdp2:
|
||||||
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
|
# https://github.com/pytorch/torchtune/blob/main/torchtune/training/_distributed.py#L465
|
||||||
from torch.distributed.tensor import DTensor
|
|
||||||
|
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
sharded_state_dict = model.state_dict()
|
sharded_state_dict = model.state_dict()
|
||||||
for param_name, param in sharded_state_dict.items():
|
for param_name, param in sharded_state_dict.items():
|
||||||
if param.is_cpu:
|
if param.is_cpu:
|
||||||
param = param.to(torch.device("cuda"))
|
param = param.to(torch.device("cuda"))
|
||||||
|
|
||||||
if isinstance(param, DTensor):
|
param = param.full_tensor()
|
||||||
param = param.full_tensor()
|
|
||||||
|
|
||||||
if torch.distributed.get_rank() == 0:
|
if torch.distributed.get_rank() == 0:
|
||||||
state_dict[param_name] = param.cpu()
|
state_dict[param_name] = param.cpu()
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
@@ -186,56 +182,10 @@ def get_state_dict(self, model, unwrap=True):
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def patch_peft_param_wrapper_for_fsdp2():
|
|
||||||
"""Patch PEFT's _LoraParameterProxy.forward for FSDP2 DTensor compatibility.
|
|
||||||
|
|
||||||
PEFT's ParamWrapper applies LoRA via torch.nn.utils.parametrize, which adds
|
|
||||||
delta_weight to the base weight W inside _LoraParameterProxy.forward().
|
|
||||||
Under FSDP2, W may be a DTensor (from FSDP unshard) while delta_weight is a
|
|
||||||
regular Tensor (or vice versa), causing a RuntimeError on mixed types.
|
|
||||||
|
|
||||||
This patch promotes the non-DTensor operand to match the DTensor's spec
|
|
||||||
using DTensor.from_local(), which is free for Replicate placement (just
|
|
||||||
metadata wrapping, no communication).
|
|
||||||
"""
|
|
||||||
from peft.tuners.lora.layer import _LoraParameterProxy
|
|
||||||
|
|
||||||
if getattr(_LoraParameterProxy, "_axolotl_fsdp2_patched", False):
|
|
||||||
return
|
|
||||||
|
|
||||||
_original_forward = _LoraParameterProxy.forward
|
|
||||||
|
|
||||||
# NOTE: Replaces (not wraps) forward; assumes original is just `W + self.delta_weight`.
|
|
||||||
def _patched_forward(self, W):
|
|
||||||
from torch.distributed.tensor import DTensor
|
|
||||||
|
|
||||||
delta = self.delta_weight
|
|
||||||
w_is_dt = isinstance(W, DTensor)
|
|
||||||
d_is_dt = isinstance(delta, DTensor)
|
|
||||||
|
|
||||||
with torch.nn.utils.parametrize.cached():
|
|
||||||
if w_is_dt == d_is_dt:
|
|
||||||
return W + delta
|
|
||||||
if w_is_dt:
|
|
||||||
return W + DTensor.from_local(delta, W.device_mesh, W.placements)
|
|
||||||
return DTensor.from_local(W, delta.device_mesh, delta.placements) + delta
|
|
||||||
|
|
||||||
_LoraParameterProxy.forward = _patched_forward
|
|
||||||
_LoraParameterProxy._axolotl_fsdp2_patched = True
|
|
||||||
LOG.info("Patched PEFT _LoraParameterProxy.forward for FSDP2 DTensor compatibility")
|
|
||||||
|
|
||||||
|
|
||||||
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
||||||
"""Helper function to process LoRA modules for FSDP2."""
|
"""Helper function to process LoRA modules for FSDP2."""
|
||||||
from peft.tuners.lora.layer import ParamWrapper
|
|
||||||
from torch.distributed.fsdp import fully_shard
|
from torch.distributed.fsdp import fully_shard
|
||||||
|
|
||||||
# Skip ParamWrapper — its lora_A/B must not be independently sharded.
|
|
||||||
# The parent decoder layer's FSDP wrapper handles unsharding them.
|
|
||||||
# TODO: review if we even need to shard them separately in first place.
|
|
||||||
if isinstance(module, ParamWrapper):
|
|
||||||
return False
|
|
||||||
|
|
||||||
log_bias_dtype_mismatch = False
|
log_bias_dtype_mismatch = False
|
||||||
|
|
||||||
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
|
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
|
||||||
@@ -377,14 +327,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
|
|
||||||
is_peft_model = isinstance(model, PeftModel)
|
is_peft_model = isinstance(model, PeftModel)
|
||||||
|
|
||||||
# Patch PEFT's _LoraParameterProxy for DTensor compatibility if any
|
|
||||||
# ParamWrapper modules exist (used for target_parameters / 3D expert params).
|
|
||||||
if is_peft_model:
|
|
||||||
from peft.tuners.lora.layer import ParamWrapper
|
|
||||||
|
|
||||||
if any(isinstance(m, ParamWrapper) for m in model.modules()):
|
|
||||||
patch_peft_param_wrapper_for_fsdp2()
|
|
||||||
|
|
||||||
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
|
||||||
log_bias_dtype_mismatch = False
|
log_bias_dtype_mismatch = False
|
||||||
if auto_wrap_policy is not None:
|
if auto_wrap_policy is not None:
|
||||||
@@ -434,43 +376,6 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def patch_tied_keys_for_meta_device():
|
|
||||||
"""Patch _adjust_tied_keys_with_tied_pointers to skip meta tensors.
|
|
||||||
|
|
||||||
Meta tensors all share data_ptr()==0, causing every parameter to be incorrectly
|
|
||||||
grouped as "tied". Skipping them is safe since they have no real storage.
|
|
||||||
"""
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
from transformers import PreTrainedModel
|
|
||||||
|
|
||||||
def _patched_adjust_tied_keys_with_tied_pointers(self, missing_keys):
|
|
||||||
param_pointers = defaultdict(list)
|
|
||||||
for param_name, param_value in self.state_dict().items():
|
|
||||||
if param_value.is_meta:
|
|
||||||
continue
|
|
||||||
param_pointers[param_value.data_ptr()].append(param_name)
|
|
||||||
|
|
||||||
tied_param_names = [
|
|
||||||
names
|
|
||||||
for names in param_pointers.values()
|
|
||||||
if len(names) > 1
|
|
||||||
and not any(name in self.all_tied_weights_keys.keys() for name in names)
|
|
||||||
and not all(name in missing_keys for name in names)
|
|
||||||
]
|
|
||||||
|
|
||||||
tied_weights_keys_by_pointers = {
|
|
||||||
param_name: group[0]
|
|
||||||
for group in tied_param_names
|
|
||||||
for param_name in group[1:]
|
|
||||||
}
|
|
||||||
self.all_tied_weights_keys.update(tied_weights_keys_by_pointers)
|
|
||||||
|
|
||||||
PreTrainedModel._adjust_tied_keys_with_tied_pointers = (
|
|
||||||
_patched_adjust_tied_keys_with_tied_pointers
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_accelerate_fsdp2():
|
def patch_accelerate_fsdp2():
|
||||||
import accelerate
|
import accelerate
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Monkeypatch to add Params4bit and Int8Params support to FSDP2. This enables QLoRA + FSDP2
|
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
|
||||||
and 8-bit LoRA + FSDP2, as well as our LoRA / QLoRA Triton kernels to work with FSDP2.
|
our LoRA / QLoRA Triton kernels to work with FSDP2.
|
||||||
|
|
||||||
This patch modifies the _init_sharded_param and init_unsharded_param methods in FSDPParam
|
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
|
||||||
to handle bitsandbytes Params4bit and Int8Params parameters, preserving their quantization
|
Params4bit parameters.
|
||||||
metadata through the FSDP2 shard/unshard cycle.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
@@ -18,8 +17,6 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
def apply_init_sharded_param_patch():
|
def apply_init_sharded_param_patch():
|
||||||
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
|
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
|
||||||
if getattr(apply_init_sharded_param_patch, "_axolotl_patched", False):
|
|
||||||
return
|
|
||||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||||
|
|
||||||
# Get original source
|
# Get original source
|
||||||
@@ -44,20 +41,9 @@ def apply_init_sharded_param_patch():
|
|||||||
bnb_quantized=param.bnb_quantized,
|
bnb_quantized=param.bnb_quantized,
|
||||||
)
|
)
|
||||||
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
||||||
elif isinstance(param, bnb.nn.modules.Int8Params):
|
|
||||||
self.sharded_param = bnb.nn.modules.Int8Params(
|
|
||||||
data=sharded_param,
|
|
||||||
requires_grad=param.requires_grad,
|
|
||||||
has_fp16_weights=param.has_fp16_weights,
|
|
||||||
CB=None,
|
|
||||||
SCB=param.SCB,
|
|
||||||
)
|
|
||||||
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
|
||||||
else:
|
else:
|
||||||
self.sharded_param = nn.Parameter(
|
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
||||||
self.to_sharded_dtensor(sharded_param),
|
self.sharded_param.requires_grad_(param.requires_grad)"""
|
||||||
requires_grad=param.requires_grad,
|
|
||||||
)"""
|
|
||||||
|
|
||||||
# Apply the replacement
|
# Apply the replacement
|
||||||
if original_param_creation in original_source:
|
if original_param_creation in original_source:
|
||||||
@@ -87,7 +73,6 @@ def apply_init_sharded_param_patch():
|
|||||||
|
|
||||||
# Replace the method
|
# Replace the method
|
||||||
FSDPParam._init_sharded_param = patched_init_sharded_param
|
FSDPParam._init_sharded_param = patched_init_sharded_param
|
||||||
apply_init_sharded_param_patch._axolotl_patched = True
|
|
||||||
LOG.info("Successfully applied FSDP _init_sharded_param patch")
|
LOG.info("Successfully applied FSDP _init_sharded_param patch")
|
||||||
else:
|
else:
|
||||||
LOG.warning("Could not find target code for _init_sharded_param patching")
|
LOG.warning("Could not find target code for _init_sharded_param patching")
|
||||||
@@ -95,8 +80,6 @@ def apply_init_sharded_param_patch():
|
|||||||
|
|
||||||
def apply_init_unsharded_param_patch():
|
def apply_init_unsharded_param_patch():
|
||||||
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
|
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
|
||||||
if getattr(apply_init_unsharded_param_patch, "_axolotl_patched", False):
|
|
||||||
return
|
|
||||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||||
|
|
||||||
# Get original source
|
# Get original source
|
||||||
@@ -122,14 +105,6 @@ def apply_init_unsharded_param_patch():
|
|||||||
module=local_tensor.module,
|
module=local_tensor.module,
|
||||||
bnb_quantized=local_tensor.bnb_quantized,
|
bnb_quantized=local_tensor.bnb_quantized,
|
||||||
)
|
)
|
||||||
elif isinstance(local_tensor, bnb.nn.modules.Int8Params):
|
|
||||||
self._unsharded_param = bnb.nn.modules.Int8Params(
|
|
||||||
data=unsharded_param,
|
|
||||||
requires_grad=self.sharded_param.requires_grad,
|
|
||||||
has_fp16_weights=local_tensor.has_fp16_weights,
|
|
||||||
CB=unsharded_param,
|
|
||||||
SCB=local_tensor.SCB,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self._unsharded_param = nn.Parameter(
|
self._unsharded_param = nn.Parameter(
|
||||||
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
||||||
@@ -163,74 +138,6 @@ def apply_init_unsharded_param_patch():
|
|||||||
|
|
||||||
# Replace the method
|
# Replace the method
|
||||||
FSDPParam.init_unsharded_param = patched_init_unsharded_param
|
FSDPParam.init_unsharded_param = patched_init_unsharded_param
|
||||||
apply_init_unsharded_param_patch._axolotl_patched = True
|
|
||||||
LOG.info("Successfully applied FSDP init_unsharded_param patch")
|
LOG.info("Successfully applied FSDP init_unsharded_param patch")
|
||||||
else:
|
else:
|
||||||
LOG.warning("Could not find target code for patching")
|
LOG.warning("Could not find target code for patching")
|
||||||
|
|
||||||
|
|
||||||
def apply_linear8bitlt_save_patch():
|
|
||||||
"""Patch Linear8bitLt._save_to_state_dict to handle DTensor-wrapped Int8Params.
|
|
||||||
|
|
||||||
After FSDP2 sharding, Linear8bitLt.weight is a DTensor wrapping Int8Params.
|
|
||||||
BnB's _save_to_state_dict accesses self.weight.SCB directly, but DTensor
|
|
||||||
doesn't proxy custom attribute access to its _local_tensor. This patch
|
|
||||||
temporarily unwraps the DTensor during saving so BnB can find the SCB attribute.
|
|
||||||
"""
|
|
||||||
if getattr(apply_linear8bitlt_save_patch, "_axolotl_patched", False):
|
|
||||||
return
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
from torch.distributed.tensor import DTensor
|
|
||||||
|
|
||||||
original_save = bnb.nn.Linear8bitLt._save_to_state_dict
|
|
||||||
|
|
||||||
def _patched_save_to_state_dict(self, destination, prefix, keep_vars):
|
|
||||||
# Use _parameters dict directly to bypass nn.Module.__setattr__ type check.
|
|
||||||
weight = self._parameters["weight"]
|
|
||||||
unwrapped = False
|
|
||||||
if isinstance(weight, DTensor) and hasattr(weight, "_local_tensor"):
|
|
||||||
self._parameters["weight"] = weight._local_tensor
|
|
||||||
unwrapped = True
|
|
||||||
try:
|
|
||||||
original_save(self, destination, prefix, keep_vars)
|
|
||||||
finally:
|
|
||||||
if unwrapped:
|
|
||||||
self._parameters["weight"] = weight
|
|
||||||
|
|
||||||
bnb.nn.Linear8bitLt._save_to_state_dict = _patched_save_to_state_dict
|
|
||||||
apply_linear8bitlt_save_patch._axolotl_patched = True
|
|
||||||
LOG.info("Patched Linear8bitLt._save_to_state_dict for DTensor compatibility")
|
|
||||||
|
|
||||||
|
|
||||||
def apply_init_dtype_attrs_patch():
|
|
||||||
"""Prevent FSDP2 mixed precision from casting non-float quantized params.
|
|
||||||
|
|
||||||
When mixed precision is enabled (e.g., bf16), FSDP2's init_dtype_attrs sets
|
|
||||||
param_dtype=bf16 for ALL params. During all-gather, _to_dtype_if_needed casts
|
|
||||||
the sharded param to param_dtype. For non-float params (uint8 packed 4-bit,
|
|
||||||
int8 quantized) without FSDP2 extensions, this destroys the quantized data.
|
|
||||||
|
|
||||||
Params4bit handles this via fsdp_pre/post_all_gather extensions, but our
|
|
||||||
parametrize-based expert quantization uses plain nn.Parameter(uint8/int8)
|
|
||||||
without extensions.
|
|
||||||
"""
|
|
||||||
if getattr(apply_init_dtype_attrs_patch, "_axolotl_patched", False):
|
|
||||||
return
|
|
||||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
|
||||||
|
|
||||||
original_init_dtype_attrs = FSDPParam.init_dtype_attrs
|
|
||||||
|
|
||||||
def patched_init_dtype_attrs(self, mp_policy):
|
|
||||||
original_init_dtype_attrs(self, mp_policy)
|
|
||||||
# Skip casting non-float quantized params (uint8/int8) without FSDP2
|
|
||||||
# extensions — the parametrization chain handles dequantization.
|
|
||||||
if self.param_dtype is not None and not self.sharded_param.is_floating_point():
|
|
||||||
local = self.sharded_param
|
|
||||||
if hasattr(local, "_local_tensor"):
|
|
||||||
local = local._local_tensor
|
|
||||||
if not hasattr(local, "fsdp_pre_all_gather"):
|
|
||||||
self.param_dtype = None
|
|
||||||
|
|
||||||
FSDPParam.init_dtype_attrs = patched_init_dtype_attrs
|
|
||||||
apply_init_dtype_attrs_patch._axolotl_patched = True
|
|
||||||
LOG.info("Patched FSDPParam.init_dtype_attrs for non-float quantized params")
|
|
||||||
|
|||||||
@@ -9,11 +9,6 @@ from axolotl.utils.logging import get_logger
|
|||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
try:
|
|
||||||
from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
|
|
||||||
except ImportError:
|
|
||||||
fla_causal_conv1d = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_cu_seqlens(position_ids):
|
def get_cu_seqlens(position_ids):
|
||||||
"""
|
"""
|
||||||
@@ -142,11 +137,6 @@ def patch_qwen3_next_gateddelta_layer():
|
|||||||
and cache_position is not None
|
and cache_position is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute cu_seqlens early for use by both causal_conv1d and chunk_gated_delta_rule
|
|
||||||
cu_seqlens = None
|
|
||||||
if not use_precomputed_states and position_ids is not None:
|
|
||||||
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
|
|
||||||
|
|
||||||
# getting projected states from cache if it exists
|
# getting projected states from cache if it exists
|
||||||
if cache_params is not None:
|
if cache_params is not None:
|
||||||
conv_state = cache_params.conv_states[self.layer_idx]
|
conv_state = cache_params.conv_states[self.layer_idx]
|
||||||
@@ -161,11 +151,12 @@ def patch_qwen3_next_gateddelta_layer():
|
|||||||
x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
|
x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
|
||||||
)
|
)
|
||||||
|
|
||||||
mixed_qkv = torch.cat((query, key, value), dim=-1) # [B, T, D]
|
mixed_qkv = torch.cat((query, key, value), dim=-1)
|
||||||
|
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||||
|
|
||||||
if use_precomputed_states:
|
if use_precomputed_states:
|
||||||
# Inference single-token path: causal_conv1d_update expects [B, D, T]
|
# 2. Convolution sequence transformation
|
||||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
# NOTE: the conv state is updated in `causal_conv1d_update`
|
||||||
mixed_qkv = self.causal_conv1d_update(
|
mixed_qkv = self.causal_conv1d_update(
|
||||||
mixed_qkv,
|
mixed_qkv,
|
||||||
conv_state,
|
conv_state,
|
||||||
@@ -173,41 +164,24 @@ def patch_qwen3_next_gateddelta_layer():
|
|||||||
self.conv1d.bias,
|
self.conv1d.bias,
|
||||||
self.activation,
|
self.activation,
|
||||||
)
|
)
|
||||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
|
||||||
else:
|
else:
|
||||||
if cache_params is not None:
|
if cache_params is not None:
|
||||||
# Cache state expects [B, D, T] for the inference update path
|
|
||||||
mixed_qkv_t = mixed_qkv.transpose(1, 2)
|
|
||||||
conv_state = F.pad(
|
conv_state = F.pad(
|
||||||
mixed_qkv_t,
|
mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
|
||||||
(self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
|
|
||||||
)
|
)
|
||||||
cache_params.conv_states[self.layer_idx] = conv_state
|
cache_params.conv_states[self.layer_idx] = conv_state
|
||||||
|
if self.causal_conv1d_fn is not None:
|
||||||
if fla_causal_conv1d is not None:
|
mixed_qkv = self.causal_conv1d_fn(
|
||||||
# FLA Triton causal_conv1d: [B, T, D] in/out, with cu_seqlens support
|
|
||||||
mixed_qkv, _ = fla_causal_conv1d(
|
|
||||||
x=mixed_qkv,
|
x=mixed_qkv,
|
||||||
weight=self.conv1d.weight.squeeze(1),
|
weight=self.conv1d.weight.squeeze(1),
|
||||||
bias=self.conv1d.bias,
|
bias=self.conv1d.bias,
|
||||||
activation=self.activation,
|
activation=self.activation,
|
||||||
cu_seqlens=cu_seqlens,
|
seq_idx=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# PyTorch fallback (no cu_seqlens support)
|
|
||||||
if cu_seqlens is not None and cu_seqlens.shape[0] > batch_size + 1:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Packed sequences require fla.modules.convolution.causal_conv1d "
|
|
||||||
"(cu_seqlens support). Install flash-linear-attention or disable packing."
|
|
||||||
)
|
|
||||||
LOG.warning_once(
|
|
||||||
"FLA causal_conv1d not available. Falling back to PyTorch conv1d."
|
|
||||||
)
|
|
||||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
|
||||||
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
|
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
|
||||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
|
||||||
|
|
||||||
# mixed_qkv is [B, T, D] in all paths
|
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||||
query, key, value = torch.split(
|
query, key, value = torch.split(
|
||||||
mixed_qkv,
|
mixed_qkv,
|
||||||
[
|
[
|
||||||
@@ -229,6 +203,7 @@ def patch_qwen3_next_gateddelta_layer():
|
|||||||
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
||||||
|
|
||||||
if not use_precomputed_states:
|
if not use_precomputed_states:
|
||||||
|
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
|
||||||
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
|
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
|
|||||||
@@ -1,188 +0,0 @@
|
|||||||
"""
|
|
||||||
Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors.
|
|
||||||
|
|
||||||
In transformers v5, MoE models store expert weights as fused 3D tensors that BnB
|
|
||||||
skips (only targets nn.Linear). This module patches weight loading to quantize them
|
|
||||||
on-the-fly (4-bit via bitsandbytes parametrize, 8-bit via custom int8 parametrization),
|
|
||||||
reducing peak VRAM from "all experts in bf16" to "one expert at a time."
|
|
||||||
"""
|
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
import torch
|
|
||||||
import torch.nn.utils.parametrize as P
|
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
# Module-level state for the loading-time quantization patch.
|
|
||||||
_moe_load_state = {
|
|
||||||
"count": 0,
|
|
||||||
"mode": "4bit",
|
|
||||||
"quant_type": "nf4",
|
|
||||||
"compress_statistics": True,
|
|
||||||
"patched": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class Bnb8bitParametrization(torch.nn.Module):
|
|
||||||
"""Parametrization that dequantizes int8 row-wise quantized data on access."""
|
|
||||||
|
|
||||||
def __init__(self, row_stats: torch.Tensor):
|
|
||||||
super().__init__()
|
|
||||||
self.register_buffer("row_stats", row_stats)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
|
|
||||||
# Flatten 3D+ to 2D for BnB's dequant, then reshape back.
|
|
||||||
orig_shape = quantized_param.shape
|
|
||||||
if quantized_param.ndim > 2:
|
|
||||||
quantized_param = quantized_param.reshape(-1, orig_shape[-1])
|
|
||||||
result = bnb.functional.int8_vectorwise_dequant(quantized_param, self.row_stats)
|
|
||||||
return result.reshape(orig_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def _enable_parametrization_cache(module, inputs):
|
|
||||||
P._cache_enabled += 1
|
|
||||||
|
|
||||||
|
|
||||||
def _disable_parametrization_cache(module, inputs, output):
|
|
||||||
P._cache_enabled -= 1
|
|
||||||
if not P._cache_enabled:
|
|
||||||
P._cache = {}
|
|
||||||
|
|
||||||
|
|
||||||
def replace_parameter_8bit(module, param_name):
|
|
||||||
"""Replace a module parameter with an 8-bit quantized version using parametrization."""
|
|
||||||
original_param = getattr(module, param_name)
|
|
||||||
int8_data, row_stats, _ = bnb.functional.int8_vectorwise_quant(
|
|
||||||
original_param.data.to(torch.float16)
|
|
||||||
)
|
|
||||||
|
|
||||||
setattr(module, param_name, torch.nn.Parameter(int8_data, requires_grad=False))
|
|
||||||
del original_param
|
|
||||||
|
|
||||||
P.register_parametrization(
|
|
||||||
module, param_name, Bnb8bitParametrization(row_stats), unsafe=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Cache dequantized values during forward to avoid redundant dequantization.
|
|
||||||
if not getattr(module, "_axolotl_8bit_hooks_registered", False):
|
|
||||||
module.register_forward_pre_hook(_enable_parametrization_cache)
|
|
||||||
module.register_forward_hook(_disable_parametrization_cache)
|
|
||||||
module._axolotl_8bit_hooks_registered = True
|
|
||||||
|
|
||||||
|
|
||||||
def patch_moe_quantization_on_load(cfg):
|
|
||||||
"""Patch transformers' weight loading to quantize MoE expert params on-the-fly.
|
|
||||||
|
|
||||||
Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their
|
|
||||||
name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low.
|
|
||||||
"""
|
|
||||||
mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit"
|
|
||||||
_moe_load_state["mode"] = mode
|
|
||||||
_moe_load_state["count"] = 0
|
|
||||||
|
|
||||||
if _moe_load_state["patched"]:
|
|
||||||
LOG.debug("MoE loading-time quantization patch already active")
|
|
||||||
return
|
|
||||||
|
|
||||||
import transformers.core_model_loading
|
|
||||||
import transformers.modeling_utils
|
|
||||||
|
|
||||||
if mode == "4bit":
|
|
||||||
from bitsandbytes.nn.parametrize import replace_parameter_4bit
|
|
||||||
|
|
||||||
quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
|
|
||||||
compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None)
|
|
||||||
if compress_statistics is None:
|
|
||||||
compress_statistics = True
|
|
||||||
|
|
||||||
_moe_load_state["quant_type"] = quant_type
|
|
||||||
_moe_load_state["compress_statistics"] = compress_statistics
|
|
||||||
|
|
||||||
# Disable caching_allocator_warmup — it pre-allocates a huge tensor at bf16
|
|
||||||
# size for all params, defeating our on-load quantization VRAM savings.
|
|
||||||
def _noop_warmup(*args, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
transformers.modeling_utils.caching_allocator_warmup = _noop_warmup
|
|
||||||
|
|
||||||
original_set_param = transformers.core_model_loading.set_param_for_module
|
|
||||||
|
|
||||||
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
|
|
||||||
original_set_param(model, target_name, param_value, *args, **kwargs)
|
|
||||||
|
|
||||||
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
|
|
||||||
if param_value.ndim >= 3 and param_value.is_cuda:
|
|
||||||
mod_path, _, pname = target_name.rpartition(".")
|
|
||||||
mod = model.get_submodule(mod_path) if mod_path else model
|
|
||||||
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
|
|
||||||
if "expert" not in target_name.lower():
|
|
||||||
LOG.debug(
|
|
||||||
"Skipping non-expert 3D param: %s (shape=%s)",
|
|
||||||
target_name,
|
|
||||||
list(param_value.shape),
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if _moe_load_state["mode"] == "4bit":
|
|
||||||
replace_parameter_4bit(
|
|
||||||
mod,
|
|
||||||
pname,
|
|
||||||
compress_statistics=_moe_load_state["compress_statistics"],
|
|
||||||
quant_type=_moe_load_state["quant_type"],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
replace_parameter_8bit(mod, pname)
|
|
||||||
_moe_load_state["count"] += 1
|
|
||||||
|
|
||||||
# Release the bf16 tensor so CUDA memory is freed immediately.
|
|
||||||
param_value.data = torch.empty(0, device="cpu")
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
|
|
||||||
_moe_load_state["patched"] = True
|
|
||||||
|
|
||||||
|
|
||||||
def get_moe_quantized_count():
|
|
||||||
"""Return the number of expert parameters quantized during loading."""
|
|
||||||
return _moe_load_state["count"]
|
|
||||||
|
|
||||||
|
|
||||||
def patch_peft_target_parameters_matching():
|
|
||||||
"""Fix PEFT's _inject_parameters to use suffix matching for parametrized modules."""
|
|
||||||
if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
|
|
||||||
return
|
|
||||||
from peft.tuners.tuners_utils import BaseTuner
|
|
||||||
|
|
||||||
original_inject = BaseTuner._inject_parameters
|
|
||||||
|
|
||||||
def _patched_inject_parameters(
|
|
||||||
self, peft_config, model, adapter_name, low_cpu_mem_usage
|
|
||||||
):
|
|
||||||
# Patch target_parameters to use full paths for parametrized modules
|
|
||||||
original_targets = list(peft_config.target_parameters)
|
|
||||||
expanded = set(original_targets)
|
|
||||||
|
|
||||||
for module_name, module in model.named_modules():
|
|
||||||
if not hasattr(module, "parametrizations"):
|
|
||||||
continue
|
|
||||||
for target in original_targets:
|
|
||||||
mod_path, _, param_name = target.rpartition(".")
|
|
||||||
if (
|
|
||||||
module_name == mod_path or module_name.endswith("." + mod_path)
|
|
||||||
) and hasattr(module, param_name):
|
|
||||||
expanded.add(f"{module_name}.{param_name}")
|
|
||||||
|
|
||||||
peft_config.target_parameters = sorted(expanded)
|
|
||||||
try:
|
|
||||||
return original_inject(
|
|
||||||
self, peft_config, model, adapter_name, low_cpu_mem_usage
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
peft_config.target_parameters = original_targets
|
|
||||||
|
|
||||||
BaseTuner._inject_parameters = _patched_inject_parameters
|
|
||||||
patch_peft_target_parameters_matching._axolotl_patched = True
|
|
||||||
LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching")
|
|
||||||
@@ -48,9 +48,9 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
):
|
):
|
||||||
# check if message_property_mappings is None or empty dict
|
# check if message_property_mappings is None or empty dict
|
||||||
if message_property_mappings is None or (not message_property_mappings):
|
if message_property_mappings is None or (not message_property_mappings):
|
||||||
|
default_message_property_mappings_keys = ["role", "content", "tool"]
|
||||||
message_property_mappings = {
|
message_property_mappings = {
|
||||||
"role": "role",
|
prop: prop for prop in default_message_property_mappings_keys
|
||||||
"content": "content",
|
|
||||||
}
|
}
|
||||||
if template_thinking_key and field_thinking:
|
if template_thinking_key and field_thinking:
|
||||||
message_property_mappings[template_thinking_key] = field_thinking
|
message_property_mappings[template_thinking_key] = field_thinking
|
||||||
|
|||||||
@@ -629,17 +629,6 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
quantize_moe_experts: bool = Field(
|
|
||||||
default=False,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Quantize MoE expert weights on load to reduce VRAM. "
|
|
||||||
"Requires adapter (lora/qlora) with load_in_4bit or load_in_8bit. "
|
|
||||||
"Requires CUDA (not compatible with ROCm or other backends). "
|
|
||||||
"Note: total parameter count may be reported incorrectly when enabled "
|
|
||||||
"(trainable param count is correct)."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
scaling_softmax: bool | None = Field(
|
scaling_softmax: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -1300,26 +1289,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_quantize_moe_experts(cls, data):
|
|
||||||
if data.get("quantize_moe_experts"):
|
|
||||||
if data.get("adapter") not in ("lora", "qlora"):
|
|
||||||
raise ValueError("quantize_moe_experts requires adapter: lora or qlora")
|
|
||||||
if not (data.get("load_in_4bit") or data.get("load_in_8bit")):
|
|
||||||
raise ValueError(
|
|
||||||
"quantize_moe_experts requires load_in_4bit or load_in_8bit"
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
data.get("capabilities")
|
|
||||||
and data["capabilities"].get("compute_capability")
|
|
||||||
and not data["capabilities"]["compute_capability"].startswith("sm_")
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"quantize_moe_experts requires CUDA (not compatible with ROCm or other backends)"
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_auto_enable_lora_kernels(cls, data):
|
def check_auto_enable_lora_kernels(cls, data):
|
||||||
|
|||||||
@@ -209,19 +209,6 @@ class LoraConfig(BaseModel):
|
|||||||
data["lora_dropout"] = 0.0
|
data["lora_dropout"] = 0.0
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def validate_lora_target_parameters_dropout(self):
|
|
||||||
if (
|
|
||||||
self.lora_target_parameters
|
|
||||||
and self.lora_dropout
|
|
||||||
and self.lora_dropout != 0.0
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"lora_dropout must be 0 when lora_target_parameters is set. "
|
|
||||||
"PEFT's ParamWrapper does not support lora_dropout != 0."
|
|
||||||
)
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class ReLoRAConfig(BaseModel):
|
class ReLoRAConfig(BaseModel):
|
||||||
"""ReLoRA configuration subset"""
|
"""ReLoRA configuration subset"""
|
||||||
|
|||||||
@@ -1,142 +0,0 @@
|
|||||||
"""Tests for MoE expert quantization config validation and PEFT patch idempotency."""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def gpu_caps():
|
|
||||||
return {"compute_capability": "sm_89", "bf16": True, "n_gpu": 1, "n_node": 1}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def env_caps():
|
|
||||||
return {"torch_version": "2.7.0"}
|
|
||||||
|
|
||||||
|
|
||||||
class TestQuantizeMoeExpertsValidation:
|
|
||||||
"""Test suite for quantize_moe_experts config validator."""
|
|
||||||
|
|
||||||
def test_requires_adapter(self, min_base_cfg, gpu_caps, env_caps):
|
|
||||||
"""quantize_moe_experts without adapter should fail."""
|
|
||||||
cfg = (
|
|
||||||
DictDefault(
|
|
||||||
quantize_moe_experts=True,
|
|
||||||
)
|
|
||||||
| min_base_cfg
|
|
||||||
)
|
|
||||||
with pytest.raises(ValueError, match="requires adapter"):
|
|
||||||
validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
|
||||||
|
|
||||||
def test_requires_quantization(self, min_base_cfg, gpu_caps, env_caps):
|
|
||||||
"""quantize_moe_experts without load_in_4bit/8bit should fail."""
|
|
||||||
cfg = (
|
|
||||||
DictDefault(
|
|
||||||
quantize_moe_experts=True,
|
|
||||||
adapter="lora",
|
|
||||||
)
|
|
||||||
| min_base_cfg
|
|
||||||
)
|
|
||||||
with pytest.raises(ValueError, match="requires load_in_4bit or load_in_8bit"):
|
|
||||||
validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
|
||||||
|
|
||||||
def test_valid_qlora_4bit(self, min_base_cfg, gpu_caps, env_caps):
|
|
||||||
"""quantize_moe_experts with qlora + 4bit should pass."""
|
|
||||||
cfg = (
|
|
||||||
DictDefault(
|
|
||||||
quantize_moe_experts=True,
|
|
||||||
adapter="qlora",
|
|
||||||
load_in_4bit=True,
|
|
||||||
)
|
|
||||||
| min_base_cfg
|
|
||||||
)
|
|
||||||
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
|
||||||
assert result["quantize_moe_experts"] is True
|
|
||||||
|
|
||||||
def test_valid_lora_8bit(self, min_base_cfg, gpu_caps, env_caps):
|
|
||||||
"""quantize_moe_experts with lora + 8bit should pass."""
|
|
||||||
cfg = (
|
|
||||||
DictDefault(
|
|
||||||
quantize_moe_experts=True,
|
|
||||||
adapter="lora",
|
|
||||||
load_in_8bit=True,
|
|
||||||
)
|
|
||||||
| min_base_cfg
|
|
||||||
)
|
|
||||||
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
|
||||||
assert result["quantize_moe_experts"] is True
|
|
||||||
|
|
||||||
def test_false_skips_validation(self, min_base_cfg, gpu_caps, env_caps):
|
|
||||||
"""quantize_moe_experts=false should not check adapter/quantization."""
|
|
||||||
cfg = (
|
|
||||||
DictDefault(
|
|
||||||
quantize_moe_experts=False,
|
|
||||||
)
|
|
||||||
| min_base_cfg
|
|
||||||
)
|
|
||||||
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
|
||||||
assert result["quantize_moe_experts"] is False
|
|
||||||
|
|
||||||
def test_default_is_false(self, min_base_cfg, gpu_caps, env_caps):
|
|
||||||
"""quantize_moe_experts should default to false."""
|
|
||||||
cfg = DictDefault({}) | min_base_cfg
|
|
||||||
result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
|
|
||||||
assert result["quantize_moe_experts"] is False
|
|
||||||
|
|
||||||
|
|
||||||
class TestLoraTargetParametersDropout:
|
|
||||||
"""Test that lora_dropout must be 0 when lora_target_parameters is set."""
|
|
||||||
|
|
||||||
def test_rejects_nonzero_dropout(self, min_base_cfg):
|
|
||||||
"""lora_dropout > 0 with lora_target_parameters should fail."""
|
|
||||||
cfg = (
|
|
||||||
DictDefault(
|
|
||||||
adapter="lora",
|
|
||||||
lora_target_parameters=["mlp.experts.gate_up_proj"],
|
|
||||||
lora_dropout=0.1,
|
|
||||||
load_in_8bit=True,
|
|
||||||
)
|
|
||||||
| min_base_cfg
|
|
||||||
)
|
|
||||||
with pytest.raises(ValueError, match="lora_dropout must be 0"):
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
def test_zero_dropout_passes(self, min_base_cfg):
|
|
||||||
"""lora_dropout=0 with lora_target_parameters should pass."""
|
|
||||||
cfg = (
|
|
||||||
DictDefault(
|
|
||||||
adapter="lora",
|
|
||||||
lora_target_parameters=["mlp.experts.gate_up_proj"],
|
|
||||||
lora_dropout=0.0,
|
|
||||||
load_in_8bit=True,
|
|
||||||
)
|
|
||||||
| min_base_cfg
|
|
||||||
)
|
|
||||||
result = validate_config(cfg)
|
|
||||||
assert result["lora_dropout"] == 0.0
|
|
||||||
|
|
||||||
|
|
||||||
class TestPeftPatchIdempotency:
|
|
||||||
"""Test that patch_peft_target_parameters_matching is idempotent."""
|
|
||||||
|
|
||||||
def test_double_call_does_not_stack_wrappers(self):
|
|
||||||
"""Calling patch twice should not double-wrap _inject_parameters."""
|
|
||||||
from peft.tuners.tuners_utils import BaseTuner
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.moe_quant import (
|
|
||||||
patch_peft_target_parameters_matching,
|
|
||||||
)
|
|
||||||
|
|
||||||
original = BaseTuner._inject_parameters
|
|
||||||
try:
|
|
||||||
patch_peft_target_parameters_matching()
|
|
||||||
first_patched = BaseTuner._inject_parameters
|
|
||||||
patch_peft_target_parameters_matching()
|
|
||||||
second_patched = BaseTuner._inject_parameters
|
|
||||||
# Should be same function, not double-wrapped
|
|
||||||
assert first_patched is second_patched
|
|
||||||
finally:
|
|
||||||
BaseTuner._inject_parameters = original
|
|
||||||
patch_peft_target_parameters_matching._axolotl_patched = False
|
|
||||||
Reference in New Issue
Block a user