diff --git a/examples/cloud/modal.yaml b/examples/cloud/modal.yaml
index 195031494..bbe8785f1 100644
--- a/examples/cloud/modal.yaml
+++ b/examples/cloud/modal.yaml
@@ -26,3 +26,5 @@ timeout: 86400
# Preprocess specific configurations
memory_preprocess: 32
timeout_preprocess: 14400
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/cohere/command-r-7b-qlora.yml b/examples/cohere/command-r-7b-qlora.yml
index 4a30e9a77..da2777270 100644
--- a/examples/cohere/command-r-7b-qlora.yml
+++ b/examples/cohere/command-r-7b-qlora.yml
@@ -35,7 +35,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
-
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
@@ -56,3 +55,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
index 2c0495ced..1a051b98b 100644
--- a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
+++ b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
index de9c956e0..807342641 100644
--- a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
+++ b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepseek-v2/fft-fsdp-16b.yaml b/examples/deepseek-v2/fft-fsdp-16b.yaml
index 0ed97db36..78bf6b179 100644
--- a/examples/deepseek-v2/fft-fsdp-16b.yaml
+++ b/examples/deepseek-v2/fft-fsdp-16b.yaml
@@ -55,3 +55,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepseek-v2/qlora-fsdp-2_5.yaml b/examples/deepseek-v2/qlora-fsdp-2_5.yaml
index 34dbeaafe..da1d9aefd 100644
--- a/examples/deepseek-v2/qlora-fsdp-2_5.yaml
+++ b/examples/deepseek-v2/qlora-fsdp-2_5.yaml
@@ -79,3 +79,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/devstral/devstral-small-qlora.yml b/examples/devstral/devstral-small-qlora.yml
index dc0051bd5..9d92e8662 100644
--- a/examples/devstral/devstral-small-qlora.yml
+++ b/examples/devstral/devstral-small-qlora.yml
@@ -62,3 +62,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
index 1dd901154..484c31fec 100644
--- a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-1b-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-qlora.yaml
index 24dc7cae3..dea2a6e6d 100644
--- a/examples/falcon-h1/falcon-h1-1b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-1b-qlora.yaml
@@ -46,7 +46,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
-
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
@@ -69,3 +68,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-34b-qlora.yaml b/examples/falcon-h1/falcon-h1-34b-qlora.yaml
index 43eb1967b..b187efbf6 100644
--- a/examples/falcon-h1/falcon-h1-34b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-34b-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-3b-qlora.yaml b/examples/falcon-h1/falcon-h1-3b-qlora.yaml
index 00929bbf0..4d981ad95 100644
--- a/examples/falcon-h1/falcon-h1-3b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-3b-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-500m-qlora.yaml b/examples/falcon-h1/falcon-h1-500m-qlora.yaml
index e2640de7b..5ee13facd 100644
--- a/examples/falcon-h1/falcon-h1-500m-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-500m-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-7b-qlora.yaml b/examples/falcon-h1/falcon-h1-7b-qlora.yaml
index 183e423b5..4b665c3cd 100644
--- a/examples/falcon-h1/falcon-h1-7b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-7b-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma2/qlora.yml b/examples/gemma2/qlora.yml
index cb96a32c1..68d213fad 100644
--- a/examples/gemma2/qlora.yml
+++ b/examples/gemma2/qlora.yml
@@ -60,3 +60,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma2/reward-model.yaml b/examples/gemma2/reward-model.yaml
index ce01a4572..624ebdcd2 100644
--- a/examples/gemma2/reward-model.yaml
+++ b/examples/gemma2/reward-model.yaml
@@ -50,3 +50,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-1b-qlora.yml b/examples/gemma3/gemma-3-1b-qlora.yml
index 217c887aa..99921770d 100644
--- a/examples/gemma3/gemma-3-1b-qlora.yml
+++ b/examples/gemma3/gemma-3-1b-qlora.yml
@@ -66,3 +66,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-4b-qlora.yml b/examples/gemma3/gemma-3-4b-qlora.yml
index d78559ae3..025cb9240 100644
--- a/examples/gemma3/gemma-3-4b-qlora.yml
+++ b/examples/gemma3/gemma-3-4b-qlora.yml
@@ -60,3 +60,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-4b-vision-qlora.yml b/examples/gemma3/gemma-3-4b-vision-qlora.yml
index 183eb88e8..e9e606b69 100644
--- a/examples/gemma3/gemma-3-4b-vision-qlora.yml
+++ b/examples/gemma3/gemma-3-4b-vision-qlora.yml
@@ -62,3 +62,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/glm4/qlora-32b.yaml b/examples/glm4/qlora-32b.yaml
index 86d9b43f8..8973cedd4 100644
--- a/examples/glm4/qlora-32b.yaml
+++ b/examples/glm4/qlora-32b.yaml
@@ -60,3 +60,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora.yaml b/examples/jamba/qlora.yaml
index 2cb0eea41..494154886 100644
--- a/examples/jamba/qlora.yaml
+++ b/examples/jamba/qlora.yaml
@@ -54,3 +54,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora_deepspeed.yaml b/examples/jamba/qlora_deepspeed.yaml
index d13ce6483..64db8f2ff 100644
--- a/examples/jamba/qlora_deepspeed.yaml
+++ b/examples/jamba/qlora_deepspeed.yaml
@@ -55,3 +55,5 @@ saves_per_epoch: 1
deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora_fsdp_large.yaml b/examples/jamba/qlora_fsdp_large.yaml
index 6badaba19..fda30e2d2 100644
--- a/examples/jamba/qlora_fsdp_large.yaml
+++ b/examples/jamba/qlora_fsdp_large.yaml
@@ -64,3 +64,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/lfm2/lfm2-350m-fft.yaml b/examples/lfm2/lfm2-350m-fft.yaml
index 95961557e..74c90c1e1 100644
--- a/examples/lfm2/lfm2-350m-fft.yaml
+++ b/examples/lfm2/lfm2-350m-fft.yaml
@@ -46,3 +46,5 @@ evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml
index 86b1b6a21..c44cd2230 100644
--- a/examples/llama-2/fft_optimized.yml
+++ b/examples/llama-2/fft_optimized.yml
@@ -55,3 +55,5 @@ saves_per_epoch: 1
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
weight_decay: 0.1
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml
index 0f1b34016..580fabdf8 100644
--- a/examples/llama-2/gptq-lora.yml
+++ b/examples/llama-2/gptq-lora.yml
@@ -64,3 +64,5 @@ special_tokens:
bos_token: ""
eos_token: ""
unk_token: ""
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/lisa.yml b/examples/llama-2/lisa.yml
index a76a792ae..a44e261be 100644
--- a/examples/llama-2/lisa.yml
+++ b/examples/llama-2/lisa.yml
@@ -60,3 +60,5 @@ special_tokens:
bos_token: ""
eos_token: ""
unk_token: ""
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/loftq.yml b/examples/llama-2/loftq.yml
index 22dbf2d99..085627f63 100644
--- a/examples/llama-2/loftq.yml
+++ b/examples/llama-2/loftq.yml
@@ -52,3 +52,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml
index 679aed3a9..759fce044 100644
--- a/examples/llama-2/lora.yml
+++ b/examples/llama-2/lora.yml
@@ -52,3 +52,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/qlora-fsdp.yml b/examples/llama-2/qlora-fsdp.yml
index a42eabd4b..3bf30120b 100644
--- a/examples/llama-2/qlora-fsdp.yml
+++ b/examples/llama-2/qlora-fsdp.yml
@@ -67,3 +67,5 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml
index de65928bc..09596c71e 100644
--- a/examples/llama-2/qlora.yml
+++ b/examples/llama-2/qlora.yml
@@ -53,3 +53,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml
index e0a5f7068..ca8b14a1c 100644
--- a/examples/llama-2/relora.yml
+++ b/examples/llama-2/relora.yml
@@ -58,3 +58,5 @@ special_tokens:
bos_token: ""
eos_token: ""
unk_token: ""
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml
index 2b0ae2c70..64d749b5a 100644
--- a/examples/llama-3-vision/lora-11b.yaml
+++ b/examples/llama-3-vision/lora-11b.yaml
@@ -57,3 +57,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/3b-qat-fsdp2.yaml b/examples/llama-3/3b-qat-fsdp2.yaml
index 5d979c96c..08d8ee5c1 100644
--- a/examples/llama-3/3b-qat-fsdp2.yaml
+++ b/examples/llama-3/3b-qat-fsdp2.yaml
@@ -77,3 +77,5 @@ fsdp_config:
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml
index eccfa6d8c..e2808935f 100644
--- a/examples/llama-3/fft-8b-liger-fsdp.yaml
+++ b/examples/llama-3/fft-8b-liger-fsdp.yaml
@@ -72,3 +72,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/fft-8b.yaml b/examples/llama-3/fft-8b.yaml
index fdae3e6c4..2dfe6d492 100644
--- a/examples/llama-3/fft-8b.yaml
+++ b/examples/llama-3/fft-8b.yaml
@@ -42,3 +42,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml
index 51f1c768b..10ab2a320 100644
--- a/examples/llama-3/instruct-dpo-lora-8b.yml
+++ b/examples/llama-3/instruct-dpo-lora-8b.yml
@@ -71,3 +71,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml
index acab862f6..83b7f9a37 100644
--- a/examples/llama-3/instruct-lora-8b.yml
+++ b/examples/llama-3/instruct-lora-8b.yml
@@ -64,3 +64,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-deduplicate-dpo.yml b/examples/llama-3/lora-1b-deduplicate-dpo.yml
index 10e9747cb..b20dbad84 100644
--- a/examples/llama-3/lora-1b-deduplicate-dpo.yml
+++ b/examples/llama-3/lora-1b-deduplicate-dpo.yml
@@ -83,3 +83,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-deduplicate-sft.yml b/examples/llama-3/lora-1b-deduplicate-sft.yml
index 630ec92f6..67e518184 100644
--- a/examples/llama-3/lora-1b-deduplicate-sft.yml
+++ b/examples/llama-3/lora-1b-deduplicate-sft.yml
@@ -61,3 +61,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-kernels.yml b/examples/llama-3/lora-1b-kernels.yml
index a2d07ca49..92a948c2e 100644
--- a/examples/llama-3/lora-1b-kernels.yml
+++ b/examples/llama-3/lora-1b-kernels.yml
@@ -65,3 +65,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-ray.yml b/examples/llama-3/lora-1b-ray.yml
index bb23164eb..178a1fb89 100644
--- a/examples/llama-3/lora-1b-ray.yml
+++ b/examples/llama-3/lora-1b-ray.yml
@@ -64,3 +64,5 @@ special_tokens:
use_ray: true
ray_num_workers: 4
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-sample-packing-sequentially.yml b/examples/llama-3/lora-1b-sample-packing-sequentially.yml
index 769dd32e6..c4ce3eb0f 100644
--- a/examples/llama-3/lora-1b-sample-packing-sequentially.yml
+++ b/examples/llama-3/lora-1b-sample-packing-sequentially.yml
@@ -63,3 +63,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b.yml b/examples/llama-3/lora-1b.yml
index acc17e21f..82085483f 100644
--- a/examples/llama-3/lora-1b.yml
+++ b/examples/llama-3/lora-1b.yml
@@ -60,3 +60,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-8b.yml b/examples/llama-3/lora-8b.yml
index ad50cd38a..c39389755 100644
--- a/examples/llama-3/lora-8b.yml
+++ b/examples/llama-3/lora-8b.yml
@@ -57,3 +57,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-1b-kto.yaml b/examples/llama-3/qlora-1b-kto.yaml
index 89a51ea68..f156e23d3 100644
--- a/examples/llama-3/qlora-1b-kto.yaml
+++ b/examples/llama-3/qlora-1b-kto.yaml
@@ -61,3 +61,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-1b.yml b/examples/llama-3/qlora-1b.yml
index 5c8fe6628..6b76ea8d9 100644
--- a/examples/llama-3/qlora-1b.yml
+++ b/examples/llama-3/qlora-1b.yml
@@ -62,3 +62,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-fsdp-405b.yaml b/examples/llama-3/qlora-fsdp-405b.yaml
index 2b7d51925..1ee922b59 100644
--- a/examples/llama-3/qlora-fsdp-405b.yaml
+++ b/examples/llama-3/qlora-fsdp-405b.yaml
@@ -60,3 +60,5 @@ fsdp_config:
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|finetune_right_pad_id|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-fsdp-70b.yaml b/examples/llama-3/qlora-fsdp-70b.yaml
index 412b6721c..5edd8353a 100644
--- a/examples/llama-3/qlora-fsdp-70b.yaml
+++ b/examples/llama-3/qlora-fsdp-70b.yaml
@@ -69,3 +69,5 @@ fsdp_config:
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora.yml b/examples/llama-3/qlora.yml
index 4cc9fc3db..a674eca27 100644
--- a/examples/llama-3/qlora.yml
+++ b/examples/llama-3/qlora.yml
@@ -54,3 +54,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/sparse-finetuning.yaml b/examples/llama-3/sparse-finetuning.yaml
index 1bbb88028..8577a19d2 100644
--- a/examples/llama-3/sparse-finetuning.yaml
+++ b/examples/llama-3/sparse-finetuning.yaml
@@ -75,3 +75,5 @@ llmcompressor:
]
start: 0
save_compressed: true
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
index 2be94f4ef..d4a038e11 100644
--- a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
+++ b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
@@ -86,3 +86,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
index eeae872a6..bea10d979 100644
--- a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
@@ -90,3 +90,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
index 17ad70634..737d93812 100644
--- a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
@@ -83,3 +83,5 @@ weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
index eff708e4d..390be5af7 100644
--- a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
@@ -86,3 +86,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
index 9a411883e..b319349c4 100644
--- a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
+++ b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
@@ -84,3 +84,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-qlora-single-h100-flex.yaml b/examples/llama-4/scout-qlora-single-h100-flex.yaml
index 20352f81e..6be3988ef 100644
--- a/examples/llama-4/scout-qlora-single-h100-flex.yaml
+++ b/examples/llama-4/scout-qlora-single-h100-flex.yaml
@@ -82,3 +82,5 @@ weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
index 9fbd34107..a67936cf1 100644
--- a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
+++ b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
@@ -87,3 +87,5 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llava/lora-7b.yaml b/examples/llava/lora-7b.yaml
index 5198c8e74..a4bac8987 100644
--- a/examples/llava/lora-7b.yaml
+++ b/examples/llava/lora-7b.yaml
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/magistral/magistral-small-fsdp-qlora.yaml b/examples/magistral/magistral-small-fsdp-qlora.yaml
index b10e8baf6..b23d2309a 100644
--- a/examples/magistral/magistral-small-fsdp-qlora.yaml
+++ b/examples/magistral/magistral-small-fsdp-qlora.yaml
@@ -70,3 +70,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
fsdp_activation_checkpointing: true
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/magistral/magistral-small-qlora.yaml b/examples/magistral/magistral-small-qlora.yaml
index e3e746f22..f0fce014f 100644
--- a/examples/magistral/magistral-small-qlora.yaml
+++ b/examples/magistral/magistral-small-qlora.yaml
@@ -61,3 +61,5 @@ flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
index 3d4583932..2261bd215 100644
--- a/examples/mamba/config.yml
+++ b/examples/mamba/config.yml
@@ -48,3 +48,5 @@ weight_decay: 0.0
special_tokens:
tokens:
save_safetensors: False
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/bigstral-ds-zero3.yaml b/examples/mistral/bigstral-ds-zero3.yaml
index f626a92a1..e9bcbb7d6 100644
--- a/examples/mistral/bigstral-ds-zero3.yaml
+++ b/examples/mistral/bigstral-ds-zero3.yaml
@@ -53,3 +53,5 @@ special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml
index 15edffb44..8c4d80f79 100644
--- a/examples/mistral/config.yml
+++ b/examples/mistral/config.yml
@@ -43,3 +43,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/lora-mps.yml b/examples/mistral/lora-mps.yml
index e6f46affb..d54c3e30b 100644
--- a/examples/mistral/lora-mps.yml
+++ b/examples/mistral/lora-mps.yml
@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/lora.yml b/examples/mistral/lora.yml
index 9af4274fd..161255468 100644
--- a/examples/mistral/lora.yml
+++ b/examples/mistral/lora.yml
@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-dpo-qlora.yml b/examples/mistral/mistral-dpo-qlora.yml
index af707973f..8d0378690 100644
--- a/examples/mistral/mistral-dpo-qlora.yml
+++ b/examples/mistral/mistral-dpo-qlora.yml
@@ -80,3 +80,5 @@ weight_decay: 0.0
special_tokens:
bos_token: "<|im_start|>"
eos_token: "<|im_end|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-qlora-fsdp.yml b/examples/mistral/mistral-qlora-fsdp.yml
index e234b19a2..cec958c54 100644
--- a/examples/mistral/mistral-qlora-fsdp.yml
+++ b/examples/mistral/mistral-qlora-fsdp.yml
@@ -74,3 +74,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-qlora-orpo.yml b/examples/mistral/mistral-qlora-orpo.yml
index 6c0212b7c..f37dc09fa 100644
--- a/examples/mistral/mistral-qlora-orpo.yml
+++ b/examples/mistral/mistral-qlora-orpo.yml
@@ -69,3 +69,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-small-3.1-24B-lora.yml b/examples/mistral/mistral-small-3.1-24B-lora.yml
index 3e3b45862..4a492c595 100644
--- a/examples/mistral/mistral-small-3.1-24B-lora.yml
+++ b/examples/mistral/mistral-small-3.1-24B-lora.yml
@@ -56,3 +56,5 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
index af6ba5a76..64ef9930c 100644
--- a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
+++ b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
@@ -72,3 +72,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral-qlora-fsdp.yml
index b1843a138..c8d0a2711 100644
--- a/examples/mistral/mixtral-qlora-fsdp.yml
+++ b/examples/mistral/mixtral-qlora-fsdp.yml
@@ -77,3 +77,5 @@ fsdp_config:
fsdp_forward_prefetch: false
fsdp_backward_prefetch: BACKWARD_PRE
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml
index 4c256420c..5be9b4db8 100644
--- a/examples/mistral/mixtral.yml
+++ b/examples/mistral/mixtral.yml
@@ -81,3 +81,5 @@ saves_per_epoch: 1
deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral_22.yml b/examples/mistral/mixtral_22.yml
index 25e1d7155..100e4464f 100644
--- a/examples/mistral/mixtral_22.yml
+++ b/examples/mistral/mixtral_22.yml
@@ -51,3 +51,5 @@ special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml
index 607e33701..08df36e15 100644
--- a/examples/mistral/qlora.yml
+++ b/examples/mistral/qlora.yml
@@ -64,3 +64,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/orpheus/finetune.yml b/examples/orpheus/finetune.yml
index 9bcbbeee0..57f65d966 100644
--- a/examples/orpheus/finetune.yml
+++ b/examples/orpheus/finetune.yml
@@ -50,3 +50,5 @@ weight_decay: 0.05
special_tokens:
pad_token:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/lora-3.5.yaml b/examples/phi/lora-3.5.yaml
index ad4ce9cd4..9f3bbdf53 100644
--- a/examples/phi/lora-3.5.yaml
+++ b/examples/phi/lora-3.5.yaml
@@ -63,3 +63,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 4
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml
index 1562a7353..fc6d649d7 100644
--- a/examples/phi/phi-ft.yml
+++ b/examples/phi/phi-ft.yml
@@ -57,3 +57,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml
index 4cd53db97..ccd92c817 100644
--- a/examples/phi/phi-qlora.yml
+++ b/examples/phi/phi-qlora.yml
@@ -60,3 +60,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi2-ft.yml b/examples/phi/phi2-ft.yml
index ca733cc71..853250ccb 100644
--- a/examples/phi/phi2-ft.yml
+++ b/examples/phi/phi2-ft.yml
@@ -57,3 +57,5 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi3-ft-fsdp.yml b/examples/phi/phi3-ft-fsdp.yml
index d0d14fea6..130298bc0 100644
--- a/examples/phi/phi3-ft-fsdp.yml
+++ b/examples/phi/phi3-ft-fsdp.yml
@@ -71,3 +71,5 @@ fsdp_config:
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi3-ft.yml b/examples/phi/phi3-ft.yml
index 17c48da6f..42b87e8d0 100644
--- a/examples/phi/phi3-ft.yml
+++ b/examples/phi/phi3-ft.yml
@@ -59,3 +59,5 @@ warmup_ratio: 0.2
debug: true
weight_decay: 0.1
resize_token_embeddings_to_32x: true
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/pixtral/lora-12b.yml b/examples/pixtral/lora-12b.yml
index 6ad0a5e99..ea769d202 100644
--- a/examples/pixtral/lora-12b.yml
+++ b/examples/pixtral/lora-12b.yml
@@ -55,3 +55,5 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2-vl/lora-7b.yaml b/examples/qwen2-vl/lora-7b.yaml
index e8932b968..8ea608199 100644
--- a/examples/qwen2-vl/lora-7b.yaml
+++ b/examples/qwen2-vl/lora-7b.yaml
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/dpo.yaml b/examples/qwen2/dpo.yaml
index bd896c2b3..69a74ae4a 100644
--- a/examples/qwen2/dpo.yaml
+++ b/examples/qwen2/dpo.yaml
@@ -54,3 +54,5 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/prm.yaml b/examples/qwen2/prm.yaml
index 4afa24f3c..af188f75d 100644
--- a/examples/qwen2/prm.yaml
+++ b/examples/qwen2/prm.yaml
@@ -55,3 +55,5 @@ eval_steps: 100
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/qlora-fsdp.yaml b/examples/qwen2/qlora-fsdp.yaml
index ed2670ab6..861ce5517 100644
--- a/examples/qwen2/qlora-fsdp.yaml
+++ b/examples/qwen2/qlora-fsdp.yaml
@@ -67,3 +67,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/reward-model.yaml b/examples/qwen2/reward-model.yaml
index 822407a1f..1854b8216 100644
--- a/examples/qwen2/reward-model.yaml
+++ b/examples/qwen2/reward-model.yaml
@@ -26,7 +26,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
-
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
@@ -50,3 +49,5 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2_5-vl/lora-7b.yaml b/examples/qwen2_5-vl/lora-7b.yaml
index 25d02805f..13a97dec3 100644
--- a/examples/qwen2_5-vl/lora-7b.yaml
+++ b/examples/qwen2_5-vl/lora-7b.yaml
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/32b-qlora.yaml b/examples/qwen3/32b-qlora.yaml
index 45a4395ac..1f148ece5 100644
--- a/examples/qwen3/32b-qlora.yaml
+++ b/examples/qwen3/32b-qlora.yaml
@@ -67,3 +67,5 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/8b-qat-fsdp2.yml b/examples/qwen3/8b-qat-fsdp2.yml
index 6832b6af7..e4d0ed4fb 100644
--- a/examples/qwen3/8b-qat-fsdp2.yml
+++ b/examples/qwen3/8b-qat-fsdp2.yml
@@ -76,3 +76,5 @@ fsdp_config:
fsdp_activation_checkpointing: true
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/qlora-fsdp.yaml b/examples/qwen3/qlora-fsdp.yaml
index dc3377b4f..762f9648d 100644
--- a/examples/qwen3/qlora-fsdp.yaml
+++ b/examples/qwen3/qlora-fsdp.yaml
@@ -66,3 +66,5 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index 4df010040..d3a3b3242 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -36,6 +36,7 @@ from axolotl.utils.callbacks import (
GCCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
+ SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
@@ -135,6 +136,8 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
+ if self.cfg.save_first_step:
+ callbacks.append(SaveModelOnFirstStepCallback())
callbacks.append(GPUStatsCallback(cfg=self.cfg))
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 5f804d6af..bb777fc90 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -64,7 +64,7 @@ class SaveBetterTransformerModelCallback(
state: TrainerState,
control: TrainerControl,
**kwargs,
- ):
+ ) -> TrainerControl:
# Save
if (
args.save_strategy == IntervalStrategy.STEPS
@@ -100,11 +100,11 @@ class GPUStatsCallback(
def on_step_end(
self,
- args: TrainingArguments,
+ args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
**kwargs,
- ):
+ ) -> TrainerControl:
if not self.logged and state.global_step > 1:
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
@@ -116,18 +116,17 @@ class LossWatchDogCallback(TrainerCallback):
def __init__(self, cfg):
self.cfg = cfg
- self.logged = False
self.violations = 0
self.threshold = cfg.loss_watchdog_threshold
self.patience = cfg.loss_watchdog_patience or 3
def on_step_end(
self,
- _args: TrainingArguments,
+ args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
**_kwargs,
- ):
+ ) -> TrainerControl:
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
if state.log_history[-1]["loss"] > self.threshold:
self.violations += 1
@@ -141,6 +140,21 @@ class LossWatchDogCallback(TrainerCallback):
return control
+class SaveModelOnFirstStepCallback(TrainerCallback):
+ """Callback to save the model on the first step of training if enabled"""
+
+ def on_step_end(
+ self,
+ args: TrainingArguments, # pylint: disable=unused-argument
+ state: TrainerState,
+ control: TrainerControl,
+ **_kwargs,
+ ) -> TrainerControl:
+ if state.global_step == 1:
+ control.should_save = True
+ return control
+
+
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 909fd637c..e20cdaf47 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -706,6 +706,7 @@ class AxolotlInputConfig(
"description": "Set to `no` to skip evaluation, `epoch` at end of each epoch, leave empty to infer from `eval_steps`"
},
)
+
save_steps: int | float | None = Field(
default=None,
json_schema_extra={
@@ -727,6 +728,13 @@ class AxolotlInputConfig(
save_total_limit: int | None = Field(
default=None, json_schema_extra={"description": "Checkpoints saved at a time"}
)
+ save_first_step: bool | None = Field(
+ default=None,
+ json_schema_extra={
+ "description": "Whether to checkpoint a model after the first step of training. Defaults to False."
+ },
+ )
+
logging_steps: int | None = Field(
default=None, json_schema_extra={"description": "Logging frequency"}
)
diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py
index 790b34f3e..34e6c9644 100644
--- a/tests/e2e/integrations/test_cut_cross_entropy.py
+++ b/tests/e2e/integrations/test_cut_cross_entropy.py
@@ -44,6 +44,7 @@ def min_cfg(temp_dir):
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
@@ -98,6 +99,7 @@ class TestCutCrossEntropyIntegration:
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/integrations/test_hooks.py b/tests/e2e/integrations/test_hooks.py
index 4734449fe..8743efb98 100644
--- a/tests/e2e/integrations/test_hooks.py
+++ b/tests/e2e/integrations/test_hooks.py
@@ -153,6 +153,7 @@ class TestPluginHooks:
"max_steps": 5,
"flash_attention": True,
"bf16": "auto",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py
index 212450e89..1ac3b537e 100644
--- a/tests/e2e/integrations/test_kd.py
+++ b/tests/e2e/integrations/test_kd.py
@@ -67,6 +67,7 @@ def min_cfg(temp_dir):
"output_dir": temp_dir,
"save_safetensors": True,
"use_tensorboard": True,
+ "save_first_step": False,
}
diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py
index 6ab3d7ab8..b1f5befdd 100644
--- a/tests/e2e/integrations/test_liger.py
+++ b/tests/e2e/integrations/test_liger.py
@@ -50,6 +50,7 @@ class LigerIntegrationTestCase:
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
+ "save_first_step": False,
}
)
# pylint: disable=duplicate-code
@@ -96,6 +97,7 @@ class LigerIntegrationTestCase:
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
+ "save_first_step": False,
}
)
# pylint: disable=duplicate-code
diff --git a/tests/e2e/integrations/test_llm_compressor.py b/tests/e2e/integrations/test_llm_compressor.py
index 247ae3bac..dceecea9f 100644
--- a/tests/e2e/integrations/test_llm_compressor.py
+++ b/tests/e2e/integrations/test_llm_compressor.py
@@ -81,6 +81,7 @@ class TestLLMCompressorIntegration:
},
"save_compressed": save_compressed,
},
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/patched/test_sp.py b/tests/e2e/multigpu/patched/test_sp.py
index 5593c7eb6..80098e684 100644
--- a/tests/e2e/multigpu/patched/test_sp.py
+++ b/tests/e2e/multigpu/patched/test_sp.py
@@ -69,6 +69,7 @@ class TestSequenceParallelism:
"use_tensorboard": True,
"sequence_parallel_degree": 2,
"ring_attn_func": ring_attn_func,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/solo/test_flex.py b/tests/e2e/multigpu/solo/test_flex.py
index bdf5ada6b..cbdf8de96 100644
--- a/tests/e2e/multigpu/solo/test_flex.py
+++ b/tests/e2e/multigpu/solo/test_flex.py
@@ -61,6 +61,7 @@ class TestPackedFlex:
"max_steps": 2,
"use_tensorboard": True,
"save_strategy": "no",
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py
index c04734345..d022ae2d9 100644
--- a/tests/e2e/multigpu/solo/test_grpo.py
+++ b/tests/e2e/multigpu/solo/test_grpo.py
@@ -223,6 +223,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -317,6 +318,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -409,6 +411,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_eval.py b/tests/e2e/multigpu/test_eval.py
index d6429cf63..4f86278ff 100644
--- a/tests/e2e/multigpu/test_eval.py
+++ b/tests/e2e/multigpu/test_eval.py
@@ -67,6 +67,7 @@ class TestMultiGPUEval:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -138,6 +139,7 @@ class TestMultiGPUEval:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_gemma3.py b/tests/e2e/multigpu/test_gemma3.py
index 3868d90f0..4a7b101a8 100644
--- a/tests/e2e/multigpu/test_gemma3.py
+++ b/tests/e2e/multigpu/test_gemma3.py
@@ -71,6 +71,7 @@ class TestMultiGPUGemma3:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py
index f0c74fbf8..aab14dcc4 100644
--- a/tests/e2e/multigpu/test_llama.py
+++ b/tests/e2e/multigpu/test_llama.py
@@ -69,6 +69,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -135,6 +136,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -210,6 +212,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -289,6 +292,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -365,6 +369,7 @@ class TestMultiGPULlama:
},
"use_tensorboard": True,
"seed": 42,
+ "save_first_step": False,
}
)
@@ -442,6 +447,7 @@ class TestMultiGPULlama:
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -520,6 +526,7 @@ class TestMultiGPULlama:
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
},
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if attention_backend == "flash":
@@ -605,6 +612,7 @@ class TestMultiGPULlama:
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -689,6 +697,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / deepspeed),
"use_tensorboard": True,
+ "save_first_step": False,
**adapter,
}
)
@@ -765,6 +774,7 @@ class TestMultiGPULlama:
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
"use_tensorboard": True,
"seed": 42,
+ "save_first_step": False,
**adapter,
}
)
@@ -840,6 +850,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
+ "save_first_step": False,
**adapter,
}
)
@@ -908,6 +919,7 @@ class TestMultiGPULlama:
"save_safetensors": True,
# "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py
index 43a722b48..dd1422296 100644
--- a/tests/e2e/multigpu/test_ray.py
+++ b/tests/e2e/multigpu/test_ray.py
@@ -56,6 +56,7 @@ class TestMultiGPURay:
"use_tensorboard": True,
"use_ray": True,
"ray_num_workers": 2,
+ "save_first_step": False,
}
)
@@ -115,6 +116,7 @@ class TestMultiGPURay:
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py
index 08b62accc..1824443e7 100644
--- a/tests/e2e/patched/test_4d_multipack_llama.py
+++ b/tests/e2e/patched/test_4d_multipack_llama.py
@@ -55,6 +55,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -102,6 +103,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_activation_checkpointing.py b/tests/e2e/patched/test_activation_checkpointing.py
index d494ed1eb..3d5b3dc56 100644
--- a/tests/e2e/patched/test_activation_checkpointing.py
+++ b/tests/e2e/patched/test_activation_checkpointing.py
@@ -69,6 +69,7 @@ class TestActivationCheckpointing:
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": gradient_checkpointing,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py
index ca8b21178..38099b220 100644
--- a/tests/e2e/patched/test_fa_xentropy.py
+++ b/tests/e2e/patched/test_fa_xentropy.py
@@ -62,6 +62,7 @@ class TestFAXentropyLlama:
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py
index a593b0791..ef31b11c7 100644
--- a/tests/e2e/patched/test_falcon_samplepack.py
+++ b/tests/e2e/patched/test_falcon_samplepack.py
@@ -58,6 +58,7 @@ class TestFalconPatched(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -99,6 +100,7 @@ class TestFalconPatched(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_flattening.py b/tests/e2e/patched/test_flattening.py
index f77a1fbe5..fdaab558d 100644
--- a/tests/e2e/patched/test_flattening.py
+++ b/tests/e2e/patched/test_flattening.py
@@ -61,6 +61,7 @@ class TestFAFlattening:
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py
index 1bbc82a38..a3fe591ee 100644
--- a/tests/e2e/patched/test_fused_llama.py
+++ b/tests/e2e/patched/test_fused_llama.py
@@ -53,6 +53,7 @@ class TestFusedLlama(unittest.TestCase):
"max_steps": 10,
"save_steps": 5,
"eval_steps": 5,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py
index d2dcc5e4b..ba5556a59 100644
--- a/tests/e2e/patched/test_llama_s2_attention.py
+++ b/tests/e2e/patched/test_llama_s2_attention.py
@@ -58,6 +58,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
"save_steps": 5,
"eval_steps": 5,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -100,6 +101,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
"save_steps": 5,
"eval_steps": 5,
"bf16": "auto",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py
index 5df6bfecc..fdf6adbc6 100644
--- a/tests/e2e/patched/test_lora_llama_multipack.py
+++ b/tests/e2e/patched/test_lora_llama_multipack.py
@@ -55,6 +55,7 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -108,6 +109,7 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py
index 442089bae..bea0f9c68 100644
--- a/tests/e2e/patched/test_mistral_samplepack.py
+++ b/tests/e2e/patched/test_mistral_samplepack.py
@@ -56,6 +56,7 @@ class TestMistral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -97,6 +98,7 @@ class TestMistral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py
index 5f778660b..09e427abd 100644
--- a/tests/e2e/patched/test_mixtral_samplepack.py
+++ b/tests/e2e/patched/test_mixtral_samplepack.py
@@ -52,6 +52,7 @@ class TestMixtral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -90,6 +91,7 @@ class TestMixtral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py
index 5ea88b001..b90be23e4 100644
--- a/tests/e2e/patched/test_model_patches.py
+++ b/tests/e2e/patched/test_model_patches.py
@@ -45,6 +45,7 @@ class TestModelPatches(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -78,6 +79,7 @@ class TestModelPatches(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_peft_embeddings.py b/tests/e2e/patched/test_peft_embeddings.py
index d4f59a128..4769319ae 100644
--- a/tests/e2e/patched/test_peft_embeddings.py
+++ b/tests/e2e/patched/test_peft_embeddings.py
@@ -49,6 +49,7 @@ class TestLlamaPeftEmbeddings:
"bf16": "auto",
"save_safetensors": True,
"embeddings_skip_upcast": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py
index d241ce185..1f0ddd630 100644
--- a/tests/e2e/patched/test_phi_multipack.py
+++ b/tests/e2e/patched/test_phi_multipack.py
@@ -54,6 +54,7 @@ class TestPhiMultipack(unittest.TestCase):
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -105,6 +106,7 @@ class TestPhiMultipack(unittest.TestCase):
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py
index 363956733..54b8245ee 100644
--- a/tests/e2e/patched/test_resume.py
+++ b/tests/e2e/patched/test_resume.py
@@ -58,6 +58,7 @@ class TestResumeLlama:
"max_steps": 15,
"use_tensorboard": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py
index 2b4d11b30..4a2c69d45 100644
--- a/tests/e2e/patched/test_sp.py
+++ b/tests/e2e/patched/test_sp.py
@@ -47,6 +47,7 @@ def fixture_cfg():
"special_tokens": {
"pad_token": "<|endoftext|>",
},
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py
index 69171481c..2c8ee4eb0 100644
--- a/tests/e2e/patched/test_unsloth_qlora.py
+++ b/tests/e2e/patched/test_unsloth_qlora.py
@@ -62,6 +62,7 @@ class TestUnslothQLoRA:
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -112,6 +113,7 @@ class TestUnslothQLoRA:
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -167,6 +169,7 @@ class TestUnslothQLoRA:
"lr_scheduler": "cosine",
"use_tensorboard": True,
"fp16": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/solo/test_flex.py b/tests/e2e/solo/test_flex.py
index 279913713..76364fc0e 100644
--- a/tests/e2e/solo/test_flex.py
+++ b/tests/e2e/solo/test_flex.py
@@ -49,6 +49,7 @@ class TestPackedFlex(unittest.TestCase):
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/solo/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py
index 7af550496..f6fcad841 100644
--- a/tests/e2e/solo/test_relora_llama.py
+++ b/tests/e2e/solo/test_relora_llama.py
@@ -65,6 +65,7 @@ class TestReLoraLlama(unittest.TestCase):
"lr_scheduler": "cosine",
"save_safetensors": True,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_deepseekv3.py b/tests/e2e/test_deepseekv3.py
index 7dfc4ae15..e4a47fb0a 100644
--- a/tests/e2e/test_deepseekv3.py
+++ b/tests/e2e/test_deepseekv3.py
@@ -67,6 +67,7 @@ class TestDeepseekV3:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -116,6 +117,7 @@ class TestDeepseekV3:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py
index 2cdb57689..a1df69535 100644
--- a/tests/e2e/test_dpo.py
+++ b/tests/e2e/test_dpo.py
@@ -56,6 +56,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -105,6 +106,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -154,6 +156,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -203,6 +206,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -251,6 +255,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -302,6 +307,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
@@ -370,6 +376,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py
index 9b65f8feb..e4a06ad14 100644
--- a/tests/e2e/test_embeddings_lr.py
+++ b/tests/e2e/test_embeddings_lr.py
@@ -48,6 +48,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
@@ -93,6 +94,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_evaluate.py b/tests/e2e/test_evaluate.py
index 6271bba28..977497e5e 100644
--- a/tests/e2e/test_evaluate.py
+++ b/tests/e2e/test_evaluate.py
@@ -36,6 +36,7 @@ class TestE2eEvaluate:
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py
index 4f88e740c..5be6efcf6 100644
--- a/tests/e2e/test_falcon.py
+++ b/tests/e2e/test_falcon.py
@@ -60,6 +60,7 @@ class TestFalcon(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -115,6 +116,7 @@ class TestFalcon(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
@@ -156,6 +158,7 @@ class TestFalcon(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_gemma3_text.py b/tests/e2e/test_gemma3_text.py
index 3f00a1384..ef38d028d 100644
--- a/tests/e2e/test_gemma3_text.py
+++ b/tests/e2e/test_gemma3_text.py
@@ -63,6 +63,7 @@ class TestGemma3Text:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -113,6 +114,7 @@ class TestGemma3Text:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py
index 2b180029c..1e6df0be9 100644
--- a/tests/e2e/test_llama.py
+++ b/tests/e2e/test_llama.py
@@ -45,6 +45,7 @@ class TestLlama:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
@@ -92,6 +93,7 @@ class TestLlama:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
@@ -136,6 +138,7 @@ class TestLlama:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
@@ -176,6 +179,7 @@ class TestLlama:
"batch_flattening": True,
"bf16": True,
"save_safetensors": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py
index fdebf2173..bd5502300 100644
--- a/tests/e2e/test_llama_pretrain.py
+++ b/tests/e2e/test_llama_pretrain.py
@@ -53,6 +53,7 @@ class TestPretrainLlama:
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py
index ad4a83c6a..760759bca 100644
--- a/tests/e2e/test_llama_vision.py
+++ b/tests/e2e/test_llama_vision.py
@@ -54,6 +54,7 @@ class TestLlamaVision(unittest.TestCase):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
@@ -100,6 +101,7 @@ class TestLlamaVision(unittest.TestCase):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index 301565302..7e0ff46cf 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -49,6 +49,7 @@ class TestLoraLlama(unittest.TestCase):
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py
index 1824619a6..73d3bdc26 100644
--- a/tests/e2e/test_mamba.py
+++ b/tests/e2e/test_mamba.py
@@ -51,6 +51,7 @@ class TestMamba(unittest.TestCase):
"save_steps": 10,
"eval_steps": None,
"save_safetensors": False,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py
index 5d9b8ba8c..f47f794e0 100644
--- a/tests/e2e/test_mistral.py
+++ b/tests/e2e/test_mistral.py
@@ -55,6 +55,7 @@ class TestMistral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
@@ -95,6 +96,7 @@ class TestMistral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py
index 761e59391..3fe2bf70f 100644
--- a/tests/e2e/test_mixtral.py
+++ b/tests/e2e/test_mixtral.py
@@ -61,6 +61,7 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
@@ -116,6 +117,7 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
@@ -170,6 +172,7 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -228,6 +231,7 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
@@ -273,6 +277,7 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
index 53ef86022..1d233a201 100644
--- a/tests/e2e/test_optimizers.py
+++ b/tests/e2e/test_optimizers.py
@@ -55,6 +55,7 @@ class TestCustomOptimizers(unittest.TestCase):
"optimizer": "optimi_adamw",
"max_steps": 5,
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
@@ -100,6 +101,7 @@ class TestCustomOptimizers(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adopt_adamw",
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
@@ -146,6 +148,7 @@ class TestCustomOptimizers(unittest.TestCase):
"optimizer": "muon",
"lr_scheduler": "cosine",
"weight_decay": 0.01,
+ "save_first_step": False,
}
)
@@ -184,6 +187,7 @@ class TestCustomOptimizers(unittest.TestCase):
"lr_scheduler": "constant",
"save_safetensors": True,
"max_steps": 10,
+ "save_first_step": False,
}
)
# pylint: disable=duplicate-code
@@ -232,6 +236,7 @@ class TestCustomOptimizers(unittest.TestCase):
"adam_epsilon2": 1e-16,
"max_steps": 5,
"lr_scheduler": "cosine",
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py
index cc2db72e0..aec9d95f8 100644
--- a/tests/e2e/test_packing_loss.py
+++ b/tests/e2e/test_packing_loss.py
@@ -48,6 +48,7 @@ class TestPackedLlama(unittest.TestCase):
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py
index 88fda9191..ab3a63674 100644
--- a/tests/e2e/test_phi.py
+++ b/tests/e2e/test_phi.py
@@ -53,6 +53,7 @@ class TestPhi(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -102,6 +103,7 @@ class TestPhi(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_process_reward_model_smollm2.py b/tests/e2e/test_process_reward_model_smollm2.py
index abfe1b0c5..bd9eec48b 100644
--- a/tests/e2e/test_process_reward_model_smollm2.py
+++ b/tests/e2e/test_process_reward_model_smollm2.py
@@ -49,6 +49,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
"use_tensorboard": True,
"special_tokens": {"pad_token": "<|endoftext|>"},
"seed": 42,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_qat.py b/tests/e2e/test_qat.py
index ef726079d..139ae155a 100644
--- a/tests/e2e/test_qat.py
+++ b/tests/e2e/test_qat.py
@@ -57,6 +57,7 @@ class TestQATLlama:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -115,6 +116,7 @@ class TestQATLlama:
"weight_dtype": "int8",
"group_size": 8,
},
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_qwen.py b/tests/e2e/test_qwen.py
index aa8b9f6c0..59267d14d 100644
--- a/tests/e2e/test_qwen.py
+++ b/tests/e2e/test_qwen.py
@@ -59,6 +59,7 @@ class TestE2eQwen:
"bf16": "auto",
"tf32": True,
"gradient_checkpointing": True,
+ "save_first_step": False,
}
)
diff --git a/tests/e2e/test_reward_model_smollm2.py b/tests/e2e/test_reward_model_smollm2.py
index 5d52bcc86..82513f99f 100644
--- a/tests/e2e/test_reward_model_smollm2.py
+++ b/tests/e2e/test_reward_model_smollm2.py
@@ -58,6 +58,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
"gradient_checkpointing": True,
"warmup_ratio": 0.1,
"use_tensorboard": True,
+ "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_save_first_step.py b/tests/e2e/test_save_first_step.py
new file mode 100644
index 000000000..5bbd2302b
--- /dev/null
+++ b/tests/e2e/test_save_first_step.py
@@ -0,0 +1,102 @@
+"""
+E2E tests for relora llama
+"""
+
+import unittest
+from pathlib import Path
+
+import pytest
+
+from axolotl.common.datasets import load_datasets
+from axolotl.train import train
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import check_model_output_exists, with_temp_dir
+
+
+class TestSaveFirstStepCallback(unittest.TestCase):
+ """Test cases for save_first_step callback config."""
+
+ @with_temp_dir
+ def test_save_first_step(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "HuggingFaceTB/SmolLM2-135M",
+ "tokenizer_type": "AutoTokenizer",
+ "sequence_len": 512,
+ "val_set_size": 0.02,
+ "special_tokens": {
+ "pad_token": "<|endoftext|>",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "max_steps": 3,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_bnb_8bit",
+ "lr_scheduler": "cosine",
+ "flash_attention": True,
+ "sample_packing": True,
+ "bf16": True,
+ "save_safetensors": True,
+ "save_first_step": True,
+ }
+ )
+
+ cfg = validate_config(cfg)
+ normalize_config(cfg)
+ dataset_meta = load_datasets(cfg=cfg)
+
+ train(cfg=cfg, dataset_meta=dataset_meta)
+ check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)
+
+ @with_temp_dir
+ def test_no_save_first_step(self, temp_dir):
+ # pylint: disable=duplicate-code
+ cfg = DictDefault(
+ {
+ "base_model": "HuggingFaceTB/SmolLM2-135M",
+ "tokenizer_type": "AutoTokenizer",
+ "sequence_len": 512,
+ "val_set_size": 0.02,
+ "special_tokens": {
+ "pad_token": "<|endoftext|>",
+ },
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ },
+ ],
+ "num_epochs": 1,
+ "max_steps": 3,
+ "micro_batch_size": 2,
+ "gradient_accumulation_steps": 1,
+ "output_dir": temp_dir,
+ "learning_rate": 0.00001,
+ "optimizer": "adamw_bnb_8bit",
+ "lr_scheduler": "cosine",
+ "flash_attention": True,
+ "sample_packing": True,
+ "bf16": True,
+ "save_safetensors": True,
+ "save_first_step": False,
+ }
+ )
+
+ cfg = validate_config(cfg)
+ normalize_config(cfg)
+ dataset_meta = load_datasets(cfg=cfg)
+
+ train(cfg=cfg, dataset_meta=dataset_meta)
+ with pytest.raises(AssertionError):
+ check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)
diff --git a/tests/e2e/test_schedulers.py b/tests/e2e/test_schedulers.py
index e98378f08..8f7a13aee 100644
--- a/tests/e2e/test_schedulers.py
+++ b/tests/e2e/test_schedulers.py
@@ -51,6 +51,7 @@ class TestCustomSchedulers(unittest.TestCase):
"lr_scheduler": "rex",
"warmup_steps": 5,
"cosine_min_lr_ratio": 0.05,
+ "save_first_step": False,
}
)