diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml
index 6180faf96..f58c05f3b 100644
--- a/.github/workflows/multi-gpu-e2e.yml
+++ b/.github/workflows/multi-gpu-e2e.yml
@@ -33,6 +33,13 @@ jobs:
axolotl_extras:
num_gpus: 2
nightly_build: "true"
+ - cuda: 126
+ cuda_version: 12.6.3
+ python_version: "3.11"
+ pytorch: 2.7.0
+ axolotl_extras: vllm
+ num_gpus: 2
+ nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
diff --git a/.github/workflows/nightlies.yml b/.github/workflows/nightlies.yml
index 824c7e4f2..49bce470b 100644
--- a/.github/workflows/nightlies.yml
+++ b/.github/workflows/nightlies.yml
@@ -12,11 +12,16 @@ jobs:
fail-fast: false
matrix:
include:
- - cuda: 124
- cuda_version: 12.4.1
+ - cuda: 126
+ cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
+ - cuda: 126
+ cuda_version: 12.6.3
+ python_version: "3.11"
+ pytorch: 2.7.1
+ axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -60,15 +65,15 @@ jobs:
strategy:
matrix:
include:
- - cuda: 124
- cuda_version: 12.4.1
+ - cuda: 126
+ cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
- pytorch: 2.6.0
+ pytorch: 2.7.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml
index b5dd50a3c..54d734e49 100644
--- a/.github/workflows/tests-nightly.yml
+++ b/.github/workflows/tests-nightly.yml
@@ -92,7 +92,7 @@ jobs:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
- timeout-minutes: 60
+ timeout-minutes: 120
needs: [pre-commit, pytest]
strategy:
@@ -116,7 +116,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
- pip install modal==0.71.8 jinja2
+ pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
diff --git a/_quarto.yml b/_quarto.yml
index 93141aa9e..3e773a748 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -276,6 +276,7 @@ website:
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
+ - docs/gradient_checkpointing.qmd
- section: "Troubleshooting"
contents:
diff --git a/docs/gradient_checkpointing.qmd b/docs/gradient_checkpointing.qmd
new file mode 100644
index 000000000..25a887999
--- /dev/null
+++ b/docs/gradient_checkpointing.qmd
@@ -0,0 +1,29 @@
+---
+title: Gradient Checkpointing and Activation Offloading
+---
+
+Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
+models by reducing the memory footprint and improving computational efficiency.
+
+### Enabling Gradient Checkpointing
+
+```yaml
+gradient_checkpointing: true
+```
+
+### Enabling Activation Offloading
+
+```yaml
+gradient_checkpointing: true # required for activation offloading
+activation_offloading: true
+```
+
+Activation offloading variants:
+
+The default `activation_offloading: true` offloads activations to CPU and uses CUDA streams
+to overlap the communications and computations when offloading.
+
+The `activation_offloading: legacy` naively offloads activations to CPU and without additional optimizations.
+
+For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
+activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
diff --git a/examples/cloud/modal.yaml b/examples/cloud/modal.yaml
index 195031494..bbe8785f1 100644
--- a/examples/cloud/modal.yaml
+++ b/examples/cloud/modal.yaml
@@ -26,3 +26,5 @@ timeout: 86400
# Preprocess specific configurations
memory_preprocess: 32
timeout_preprocess: 14400
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/cohere/command-r-7b-qlora.yml b/examples/cohere/command-r-7b-qlora.yml
index 4a30e9a77..da2777270 100644
--- a/examples/cohere/command-r-7b-qlora.yml
+++ b/examples/cohere/command-r-7b-qlora.yml
@@ -35,7 +35,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
-
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
@@ -56,3 +55,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb
index bcb99f19e..112658007 100644
--- a/examples/colab-notebooks/colab-axolotl-example.ipynb
+++ b/examples/colab-notebooks/colab-axolotl-example.ipynb
@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
- "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154\""
+ "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19\""
]
},
{
diff --git a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
index 2c0495ced..1a051b98b 100644
--- a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
+++ b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
index de9c956e0..807342641 100644
--- a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
+++ b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepseek-v2/fft-fsdp-16b.yaml b/examples/deepseek-v2/fft-fsdp-16b.yaml
index 0ed97db36..78bf6b179 100644
--- a/examples/deepseek-v2/fft-fsdp-16b.yaml
+++ b/examples/deepseek-v2/fft-fsdp-16b.yaml
@@ -55,3 +55,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepseek-v2/qlora-fsdp-2_5.yaml b/examples/deepseek-v2/qlora-fsdp-2_5.yaml
index 34dbeaafe..da1d9aefd 100644
--- a/examples/deepseek-v2/qlora-fsdp-2_5.yaml
+++ b/examples/deepseek-v2/qlora-fsdp-2_5.yaml
@@ -79,3 +79,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/devstral/devstral-small-qlora.yml b/examples/devstral/devstral-small-qlora.yml
index dc0051bd5..9d92e8662 100644
--- a/examples/devstral/devstral-small-qlora.yml
+++ b/examples/devstral/devstral-small-qlora.yml
@@ -62,3 +62,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
index 1dd901154..484c31fec 100644
--- a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-1b-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-qlora.yaml
index 24dc7cae3..dea2a6e6d 100644
--- a/examples/falcon-h1/falcon-h1-1b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-1b-qlora.yaml
@@ -46,7 +46,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
-
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
@@ -69,3 +68,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-34b-qlora.yaml b/examples/falcon-h1/falcon-h1-34b-qlora.yaml
index 43eb1967b..b187efbf6 100644
--- a/examples/falcon-h1/falcon-h1-34b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-34b-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-3b-qlora.yaml b/examples/falcon-h1/falcon-h1-3b-qlora.yaml
index 00929bbf0..4d981ad95 100644
--- a/examples/falcon-h1/falcon-h1-3b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-3b-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-500m-qlora.yaml b/examples/falcon-h1/falcon-h1-500m-qlora.yaml
index e2640de7b..5ee13facd 100644
--- a/examples/falcon-h1/falcon-h1-500m-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-500m-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-7b-qlora.yaml b/examples/falcon-h1/falcon-h1-7b-qlora.yaml
index 183e423b5..4b665c3cd 100644
--- a/examples/falcon-h1/falcon-h1-7b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-7b-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma2/qlora.yml b/examples/gemma2/qlora.yml
index cb96a32c1..68d213fad 100644
--- a/examples/gemma2/qlora.yml
+++ b/examples/gemma2/qlora.yml
@@ -60,3 +60,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma2/reward-model.yaml b/examples/gemma2/reward-model.yaml
index ce01a4572..624ebdcd2 100644
--- a/examples/gemma2/reward-model.yaml
+++ b/examples/gemma2/reward-model.yaml
@@ -50,3 +50,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-1b-qlora.yml b/examples/gemma3/gemma-3-1b-qlora.yml
index 217c887aa..99921770d 100644
--- a/examples/gemma3/gemma-3-1b-qlora.yml
+++ b/examples/gemma3/gemma-3-1b-qlora.yml
@@ -66,3 +66,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-4b-qlora.yml b/examples/gemma3/gemma-3-4b-qlora.yml
index d78559ae3..025cb9240 100644
--- a/examples/gemma3/gemma-3-4b-qlora.yml
+++ b/examples/gemma3/gemma-3-4b-qlora.yml
@@ -60,3 +60,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-4b-vision-qlora.yml b/examples/gemma3/gemma-3-4b-vision-qlora.yml
index 183eb88e8..e9e606b69 100644
--- a/examples/gemma3/gemma-3-4b-vision-qlora.yml
+++ b/examples/gemma3/gemma-3-4b-vision-qlora.yml
@@ -62,3 +62,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/glm4/qlora-32b.yaml b/examples/glm4/qlora-32b.yaml
index 86d9b43f8..8973cedd4 100644
--- a/examples/glm4/qlora-32b.yaml
+++ b/examples/glm4/qlora-32b.yaml
@@ -60,3 +60,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora.yaml b/examples/jamba/qlora.yaml
index 2cb0eea41..494154886 100644
--- a/examples/jamba/qlora.yaml
+++ b/examples/jamba/qlora.yaml
@@ -54,3 +54,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora_deepspeed.yaml b/examples/jamba/qlora_deepspeed.yaml
index d13ce6483..64db8f2ff 100644
--- a/examples/jamba/qlora_deepspeed.yaml
+++ b/examples/jamba/qlora_deepspeed.yaml
@@ -55,3 +55,5 @@ saves_per_epoch: 1
deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora_fsdp_large.yaml b/examples/jamba/qlora_fsdp_large.yaml
index 6badaba19..fda30e2d2 100644
--- a/examples/jamba/qlora_fsdp_large.yaml
+++ b/examples/jamba/qlora_fsdp_large.yaml
@@ -64,3 +64,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/lfm2/lfm2-350m-fft.yaml b/examples/lfm2/lfm2-350m-fft.yaml
index 95961557e..74c90c1e1 100644
--- a/examples/lfm2/lfm2-350m-fft.yaml
+++ b/examples/lfm2/lfm2-350m-fft.yaml
@@ -46,3 +46,5 @@ evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml
index 86b1b6a21..c44cd2230 100644
--- a/examples/llama-2/fft_optimized.yml
+++ b/examples/llama-2/fft_optimized.yml
@@ -55,3 +55,5 @@ saves_per_epoch: 1
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
weight_decay: 0.1
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml
index 0f1b34016..580fabdf8 100644
--- a/examples/llama-2/gptq-lora.yml
+++ b/examples/llama-2/gptq-lora.yml
@@ -64,3 +64,5 @@ special_tokens:
bos_token: ""
eos_token: ""
unk_token: ""
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/lisa.yml b/examples/llama-2/lisa.yml
index a76a792ae..a44e261be 100644
--- a/examples/llama-2/lisa.yml
+++ b/examples/llama-2/lisa.yml
@@ -60,3 +60,5 @@ special_tokens:
bos_token: ""
eos_token: ""
unk_token: ""
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/loftq.yml b/examples/llama-2/loftq.yml
index 22dbf2d99..085627f63 100644
--- a/examples/llama-2/loftq.yml
+++ b/examples/llama-2/loftq.yml
@@ -52,3 +52,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml
index 679aed3a9..759fce044 100644
--- a/examples/llama-2/lora.yml
+++ b/examples/llama-2/lora.yml
@@ -52,3 +52,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/qlora-fsdp.yml b/examples/llama-2/qlora-fsdp.yml
index a42eabd4b..3bf30120b 100644
--- a/examples/llama-2/qlora-fsdp.yml
+++ b/examples/llama-2/qlora-fsdp.yml
@@ -67,3 +67,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml
index de65928bc..09596c71e 100644
--- a/examples/llama-2/qlora.yml
+++ b/examples/llama-2/qlora.yml
@@ -53,3 +53,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml
index e0a5f7068..ca8b14a1c 100644
--- a/examples/llama-2/relora.yml
+++ b/examples/llama-2/relora.yml
@@ -58,3 +58,5 @@ special_tokens:
bos_token: ""
eos_token: ""
unk_token: ""
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml
index 2b0ae2c70..64d749b5a 100644
--- a/examples/llama-3-vision/lora-11b.yaml
+++ b/examples/llama-3-vision/lora-11b.yaml
@@ -57,3 +57,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/3b-qat-fsdp2.yaml b/examples/llama-3/3b-qat-fsdp2.yaml
index 5d979c96c..08d8ee5c1 100644
--- a/examples/llama-3/3b-qat-fsdp2.yaml
+++ b/examples/llama-3/3b-qat-fsdp2.yaml
@@ -77,3 +77,5 @@ fsdp_config:
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml
index eccfa6d8c..e2808935f 100644
--- a/examples/llama-3/fft-8b-liger-fsdp.yaml
+++ b/examples/llama-3/fft-8b-liger-fsdp.yaml
@@ -72,3 +72,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/fft-8b.yaml b/examples/llama-3/fft-8b.yaml
index fdae3e6c4..2dfe6d492 100644
--- a/examples/llama-3/fft-8b.yaml
+++ b/examples/llama-3/fft-8b.yaml
@@ -42,3 +42,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml
index 51f1c768b..10ab2a320 100644
--- a/examples/llama-3/instruct-dpo-lora-8b.yml
+++ b/examples/llama-3/instruct-dpo-lora-8b.yml
@@ -71,3 +71,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml
index acab862f6..83b7f9a37 100644
--- a/examples/llama-3/instruct-lora-8b.yml
+++ b/examples/llama-3/instruct-lora-8b.yml
@@ -64,3 +64,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-deduplicate-dpo.yml b/examples/llama-3/lora-1b-deduplicate-dpo.yml
index 10e9747cb..b20dbad84 100644
--- a/examples/llama-3/lora-1b-deduplicate-dpo.yml
+++ b/examples/llama-3/lora-1b-deduplicate-dpo.yml
@@ -83,3 +83,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-deduplicate-sft.yml b/examples/llama-3/lora-1b-deduplicate-sft.yml
index 630ec92f6..67e518184 100644
--- a/examples/llama-3/lora-1b-deduplicate-sft.yml
+++ b/examples/llama-3/lora-1b-deduplicate-sft.yml
@@ -61,3 +61,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-kernels.yml b/examples/llama-3/lora-1b-kernels.yml
index a2d07ca49..92a948c2e 100644
--- a/examples/llama-3/lora-1b-kernels.yml
+++ b/examples/llama-3/lora-1b-kernels.yml
@@ -65,3 +65,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-ray.yml b/examples/llama-3/lora-1b-ray.yml
index bb23164eb..178a1fb89 100644
--- a/examples/llama-3/lora-1b-ray.yml
+++ b/examples/llama-3/lora-1b-ray.yml
@@ -64,3 +64,5 @@ special_tokens:
use_ray: true
ray_num_workers: 4
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-sample-packing-sequentially.yml b/examples/llama-3/lora-1b-sample-packing-sequentially.yml
index 769dd32e6..c4ce3eb0f 100644
--- a/examples/llama-3/lora-1b-sample-packing-sequentially.yml
+++ b/examples/llama-3/lora-1b-sample-packing-sequentially.yml
@@ -63,3 +63,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b.yml b/examples/llama-3/lora-1b.yml
index acc17e21f..82085483f 100644
--- a/examples/llama-3/lora-1b.yml
+++ b/examples/llama-3/lora-1b.yml
@@ -60,3 +60,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-8b.yml b/examples/llama-3/lora-8b.yml
index ad50cd38a..c39389755 100644
--- a/examples/llama-3/lora-8b.yml
+++ b/examples/llama-3/lora-8b.yml
@@ -57,3 +57,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-1b-kto.yaml b/examples/llama-3/qlora-1b-kto.yaml
index 89a51ea68..f156e23d3 100644
--- a/examples/llama-3/qlora-1b-kto.yaml
+++ b/examples/llama-3/qlora-1b-kto.yaml
@@ -61,3 +61,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-1b.yml b/examples/llama-3/qlora-1b.yml
index 5c8fe6628..6b76ea8d9 100644
--- a/examples/llama-3/qlora-1b.yml
+++ b/examples/llama-3/qlora-1b.yml
@@ -62,3 +62,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-fsdp-405b.yaml b/examples/llama-3/qlora-fsdp-405b.yaml
index 2b7d51925..1ee922b59 100644
--- a/examples/llama-3/qlora-fsdp-405b.yaml
+++ b/examples/llama-3/qlora-fsdp-405b.yaml
@@ -60,3 +60,5 @@ fsdp_config:
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|finetune_right_pad_id|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-fsdp-70b.yaml b/examples/llama-3/qlora-fsdp-70b.yaml
index 412b6721c..5edd8353a 100644
--- a/examples/llama-3/qlora-fsdp-70b.yaml
+++ b/examples/llama-3/qlora-fsdp-70b.yaml
@@ -69,3 +69,5 @@ fsdp_config:
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora.yml b/examples/llama-3/qlora.yml
index 4cc9fc3db..a674eca27 100644
--- a/examples/llama-3/qlora.yml
+++ b/examples/llama-3/qlora.yml
@@ -54,3 +54,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/sparse-finetuning.yaml b/examples/llama-3/sparse-finetuning.yaml
index 1bbb88028..8577a19d2 100644
--- a/examples/llama-3/sparse-finetuning.yaml
+++ b/examples/llama-3/sparse-finetuning.yaml
@@ -75,3 +75,5 @@ llmcompressor:
]
start: 0
save_compressed: true
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
index 2be94f4ef..d4a038e11 100644
--- a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
+++ b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
@@ -86,3 +86,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
index eeae872a6..bea10d979 100644
--- a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
@@ -90,3 +90,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
index 17ad70634..737d93812 100644
--- a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
@@ -83,3 +83,5 @@ weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
index eff708e4d..390be5af7 100644
--- a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
@@ -86,3 +86,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
index 9a411883e..b319349c4 100644
--- a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
+++ b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
@@ -84,3 +84,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-qlora-single-h100-flex.yaml b/examples/llama-4/scout-qlora-single-h100-flex.yaml
index 20352f81e..6be3988ef 100644
--- a/examples/llama-4/scout-qlora-single-h100-flex.yaml
+++ b/examples/llama-4/scout-qlora-single-h100-flex.yaml
@@ -82,3 +82,5 @@ weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
index 9fbd34107..a67936cf1 100644
--- a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
+++ b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
@@ -87,3 +87,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llava/lora-7b.yaml b/examples/llava/lora-7b.yaml
index 5198c8e74..a4bac8987 100644
--- a/examples/llava/lora-7b.yaml
+++ b/examples/llava/lora-7b.yaml
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/magistral/magistral-small-fsdp-qlora.yaml b/examples/magistral/magistral-small-fsdp-qlora.yaml
index b10e8baf6..b23d2309a 100644
--- a/examples/magistral/magistral-small-fsdp-qlora.yaml
+++ b/examples/magistral/magistral-small-fsdp-qlora.yaml
@@ -70,3 +70,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
fsdp_activation_checkpointing: true
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/magistral/magistral-small-qlora.yaml b/examples/magistral/magistral-small-qlora.yaml
index e3e746f22..f0fce014f 100644
--- a/examples/magistral/magistral-small-qlora.yaml
+++ b/examples/magistral/magistral-small-qlora.yaml
@@ -61,3 +61,5 @@ flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
index 3d4583932..2261bd215 100644
--- a/examples/mamba/config.yml
+++ b/examples/mamba/config.yml
@@ -48,3 +48,5 @@ weight_decay: 0.0
special_tokens:
tokens:
save_safetensors: False
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/bigstral-ds-zero3.yaml b/examples/mistral/bigstral-ds-zero3.yaml
index f626a92a1..e9bcbb7d6 100644
--- a/examples/mistral/bigstral-ds-zero3.yaml
+++ b/examples/mistral/bigstral-ds-zero3.yaml
@@ -53,3 +53,5 @@ special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml
index 15edffb44..8c4d80f79 100644
--- a/examples/mistral/config.yml
+++ b/examples/mistral/config.yml
@@ -43,3 +43,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/lora-mps.yml b/examples/mistral/lora-mps.yml
index e6f46affb..d54c3e30b 100644
--- a/examples/mistral/lora-mps.yml
+++ b/examples/mistral/lora-mps.yml
@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/lora.yml b/examples/mistral/lora.yml
index 9af4274fd..161255468 100644
--- a/examples/mistral/lora.yml
+++ b/examples/mistral/lora.yml
@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-dpo-qlora.yml b/examples/mistral/mistral-dpo-qlora.yml
index af707973f..8d0378690 100644
--- a/examples/mistral/mistral-dpo-qlora.yml
+++ b/examples/mistral/mistral-dpo-qlora.yml
@@ -80,3 +80,5 @@ weight_decay: 0.0
special_tokens:
bos_token: "<|im_start|>"
eos_token: "<|im_end|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-qlora-fsdp.yml b/examples/mistral/mistral-qlora-fsdp.yml
index e234b19a2..cec958c54 100644
--- a/examples/mistral/mistral-qlora-fsdp.yml
+++ b/examples/mistral/mistral-qlora-fsdp.yml
@@ -74,3 +74,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-qlora-orpo.yml b/examples/mistral/mistral-qlora-orpo.yml
index 6c0212b7c..f37dc09fa 100644
--- a/examples/mistral/mistral-qlora-orpo.yml
+++ b/examples/mistral/mistral-qlora-orpo.yml
@@ -69,3 +69,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-small-3.1-24B-lora.yml b/examples/mistral/mistral-small-3.1-24B-lora.yml
index 3e3b45862..4a492c595 100644
--- a/examples/mistral/mistral-small-3.1-24B-lora.yml
+++ b/examples/mistral/mistral-small-3.1-24B-lora.yml
@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
index af6ba5a76..64ef9930c 100644
--- a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
+++ b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
@@ -72,3 +72,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral-qlora-fsdp.yml
index b1843a138..c8d0a2711 100644
--- a/examples/mistral/mixtral-qlora-fsdp.yml
+++ b/examples/mistral/mixtral-qlora-fsdp.yml
@@ -77,3 +77,5 @@ fsdp_config:
fsdp_forward_prefetch: false
fsdp_backward_prefetch: BACKWARD_PRE
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml
index 4c256420c..5be9b4db8 100644
--- a/examples/mistral/mixtral.yml
+++ b/examples/mistral/mixtral.yml
@@ -81,3 +81,5 @@ saves_per_epoch: 1
deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral_22.yml b/examples/mistral/mixtral_22.yml
index 25e1d7155..100e4464f 100644
--- a/examples/mistral/mixtral_22.yml
+++ b/examples/mistral/mixtral_22.yml
@@ -51,3 +51,5 @@ special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml
index 607e33701..08df36e15 100644
--- a/examples/mistral/qlora.yml
+++ b/examples/mistral/qlora.yml
@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/orpheus/finetune.yml b/examples/orpheus/finetune.yml
index 9bcbbeee0..57f65d966 100644
--- a/examples/orpheus/finetune.yml
+++ b/examples/orpheus/finetune.yml
@@ -50,3 +50,5 @@ weight_decay: 0.05
special_tokens:
pad_token:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/lora-3.5.yaml b/examples/phi/lora-3.5.yaml
index ad4ce9cd4..9f3bbdf53 100644
--- a/examples/phi/lora-3.5.yaml
+++ b/examples/phi/lora-3.5.yaml
@@ -63,3 +63,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 4
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml
index 1562a7353..fc6d649d7 100644
--- a/examples/phi/phi-ft.yml
+++ b/examples/phi/phi-ft.yml
@@ -57,3 +57,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml
index 4cd53db97..ccd92c817 100644
--- a/examples/phi/phi-qlora.yml
+++ b/examples/phi/phi-qlora.yml
@@ -60,3 +60,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi2-ft.yml b/examples/phi/phi2-ft.yml
index ca733cc71..853250ccb 100644
--- a/examples/phi/phi2-ft.yml
+++ b/examples/phi/phi2-ft.yml
@@ -57,3 +57,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi3-ft-fsdp.yml b/examples/phi/phi3-ft-fsdp.yml
index d0d14fea6..130298bc0 100644
--- a/examples/phi/phi3-ft-fsdp.yml
+++ b/examples/phi/phi3-ft-fsdp.yml
@@ -71,3 +71,5 @@ fsdp_config:
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi3-ft.yml b/examples/phi/phi3-ft.yml
index 17c48da6f..42b87e8d0 100644
--- a/examples/phi/phi3-ft.yml
+++ b/examples/phi/phi3-ft.yml
@@ -59,3 +59,5 @@ warmup_ratio: 0.2
debug: true
weight_decay: 0.1
resize_token_embeddings_to_32x: true
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/pixtral/lora-12b.yml b/examples/pixtral/lora-12b.yml
index 6ad0a5e99..ea769d202 100644
--- a/examples/pixtral/lora-12b.yml
+++ b/examples/pixtral/lora-12b.yml
@@ -55,3 +55,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2-vl/lora-7b.yaml b/examples/qwen2-vl/lora-7b.yaml
index e8932b968..8ea608199 100644
--- a/examples/qwen2-vl/lora-7b.yaml
+++ b/examples/qwen2-vl/lora-7b.yaml
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/dpo.yaml b/examples/qwen2/dpo.yaml
index bd896c2b3..69a74ae4a 100644
--- a/examples/qwen2/dpo.yaml
+++ b/examples/qwen2/dpo.yaml
@@ -54,3 +54,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/prm.yaml b/examples/qwen2/prm.yaml
index 4afa24f3c..af188f75d 100644
--- a/examples/qwen2/prm.yaml
+++ b/examples/qwen2/prm.yaml
@@ -55,3 +55,5 @@ eval_steps: 100
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/qlora-fsdp.yaml b/examples/qwen2/qlora-fsdp.yaml
index ed2670ab6..861ce5517 100644
--- a/examples/qwen2/qlora-fsdp.yaml
+++ b/examples/qwen2/qlora-fsdp.yaml
@@ -67,3 +67,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/reward-model.yaml b/examples/qwen2/reward-model.yaml
index 822407a1f..1854b8216 100644
--- a/examples/qwen2/reward-model.yaml
+++ b/examples/qwen2/reward-model.yaml
@@ -26,7 +26,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
-
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
@@ -50,3 +49,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2_5-vl/lora-7b.yaml b/examples/qwen2_5-vl/lora-7b.yaml
index 25d02805f..13a97dec3 100644
--- a/examples/qwen2_5-vl/lora-7b.yaml
+++ b/examples/qwen2_5-vl/lora-7b.yaml
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/32b-qlora.yaml b/examples/qwen3/32b-qlora.yaml
index 45a4395ac..1f148ece5 100644
--- a/examples/qwen3/32b-qlora.yaml
+++ b/examples/qwen3/32b-qlora.yaml
@@ -67,3 +67,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/8b-qat-fsdp2.yml b/examples/qwen3/8b-qat-fsdp2.yml
index 6832b6af7..e4d0ed4fb 100644
--- a/examples/qwen3/8b-qat-fsdp2.yml
+++ b/examples/qwen3/8b-qat-fsdp2.yml
@@ -76,3 +76,5 @@ fsdp_config:
fsdp_activation_checkpointing: true
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/qlora-fsdp.yaml b/examples/qwen3/qlora-fsdp.yaml
index dc3377b4f..762f9648d 100644
--- a/examples/qwen3/qlora-fsdp.yaml
+++ b/examples/qwen3/qlora-fsdp.yaml
@@ -66,3 +66,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/requirements.txt b/requirements.txt
index 77d6d31aa..85c7d02be 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,19 +6,19 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
-liger-kernel==0.5.10
+liger-kernel==0.6.0
# END section
packaging==23.2
-huggingface_hub==0.32.2
-peft==0.15.2
-transformers==4.53.1
+huggingface_hub>=0.33.0
+peft==0.16.0
+transformers==4.53.2
tokenizers>=0.21.1
accelerate==1.8.1
-datasets==3.6.0
+datasets==4.0.0
deepspeed>=0.17.0
-trl==0.18.2
+trl==0.19.1
hf_xet==1.1.2
optimum==1.16.2
@@ -26,7 +26,7 @@ hf_transfer
sentencepiece
gradio==5.23.3
-modal==0.70.5
+modal==1.0.2
pydantic==2.10.6
addict
fire
diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py
index 06bad8bef..6840aef50 100644
--- a/scripts/cutcrossentropy_install.py
+++ b/scripts/cutcrossentropy_install.py
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
- + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"'
+ + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"'
)
diff --git a/setup.py b/setup.py
index 731fe8a6f..df9a23154 100644
--- a/setup.py
+++ b/setup.py
@@ -73,9 +73,9 @@ def parse_requirements(extras_require_map):
extras_require_map["vllm"] = ["vllm>=0.9.0"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
- _install_requires.append(
- "xformers==0.0.29.post2"
- ) # vllm needs post2 w torch 2.6
+ _install_requires.append("xformers==0.0.29.post3")
+ # since we only support 2.6.0+cu126
+ _dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
@@ -121,7 +121,7 @@ extras_require = {
"yunchang==0.6.0",
],
"deepspeed": [
- "deepspeed==0.17.1",
+ "deepspeed==0.17.2",
"deepspeed-kernels",
],
"mamba-ssm": [
diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py
index 5f75352f3..cb0eece7f 100644
--- a/src/axolotl/cli/config.py
+++ b/src/axolotl/cli/config.py
@@ -16,7 +16,6 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
- migrate_fsdp_config,
normalize_cfg_datasets,
normalize_config,
validate_config,
@@ -227,7 +226,6 @@ def load_cfg(
},
)
- migrate_fsdp_config(cfg)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)
diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py
index d0c2ad165..ebadc9bf1 100644
--- a/src/axolotl/cli/preprocess.py
+++ b/src/axolotl/cli/preprocess.py
@@ -1,5 +1,6 @@
"""CLI to run preprocessing of a dataset."""
+import os
import warnings
from pathlib import Path
from typing import Union
@@ -95,6 +96,7 @@ def do_cli(
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
+ os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)
diff --git a/src/axolotl/cli/vllm_serve.py b/src/axolotl/cli/vllm_serve.py
index 448b25a7e..f092cc59a 100644
--- a/src/axolotl/cli/vllm_serve.py
+++ b/src/axolotl/cli/vllm_serve.py
@@ -37,7 +37,6 @@ def do_vllm_serve(
Returns:
process_id: the process id of the started VLLM server
"""
- patch_vllm_worker()
cfg = load_cfg(config)
model = cfg.base_model
@@ -47,6 +46,9 @@ def do_vllm_serve(
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)
+ data_parallel_size = (
+ cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
+ )
host = cli_args.get("host") or cfg.vllm.host
port = cli_args.get("port") or cfg.vllm.port
gpu_memory_utilization = (
@@ -68,6 +70,7 @@ def do_vllm_serve(
vllm_script_args = AxolotlScriptArguments(
model=model,
tensor_parallel_size=tensor_parallel_size,
+ data_parallel_size=data_parallel_size,
host=host,
port=port,
gpu_memory_utilization=gpu_memory_utilization,
diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index 3c0ca77de..d3a3b3242 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -36,6 +36,7 @@ from axolotl.utils.callbacks import (
GCCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
+ SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
@@ -112,13 +113,6 @@ class TrainerBuilderBase(abc.ABC):
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
- if self.cfg.profiler_steps:
- callbacks.append(
- PytorchProfilerCallback(
- steps_to_profile=self.cfg.profiler_steps,
- )
- )
-
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
@@ -142,9 +136,19 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
+ if self.cfg.save_first_step:
+ callbacks.append(SaveModelOnFirstStepCallback())
callbacks.append(GPUStatsCallback(cfg=self.cfg))
+ if self.cfg.profiler_steps:
+ callbacks.append(
+ PytorchProfilerCallback(
+ steps_to_profile=self.cfg.profiler_steps,
+ profiler_steps_start=self.cfg.profiler_steps_start,
+ )
+ )
+
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -418,6 +422,9 @@ class TrainerBuilderBase(abc.ABC):
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
+ torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
+ 256
+ )
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_args_kwargs["torch_compile_backend"] = (
@@ -426,8 +433,16 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.torch_compile_mode:
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
+ def _configure_accelerator_config(self, training_args_kwargs: dict):
+ if self.cfg.accelerator_config:
+ training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
+
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
- if self.cfg.gradient_checkpointing:
+ if self.cfg.activation_offloading is True:
+ # don't use the HF gradient checkpointing, manually wrap
+ training_args_kwargs["gradient_checkpointing"] = False
+ training_args_kwargs["activation_offloading"] = True
+ elif self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
@@ -510,5 +525,6 @@ class TrainerBuilderBase(abc.ABC):
self._configure_scheduler(training_args_kwargs)
self._configure_optimizer(training_args_kwargs, trainer_kwargs)
self._configure_torch_compile(training_args_kwargs)
+ self._configure_accelerator_config(training_args_kwargs)
return training_args_kwargs, trainer_kwargs
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index 9fcd51c1d..00cee35a7 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -310,11 +310,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.neftune_noise_alpha
)
- if self.cfg.accelerator_config:
- training_arguments_kwargs["accelerator_config"] = (
- self.cfg.accelerator_config
- )
-
if self.cfg.image_size:
training_arguments_kwargs["image_size"] = self.cfg.image_size
if self.cfg.image_resize_algorithm:
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 81a2f5a45..b983f1076 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -25,6 +25,7 @@ from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.mixins import (
+ ActivationOffloadingMixin,
CheckpointSaveMixin,
OptimizerMixin,
PackingMixin,
@@ -48,6 +49,7 @@ class AxolotlTrainer(
OptimizerMixin,
RngLoaderMixin,
CheckpointSaveMixin,
+ ActivationOffloadingMixin,
Trainer,
):
"""Extend the base Trainer for axolotl helpers"""
@@ -75,18 +77,6 @@ class AxolotlTrainer(
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
- def _wrap_model(self, model, training=True, dataloader=None):
- if self.args.torch_compile:
- torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
- 256
- )
- model = torch.compile(
- model,
- backend=self.args.torch_compile_backend,
- mode=self.args.torch_compile_mode,
- )
- return super()._wrap_model(model, training=training, dataloader=dataloader)
-
def _create_multipack_sampler(
self, base_sampler: Sampler, dataset: Dataset
) -> MultipackBatchSampler:
diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py
index c0f10be23..771f788fe 100644
--- a/src/axolotl/core/trainers/grpo/__init__.py
+++ b/src/axolotl/core/trainers/grpo/__init__.py
@@ -14,6 +14,7 @@ from axolotl.core.trainers.grpo.trainer import (
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.trl import TRLConfig
+from axolotl.utils.schemas.vllm import VllmConfig
LOG = get_logger(__name__)
@@ -41,9 +42,18 @@ class GRPOStrategy:
return grpo_args_kwargs
trl: TRLConfig = cfg.trl # type: ignore
+ vllm_cfg: VllmConfig = cfg.vllm # type: ignore
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
+ grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
+ if trl.vllm_mode == "colocate":
+ grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
+ vllm_cfg.gpu_memory_utilization
+ )
+ grpo_args_kwargs["vllm_tensor_parallel_size"] = (
+ vllm_cfg.tensor_parallel_size
+ )
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host # type: ignore[attr-defined]
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port # type: ignore[attr-defined]
if trl.vllm_server_timeout:
diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py
index c97fccd31..70b3cf3b5 100644
--- a/src/axolotl/core/trainers/grpo/trainer.py
+++ b/src/axolotl/core/trainers/grpo/trainer.py
@@ -59,42 +59,6 @@ class AxolotlGRPOTrainer(
_tag_names = ["trl", "grpo", "axolotl"]
- def get_train_dataloader(self):
- if self.train_dataset is None:
- raise ValueError("Trainer: training requires a train_dataset.")
-
- train_dataset = self.train_dataset
- data_collator = self.data_collator
- if isinstance(train_dataset, datasets.Dataset):
- train_dataset = self._remove_unused_columns(
- train_dataset, description="training"
- )
- else:
- data_collator = self._get_collator_with_removed_columns(
- data_collator, description="training"
- )
-
- dataloader_params = {
- "batch_size": self._train_batch_size
- * self.args.steps_per_generation, # < this is the change
- "collate_fn": data_collator,
- "num_workers": self.args.dataloader_num_workers,
- "pin_memory": self.args.dataloader_pin_memory,
- "persistent_workers": self.args.dataloader_persistent_workers,
- }
-
- if not isinstance(train_dataset, torch.utils.data.IterableDataset):
- dataloader_params["sampler"] = self._get_train_sampler()
- dataloader_params["drop_last"] = self.args.dataloader_drop_last
- dataloader_params["worker_init_fn"] = partial(
- seed_worker,
- num_workers=self.args.dataloader_num_workers,
- rank=self.args.process_index,
- )
- dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
-
- return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
-
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""
@@ -252,7 +216,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
dataloader_params["drop_last"] = self.args.dataloader_drop_last
if not is_eval:
- dataloader_params["worker_init_fn"] = seed_worker
+ dataloader_params["worker_init_fn"] = partial(
+ seed_worker,
+ num_workers=self.args.dataloader_num_workers,
+ rank=self.args.process_index,
+ )
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)
diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py
index b73b51126..453810aac 100644
--- a/src/axolotl/core/trainers/mixins/__init__.py
+++ b/src/axolotl/core/trainers/mixins/__init__.py
@@ -3,6 +3,7 @@
# pylint: disable=unused-import
# flake8: noqa
+from .activation_checkpointing import ActivationOffloadingMixin
from .checkpoints import CheckpointSaveMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin
diff --git a/src/axolotl/core/trainers/mixins/activation_checkpointing.py b/src/axolotl/core/trainers/mixins/activation_checkpointing.py
new file mode 100644
index 000000000..9488186cd
--- /dev/null
+++ b/src/axolotl/core/trainers/mixins/activation_checkpointing.py
@@ -0,0 +1,37 @@
+"""
+Trainer mixin for activation checkpointing w offloading
+"""
+
+import contextlib
+
+from torch import nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+ apply_activation_checkpointing,
+)
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+from transformers import GradientCheckpointingLayer, Trainer
+from trl.models.activation_offloading import get_act_offloading_ctx_manager
+
+
+class ActivationOffloadingMixin(Trainer):
+ """
+ Trainer mixin class for activation checkpointing w offloading
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if self.args.activation_offloading:
+ self.activation_offload_context = get_act_offloading_ctx_manager(
+ self.model, use_streams=True
+ )
+ else:
+ self.activation_offload_context = contextlib.nullcontext()
+
+ def training_step(self, *args, **kwargs):
+ with self.activation_offload_context:
+ return super().training_step(*args, **kwargs)
+
+
+def ac_wrap_hf_model(model: nn.Module, **kwargs):
+ auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
+ apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py
index 2e1987e82..4b74676ce 100644
--- a/src/axolotl/core/training_args_base.py
+++ b/src/axolotl/core/training_args_base.py
@@ -217,6 +217,11 @@ class AxolotlTrainingMixins:
},
)
+ activation_offloading: bool | None = field(
+ default=None,
+ metadata={"help": "Use activation offloading with CUDA streams for training."},
+ )
+
# multi-modal section
image_size: int | tuple[int, int] | None = field(
diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md
index 3c0a393ca..dc7c908dd 100644
--- a/src/axolotl/integrations/cut_cross_entropy/README.md
+++ b/src/axolotl/integrations/cut_cross_entropy/README.md
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"
```
## Usage
diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py
index 75b17580f..6c47097b7 100644
--- a/src/axolotl/integrations/cut_cross_entropy/__init__.py
+++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py
@@ -19,11 +19,13 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""
import importlib
+from functools import partial
import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
+from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.logging import get_logger
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
@@ -32,7 +34,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
- '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"`'
+ '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@50cef19"`'
)
@@ -84,6 +86,7 @@ class CutCrossEntropyPlugin(BasePlugin):
"""Apply cut cross entropy before model loading if enabled."""
if cfg.cut_cross_entropy:
self._check_requirements()
+ self.patch_llama_like(cfg.model_config_type)
from cut_cross_entropy.transformers.patch import cce_patch
@@ -93,3 +96,48 @@ class CutCrossEntropyPlugin(BasePlugin):
# The patch checks model_type internally
cce_patch(cfg.model_config_type)
+
+ def patch_llama_like(
+ self,
+ model_type: str,
+ ) -> None:
+ """
+ Generic patch for model architectures with causal lm similar to llama
+ """
+ from cut_cross_entropy.transformers.patch import PATCH_FNS
+
+ def patch_generic(
+ maybe_model, patch_options, model_type: str
+ ): # pylint: disable=unused-argument
+ import cut_cross_entropy.transformers.llama
+ from cut_cross_entropy.transformers.llama import cce_forward
+
+ try:
+ # Dynamically import the module and CausalLM class
+ module_path = f"transformers.models.{model_type}.modeling_{model_type}"
+ model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
+ module = __import__(
+ module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]
+ )
+ model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
+
+ cut_cross_entropy.transformers.llama._PATCH_OPTS = ( # pylint: disable=protected-access
+ patch_options
+ )
+
+ model_cls.forward = cce_forward
+ # pylint: disable=duplicate-code
+ except (ImportError, AttributeError) as e:
+ raise RuntimeError(
+ f"Could not import ForCausalLM class for model_type: {model_type}. "
+ f"Error: {str(e)}"
+ ) from e
+
+ if model_type not in PATCH_FNS:
+ LOG.warning_once(
+ "Setting up generic cce patch for model type: %s", model_type
+ )
+ LOG.warning_once(
+ f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
+ )
+ PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)
diff --git a/src/axolotl/integrations/kd/kernels/models.py b/src/axolotl/integrations/kd/kernels/models.py
index 6a8b6da1c..4319f5f7d 100644
--- a/src/axolotl/integrations/kd/kernels/models.py
+++ b/src/axolotl/integrations/kd/kernels/models.py
@@ -22,6 +22,8 @@ except ImportError:
TransformersKwargs,
)
+from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
+
def kldiv_forward_llama_like(
self,
@@ -97,7 +99,7 @@ def kldiv_forward_llama_like(
def apply_kernel(model_type):
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
- model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
+ model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = kldiv_forward_llama_like
diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py
index 8de94c78b..86d56be80 100644
--- a/src/axolotl/integrations/liger/__init__.py
+++ b/src/axolotl/integrations/liger/__init__.py
@@ -18,170 +18,10 @@ Module for the Plugin for LIGER integraton with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""
-import inspect
-import sys
+from .args import LigerArgs
+from .plugin import LigerPlugin
-from axolotl.integrations.base import BasePlugin
-from axolotl.utils.logging import get_logger
-
-from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
-from .utils import patch_with_compile_disable
-
-LOG = get_logger(__name__)
-
-
-class LigerPlugin(BasePlugin):
- """
- Plugin for LIGER integraton with Axolotl.
- """
-
- def get_input_args(self):
- return "axolotl.integrations.liger.LigerArgs"
-
- def pre_model_load(self, cfg):
- if cfg.torch_compile:
- # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
- import liger_kernel.ops.fused_linear_cross_entropy
-
- patch_with_compile_disable(
- liger_kernel.ops.fused_linear_cross_entropy,
- "fused_linear_cross_entropy_forward",
- )
- patch_with_compile_disable(
- liger_kernel.ops.fused_linear_cross_entropy,
- "fused_linear_cross_entropy_backward",
- )
- from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
- from liger_kernel.transformers.functional import liger_cross_entropy
- from liger_kernel.transformers.layer_norm import LigerLayerNorm
- from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
- from liger_kernel.transformers.rms_norm import LigerRMSNorm
- from liger_kernel.transformers.rope import liger_rotary_pos_emb
- from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
-
- if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
- raise ValueError(
- "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
- )
-
- if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
- apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
- liger_fn_sig = inspect.signature(apply_liger_fn)
- kwargs = {}
- if "rope" in liger_fn_sig.parameters:
- kwargs["rope"] = cfg.liger_rope
- if "cross_entropy" in liger_fn_sig.parameters:
- kwargs["cross_entropy"] = cfg.liger_cross_entropy
- if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
- kwargs["fused_linear_cross_entropy"] = (
- cfg.liger_fused_linear_cross_entropy
- )
- if "rms_norm" in liger_fn_sig.parameters:
- kwargs["rms_norm"] = cfg.liger_rms_norm
- if "layer_norm" in liger_fn_sig.parameters:
- kwargs["layer_norm"] = cfg.liger_layer_norm
- if "geglu" in liger_fn_sig.parameters:
- kwargs["geglu"] = cfg.liger_glu_activation
- elif "swiglu" in liger_fn_sig.parameters:
- kwargs["swiglu"] = cfg.liger_glu_activation
- LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
- apply_liger_fn(**kwargs)
- elif cfg.model_config_type == "jamba":
- from transformers.models.jamba import modeling_jamba
-
- from .models.jamba import lce_forward as jamba_lce_forward
-
- if cfg.liger_rope:
- modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
- if cfg.liger_rms_norm:
- modeling_jamba.JambaRMSNorm = LigerRMSNorm
- if cfg.liger_glu_activation:
- modeling_jamba.JambaMLP = LigerSwiGLUMLP
- if cfg.liger_layer_norm:
- modeling_jamba.nn.LayerNorm = LigerLayerNorm
- if cfg.liger_cross_entropy:
- from transformers.loss.loss_utils import nn
-
- nn.functional.cross_entropy = liger_cross_entropy
- if cfg.liger_fused_linear_cross_entropy:
- modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
- elif cfg.model_config_type == "deepseek_v2":
- from accelerate import init_empty_weights
- from transformers import AutoModelForCausalLM
-
- with init_empty_weights():
- model = AutoModelForCausalLM.from_pretrained(
- cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
- )
- modeling_mod = sys.modules[model.__class__.__module__]
-
- from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
-
- if cfg.liger_rope:
- # The DeepseekV2 version of RoPE is different than upstream LLaMA.
- # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
- LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
- if cfg.liger_glu_activation:
- LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
- if cfg.liger_rms_norm:
- modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
- if cfg.liger_glu_activation:
- modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
- if cfg.liger_layer_norm:
- modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
- if cfg.liger_cross_entropy:
- # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
- # nn.CrossEntropyLoss in the forward method.
- modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
- if cfg.liger_fused_linear_cross_entropy:
- modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
- elif cfg.model_config_type == "llama4":
- from axolotl.integrations.liger.models.llama4 import (
- apply_liger_kernel_to_llama4,
- )
-
- apply_liger_kernel_to_llama4(
- cross_entropy=cfg.liger_cross_entropy,
- fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
- glu_activation=cfg.liger_glu_activation,
- rms_norm=cfg.liger_rms_norm,
- layer_norm=cfg.liger_layer_norm,
- )
- elif cfg.model_config_type == "qwen3":
- from axolotl.integrations.liger.models.qwen3 import (
- apply_liger_kernel_to_qwen3,
- )
-
- apply_liger_kernel_to_qwen3(
- cross_entropy=cfg.liger_cross_entropy,
- fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
- glu_activation=cfg.liger_glu_activation,
- rms_norm=cfg.liger_rms_norm,
- layer_norm=cfg.liger_layer_norm,
- )
- elif cfg.model_config_type == "qwen3_moe":
- from axolotl.integrations.liger.models.qwen3_moe import (
- apply_liger_kernel_to_qwen3_moe,
- )
-
- apply_liger_kernel_to_qwen3_moe(
- cross_entropy=cfg.liger_cross_entropy,
- fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
- glu_activation=cfg.liger_glu_activation,
- rms_norm=cfg.liger_rms_norm,
- layer_norm=cfg.liger_layer_norm,
- )
- elif cfg.model_config_type == "granitemoe":
- from liger_kernel.transformers import apply_liger_kernel_to_granite
-
- apply_liger_kernel_to_granite(
- rope=cfg.liger_rope,
- cross_entropy=cfg.liger_cross_entropy,
- fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
- rms_norm=cfg.liger_rms_norm,
- swiglu=cfg.liger_glu_activation,
- )
- else:
- LOG.warning(
- f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
- )
+__all__ = [
+ "LigerArgs",
+ "LigerPlugin",
+]
diff --git a/src/axolotl/integrations/liger/models/base.py b/src/axolotl/integrations/liger/models/base.py
new file mode 100644
index 000000000..f3cf4299a
--- /dev/null
+++ b/src/axolotl/integrations/liger/models/base.py
@@ -0,0 +1,189 @@
+"""
+Generic FLCE patch for untested models similar to Llama
+"""
+
+from typing import Optional, Tuple, Union
+
+import torch
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.trainer.orpo_trainer import _FSDPForwardRedirection
+from liger_kernel.utils import PEFT_AVAILABLE
+from peft.utils import ModulesToSaveWrapper
+from torch.distributed.fsdp import FullyShardedDataParallel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
+
+
+def lce_forward(
+ self,
+ *args,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ labels: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ skip_logits: Optional[bool] = None,
+ **kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+ """
+
+ # pylint: disable=duplicate-code
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ *args,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = (
+ slice(-logits_to_keep, None)
+ if isinstance(logits_to_keep, int)
+ else logits_to_keep
+ )
+ kept_hidden_states = hidden_states[:, slice_indices, :]
+
+ shift_labels = kwargs.pop("shift_labels", None)
+ logits = None
+ loss = None
+
+ # if in training mode, don't materialize logits
+ if skip_logits and labels is None and shift_labels is None:
+ raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+ if skip_logits is None:
+ # By default, if in training mode, don't materialize logits
+ skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+ if skip_logits:
+ loss = lce_maybe_trainable_lm_head(
+ self,
+ hidden_states=kept_hidden_states,
+ hidden_size=self.config.hidden_size,
+ labels=labels,
+ shift_labels=shift_labels,
+ **kwargs,
+ )
+
+ else:
+ logits = self.lm_head(kept_hidden_states)
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits,
+ labels=labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+def lce_maybe_trainable_lm_head(
+ self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs
+):
+ lm_head = self.lm_head
+
+ # Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration,
+ # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read
+ # from the unwrapped module.
+ # See https://huggingface.co/docs/peft/package_reference/lora for reference.
+ if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
+ lm_head = lm_head.modules_to_save.default
+
+ # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,
+ # reading the lm_head module weights and calling the kernel must be done within FSDP forward pass
+ # so the module entire parameters are summoned and kept in memory during the kernel execution.
+ if isinstance(lm_head, FullyShardedDataParallel):
+ return _FSDPForwardRedirection()(
+ lm_head,
+ _liger_for_causal_lm_loss,
+ lm_head.module,
+ hidden_states,
+ hidden_size,
+ labels,
+ shift_labels,
+ **loss_kwargs,
+ )
+
+ # FSDP is not used so we can read the lm_head weights and call the kernel directly
+ return _liger_for_causal_lm_loss(
+ lm_head=self.lm_head,
+ hidden_states=hidden_states,
+ hidden_size=hidden_size,
+ labels=labels,
+ shift_labels=shift_labels,
+ **loss_kwargs,
+ )
+
+
+def _liger_for_causal_lm_loss(
+ lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs
+):
+ return LigerForCausalLMLoss(
+ hidden_states=hidden_states,
+ lm_head_weight=lm_head.weight,
+ labels=labels,
+ hidden_size=hidden_size,
+ shift_labels=shift_labels,
+ **loss_kwargs,
+ )
+
+
+def patch_lce_forward(
+ model_type,
+):
+ try:
+ # Dynamically import the module and MLP class
+ module_path = f"transformers.models.{model_type}.modeling_{model_type}"
+ model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
+ module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
+ model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
+
+ model_cls.forward = lce_forward
+ # pylint: disable=duplicate-code
+ except (ImportError, AttributeError) as e:
+ raise RuntimeError(
+ f"Could not import ForCausalLM class for model_type: {model_type}. "
+ f"Error: {str(e)}"
+ ) from e
diff --git a/src/axolotl/integrations/liger/plugin.py b/src/axolotl/integrations/liger/plugin.py
new file mode 100644
index 000000000..89f7c37b7
--- /dev/null
+++ b/src/axolotl/integrations/liger/plugin.py
@@ -0,0 +1,182 @@
+"""
+Liger-Kernel Plugin for Axolotl
+"""
+
+import inspect
+import sys
+
+from axolotl.integrations.base import BasePlugin
+from axolotl.utils.logging import get_logger
+
+from .models.base import patch_lce_forward
+from .utils import patch_with_compile_disable
+
+LOG = get_logger(__name__)
+
+
+class LigerPlugin(BasePlugin):
+ """
+ Plugin for LIGER integraton with Axolotl.
+ """
+
+ def get_input_args(self):
+ return "axolotl.integrations.liger.LigerArgs"
+
+ def pre_model_load(self, cfg):
+ if cfg.torch_compile:
+ # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
+ import liger_kernel.ops.fused_linear_cross_entropy
+
+ patch_with_compile_disable(
+ liger_kernel.ops.fused_linear_cross_entropy,
+ "fused_linear_cross_entropy_forward",
+ )
+ patch_with_compile_disable(
+ liger_kernel.ops.fused_linear_cross_entropy,
+ "fused_linear_cross_entropy_backward",
+ )
+ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+ from liger_kernel.transformers.functional import liger_cross_entropy
+ from liger_kernel.transformers.layer_norm import LigerLayerNorm
+ from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm
+ from liger_kernel.transformers.rope import liger_rotary_pos_emb
+ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
+
+ if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
+ raise ValueError(
+ "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
+ )
+
+ if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
+ apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
+ liger_fn_sig = inspect.signature(apply_liger_fn)
+ kwargs = {}
+ if "rope" in liger_fn_sig.parameters:
+ kwargs["rope"] = cfg.liger_rope
+ if "cross_entropy" in liger_fn_sig.parameters:
+ kwargs["cross_entropy"] = cfg.liger_cross_entropy
+ if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
+ kwargs["fused_linear_cross_entropy"] = (
+ cfg.liger_fused_linear_cross_entropy
+ )
+ if "rms_norm" in liger_fn_sig.parameters:
+ kwargs["rms_norm"] = cfg.liger_rms_norm
+ if "layer_norm" in liger_fn_sig.parameters:
+ kwargs["layer_norm"] = cfg.liger_layer_norm
+ if "geglu" in liger_fn_sig.parameters:
+ kwargs["geglu"] = cfg.liger_glu_activation
+ elif "swiglu" in liger_fn_sig.parameters:
+ kwargs["swiglu"] = cfg.liger_glu_activation
+ LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
+ apply_liger_fn(**kwargs)
+ elif cfg.model_config_type == "jamba":
+ from transformers.models.jamba import modeling_jamba
+
+ from .models.jamba import lce_forward as jamba_lce_forward
+
+ if cfg.liger_rope:
+ modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
+ if cfg.liger_rms_norm:
+ modeling_jamba.JambaRMSNorm = LigerRMSNorm
+ if cfg.liger_glu_activation:
+ modeling_jamba.JambaMLP = LigerSwiGLUMLP
+ if cfg.liger_layer_norm:
+ modeling_jamba.nn.LayerNorm = LigerLayerNorm
+ if cfg.liger_cross_entropy:
+ from transformers.loss.loss_utils import nn
+
+ nn.functional.cross_entropy = liger_cross_entropy
+ if cfg.liger_fused_linear_cross_entropy:
+ modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
+ elif cfg.model_config_type == "deepseek_v2":
+ from accelerate import init_empty_weights
+ from transformers import AutoModelForCausalLM
+
+ with init_empty_weights():
+ model = AutoModelForCausalLM.from_pretrained(
+ cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
+ )
+ modeling_mod = sys.modules[model.__class__.__module__]
+
+ from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
+
+ if cfg.liger_rope:
+ # The DeepseekV2 version of RoPE is different than upstream LLaMA.
+ # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
+ LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
+ if cfg.liger_rms_norm:
+ modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
+ if cfg.liger_glu_activation:
+ modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
+ if cfg.liger_layer_norm:
+ LOG.warning("liger_layer_norm is not supported for DeepseekV2.")
+ if cfg.liger_cross_entropy:
+ # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
+ # nn.CrossEntropyLoss in the forward method.
+ modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
+ if cfg.liger_fused_linear_cross_entropy:
+ modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
+ elif cfg.model_config_type == "llama4":
+ from axolotl.integrations.liger.models.llama4 import (
+ apply_liger_kernel_to_llama4,
+ )
+
+ apply_liger_kernel_to_llama4(
+ cross_entropy=cfg.liger_cross_entropy,
+ fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
+ glu_activation=cfg.liger_glu_activation,
+ rms_norm=cfg.liger_rms_norm,
+ layer_norm=cfg.liger_layer_norm,
+ )
+ elif cfg.model_config_type == "qwen3":
+ from axolotl.integrations.liger.models.qwen3 import (
+ apply_liger_kernel_to_qwen3,
+ )
+
+ apply_liger_kernel_to_qwen3(
+ cross_entropy=cfg.liger_cross_entropy,
+ fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
+ glu_activation=cfg.liger_glu_activation,
+ rms_norm=cfg.liger_rms_norm,
+ layer_norm=cfg.liger_layer_norm,
+ )
+ elif cfg.model_config_type == "qwen3_moe":
+ from axolotl.integrations.liger.models.qwen3_moe import (
+ apply_liger_kernel_to_qwen3_moe,
+ )
+
+ apply_liger_kernel_to_qwen3_moe(
+ cross_entropy=cfg.liger_cross_entropy,
+ fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
+ glu_activation=cfg.liger_glu_activation,
+ rms_norm=cfg.liger_rms_norm,
+ layer_norm=cfg.liger_layer_norm,
+ )
+ elif cfg.model_config_type == "granitemoe":
+ from liger_kernel.transformers import apply_liger_kernel_to_granite
+
+ apply_liger_kernel_to_granite(
+ rope=cfg.liger_rope,
+ cross_entropy=cfg.liger_cross_entropy,
+ fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
+ rms_norm=cfg.liger_rms_norm,
+ swiglu=cfg.liger_glu_activation,
+ )
+ elif cfg.liger_fused_linear_cross_entropy:
+ try:
+ patch_lce_forward(cfg.model_config_type)
+ LOG.warning_once(
+ f"Applied ONLY liger_fused_linear_cross_entropy genericpatches for model type: {cfg.model_config_type}"
+ )
+ LOG.warning_once(
+ f"Liger + {cfg.model_config_type} generic FLCE support is experimental and may not work as expected."
+ )
+ except RuntimeError:
+ LOG.warning(
+ f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
+ )
+ else:
+ LOG.warning(
+ f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
+ )
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 03678e1b4..1ce98ef31 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -198,12 +198,22 @@ class ModelLoader:
):
self.model = self.model.merge_and_unload()
+ self._apply_activation_checkpointing()
self._resize_token_embeddings()
self._adjust_model_config()
self._configure_embedding_dtypes()
self._configure_qat()
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
+ def _apply_activation_checkpointing(self):
+ if self.cfg.activation_offloading is True:
+ from axolotl.core.trainers.mixins.activation_checkpointing import (
+ ac_wrap_hf_model,
+ )
+
+ # ^^ importing this at the module level breaks plugins
+ ac_wrap_hf_model(self.model)
+
def _resize_token_embeddings(self):
"""Resize token embeddings if needed."""
embeddings_len = (
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 2544429e6..f346c56e0 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -7,7 +7,6 @@ import importlib.util
from functools import cached_property
import addict
-import torch
import transformers
from transformers import PretrainedConfig, PreTrainedModel
@@ -168,28 +167,19 @@ class PatchManager:
def _apply_gradient_checkpointing_patches(self):
"""Apply patches for gradient checkpointing."""
- if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
+ if (
+ self.cfg.gradient_checkpointing
+ and self.cfg.activation_offloading == "legacy"
+ ):
from axolotl.monkeypatch.gradient_checkpointing import (
- CheckpointFunctionWithCPUOffload,
hf_grad_checkpoint_offload_wrapper,
)
- if (
- self.cfg.gradient_checkpointing_kwargs
- and "use_reentrant" in self.cfg.gradient_checkpointing_kwargs
- and self.cfg.gradient_checkpointing_kwargs["use_reentrant"] is False
- ):
- transformers.modeling_utils.checkpoint = (
- hf_grad_checkpoint_offload_wrapper
- )
- else:
- transformers.modeling_utils.checkpoint.CheckpointFunction = (
- CheckpointFunctionWithCPUOffload
- )
- torch.utils.checkpoint.CheckpointFunction = (
- CheckpointFunctionWithCPUOffload
- )
- if self.cfg.gradient_checkpointing == "offload_disk":
+ transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
+ elif (
+ self.cfg.gradient_checkpointing
+ and self.cfg.activation_offloading == "offload_disk"
+ ):
from axolotl.monkeypatch.gradient_checkpointing import (
hf_grad_checkpoint_disk_offload_wrapper,
)
@@ -282,7 +272,11 @@ class PatchManager:
if self.cfg.tiled_mlp:
from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
- patch_tiled_mlp(model_type, cfg_num_shards=self.cfg.tiled_mlp_num_shards)
+ patch_tiled_mlp(
+ model_type,
+ use_original_mlp=self.cfg.tiled_mlp_use_original_mlp,
+ cfg_num_shards=self.cfg.tiled_mlp_num_shards,
+ )
def _patch_attention(self):
"""Apply attention-specific patches based on model type."""
diff --git a/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py b/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py
index 6ca8e0240..3b090d5e5 100644
--- a/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py
+++ b/src/axolotl/monkeypatch/gradient_checkpointing/__init__.py
@@ -6,7 +6,6 @@ from functools import partial
from packaging import version
from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import ( # noqa: F401
- CheckpointFunctionWithCPUOffload,
CPU_Offloaded_Gradient_Checkpointer,
)
from axolotl.monkeypatch.gradient_checkpointing.offload_disk import (
diff --git a/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py b/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py
index 432cafb35..bbcfb91e6 100644
--- a/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py
+++ b/src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py
@@ -14,18 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import contextlib
import inspect
import torch
from packaging import version
from torch.utils.checkpoint import (
- _get_autocast_kwargs,
- _get_device_module,
- _infer_device_type,
- check_backward_validity,
- detach_variable,
- get_device_states,
set_device_states,
)
@@ -76,153 +69,3 @@ class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
) + (
None,
) * len(ctx.args)
-
-
-# Copyright 2025 Snowflake Inc.
-# SPDX-License-Identifier: Apache-2.0
-# https://github.com/snowflakedb/ArcticTraining/blob/main/arctic_training/monkey_patches.py
-class CheckpointFunctionWithCPUOffload(torch.autograd.Function):
- """
- This is a torch/utils/checkpoint.py CheckpointFunction monkey patch that offloads the first tensor to cpu during forward and back to cuda during backward. This allows significant memory savings when using a very long seqlen. e.g. for llama 8b at 100k it's 24GB saved per gpu: `((100_000*4096)*2*32/2**30)`
- In the case of a very long seqlen 100k+ the copying to/from cpu overhead is not big, because dense quadratic attention compute will dominate.
- """
-
- @staticmethod
- def forward(ctx, run_function, preserve_rng_state, *args):
- check_backward_validity(args)
- ctx.run_function = run_function
- ctx.preserve_rng_state = preserve_rng_state
- # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
- ctx.device_type = _infer_device_type(*args)
- ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
- ctx.device_type
- )
- if preserve_rng_state:
- ctx.fwd_cpu_state = torch.get_rng_state()
- # Don't eagerly initialize the cuda context by accident.
- # (If the user intends that the context is initialized later, within their
- # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
- # we have no way to anticipate this will happen before we run the function.)
- ctx.had_device_in_fwd = False
- device_module = _get_device_module(ctx.device_type)
- if getattr(device_module, "_initialized", False):
- ctx.had_device_in_fwd = True
- ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)
-
- # Save non-tensor inputs in ctx, keep a placeholder None for tensors
- # to be filled out during the backward.
- ctx.inputs = []
- ctx.tensor_indices = []
- tensor_inputs = []
- # x = None
- for i, arg in enumerate(args):
- if torch.is_tensor(arg):
- # cpu-offload
- # we don't want the 2nd tensor - usually it's a shared 4D attn mask which is huge [seq,seq]
- # upstream could accept a list of arg indices to offload
- if i == 0:
- # print(f"{arg.shape=}")
- ctx.x_device = arg.device
- ctx.x_requires_grad = arg.requires_grad
- t = arg.detach().cpu()
- else:
- t = arg
- tensor_inputs.append(t)
- ctx.tensor_indices.append(i)
- ctx.inputs.append(None)
- else:
- ctx.inputs.append(arg)
-
- ctx.save_for_backward(*tensor_inputs)
-
- with torch.no_grad():
- outputs = run_function(*args)
-
- return outputs
-
- @staticmethod
- def backward(ctx, *args):
- if (
- not torch.autograd._is_checkpoint_valid() # pylint: disable=protected-access
- ):
- raise RuntimeError(
- "When use_reentrant=True, torch.utils.checkpoint is incompatible"
- " with .grad() or passing an `inputs` parameter to .backward()."
- " To resolve this error, you can either set use_reentrant=False,"
- " or call .backward() without passing the `inputs` argument."
- )
- # Copy the list to avoid modifying original list.
- inputs = list(ctx.inputs)
- tensor_indices = ctx.tensor_indices
- tensors = ctx.saved_tensors
-
- # Fill in inputs with appropriate saved tensors.
- for i, idx in enumerate(tensor_indices):
- if i == 0:
- t = (
- tensors[i]
- .to(ctx.x_device)
- .detach()
- .requires_grad_(ctx.x_requires_grad)
- )
- else:
- t = tensors[i]
- inputs[idx] = t
-
- # Stash the surrounding rng state, and mimic the state that was
- # present at this time during forward. Restore the surrounding state
- # when we're done.
- rng_devices = []
- if ctx.preserve_rng_state and ctx.had_device_in_fwd:
- rng_devices = ctx.fwd_devices
- with torch.random.fork_rng(
- devices=rng_devices,
- enabled=ctx.preserve_rng_state,
- device_type=ctx.device_type,
- ):
- if ctx.preserve_rng_state:
- torch.set_rng_state(ctx.fwd_cpu_state)
- if ctx.had_device_in_fwd:
- if has_device_type:
- # newer pytorch (as early as 2.7)
- set_device_states(
- ctx.fwd_devices,
- ctx.fwd_device_states,
- device_type=ctx.device_type,
- )
- else:
- # older pytorch (at least 2.4)
- set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
- detached_inputs = detach_variable(tuple(inputs))
-
- device_autocast_ctx = (
- torch.amp.autocast(
- device_type=ctx.device_type, **ctx.device_autocast_kwargs
- )
- if torch.amp.is_autocast_available(ctx.device_type)
- else contextlib.nullcontext()
- )
- with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
- outputs = ctx.run_function(*detached_inputs)
-
- if isinstance(outputs, torch.Tensor):
- outputs = (outputs,)
-
- # run backward() with only tensor that requires grad
- outputs_with_grad = []
- args_with_grad = []
- for i in range(len(outputs)): # pylint: disable=consider-using-enumerate
- if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
- outputs_with_grad.append(outputs[i])
- args_with_grad.append(args[i])
- if len(outputs_with_grad) == 0:
- raise RuntimeError(
- "none of output has requires_grad=True, this checkpoint() is not necessary"
- )
- torch.autograd.backward(outputs_with_grad, args_with_grad)
- grads = tuple(
- inp.grad if isinstance(inp, torch.Tensor) else None
- for inp in detached_inputs
- )
-
- return (None, None) + grads
diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py
index 586412dd7..4702ad19d 100644
--- a/src/axolotl/monkeypatch/lora_kernels.py
+++ b/src/axolotl/monkeypatch/lora_kernels.py
@@ -18,6 +18,7 @@ from axolotl.kernels.lora import (
apply_lora_qkv,
)
from axolotl.monkeypatch.utils import detab_code
+from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -153,9 +154,7 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
- model_cls_prefix = "".join(
- [part.capitalize() for part in model_type.split("_")]
- )
+ model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
diff --git a/src/axolotl/monkeypatch/tiled_mlp.py b/src/axolotl/monkeypatch/tiled_mlp.py
index 99a10df9c..3818c6b35 100644
--- a/src/axolotl/monkeypatch/tiled_mlp.py
+++ b/src/axolotl/monkeypatch/tiled_mlp.py
@@ -6,6 +6,8 @@ import os
import torch
import torch.distributed as dist
+from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
+
def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
@@ -13,9 +15,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
try:
# Dynamically import the module and MLP class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
- model_cls_prefix = "".join(
- [part.capitalize() for part in model_type.split("_")]
- )
+ model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
@@ -45,11 +45,12 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
else:
num_shards = cfg_num_shards
- compute_params = [
- self.down_proj.weight,
- self.gate_proj.weight,
- self.up_proj.weight,
- ]
+ if not self._compute_params: # pylint: disable=protected-access
+ self._compute_params = [ # pylint: disable=protected-access
+ p for p in self.parameters() if p.requires_grad
+ ]
+
+ compute_params = self._compute_params # pylint: disable=protected-access
down_res = TiledMLP.apply(
mlp_forward,
@@ -61,6 +62,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
return down_res
mlp_cls.forward = tiled_mlp_forward
+ mlp_cls._compute_params = [] # pylint: disable=protected-access
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import MLP class for model_type: {model_type}. "
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index a9d26a650..ced8c8da6 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -379,6 +379,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
Public method that can handle either a single prompt or a batch of prompts.
"""
+ def _remove_none_values(obj):
+ """
+ Remove null from a dictionary-like obj or list.
+ These can appear due to Dataset loading causing schema merge.
+ See https://github.com/axolotl-ai-cloud/axolotl/pull/2909
+ """
+ if hasattr(obj, "items"):
+ return {
+ k: _remove_none_values(v) for k, v in obj.items() if v is not None
+ }
+ if isinstance(obj, list):
+ return [_remove_none_values(elem) for elem in obj]
+ return obj
+
+ prompt = _remove_none_values(prompt)
+
if not self.is_prompt_batched(prompt) or not self.supports_batched:
return self._tokenize_single_prompt(prompt)
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 35c58501c..967179903 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -224,6 +224,9 @@ def execute_training(
# torch.set_default_dtype(torch.bfloat16)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+ plugin_manager = PluginManager.get_instance()
+ plugin_manager.post_train(cfg, trainer.model)
+
def save_trained_model(
cfg: DictDefault,
@@ -510,6 +513,9 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
peft_config=peft_config,
)
+ plugin_manager = PluginManager.get_instance()
+ plugin_manager.post_trainer_create(cfg, trainer)
+
return (
trainer,
model,
@@ -541,9 +547,6 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
- plugin_manager = PluginManager.get_instance()
- plugin_manager.post_trainer_create(cfg, trainer)
-
# Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True
train_dataset = dataset_meta.train_dataset
@@ -566,6 +569,4 @@ def train(
if not cfg.use_ray:
cleanup_distributed()
- plugin_manager.post_train(cfg, model)
-
return model, tokenizer, trainer
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 2a93ceef5..bb777fc90 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -64,7 +64,7 @@ class SaveBetterTransformerModelCallback(
state: TrainerState,
control: TrainerControl,
**kwargs,
- ):
+ ) -> TrainerControl:
# Save
if (
args.save_strategy == IntervalStrategy.STEPS
@@ -100,11 +100,11 @@ class GPUStatsCallback(
def on_step_end(
self,
- args: TrainingArguments,
+ args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
**kwargs,
- ):
+ ) -> TrainerControl:
if not self.logged and state.global_step > 1:
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
@@ -116,18 +116,17 @@ class LossWatchDogCallback(TrainerCallback):
def __init__(self, cfg):
self.cfg = cfg
- self.logged = False
self.violations = 0
self.threshold = cfg.loss_watchdog_threshold
self.patience = cfg.loss_watchdog_patience or 3
def on_step_end(
self,
- _args: TrainingArguments,
+ args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
**_kwargs,
- ):
+ ) -> TrainerControl:
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
if state.log_history[-1]["loss"] > self.threshold:
self.violations += 1
@@ -141,6 +140,21 @@ class LossWatchDogCallback(TrainerCallback):
return control
+class SaveModelOnFirstStepCallback(TrainerCallback):
+ """Callback to save the model on the first step of training if enabled"""
+
+ def on_step_end(
+ self,
+ args: TrainingArguments, # pylint: disable=unused-argument
+ state: TrainerState,
+ control: TrainerControl,
+ **_kwargs,
+ ) -> TrainerControl:
+ if state.global_step == 1:
+ control.should_save = True
+ return control
+
+
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [
@@ -841,21 +855,35 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
class GCCallback(TrainerCallback):
"""Callback to garbage collect torch cache"""
- def __init__(self, gc_steps=None):
- self.gc_steps = gc_steps
+ def __init__(self, gc_steps: int | None = -1):
+ self.gc_steps: int = gc_steps or -1
+ self.next_gc_on_begin_step: int = -1
+
+ def _gc(self):
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ def on_step_begin(
+ self, args, state, control, **kwargs # pylint: disable=unused-argument
+ ):
+ if self.next_gc_on_begin_step == state.global_step:
+ self._gc()
def on_step_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
- if self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
- torch.cuda.empty_cache()
- gc.collect()
+ if control.should_evaluate:
+ # automatically GC before evals so the eval memory spike from the CEL doesn't OOM the trainer
+ self._gc()
+ # also GC on the start of the next step after the eval
+ self.next_gc_on_begin_step = state.global_step + 1
+ elif self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
+ self._gc()
def on_epoch_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
- torch.cuda.empty_cache()
- gc.collect()
+ self._gc()
def colab_inference_post_train_callback(trainer: Trainer):
diff --git a/src/axolotl/utils/callbacks/models.py b/src/axolotl/utils/callbacks/models.py
new file mode 100644
index 000000000..5a20d70d9
--- /dev/null
+++ b/src/axolotl/utils/callbacks/models.py
@@ -0,0 +1,23 @@
+"""Helper functions for model classes"""
+
+from typing import Tuple
+
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+
+
+def get_causal_lm_model_cls_prefix(model_type: str) -> Tuple[str, str]:
+ if model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+ causal_lm_cls = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
+ causal_lm_cls_prefix = causal_lm_cls
+ for suffix in [
+ "ForCausalLM",
+ "ForConditionalGeneration",
+ "LMHeadModel",
+ "GenerationDecoder",
+ ]:
+ causal_lm_cls_prefix = causal_lm_cls_prefix.replace(suffix, "")
+ return causal_lm_cls_prefix, causal_lm_cls
+ causal_lm_cls_prefix = "".join(
+ [part.capitalize() for part in model_type.split("_")]
+ )
+ return causal_lm_cls_prefix, f"{causal_lm_cls_prefix}ForCausalLM"
diff --git a/src/axolotl/utils/callbacks/profiler.py b/src/axolotl/utils/callbacks/profiler.py
index 36604813f..d26b7f9dd 100644
--- a/src/axolotl/utils/callbacks/profiler.py
+++ b/src/axolotl/utils/callbacks/profiler.py
@@ -19,9 +19,27 @@ class PytorchProfilerCallback(TrainerCallback):
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
"""
- def __init__(self, steps_to_profile: int = 5):
- self.steps_to_profile = steps_to_profile
- if self.steps_to_profile:
+ def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
+ # steps are 0 indexed, so to start at 0-th step, we start at beginning of first step,
+ # and finish at end of last step, so 5 steps_to_profile is steps [0, 1, 2, 3, 4]
+ self.profiler_steps_end = profiler_steps_start + steps_to_profile - 1
+ if profiler_steps_start == 0:
+ # start recording memory allocations before everything is allocated, because if we start
+ # at the beginning of step 0, we won't have any memory allocations in the traces
+ torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
+ enabled="all"
+ )
+ profiler_steps_start = -1
+ self.profiler_steps_start = profiler_steps_start
+
+ def on_step_begin( # pylint: disable=unused-argument
+ self,
+ args: TrainingArguments, # pylint: disable=unused-argument
+ state: TrainerState,
+ control: TrainerControl, # pylint: disable=unused-argument
+ **kwargs, # pylint: disable=unused-argument
+ ):
+ if state.global_step == self.profiler_steps_start:
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
enabled="all"
)
@@ -33,7 +51,28 @@ class PytorchProfilerCallback(TrainerCallback):
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
- if state.global_step == self.steps_to_profile:
+ if state.global_step == self.profiler_steps_end:
+ snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access
+ with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
+ dump(snapshot, fout)
+
+ # tell CUDA to stop recording memory allocations now
+ torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
+ enabled=None
+ )
+
+ def on_train_end( # pylint: disable=unused-argument
+ self,
+ args: TrainingArguments, # pylint: disable=unused-argument
+ state: TrainerState,
+ control: TrainerControl, # pylint: disable=unused-argument
+ **kwargs, # pylint: disable=unused-argument
+ ):
+ # make sure to record if we happen to have more steps than steps to profile
+ if (
+ state.global_step >= self.profiler_steps_start
+ and state.global_step < self.profiler_steps_end
+ ):
snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access
with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
dump(snapshot, fout)
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index 4de606565..aaa203e82 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -115,6 +115,7 @@ def normalize_config(cfg):
"chrf",
]
choose_device(cfg)
+ cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.world_size != 1:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
if cfg.fsdp or cfg.fsdp_config or cfg.ddp:
@@ -313,16 +314,3 @@ def prepare_plugins(cfg):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
-
-
-# TODO @SalmanMohammadi remove this function in 0.12
-def migrate_fsdp_config(cfg):
- if cfg.get("fsdp_config"):
- fsdp_config_keys = cfg.fsdp_config.keys()
- if "fsdp_version" in fsdp_config_keys:
- cfg.fsdp_version = cfg.fsdp_config.pop("fsdp_version")
-
- for key in list(fsdp_config_keys):
- if key.startswith("fsdp_") and key != "fsdp_version":
- cfg.fsdp_config[key.replace("fsdp_", "")] = cfg.fsdp_config[key]
- del cfg.fsdp_config[key]
diff --git a/src/axolotl/utils/mistral_tokenizer.py b/src/axolotl/utils/mistral_tokenizer.py
index 95c87a822..33c08db46 100644
--- a/src/axolotl/utils/mistral_tokenizer.py
+++ b/src/axolotl/utils/mistral_tokenizer.py
@@ -497,3 +497,131 @@ class HFMistralTokenizer:
return [
self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids
]
+
+ def __call__(
+ self,
+ text: str | list[str],
+ add_special_tokens: bool = True,
+ padding: bool | str = False,
+ truncation: bool = False,
+ max_length: int | None = None,
+ return_tensors: str | None = None,
+ **kwargs,
+ ) -> dict[str, list[int] | np.ndarray | Tensor]:
+ """
+ Tokenize text and return a dictionary with input_ids and attention_mask.
+
+ Args:
+ text: Input text string or list of strings to tokenize.
+ add_special_tokens: Whether to add special tokens (BOS/EOS).
+ padding: Whether to pad sequences. Can be True, False, "longest", or "max_length".
+ truncation: Whether to truncate sequences to max_length.
+ max_length: Maximum sequence length for truncation/padding.
+ return_tensors: Return format ("pt" for PyTorch, "np" for NumPy, None for lists).
+
+ Returns:
+ Dictionary with "input_ids" and "attention_mask" keys.
+ """
+ # if kwargs passed, raise error
+ if kwargs:
+ raise ValueError(
+ f"Unsupported kwargs: {kwargs}. Please create an issue on GitHub."
+ )
+
+ # `np` can work with inhomogeneous shapes but let's not support it until needed.
+ if (
+ isinstance(text, list)
+ and len(text) > 1
+ and return_tensors in ("pt", "np")
+ and padding is False
+ and truncation is False
+ ):
+ raise ValueError(
+ "return_tensors='pt' or 'np' requires padding or truncation."
+ )
+
+ # Handle single string input
+ if isinstance(text, str):
+ text = [text]
+
+ # Encode all texts
+ # TODO: figure out how to parallelize this
+ batch_input_ids = []
+ for single_text in text:
+ input_ids = self.encode(single_text, add_special_tokens=add_special_tokens)
+
+ # Handle truncation
+ if truncation and max_length is not None and len(input_ids) > max_length:
+ input_ids = input_ids[:max_length]
+
+ batch_input_ids.append(input_ids)
+
+ # Create attention masks (1 for real tokens, 0 for padding)
+ attention_masks = [[1] * len(input_ids) for input_ids in batch_input_ids]
+
+ # Handle padding
+ if padding in (True, "longest"):
+ # Pad to longest sequence in batch
+ max_len = max(len(input_ids) for input_ids in batch_input_ids)
+
+ for i, input_ids in enumerate(batch_input_ids):
+ pad_length = max_len - len(input_ids)
+ if pad_length > 0:
+ if self.padding_side == "right":
+ batch_input_ids[i] = (
+ input_ids + [self.pad_token_id] * pad_length
+ )
+ attention_masks[i] = attention_masks[i] + [0] * pad_length
+ else: # left padding
+ batch_input_ids[i] = [
+ self.pad_token_id
+ ] * pad_length + input_ids
+ attention_masks[i] = [0] * pad_length + attention_masks[i]
+
+ elif padding == "max_length":
+ if max_length is None:
+ raise ValueError(
+ "max_length must be specified when padding='max_length'"
+ )
+
+ for i, input_ids in enumerate(batch_input_ids):
+ pad_length = max_length - len(input_ids)
+ if pad_length > 0:
+ if self.padding_side == "right":
+ batch_input_ids[i] = (
+ input_ids + [self.pad_token_id] * pad_length
+ )
+ attention_masks[i] = attention_masks[i] + [0] * pad_length
+ else: # left padding
+ batch_input_ids[i] = [
+ self.pad_token_id
+ ] * pad_length + input_ids
+ attention_masks[i] = [0] * pad_length + attention_masks[i]
+
+ # Prepare result
+ result = {}
+
+ # Handle return tensor format
+ if return_tensors == "pt":
+ import torch
+
+ result["input_ids"] = torch.tensor(batch_input_ids, dtype=torch.long)
+ result["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
+ elif return_tensors == "np":
+ result["input_ids"] = np.array(batch_input_ids, dtype=np.int64)
+ result["attention_mask"] = np.array(attention_masks, dtype=np.int64)
+ elif return_tensors is None:
+ result["input_ids"] = batch_input_ids
+ result["attention_mask"] = attention_masks
+ else:
+ raise ValueError(
+ f"Unsupported return_tensors='{return_tensors}'. "
+ "Only 'pt' and 'np' are supported."
+ )
+
+ # If single input, return single sequences (not batched)
+ if len(text) == 1 and return_tensors is None:
+ result["input_ids"] = result["input_ids"][0]
+ result["attention_mask"] = result["attention_mask"][0]
+
+ return result
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index de80d1b79..06212a27f 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -320,7 +320,12 @@ class AxolotlInputConfig(
},
)
- gc_steps: int | None = None
+ gc_steps: int | None = Field(
+ default=None,
+ json_schema_extra={
+ "description": "Run garbage collection every `gc_steps` steps. -1 will run on epoch end and before evaluations. Default is 0 (disabled)."
+ },
+ )
bf16: Literal["auto"] | bool | None = Field(
default="auto",
@@ -360,6 +365,12 @@ class AxolotlInputConfig(
"description": "Additional kwargs to pass to the trainer for gradient checkpointing"
},
)
+ activation_offloading: Literal["legacy", "disk"] | bool | None = Field(
+ default=False,
+ json_schema_extra={
+ "description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
+ },
+ )
unfrozen_parameters: list[str] | None = None
@@ -565,6 +576,13 @@ class AxolotlInputConfig(
},
)
+ tiled_mlp_use_original_mlp: bool | None = Field(
+ default=None,
+ json_schema_extra={
+ "description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama."
+ },
+ )
+
llama4_linearized_experts: bool | None = None
deepspeed: str | dict[str, Any] | None = Field(
@@ -573,6 +591,12 @@ class AxolotlInputConfig(
"description": "Deepspeed config path. e.g., deepspeed_configs/zero3.json"
},
)
+ deepcompile: bool | None = Field(
+ default=None,
+ json_schema_extra={
+ "description": "Whether to use deepcompile for faster training with deepspeed"
+ },
+ )
fsdp: list[str] | None = Field(
default=None,
json_schema_extra={"description": "FSDP configuration"},
@@ -618,7 +642,12 @@ class AxolotlInputConfig(
"description": "One of 'varlen_llama3', 'batch_ring', 'batch_zigzag', 'batch_stripe'. Defaults to 'varlen_llama3' in the sample packing case, and 'batch_ring' in the non-sample packing case."
},
)
-
+ tensor_parallel_size: int | None = Field(
+ default=None,
+ json_schema_extra={
+ "description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
+ },
+ )
special_tokens: SpecialTokensConfig | None = Field(
default=None,
json_schema_extra={
@@ -684,6 +713,7 @@ class AxolotlInputConfig(
"description": "Set to `no` to skip evaluation, `epoch` at end of each epoch, leave empty to infer from `eval_steps`"
},
)
+
save_steps: int | float | None = Field(
default=None,
json_schema_extra={
@@ -705,6 +735,13 @@ class AxolotlInputConfig(
save_total_limit: int | None = Field(
default=None, json_schema_extra={"description": "Checkpoints saved at a time"}
)
+ save_first_step: bool | None = Field(
+ default=None,
+ json_schema_extra={
+ "description": "Whether to checkpoint a model after the first step of training. Defaults to False."
+ },
+ )
+
logging_steps: int | None = Field(
default=None, json_schema_extra={"description": "Logging frequency"}
)
@@ -730,6 +767,12 @@ class AxolotlInputConfig(
"description": "Enable the pytorch profiler to capture the first N steps of training to the output_dir. see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information. Snapshots can be visualized @ https://pytorch.org/memory_viz"
},
)
+ profiler_steps_start: int | None = Field(
+ default=0,
+ json_schema_extra={
+ "description": "Which step to start the profiler at. Useful for only capturing a few steps mid-run."
+ },
+ )
include_tokens_per_second: bool | None = Field(
default=None,
json_schema_extra={
@@ -1143,72 +1186,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
return data
- @model_validator(mode="before")
- @classmethod
- def check_fsdp_version(cls, data):
- fsdp_config = data.get("fsdp_config", {})
- if fsdp_config and str(data.get("fsdp_version")) != "2":
- LOG.info(
- "FSDP1 will be deprecated in an upcoming release of Axolotl."
- "We recommend that you use FSDP version 2 for better performance and compatibility. "
- "Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp "
- "For more details on migrating your config. "
- )
- return data
-
- @model_validator(mode="before")
- @classmethod
- def check_fsdp2_base_model_quant_ram_efficient_loading(cls, data):
- fsdp_config = data.get("fsdp_config")
- if fsdp_config and data.get("fsdp_version") == 2:
- if fsdp_config.get("cpu_ram_efficient_loading") and (
- data.get("load_in_8bit") or data.get("load_in_4bit")
- ):
- raise ValueError(
- "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
- "set fsdp_version to 1, or disable cpu_ram_efficient_loading."
- )
- return data
-
- @model_validator(mode="before")
- @classmethod
- def check_fsdp2_base_model_quant_dpo(cls, data):
- if data.get("fsdp_version") == 2 and data.get("rl") in [
- RLType.DPO,
- RLType.KTO,
- RLType.ORPO,
- RLType.IPO,
- ]:
- if data.get("load_in_8bit") or data.get("load_in_4bit"):
- raise ValueError(
- "FSDP2 does not support load_in_8bit or load_in_4bit with DPO. Please use DeepSpeed or set `fsdp_version` to 1."
- )
-
- return data
-
- @model_validator(mode="before")
- @classmethod
- def check_fsdp_version_in_fsdp_config(cls, data):
- if fsdp_config := data.get("fsdp_config"):
- if fsdp_config.get("fsdp_version"):
- LOG.warning(
- "Configuring `fsdp_version` in `fsdp_config` is deprecated. "
- "Please configure `fsdp_version` as a top-level field."
- )
- return data
-
- @model_validator(mode="before")
- @classmethod
- def check_fsdp_config_kwargs_prefix(cls, data):
- if fsdp_config := data.get("fsdp_config"):
- for key, _ in fsdp_config.items():
- if key.startswith("fsdp_"):
- LOG.warning_once(
- "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
- "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
- )
- return data
-
@model_validator(mode="before")
@classmethod
def default_dataloader_opts(cls, data):
diff --git a/src/axolotl/utils/schemas/trl.py b/src/axolotl/utils/schemas/trl.py
index d1b18a56e..e4d17bc94 100644
--- a/src/axolotl/utils/schemas/trl.py
+++ b/src/axolotl/utils/schemas/trl.py
@@ -1,5 +1,7 @@
"""Pydantic models for TRL trainer configuration"""
+from typing import Literal
+
from pydantic import BaseModel, Field
@@ -27,6 +29,12 @@ class TRLConfig(BaseModel):
default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training."},
)
+ vllm_mode: Literal["server", "colocate"] | None = Field(
+ default=None,
+ json_schema_extra={
+ "description": "VLLM mode to use, one of 'server' or 'colocate'"
+ },
+ )
vllm_server_host: str | None = Field(
default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Host of the vLLM server to connect to."},
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 57959c4fa..292159bb8 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -1,8 +1,11 @@
"""Module with validation methods for config pydantic model."""
-# pylint: disable=too-many-lines
+# pylint: disable=too-many-boolean-expressions
+import json
import logging
+import tempfile
+from pathlib import Path
from pydantic import (
field_validator,
@@ -12,6 +15,8 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
+# pylint: disable=too-many-lines
+
LOG = logging.getLogger(__name__)
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
@@ -748,43 +753,181 @@ class OptimizationValidationMixin:
@model_validator(mode="before")
@classmethod
- def check_fsdp_offload_w_8bit_optimizer(cls, data):
- if (
- data.get("fsdp")
- and "8bit" in data.get("optimizer", "")
- and data.get("fsdp_config")
- and data["fsdp_config"].get("fsdp_offload_params")
- and str(data["fsdp_config"].get("fsdp_version")) != "2"
- ):
- raise ValueError(
- f"FSDP Offload not compatible with {data.get('optimizer')}"
+ def check_fsdp_version(cls, data):
+ fsdp_config = data.get("fsdp_config", {})
+ if fsdp_config and str(data.get("fsdp_version")) != "2":
+ LOG.info(
+ "FSDP1 will be deprecated in an upcoming release of Axolotl."
+ "We recommend that you use FSDP version 2 for better performance and compatibility. "
+ "Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp "
+ "For more details on migrating your config. "
)
- if (
- data.get("fsdp")
- and "8bit" in data.get("optimizer", "")
- and data.get("fsdp_config")
- and str(data["fsdp_config"].get("fsdp_version")) == "2"
- ):
- if data.get("optimizer", "") in ["adamw_8bit", "adamw_bnb_8bit"]:
- # CUDA ops errors with bnb 8bit optimizer + FSDP2
+ return data
+
+ @model_validator(mode="after")
+ def check_fsdp2_base_model_quant_ram_efficient_loading(self):
+ fsdp_config = self.fsdp_config if hasattr(self, "fsdp_config") else None
+ fsdp_version = self.fsdp_version if hasattr(self, "fsdp_version") else None
+ load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None
+ load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None
+ if fsdp_config and fsdp_version == 2:
+ if fsdp_config.get("cpu_ram_efficient_loading") and (
+ load_in_8bit or load_in_4bit
+ ):
raise ValueError(
- f"FSDP2 not compatible with {data.get('optimizer')}, use `adamw_torch_8bit` instead"
+ "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
+ "set fsdp_version to 1, or disable cpu_ram_efficient_loading."
+ )
+ return self
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_fsdp2_base_model_quant_rl(cls, data):
+ if data.get("fsdp_version") == 2 and data.get("rl") in [
+ RLType.DPO,
+ RLType.KTO,
+ RLType.ORPO,
+ RLType.IPO,
+ ]:
+ if data.get("load_in_8bit") or data.get("load_in_4bit"):
+ raise ValueError(
+ f"FSDP2 does not support load_in_8bit or load_in_4bit with {data.get('rl')}. Please use DeepSpeed or set `fsdp_version` to 1."
)
return data
@model_validator(mode="before")
@classmethod
- def check_fsdp_sharded_state_dict_w_safetensors(cls, data):
+ def check_fsdp_version_in_fsdp_config(cls, data):
+ if data.get("fsdp_config"):
+ if data.get("fsdp_config", {}).get("fsdp_version"):
+ LOG.warning(
+ "Configuring `fsdp_version` in `fsdp_config` is deprecated. "
+ "Please configure `fsdp_version` as a top-level field."
+ )
+ data["fsdp_version"] = data.get("fsdp_config").pop("fsdp_version")
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_fsdp_config_kwargs_prefix(cls, data):
+ if fsdp_config := data.get("fsdp_config"):
+ should_fix = False
+ for key, _ in fsdp_config.items():
+ if key.startswith("fsdp_"):
+ should_fix = True
+ LOG.warning_once(
+ "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
+ "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
+ )
+ if should_fix:
+ update_fsdp_config = {}
+ for key, value in fsdp_config.items():
+ if key.startswith("fsdp_") and key != "fsdp_version":
+ update_fsdp_config[key.replace("fsdp_", "")] = value
+ else:
+ update_fsdp_config[key] = value
+ data["fsdp_config"] = update_fsdp_config
+ return data
+
+ @model_validator(mode="after")
+ def check_fsdp_offload_w_8bit_optimizer(self):
if (
- data.get("fsdp_config")
- and data.get("save_safetensors")
- and data.get("fsdp_config")
- and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
+ hasattr(self, "fsdp_config")
+ and self.fsdp_config
+ and self.optimizer
+ and "8bit" in self.optimizer.value
+ and self.fsdp_config["offload_params"]
+ and str(self.fsdp_version) != "2"
+ ):
+ raise ValueError(
+ f"FSDP Offload not compatible with {str(self.optimizer.value)}"
+ )
+ return self
+
+ @model_validator(mode="after")
+ def check_fsdp2_w_8bit_optimizer(self):
+ if (
+ hasattr(self, "fsdp_config")
+ and self.fsdp_config
+ and self.optimizer
+ and "8bit" in self.optimizer.value
+ and str(self.fsdp_version) == "2"
+ ):
+ if self.optimizer in ["adamw_8bit", "adamw_bnb_8bit"]:
+ # CUDA ops errors with bnb 8bit optimizer + FSDP2
+ raise ValueError(
+ f"FSDP2 not compatible with {self.optimizer.value}, use `adamw_torch_8bit` instead"
+ )
+
+ return self
+
+ @model_validator(mode="after")
+ def check_fsdp_sharded_state_dict_w_safetensors(self):
+ if (
+ hasattr(self, "fsdp_config")
+ and self.fsdp_config
+ and hasattr(self, "save_safetensors")
+ and self.save_safetensors
+ and self.fsdp_config.get("state_dict_type", "") == "SHARDED_STATE_DICT"
+ and str(getattr(self, "fsdp_version", "1")) != "2"
):
raise ValueError(
"FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
)
+ return self
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_tensor_parallel_size_update_ds_json(cls, data):
+ tensor_parallel_size = data.get("tensor_parallel_size")
+ if tensor_parallel_size is not None and tensor_parallel_size > 1:
+ if not data.get("deepspeed"):
+ raise ValueError(
+ "Tensor parallelism (TP) is only supported with DeepSpeed"
+ )
+ with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
+ ds_config = json.load(ds_fin)
+ should_save = False
+ if "tensor_parallel" not in ds_config:
+ ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
+ should_save = True
+ if (
+ "gather_16bit_weights_on_model_save"
+ not in ds_config["zero_optimization"]
+ ):
+ ds_config["zero_optimization"][
+ "gather_16bit_weights_on_model_save"
+ ] = True
+ should_save = True
+ if should_save:
+ temp_dir = tempfile.mkdtemp()
+ with open(
+ Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
+ ) as ds_fout:
+ json.dump(ds_config, ds_fout, indent=4)
+ data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
+
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_deepcompile(cls, data):
+ deepcompile = data.get("deepcompile")
+ if deepcompile:
+ if not data.get("deepspeed"):
+ raise ValueError("DeepCompile is only supported with DeepSpeed")
+ with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
+ ds_config = json.load(ds_fin)
+ if "compile" not in ds_config:
+ ds_config["compile"] = {"deepcompile": True}
+ temp_dir = tempfile.mkdtemp()
+ with open(
+ Path(temp_dir) / "deepcompile_ds.json", "w", encoding="utf-8"
+ ) as ds_fout:
+ json.dump(ds_config, ds_fout, indent=4)
+ data["deepspeed"] = str(Path(temp_dir) / "deepcompile_ds.json")
+
return data
@@ -924,12 +1067,47 @@ class ModelCompatibilityValidationMixin:
return self
@model_validator(mode="after")
- def check_offload_grad_checkpointing(self):
- if self.gradient_checkpointing and self.gradient_checkpointing == "unsloth":
+ def check_gradient_checkpointing_w_offload(self):
+ if self.gradient_checkpointing == "offload":
LOG.warning(
- "`unsloth` is deprecated for gradient_checkpointing, use `offload`"
+ "`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true` or `activation_offloading: legacy`"
)
- self.gradient_checkpointing = "offload"
+ self.gradient_checkpointing = True
+ if self.adapter and "lora" in self.adapter:
+ LOG.warning(
+ "offloading with CUDA streams is not supported for LoRA adapters, using the `activation_offloading: legacy` implementation."
+ )
+ self.activation_offloading = "legacy"
+ else:
+ LOG.warning(
+ "`offload` uses a new stream implementation; to use the previous implementation, use `activation_offloading: legacy`"
+ )
+ self.activation_offloading = True
+ if self.gradient_checkpointing == "offload_disk":
+ LOG.warning(
+ "`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`"
+ )
+ self.gradient_checkpointing = True
+ self.activation_offloading = "disk"
+ return self
+
+ @model_validator(mode="after")
+ def check_activation_offloading_w_lora(self):
+ if (
+ self.activation_offloading is True
+ and self.adapter
+ and "lora" in self.adapter
+ ):
+ LOG.warning(
+ "activation_offloading with CUDA streams is not supported for LoRA adapters. Setting `activation_offloading: legacy`"
+ )
+ self.activation_offloading = "legacy"
+ return self
+
+ @model_validator(mode="after")
+ def check_activation_offloading_wo_gc(self):
+ if self.activation_offloading and not self.gradient_checkpointing:
+ raise ValueError("activation_offloading requires gradient_checkpointing")
return self
@model_validator(mode="after")
@@ -1019,6 +1197,12 @@ class ComplexValidationMixin:
)
return self
+ @model_validator(mode="after")
+ def check_tensor_parallel_size(self):
+ if not self.tensor_parallel_size:
+ self.tensor_parallel_size = 1
+ return self
+
@model_validator(mode="after")
def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree:
diff --git a/src/axolotl/utils/schemas/vllm.py b/src/axolotl/utils/schemas/vllm.py
index 0ae635589..518b8f62d 100644
--- a/src/axolotl/utils/schemas/vllm.py
+++ b/src/axolotl/utils/schemas/vllm.py
@@ -18,6 +18,10 @@ class VllmConfig(BaseModel):
default=None,
json_schema_extra={"description": "Tensor parallel size for VLLM"},
)
+ data_parallel_size: int | None = Field(
+ default=None,
+ json_schema_extra={"description": "Data parallel size for VLLM"},
+ )
gpu_memory_utilization: float | None = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 9224202e1..8371b2dd7 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
)
* cfg.num_epochs
* cfg.sequence_parallel_degree
+ * cfg.tensor_parallel_size
)
LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
@@ -481,7 +482,10 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
# on the agreed on value for sample_packing_eff_est
total_num_steps = int(
math.floor(
- data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
+ data_loader_len
+ * cfg.num_epochs
+ * cfg.sequence_parallel_degree
+ * cfg.tensor_parallel_size
)
)
if cfg.dataloader_drop_last:
@@ -508,6 +512,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
len(train_dataset)
* cfg.num_epochs
* cfg.sequence_parallel_degree
+ * cfg.tensor_parallel_size
/ cfg.batch_size
)
)
@@ -546,7 +551,10 @@ def setup_deepspeed_env(cfg, stage=None):
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
# to model load.
- if int(os.environ.get("WORLD_SIZE", "1")) == 1:
+ if (
+ int(os.environ.get("WORLD_SIZE", "1")) == 1
+ and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
+ ):
os.environ["WORLD_SIZE"] = "1" # force it in case not set
os.environ["LOCAL_RANK"] = "0" # force it in case not set
os.environ["RANK"] = os.environ.get("LOCAL_RANK", "0")
diff --git a/tests/conftest.py b/tests/conftest.py
index 24615fa22..9e1af318d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,6 +22,8 @@ from huggingface_hub.errors import LocalEntryNotFoundError
from tokenizers import AddedToken
from transformers import AutoTokenizer
+from axolotl.utils.dict import DictDefault
+
from tests.hf_offline_utils import (
enable_hf_offline,
hf_offline_context,
@@ -539,6 +541,22 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
return datasets.load_from_disk(ds_path)["train"]
+@pytest.fixture(name="min_base_cfg")
+def fixture_min_base_cfg():
+ return DictDefault(
+ base_model="HuggingFaceTB/SmolLM2-135M",
+ learning_rate=1e-3,
+ datasets=[
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ micro_batch_size=1,
+ gradient_accumulation_steps=1,
+ )
+
+
# # pylint: disable=redefined-outer-name,unused-argument
@pytest.mark.skipif(
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py
index e66b8e009..0053b4d27 100644
--- a/tests/core/test_builders.py
+++ b/tests/core/test_builders.py
@@ -65,6 +65,7 @@ def fixture_base_cfg():
"dataloader_pin_memory": True,
"dataloader_prefetch_factor": 2,
"sequence_parallel_degree": 1,
+ "tensor_parallel_size": 1,
# Dtype
"fp16": False,
"bf16": False,
diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py
index 790b34f3e..34e6c9644 100644
--- a/tests/e2e/integrations/test_cut_cross_entropy.py
+++ b/tests/e2e/integrations/test_cut_cross_entropy.py
@@ -44,6 +44,7 @@ def min_cfg(temp_dir):
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
@@ -98,6 +99,7 @@ class TestCutCrossEntropyIntegration:
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/integrations/test_hooks.py b/tests/e2e/integrations/test_hooks.py
index 4734449fe..8743efb98 100644
--- a/tests/e2e/integrations/test_hooks.py
+++ b/tests/e2e/integrations/test_hooks.py
@@ -153,6 +153,7 @@ class TestPluginHooks:
"max_steps": 5,
"flash_attention": True,
"bf16": "auto",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py
index 212450e89..1ac3b537e 100644
--- a/tests/e2e/integrations/test_kd.py
+++ b/tests/e2e/integrations/test_kd.py
@@ -67,6 +67,7 @@ def min_cfg(temp_dir):
"output_dir": temp_dir,
"save_safetensors": True,
"use_tensorboard": True,
+ "save_first_step": False,
}
diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py
index 6ab3d7ab8..b1f5befdd 100644
--- a/tests/e2e/integrations/test_liger.py
+++ b/tests/e2e/integrations/test_liger.py
@@ -50,6 +50,7 @@ class LigerIntegrationTestCase:
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
+ "save_first_step": False,
}
)
# pylint: disable=duplicate-code
@@ -96,6 +97,7 @@ class LigerIntegrationTestCase:
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
+ "save_first_step": False,
}
)
# pylint: disable=duplicate-code
diff --git a/tests/e2e/integrations/test_llm_compressor.py b/tests/e2e/integrations/test_llm_compressor.py
index 247ae3bac..dceecea9f 100644
--- a/tests/e2e/integrations/test_llm_compressor.py
+++ b/tests/e2e/integrations/test_llm_compressor.py
@@ -81,6 +81,7 @@ class TestLLMCompressorIntegration:
},
"save_compressed": save_compressed,
},
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/patched/test_sp.py b/tests/e2e/multigpu/patched/test_sp.py
index 5593c7eb6..80098e684 100644
--- a/tests/e2e/multigpu/patched/test_sp.py
+++ b/tests/e2e/multigpu/patched/test_sp.py
@@ -69,6 +69,7 @@ class TestSequenceParallelism:
"use_tensorboard": True,
"sequence_parallel_degree": 2,
"ring_attn_func": ring_attn_func,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/solo/test_flex.py b/tests/e2e/multigpu/solo/test_flex.py
index bdf5ada6b..cbdf8de96 100644
--- a/tests/e2e/multigpu/solo/test_flex.py
+++ b/tests/e2e/multigpu/solo/test_flex.py
@@ -61,6 +61,7 @@ class TestPackedFlex:
"max_steps": 2,
"use_tensorboard": True,
"save_strategy": "no",
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py
index c595d3fc0..d022ae2d9 100644
--- a/tests/e2e/multigpu/solo/test_grpo.py
+++ b/tests/e2e/multigpu/solo/test_grpo.py
@@ -141,6 +141,7 @@ def recursive_kill(process: subprocess.Popen):
os.kill(process.pid, 9)
+@pytest.mark.skip(reason="flaky vllm tests in modal")
class TestGRPO:
"""
Test case for GRPO training using multilpe GPUs
@@ -222,6 +223,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -316,6 +318,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -408,6 +411,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_eval.py b/tests/e2e/multigpu/test_eval.py
index d6429cf63..4f86278ff 100644
--- a/tests/e2e/multigpu/test_eval.py
+++ b/tests/e2e/multigpu/test_eval.py
@@ -67,6 +67,7 @@ class TestMultiGPUEval:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -138,6 +139,7 @@ class TestMultiGPUEval:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_gemma3.py b/tests/e2e/multigpu/test_gemma3.py
index 3868d90f0..4a7b101a8 100644
--- a/tests/e2e/multigpu/test_gemma3.py
+++ b/tests/e2e/multigpu/test_gemma3.py
@@ -71,6 +71,7 @@ class TestMultiGPUGemma3:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py
index 7f9db12f3..aab14dcc4 100644
--- a/tests/e2e/multigpu/test_llama.py
+++ b/tests/e2e/multigpu/test_llama.py
@@ -69,6 +69,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -135,6 +136,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -210,6 +212,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -289,6 +292,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -365,6 +369,7 @@ class TestMultiGPULlama:
},
"use_tensorboard": True,
"seed": 42,
+ "save_first_step": False,
}
)
@@ -391,7 +396,10 @@ class TestMultiGPULlama:
@pytest.mark.parametrize(
"fsdp_state_dict_type",
- ["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
+ [
+ "FULL_STATE_DICT",
+ # "SHARDED_STATE_DICT", # not supported since intermediate checkpoints fail with fsdp1
+ ],
)
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
# pylint: disable=duplicate-code
@@ -413,7 +421,8 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
- "max_steps": 2,
+ "max_steps": 3,
+ "save_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
@@ -438,6 +447,7 @@ class TestMultiGPULlama:
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -516,6 +526,7 @@ class TestMultiGPULlama:
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
},
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if attention_backend == "flash":
@@ -597,10 +608,11 @@ class TestMultiGPULlama:
"fsdp_use_orig_params": False,
"fsdp_cpu_ram_efficient_loading": True,
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
- "fsdp_state_dict_type": "SHARDED_STATE_DICT",
+ "fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -685,6 +697,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / deepspeed),
"use_tensorboard": True,
+ "save_first_step": False,
**adapter,
}
)
@@ -707,7 +720,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
- temp_dir + "/runs", "train/train_loss", 2.4, "Train Loss (%s) is too high"
+ temp_dir + "/runs", "train/train_loss", 2.45, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -761,6 +774,7 @@ class TestMultiGPULlama:
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
"use_tensorboard": True,
"seed": 42,
+ "save_first_step": False,
**adapter,
}
)
@@ -836,6 +850,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
+ "save_first_step": False,
**adapter,
}
)
@@ -904,6 +919,7 @@ class TestMultiGPULlama:
"save_safetensors": True,
# "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py
index 43a722b48..dd1422296 100644
--- a/tests/e2e/multigpu/test_ray.py
+++ b/tests/e2e/multigpu/test_ray.py
@@ -56,6 +56,7 @@ class TestMultiGPURay:
"use_tensorboard": True,
"use_ray": True,
"ray_num_workers": 2,
+ "save_first_step": False,
}
)
@@ -115,6 +116,7 @@ class TestMultiGPURay:
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py
index 08b62accc..1824443e7 100644
--- a/tests/e2e/patched/test_4d_multipack_llama.py
+++ b/tests/e2e/patched/test_4d_multipack_llama.py
@@ -55,6 +55,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -102,6 +103,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_activation_checkpointing.py b/tests/e2e/patched/test_activation_checkpointing.py
index d494ed1eb..3d5b3dc56 100644
--- a/tests/e2e/patched/test_activation_checkpointing.py
+++ b/tests/e2e/patched/test_activation_checkpointing.py
@@ -69,6 +69,7 @@ class TestActivationCheckpointing:
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": gradient_checkpointing,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py
index ca8b21178..38099b220 100644
--- a/tests/e2e/patched/test_fa_xentropy.py
+++ b/tests/e2e/patched/test_fa_xentropy.py
@@ -62,6 +62,7 @@ class TestFAXentropyLlama:
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py
index a593b0791..ef31b11c7 100644
--- a/tests/e2e/patched/test_falcon_samplepack.py
+++ b/tests/e2e/patched/test_falcon_samplepack.py
@@ -58,6 +58,7 @@ class TestFalconPatched(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -99,6 +100,7 @@ class TestFalconPatched(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_flattening.py b/tests/e2e/patched/test_flattening.py
index f77a1fbe5..fdaab558d 100644
--- a/tests/e2e/patched/test_flattening.py
+++ b/tests/e2e/patched/test_flattening.py
@@ -61,6 +61,7 @@ class TestFAFlattening:
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py
index 1bbc82a38..a3fe591ee 100644
--- a/tests/e2e/patched/test_fused_llama.py
+++ b/tests/e2e/patched/test_fused_llama.py
@@ -53,6 +53,7 @@ class TestFusedLlama(unittest.TestCase):
"max_steps": 10,
"save_steps": 5,
"eval_steps": 5,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py
index d2dcc5e4b..ba5556a59 100644
--- a/tests/e2e/patched/test_llama_s2_attention.py
+++ b/tests/e2e/patched/test_llama_s2_attention.py
@@ -58,6 +58,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
"save_steps": 5,
"eval_steps": 5,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -100,6 +101,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
"save_steps": 5,
"eval_steps": 5,
"bf16": "auto",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py
index 5df6bfecc..fdf6adbc6 100644
--- a/tests/e2e/patched/test_lora_llama_multipack.py
+++ b/tests/e2e/patched/test_lora_llama_multipack.py
@@ -55,6 +55,7 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -108,6 +109,7 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py
index 442089bae..bea0f9c68 100644
--- a/tests/e2e/patched/test_mistral_samplepack.py
+++ b/tests/e2e/patched/test_mistral_samplepack.py
@@ -56,6 +56,7 @@ class TestMistral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -97,6 +98,7 @@ class TestMistral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py
index 5f778660b..09e427abd 100644
--- a/tests/e2e/patched/test_mixtral_samplepack.py
+++ b/tests/e2e/patched/test_mixtral_samplepack.py
@@ -52,6 +52,7 @@ class TestMixtral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -90,6 +91,7 @@ class TestMixtral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py
index 5ea88b001..b90be23e4 100644
--- a/tests/e2e/patched/test_model_patches.py
+++ b/tests/e2e/patched/test_model_patches.py
@@ -45,6 +45,7 @@ class TestModelPatches(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -78,6 +79,7 @@ class TestModelPatches(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_peft_embeddings.py b/tests/e2e/patched/test_peft_embeddings.py
index d4f59a128..4769319ae 100644
--- a/tests/e2e/patched/test_peft_embeddings.py
+++ b/tests/e2e/patched/test_peft_embeddings.py
@@ -49,6 +49,7 @@ class TestLlamaPeftEmbeddings:
"bf16": "auto",
"save_safetensors": True,
"embeddings_skip_upcast": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py
index d241ce185..1f0ddd630 100644
--- a/tests/e2e/patched/test_phi_multipack.py
+++ b/tests/e2e/patched/test_phi_multipack.py
@@ -54,6 +54,7 @@ class TestPhiMultipack(unittest.TestCase):
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -105,6 +106,7 @@ class TestPhiMultipack(unittest.TestCase):
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py
index 363956733..54b8245ee 100644
--- a/tests/e2e/patched/test_resume.py
+++ b/tests/e2e/patched/test_resume.py
@@ -58,6 +58,7 @@ class TestResumeLlama:
"max_steps": 15,
"use_tensorboard": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py
index 2b4d11b30..4a2c69d45 100644
--- a/tests/e2e/patched/test_sp.py
+++ b/tests/e2e/patched/test_sp.py
@@ -47,6 +47,7 @@ def fixture_cfg():
"special_tokens": {
"pad_token": "<|endoftext|>",
},
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py
index 69171481c..2c8ee4eb0 100644
--- a/tests/e2e/patched/test_unsloth_qlora.py
+++ b/tests/e2e/patched/test_unsloth_qlora.py
@@ -62,6 +62,7 @@ class TestUnslothQLoRA:
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -112,6 +113,7 @@ class TestUnslothQLoRA:
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -167,6 +169,7 @@ class TestUnslothQLoRA:
"lr_scheduler": "cosine",
"use_tensorboard": True,
"fp16": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/solo/test_flex.py b/tests/e2e/solo/test_flex.py
index 279913713..76364fc0e 100644
--- a/tests/e2e/solo/test_flex.py
+++ b/tests/e2e/solo/test_flex.py
@@ -49,6 +49,7 @@ class TestPackedFlex(unittest.TestCase):
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/solo/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py
index 7af550496..f6fcad841 100644
--- a/tests/e2e/solo/test_relora_llama.py
+++ b/tests/e2e/solo/test_relora_llama.py
@@ -65,6 +65,7 @@ class TestReLoraLlama(unittest.TestCase):
"lr_scheduler": "cosine",
"save_safetensors": True,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_deepseekv3.py b/tests/e2e/test_deepseekv3.py
index 7dfc4ae15..e4a47fb0a 100644
--- a/tests/e2e/test_deepseekv3.py
+++ b/tests/e2e/test_deepseekv3.py
@@ -67,6 +67,7 @@ class TestDeepseekV3:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -116,6 +117,7 @@ class TestDeepseekV3:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py
index 2cdb57689..a1df69535 100644
--- a/tests/e2e/test_dpo.py
+++ b/tests/e2e/test_dpo.py
@@ -56,6 +56,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -105,6 +106,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -154,6 +156,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -203,6 +206,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -251,6 +255,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -302,6 +307,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -370,6 +376,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py
index 9b65f8feb..e4a06ad14 100644
--- a/tests/e2e/test_embeddings_lr.py
+++ b/tests/e2e/test_embeddings_lr.py
@@ -48,6 +48,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -93,6 +94,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_evaluate.py b/tests/e2e/test_evaluate.py
index 6271bba28..977497e5e 100644
--- a/tests/e2e/test_evaluate.py
+++ b/tests/e2e/test_evaluate.py
@@ -36,6 +36,7 @@ class TestE2eEvaluate:
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py
index 4f88e740c..5be6efcf6 100644
--- a/tests/e2e/test_falcon.py
+++ b/tests/e2e/test_falcon.py
@@ -60,6 +60,7 @@ class TestFalcon(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -115,6 +116,7 @@ class TestFalcon(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -156,6 +158,7 @@ class TestFalcon(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_gemma3_text.py b/tests/e2e/test_gemma3_text.py
index 3f00a1384..ef38d028d 100644
--- a/tests/e2e/test_gemma3_text.py
+++ b/tests/e2e/test_gemma3_text.py
@@ -63,6 +63,7 @@ class TestGemma3Text:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -113,6 +114,7 @@ class TestGemma3Text:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py
index 2b180029c..1e6df0be9 100644
--- a/tests/e2e/test_llama.py
+++ b/tests/e2e/test_llama.py
@@ -45,6 +45,7 @@ class TestLlama:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
@@ -92,6 +93,7 @@ class TestLlama:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
@@ -136,6 +138,7 @@ class TestLlama:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
@@ -176,6 +179,7 @@ class TestLlama:
"batch_flattening": True,
"bf16": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py
index fdebf2173..bd5502300 100644
--- a/tests/e2e/test_llama_pretrain.py
+++ b/tests/e2e/test_llama_pretrain.py
@@ -53,6 +53,7 @@ class TestPretrainLlama:
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py
index ad4a83c6a..760759bca 100644
--- a/tests/e2e/test_llama_vision.py
+++ b/tests/e2e/test_llama_vision.py
@@ -54,6 +54,7 @@ class TestLlamaVision(unittest.TestCase):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -100,6 +101,7 @@ class TestLlamaVision(unittest.TestCase):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index 301565302..7e0ff46cf 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -49,6 +49,7 @@ class TestLoraLlama(unittest.TestCase):
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py
index 1824619a6..73d3bdc26 100644
--- a/tests/e2e/test_mamba.py
+++ b/tests/e2e/test_mamba.py
@@ -51,6 +51,7 @@ class TestMamba(unittest.TestCase):
"save_steps": 10,
"eval_steps": None,
"save_safetensors": False,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py
index 5d9b8ba8c..f47f794e0 100644
--- a/tests/e2e/test_mistral.py
+++ b/tests/e2e/test_mistral.py
@@ -55,6 +55,7 @@ class TestMistral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
@@ -95,6 +96,7 @@ class TestMistral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py
index 761e59391..3fe2bf70f 100644
--- a/tests/e2e/test_mixtral.py
+++ b/tests/e2e/test_mixtral.py
@@ -61,6 +61,7 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
@@ -116,6 +117,7 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
@@ -170,6 +172,7 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -228,6 +231,7 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
@@ -273,6 +277,7 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
index 53ef86022..1d233a201 100644
--- a/tests/e2e/test_optimizers.py
+++ b/tests/e2e/test_optimizers.py
@@ -55,6 +55,7 @@ class TestCustomOptimizers(unittest.TestCase):
"optimizer": "optimi_adamw",
"max_steps": 5,
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
@@ -100,6 +101,7 @@ class TestCustomOptimizers(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adopt_adamw",
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
@@ -146,6 +148,7 @@ class TestCustomOptimizers(unittest.TestCase):
"optimizer": "muon",
"lr_scheduler": "cosine",
"weight_decay": 0.01,
+ "save_first_step": False,
}
)
@@ -184,6 +187,7 @@ class TestCustomOptimizers(unittest.TestCase):
"lr_scheduler": "constant",
"save_safetensors": True,
"max_steps": 10,
+ "save_first_step": False,
}
)
# pylint: disable=duplicate-code
@@ -232,6 +236,7 @@ class TestCustomOptimizers(unittest.TestCase):
"adam_epsilon2": 1e-16,
"max_steps": 5,
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py
index cc2db72e0..aec9d95f8 100644
--- a/tests/e2e/test_packing_loss.py
+++ b/tests/e2e/test_packing_loss.py
@@ -48,6 +48,7 @@ class TestPackedLlama(unittest.TestCase):
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py
index 88fda9191..ab3a63674 100644
--- a/tests/e2e/test_phi.py
+++ b/tests/e2e/test_phi.py
@@ -53,6 +53,7 @@ class TestPhi(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -102,6 +103,7 @@ class TestPhi(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_preprocess.py b/tests/e2e/test_preprocess.py
new file mode 100644
index 000000000..25f42e832
--- /dev/null
+++ b/tests/e2e/test_preprocess.py
@@ -0,0 +1,58 @@
+"""E2E Test the preprocess cli"""
+
+from pathlib import Path
+
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+
+from axolotl.utils.dict import DictDefault
+
+AXOLOTL_ROOT = Path(__file__).parent.parent.parent
+
+
+class TestPreprocess:
+ """test cases for preprocess"""
+
+ def test_w_deepspeed(self, temp_dir):
+ """make sure preproces doesn't choke when using deepspeed in the config"""
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "Qwen/Qwen2.5-0.5B",
+ "sequence_len": 2048,
+ "val_set_size": 0.01,
+ "datasets": [
+ {
+ "path": "tatsu-lab/alpaca",
+ "type": "alpaca",
+ "split": "train[:10%]",
+ },
+ ],
+ "num_epochs": 1,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_torch_fused",
+ "lr_scheduler": "cosine",
+ "flash_attention": True,
+ "bf16": "auto",
+ "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
+ "dataset_prepared_path": temp_dir + "/last_run_prepared",
+ }
+ )
+
+ # write cfg to yaml file
+ Path(temp_dir).mkdir(parents=True, exist_ok=True)
+ with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+ fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+ execute_subprocess_async(
+ [
+ "axolotl",
+ "preprocess",
+ str(Path(temp_dir) / "config.yaml"),
+ ]
+ )
+
+ assert (Path(temp_dir) / "last_run_prepared").exists()
diff --git a/tests/e2e/test_process_reward_model_smollm2.py b/tests/e2e/test_process_reward_model_smollm2.py
index abfe1b0c5..bd9eec48b 100644
--- a/tests/e2e/test_process_reward_model_smollm2.py
+++ b/tests/e2e/test_process_reward_model_smollm2.py
@@ -49,6 +49,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
"use_tensorboard": True,
"special_tokens": {"pad_token": "<|endoftext|>"},
"seed": 42,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_profiler.py b/tests/e2e/test_profiler.py
new file mode 100644
index 000000000..ab273b981
--- /dev/null
+++ b/tests/e2e/test_profiler.py
@@ -0,0 +1,113 @@
+"""
+e2e gpu test for the pytorch profiler callback
+"""
+
+from pathlib import Path
+
+import pytest
+
+from axolotl.common.datasets import load_datasets
+from axolotl.train import train
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture(name="profiler_base_cfg")
+def fixture_profiler_base_cfg():
+ cfg = DictDefault(
+ base_model="HuggingFaceTB/SmolLM2-135M",
+ tokenizer_type="AutoTokenizer",
+ sequence_len=1024,
+ load_in_8bit=True,
+ adapter="lora",
+ lora_r=8,
+ lora_alpha=16,
+ lora_dropout=0.05,
+ lora_target_linear=True,
+ val_set_size=0.02,
+ special_tokens={"pad_token": "<|endoftext|>"},
+ datasets=[
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ num_epochs=1,
+ micro_batch_size=2,
+ gradient_accumulation_steps=1,
+ learning_rate=0.00001,
+ optimizer="adamw_torch_fused",
+ lr_scheduler="cosine",
+ )
+ return cfg
+
+
+class TestProfiler:
+ """
+ test cases for the pytorch profiler callback
+ """
+
+ def test_profiler_saves(self, profiler_base_cfg, temp_dir):
+ cfg = profiler_base_cfg | DictDefault(
+ output_dir=temp_dir,
+ max_steps=5,
+ profiler_steps=3,
+ )
+
+ cfg = validate_config(cfg)
+ normalize_config(cfg)
+ dataset_meta = load_datasets(cfg=cfg)
+
+ train(cfg=cfg, dataset_meta=dataset_meta)
+ assert (Path(temp_dir) / "snapshot.pickle").exists()
+
+ def test_profiler_saves_w_start(self, profiler_base_cfg, temp_dir):
+ cfg = profiler_base_cfg | DictDefault(
+ output_dir=temp_dir,
+ max_steps=5,
+ profiler_steps=3,
+ profiler_steps_start=1,
+ )
+
+ cfg = validate_config(cfg)
+ normalize_config(cfg)
+ dataset_meta = load_datasets(cfg=cfg)
+
+ train(cfg=cfg, dataset_meta=dataset_meta)
+ assert (Path(temp_dir) / "snapshot.pickle").exists()
+
+ @pytest.mark.parametrize(
+ "profiler_steps_start",
+ [3, 5],
+ )
+ def test_profiler_saves_past_end(
+ self, profiler_base_cfg, temp_dir, profiler_steps_start
+ ):
+ cfg = profiler_base_cfg | DictDefault(
+ output_dir=temp_dir,
+ max_steps=5,
+ profiler_steps=3,
+ profiler_steps_start=profiler_steps_start,
+ )
+
+ cfg = validate_config(cfg)
+ normalize_config(cfg)
+ dataset_meta = load_datasets(cfg=cfg)
+
+ train(cfg=cfg, dataset_meta=dataset_meta)
+ assert (Path(temp_dir) / "snapshot.pickle").exists()
+
+ def test_profiler_never_started(self, profiler_base_cfg, temp_dir):
+ cfg = profiler_base_cfg | DictDefault(
+ output_dir=temp_dir,
+ max_steps=5,
+ profiler_steps=3,
+ profiler_steps_start=6,
+ )
+
+ cfg = validate_config(cfg)
+ normalize_config(cfg)
+ dataset_meta = load_datasets(cfg=cfg)
+
+ train(cfg=cfg, dataset_meta=dataset_meta)
+ assert not (Path(temp_dir) / "snapshot.pickle").exists()
diff --git a/tests/e2e/test_qat.py b/tests/e2e/test_qat.py
index ef726079d..139ae155a 100644
--- a/tests/e2e/test_qat.py
+++ b/tests/e2e/test_qat.py
@@ -57,6 +57,7 @@ class TestQATLlama:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -115,6 +116,7 @@ class TestQATLlama:
"weight_dtype": "int8",
"group_size": 8,
},
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_qwen.py b/tests/e2e/test_qwen.py
index aa8b9f6c0..59267d14d 100644
--- a/tests/e2e/test_qwen.py
+++ b/tests/e2e/test_qwen.py
@@ -59,6 +59,7 @@ class TestE2eQwen:
"bf16": "auto",
"tf32": True,
"gradient_checkpointing": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_reward_model_smollm2.py b/tests/e2e/test_reward_model_smollm2.py
index 5d52bcc86..82513f99f 100644
--- a/tests/e2e/test_reward_model_smollm2.py
+++ b/tests/e2e/test_reward_model_smollm2.py
@@ -58,6 +58,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
"gradient_checkpointing": True,
"warmup_ratio": 0.1,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_save_first_step.py b/tests/e2e/test_save_first_step.py
new file mode 100644
index 000000000..5bbd2302b
--- /dev/null
+++ b/tests/e2e/test_save_first_step.py
@@ -0,0 +1,102 @@
+"""
+E2E tests for relora llama
+"""
+
+import unittest
+from pathlib import Path
+
+import pytest
+
+from axolotl.common.datasets import load_datasets
+from axolotl.train import train
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import check_model_output_exists, with_temp_dir
+
+
+class TestSaveFirstStepCallback(unittest.TestCase):
+ """Test cases for save_first_step callback config."""
+
+ @with_temp_dir
+ def test_save_first_step(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "HuggingFaceTB/SmolLM2-135M",
+ "tokenizer_type": "AutoTokenizer",
+ "sequence_len": 512,
+ "val_set_size": 0.02,
+ "special_tokens": {
+ "pad_token": "<|endoftext|>",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "max_steps": 3,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_bnb_8bit",
+ "lr_scheduler": "cosine",
+ "flash_attention": True,
+ "sample_packing": True,
+ "bf16": True,
+ "save_safetensors": True,
+ "save_first_step": True,
+ }
+ )
+
+ cfg = validate_config(cfg)
+ normalize_config(cfg)
+ dataset_meta = load_datasets(cfg=cfg)
+
+ train(cfg=cfg, dataset_meta=dataset_meta)
+ check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)
+
+ @with_temp_dir
+ def test_no_save_first_step(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "HuggingFaceTB/SmolLM2-135M",
+ "tokenizer_type": "AutoTokenizer",
+ "sequence_len": 512,
+ "val_set_size": 0.02,
+ "special_tokens": {
+ "pad_token": "<|endoftext|>",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "max_steps": 3,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_bnb_8bit",
+ "lr_scheduler": "cosine",
+ "flash_attention": True,
+ "sample_packing": True,
+ "bf16": True,
+ "save_safetensors": True,
+ "save_first_step": False,
+ }
+ )
+
+ cfg = validate_config(cfg)
+ normalize_config(cfg)
+ dataset_meta = load_datasets(cfg=cfg)
+
+ train(cfg=cfg, dataset_meta=dataset_meta)
+ with pytest.raises(AssertionError):
+ check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)
diff --git a/tests/e2e/test_schedulers.py b/tests/e2e/test_schedulers.py
index e98378f08..8f7a13aee 100644
--- a/tests/e2e/test_schedulers.py
+++ b/tests/e2e/test_schedulers.py
@@ -51,6 +51,7 @@ class TestCustomSchedulers(unittest.TestCase):
"lr_scheduler": "rex",
"warmup_steps": 5,
"cosine_min_lr_ratio": 0.05,
+ "save_first_step": False,
}
)
diff --git a/tests/prompt_strategies/test_chat_template_ds_schema_unification.py b/tests/prompt_strategies/test_chat_template_ds_schema_unification.py
new file mode 100644
index 000000000..502efae4b
--- /dev/null
+++ b/tests/prompt_strategies/test_chat_template_ds_schema_unification.py
@@ -0,0 +1,75 @@
+"""
+Tests for chat template prompt strategy with schema unification for none fields
+"""
+
+import json
+
+import pytest
+from datasets import Dataset
+from transformers import AutoTokenizer
+
+from axolotl.prompt_strategies.chat_template import StrategyLoader
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture(name="messages_w_tools")
+def fixture_messages_w_tools():
+ jsons = """
+{"messages":[{"role":"user","content":"move to (0, 1)"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"move","arguments":{"x":0,"y":1}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
+{"messages":[{"role":"user","content":"turn 270 degree"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"turn","arguments":{"theta": 270}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
+{"messages":[{"role":"user","content":"jump high"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"invalid_prompt","arguments":{"message": "jump is not a valid action"}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
+ """.strip().split(
+ "\n"
+ )
+ rows = [json.loads(row) for row in jsons]
+ return Dataset.from_list(rows)
+
+
+@pytest.fixture(name="qwen3_tokenizer")
+def qwen3_tokenizer_fixture(
+ download_qwen3_half_billion_model,
+): # pylint: disable=unused-argument
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
+
+ return tokenizer
+
+
+@pytest.fixture(name="qwen3_prompt_strategy")
+def qwen3_chat_template_strategy(qwen3_tokenizer):
+ cfg = DictDefault(
+ sequence_len=2048,
+ chat_template="qwen3",
+ eot_tokens=["<|im_end|>"],
+ )
+ ds_cfg = DictDefault(
+ type="chat_template",
+ )
+ load = StrategyLoader()
+ strat = load(qwen3_tokenizer, cfg, ds_cfg)
+ return strat
+
+
+class TestSchemaUnification:
+ """
+ Test class on handling null fields for tool calling
+ """
+
+ def test_schema_unification_single_prompt(
+ self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer
+ ):
+ for row in messages_w_tools:
+ inputs = qwen3_prompt_strategy.tokenize_prompt(row)
+ decoded = qwen3_tokenizer.decode(inputs["input_ids"])
+ tool_call = decoded.split("")[-1].split("")[0]
+ assert '"message": null' not in tool_call
+ assert '"theta": null' not in tool_call
+
+ def test_schema_unification_batched(
+ self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer
+ ):
+ rows = messages_w_tools.map(qwen3_prompt_strategy.tokenize_prompt, batched=True)
+ for row in rows:
+ decoded = qwen3_tokenizer.decode(row["input_ids"])
+ tool_call = decoded.split("")[-1].split("")[0]
+ assert '"message": null' not in tool_call
+ assert '"theta": null' not in tool_call
diff --git a/tests/prompt_strategies/test_chat_templates_mistral.py b/tests/prompt_strategies/test_chat_templates_mistral.py
index f26ed0838..8e3f494b1 100644
--- a/tests/prompt_strategies/test_chat_templates_mistral.py
+++ b/tests/prompt_strategies/test_chat_templates_mistral.py
@@ -6,6 +6,8 @@ from typing import TYPE_CHECKING
import pytest
if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer
+
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
@@ -748,5 +750,100 @@ def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"):
assert "Not the same number of function calls and responses" in str(e)
+def test_magistral_tokenizer_call_method(
+ magistral_tokenizer: "HFMistralTokenizer", llama3_tokenizer: "PreTrainedTokenizer"
+):
+ """Test the __call__ method behavior matches HuggingFace standards"""
+ from copy import deepcopy
+
+ import numpy as np
+ import torch
+
+ hf_tokenizer = deepcopy(llama3_tokenizer)
+ hf_tokenizer.pad_token = hf_tokenizer.eos_token
+
+ test_text = "Hello, how are you?"
+ batch_texts = ["Hello world", "How are you?"]
+
+ # Test single string with return_tensors=None
+ hf_result: dict[str, list[int]] = hf_tokenizer(test_text, return_tensors=None)
+ mistral_result: dict[str, list[int]] = magistral_tokenizer(
+ test_text, return_tensors=None
+ )
+
+ assert isinstance(mistral_result, dict)
+ assert set(mistral_result.keys()) == {"input_ids", "attention_mask"}
+ assert isinstance(mistral_result["input_ids"], type(hf_result["input_ids"])) # list
+ assert isinstance(
+ mistral_result["attention_mask"], type(hf_result["attention_mask"])
+ )
+ assert len(mistral_result["input_ids"]) == len(mistral_result["attention_mask"])
+ assert np.all(mistral_result["attention_mask"])
+ assert len(np.array(mistral_result["input_ids"]).shape) == 1 # 1D array
+
+ # Test single string with return_tensors='pt'
+ hf_result_pt: dict[str, torch.Tensor] = hf_tokenizer(test_text, return_tensors="pt")
+ mistral_result_pt: dict[str, torch.Tensor] = magistral_tokenizer(
+ test_text, return_tensors="pt"
+ )
+
+ # Check structure and types
+ assert isinstance(mistral_result_pt["input_ids"], torch.Tensor)
+ assert isinstance(mistral_result_pt["attention_mask"], torch.Tensor)
+
+ # Check shapes match (don't compare token dimension)
+ assert len(hf_result_pt["input_ids"].shape) == len(
+ mistral_result_pt["input_ids"].shape
+ )
+ assert hf_result_pt["input_ids"].shape[0] == mistral_result_pt["input_ids"].shape[0]
+ assert (
+ mistral_result_pt["attention_mask"].shape
+ == mistral_result_pt["input_ids"].shape
+ )
+ assert torch.all(mistral_result_pt["attention_mask"] == 1)
+
+ # Test batch input with padding
+ hf_batch: dict[str, torch.Tensor] = hf_tokenizer(
+ batch_texts, return_tensors="pt", padding=True
+ )
+ mistral_batch: dict[str, torch.Tensor] = magistral_tokenizer(
+ batch_texts, return_tensors="pt", padding=True
+ )
+
+ # Check batch behavior
+ assert len(hf_batch["input_ids"].shape) == len(mistral_batch["input_ids"].shape)
+ assert hf_batch["input_ids"].shape[0] == mistral_batch["input_ids"].shape[0]
+ assert mistral_batch["attention_mask"].shape == mistral_batch["input_ids"].shape
+ assert torch.any(
+ mistral_batch["attention_mask"][0] == 0
+ ) # padding in shorter sequence
+ assert torch.all(
+ mistral_batch["attention_mask"][1] == 1
+ ) # no padding in longer sequence
+
+ # Test numpy tensors
+ mistral_result_np: dict[str, np.ndarray] = magistral_tokenizer(
+ test_text, return_tensors="np"
+ )
+ assert isinstance(mistral_result_np["input_ids"], np.ndarray)
+ assert isinstance(mistral_result_np["attention_mask"], np.ndarray)
+
+ # Test consistency with encode()
+ encoded: list[int] = magistral_tokenizer.encode(test_text, add_special_tokens=True)
+ called: dict[str, torch.Tensor] = magistral_tokenizer(
+ test_text, return_tensors="pt"
+ )
+ assert encoded == called["input_ids"][0].tolist()
+
+ # Test Error handling
+ with pytest.raises(ValueError, match="Unsupported kwargs"):
+ magistral_tokenizer(test_text, unsupported_param=True)
+
+ with pytest.raises(
+ ValueError, match="return_tensors='pt' or 'np' requires padding or truncation"
+ ):
+ magistral_tokenizer(batch_texts, return_tensors="pt")
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_normalize_config.py b/tests/test_normalize_config.py
index 31d04fc64..658e06fcb 100644
--- a/tests/test_normalize_config.py
+++ b/tests/test_normalize_config.py
@@ -6,9 +6,9 @@ import unittest
from unittest.mock import patch
from axolotl.utils.config import (
- migrate_fsdp_config,
normalize_cfg_datasets,
normalize_config,
+ validate_config,
)
from axolotl.utils.dict import DictDefault
@@ -27,6 +27,13 @@ class NormalizeConfigTestCase(unittest.TestCase):
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "learning_rate": 0.0001,
}
)
@@ -97,7 +104,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
def test_migrate_fsdp_config(self):
"""Test basic FSDP config migration with and without fsdp_version"""
- cfg_with_version = DictDefault(
+ cfg_with_version = self._get_base_cfg() | DictDefault(
{
"fsdp_config": {
"fsdp_version": 2,
@@ -109,7 +116,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
}
)
- migrate_fsdp_config(cfg_with_version)
+ cfg_with_version = validate_config(cfg_with_version)
self.assertEqual(cfg_with_version.fsdp_version, 2)
self.assertEqual(
@@ -125,7 +132,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
self.assertNotIn("fsdp_version", cfg_with_version.fsdp_config)
self.assertNotIn("version", cfg_with_version.fsdp_config)
- cfg_without_version = DictDefault(
+ cfg_without_version = self._get_base_cfg() | DictDefault(
{
"fsdp_config": {
"fsdp_auto_wrap_policy": "SIZE_BASED_WRAP",
@@ -135,7 +142,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
}
)
- migrate_fsdp_config(cfg_without_version)
+ cfg_without_version = validate_config(cfg_without_version)
self.assertNotIn("fsdp_version", cfg_without_version)
self.assertEqual(
@@ -149,26 +156,25 @@ class NormalizeConfigTestCase(unittest.TestCase):
def test_migrate_fsdp_config_no_fsdp_config(self):
"""Test that function doesn't crash when no fsdp_config is present"""
- cfg = DictDefault({"some_other_config": "value"})
+ cfg = self._get_base_cfg()
- migrate_fsdp_config(cfg)
+ cfg = validate_config(cfg)
self.assertNotIn("fsdp_config", cfg)
self.assertNotIn("fsdp_version", cfg)
- self.assertEqual(cfg.some_other_config, "value")
def test_migrate_fsdp_config_empty_fsdp_config(self):
"""Test migration with empty fsdp_config"""
- cfg = DictDefault({"fsdp_config": {}})
+ cfg = self._get_base_cfg() | DictDefault({"fsdp_config": {}})
- migrate_fsdp_config(cfg)
+ cfg = validate_config(cfg)
self.assertNotIn("fsdp_version", cfg)
self.assertEqual(cfg.fsdp_config, {})
def test_migrate_fsdp_config_mixed_keys(self):
"""Test migration with a mix of fsdp_ and non-fsdp_ keys"""
- cfg = DictDefault(
+ cfg = self._get_base_cfg() | DictDefault(
{
"fsdp_config": {
"fsdp_version": 1,
@@ -180,7 +186,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
}
)
- migrate_fsdp_config(cfg)
+ cfg = validate_config(cfg)
self.assertEqual(cfg.fsdp_version, 1)
self.assertEqual(cfg.fsdp_config.state_dict_type, "FULL_STATE_DICT")
diff --git a/tests/test_train.py b/tests/test_train.py
new file mode 100644
index 000000000..2c29b58ee
--- /dev/null
+++ b/tests/test_train.py
@@ -0,0 +1,39 @@
+"""Test for batch size calculation for multi-gpu training."""
+
+import pytest
+
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+
+@pytest.fixture(name="train_base_cfg")
+def fixture_train_base_cfg(min_base_cfg):
+ return (
+ DictDefault(
+ micro_batch_size=2,
+ gradient_accumulation_steps=4,
+ sequence_len=2048,
+ sample_packing=True,
+ num_epochs=1,
+ )
+ | min_base_cfg
+ )
+
+
+class TestTrain:
+ """test class for train related tests"""
+
+ @pytest.mark.parametrize(
+ "world_size, expected_batch_size",
+ [
+ (1, 8),
+ (4, 32),
+ ],
+ )
+ def test_batch_size_ddp(
+ self, train_base_cfg, monkeypatch, world_size, expected_batch_size
+ ):
+ monkeypatch.setenv("WORLD_SIZE", str(world_size))
+ cfg = validate_config(train_base_cfg)
+ normalize_config(cfg)
+ assert cfg.batch_size == expected_batch_size
diff --git a/tests/utils/schemas/validation/test_activation_offloading.py b/tests/utils/schemas/validation/test_activation_offloading.py
new file mode 100644
index 000000000..92ac8f45c
--- /dev/null
+++ b/tests/utils/schemas/validation/test_activation_offloading.py
@@ -0,0 +1,91 @@
+"""Test for config validation for activation offloading."""
+
+from axolotl.utils.config import validate_config
+from axolotl.utils.dict import DictDefault
+
+
+class TestActivationOffloading:
+ """
+ Test cases for activation offloading schema validation
+ """
+
+ def test_gc_converts_offload_wo_lora(self, min_base_cfg):
+ cfg = (
+ DictDefault(
+ gradient_checkpointing="offload",
+ )
+ | min_base_cfg
+ )
+
+ cfg = validate_config(cfg)
+ assert cfg.gradient_checkpointing is True
+ assert cfg.activation_offloading is True
+
+ def test_gc_converts_offload_w_lora(self, min_base_cfg):
+ cfg = (
+ DictDefault(
+ gradient_checkpointing="offload",
+ adapter="lora",
+ )
+ | min_base_cfg
+ )
+
+ cfg = validate_config(cfg)
+ assert cfg.gradient_checkpointing is True
+ assert cfg.activation_offloading == "legacy"
+
+ def test_gc_converts_offload_w_qlora(self, min_base_cfg):
+ cfg = (
+ DictDefault(
+ gradient_checkpointing="offload",
+ adapter="qlora",
+ load_in_4bit=True,
+ )
+ | min_base_cfg
+ )
+
+ cfg = validate_config(cfg)
+ assert cfg.gradient_checkpointing is True
+ assert cfg.activation_offloading == "legacy"
+
+ def test_ac_impl_changes_w_lora(self, min_base_cfg):
+ cfg = (
+ DictDefault(
+ gradient_checkpointing=True,
+ activation_offloading=True,
+ adapter="lora",
+ )
+ | min_base_cfg
+ )
+
+ cfg = validate_config(cfg)
+ assert cfg.gradient_checkpointing is True
+ assert cfg.activation_offloading == "legacy"
+
+ def test_ac_impl_changes_w_qlora(self, min_base_cfg):
+ cfg = (
+ DictDefault(
+ gradient_checkpointing=True,
+ activation_offloading=True,
+ adapter="qlora",
+ load_in_4bit=True,
+ )
+ | min_base_cfg
+ )
+
+ cfg = validate_config(cfg)
+ assert cfg.gradient_checkpointing is True
+ assert cfg.activation_offloading == "legacy"
+
+ def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg):
+ cfg = (
+ DictDefault(
+ gradient_checkpointing=True,
+ activation_offloading=True,
+ )
+ | min_base_cfg
+ )
+
+ cfg = validate_config(cfg)
+ assert cfg.gradient_checkpointing is True
+ assert cfg.activation_offloading is True
diff --git a/tests/utils/schemas/validation/test_fsdp.py b/tests/utils/schemas/validation/test_fsdp.py
new file mode 100644
index 000000000..67f4a5cf9
--- /dev/null
+++ b/tests/utils/schemas/validation/test_fsdp.py
@@ -0,0 +1,139 @@
+"""
+tests for pydantic fsdp validation
+"""
+
+# pylint: disable=too-many-boolean-expressions
+import pytest
+
+from axolotl.utils.config import validate_config
+from axolotl.utils.dict import DictDefault
+
+
+class TestFSDPValidation:
+ """
+ test class for pydantic fsdp validation
+ """
+
+ def test_fsdp_version_in_fsdp_config(self, min_base_cfg):
+ cfg = min_base_cfg | DictDefault(
+ fsdp_config={
+ "fsdp_version": 2,
+ },
+ )
+ cfg = validate_config(
+ cfg,
+ )
+ assert cfg.fsdp_version == 2
+ assert cfg.fsdp_config.fsdp_version is None
+
+ def test_fsdp_sharded_state_dict_safetensors(self, min_base_cfg):
+ cfg = min_base_cfg | DictDefault(
+ fsdp_config={
+ "fsdp_state_dict_type": "SHARDED_STATE_DICT",
+ },
+ save_safetensors=True,
+ )
+ with pytest.raises(
+ ValueError,
+ match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
+ ):
+ validate_config(cfg)
+
+ # test w/o prefix too
+ cfg = min_base_cfg | DictDefault(
+ fsdp_config={
+ "state_dict_type": "SHARDED_STATE_DICT",
+ },
+ save_safetensors=True,
+ )
+ with pytest.raises(
+ ValueError,
+ match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
+ ):
+ validate_config(cfg)
+
+ def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
+ cfg = min_base_cfg | DictDefault(
+ fsdp_config={
+ "offload_params": True,
+ },
+ optimizer="adamw_8bit",
+ fsdp_version=1,
+ )
+ with pytest.raises(
+ ValueError, match="FSDP Offload not compatible with adamw_8bit"
+ ):
+ validate_config(cfg)
+
+ def test_fsdp2_w_8bit_optim(self, min_base_cfg):
+ cfg = min_base_cfg | DictDefault(
+ fsdp_config={
+ "offload_params": True,
+ },
+ optimizer="adamw_8bit",
+ fsdp_version=2,
+ )
+ with pytest.raises(
+ ValueError,
+ match="FSDP2 not compatible with adamw_8bit, use `adamw_torch_8bit` instead",
+ ):
+ validate_config(cfg)
+
+ def test_fsdp2_w_cpu_ram_efficient_loading(self, min_base_cfg):
+ cfg = min_base_cfg | DictDefault(
+ load_in_8bit=True,
+ adapter="lora",
+ fsdp_config={
+ "cpu_ram_efficient_loading": True,
+ },
+ fsdp_version=2,
+ )
+ with pytest.raises(
+ ValueError,
+ match="FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading.",
+ ):
+ validate_config(cfg)
+
+ def test_fsdp_prefixes_removed(self, min_base_cfg):
+ cfg = min_base_cfg | DictDefault(
+ fsdp_config={
+ "fsdp_version": 2,
+ "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+ "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
+ "fsdp_reshard_after_forward": True,
+ }
+ )
+ cfg = validate_config(cfg)
+ assert cfg.fsdp_version == 2
+ assert cfg.fsdp_config.fsdp_version is None
+ for keys in cfg.fsdp_config.keys():
+ assert not keys.startswith("fsdp_")
+ assert cfg.fsdp_config.auto_wrap_policy == "TRANSFORMER_BASED_WRAP"
+ assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer"
+ assert cfg.fsdp_config.reshard_after_forward is True
+
+ @pytest.mark.parametrize(
+ "rl",
+ [
+ "dpo",
+ "kto",
+ "orpo",
+ "ipo",
+ ],
+ )
+ def test_fsdp2_dpo(self, min_base_cfg, rl):
+ cfg = min_base_cfg | DictDefault(
+ fsdp_version=2,
+ fsdp_config={
+ "reshard_after_forward": True,
+ },
+ rl=rl,
+ load_in_8bit=True,
+ adapter="lora",
+ remove_unused_columns=False,
+ )
+ with pytest.raises(
+ ValueError,
+ match="FSDP2 does not support load_in_8bit or load_in_4bit with ",
+ ):
+ validate_config(cfg)