diff --git a/examples/cloud/modal.yaml b/examples/cloud/modal.yaml
index bbe8785f1..195031494 100644
--- a/examples/cloud/modal.yaml
+++ b/examples/cloud/modal.yaml
@@ -26,5 +26,3 @@ timeout: 86400
# Preprocess specific configurations
memory_preprocess: 32
timeout_preprocess: 14400
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/cohere/command-r-7b-qlora.yml b/examples/cohere/command-r-7b-qlora.yml
index da2777270..4a30e9a77 100644
--- a/examples/cohere/command-r-7b-qlora.yml
+++ b/examples/cohere/command-r-7b-qlora.yml
@@ -35,6 +35,7 @@ wandb_watch:
wandb_name:
wandb_log_model:
+
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
@@ -55,5 +56,3 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
index 1a051b98b..2c0495ced 100644
--- a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
+++ b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
@@ -56,5 +56,3 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
index 807342641..de9c956e0 100644
--- a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
+++ b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
@@ -56,5 +56,3 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepseek-v2/fft-fsdp-16b.yaml b/examples/deepseek-v2/fft-fsdp-16b.yaml
index 78bf6b179..0ed97db36 100644
--- a/examples/deepseek-v2/fft-fsdp-16b.yaml
+++ b/examples/deepseek-v2/fft-fsdp-16b.yaml
@@ -55,5 +55,3 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepseek-v2/qlora-fsdp-2_5.yaml b/examples/deepseek-v2/qlora-fsdp-2_5.yaml
index da1d9aefd..34dbeaafe 100644
--- a/examples/deepseek-v2/qlora-fsdp-2_5.yaml
+++ b/examples/deepseek-v2/qlora-fsdp-2_5.yaml
@@ -79,5 +79,3 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/devstral/devstral-small-qlora.yml b/examples/devstral/devstral-small-qlora.yml
index 9d92e8662..dc0051bd5 100644
--- a/examples/devstral/devstral-small-qlora.yml
+++ b/examples/devstral/devstral-small-qlora.yml
@@ -62,5 +62,3 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
index 484c31fec..1dd901154 100644
--- a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
@@ -69,5 +69,3 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-1b-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-qlora.yaml
index dea2a6e6d..24dc7cae3 100644
--- a/examples/falcon-h1/falcon-h1-1b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-1b-qlora.yaml
@@ -46,6 +46,7 @@ wandb_watch:
wandb_name:
wandb_log_model:
+
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
@@ -68,5 +69,3 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-34b-qlora.yaml b/examples/falcon-h1/falcon-h1-34b-qlora.yaml
index b187efbf6..43eb1967b 100644
--- a/examples/falcon-h1/falcon-h1-34b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-34b-qlora.yaml
@@ -69,5 +69,3 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-3b-qlora.yaml b/examples/falcon-h1/falcon-h1-3b-qlora.yaml
index 4d981ad95..00929bbf0 100644
--- a/examples/falcon-h1/falcon-h1-3b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-3b-qlora.yaml
@@ -69,5 +69,3 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-500m-qlora.yaml b/examples/falcon-h1/falcon-h1-500m-qlora.yaml
index 5ee13facd..e2640de7b 100644
--- a/examples/falcon-h1/falcon-h1-500m-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-500m-qlora.yaml
@@ -69,5 +69,3 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-7b-qlora.yaml b/examples/falcon-h1/falcon-h1-7b-qlora.yaml
index 4b665c3cd..183e423b5 100644
--- a/examples/falcon-h1/falcon-h1-7b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-7b-qlora.yaml
@@ -69,5 +69,3 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma2/qlora.yml b/examples/gemma2/qlora.yml
index 68d213fad..cb96a32c1 100644
--- a/examples/gemma2/qlora.yml
+++ b/examples/gemma2/qlora.yml
@@ -60,5 +60,3 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma2/reward-model.yaml b/examples/gemma2/reward-model.yaml
index 624ebdcd2..ce01a4572 100644
--- a/examples/gemma2/reward-model.yaml
+++ b/examples/gemma2/reward-model.yaml
@@ -50,5 +50,3 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-1b-qlora.yml b/examples/gemma3/gemma-3-1b-qlora.yml
index 99921770d..217c887aa 100644
--- a/examples/gemma3/gemma-3-1b-qlora.yml
+++ b/examples/gemma3/gemma-3-1b-qlora.yml
@@ -66,5 +66,3 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-4b-qlora.yml b/examples/gemma3/gemma-3-4b-qlora.yml
index 025cb9240..d78559ae3 100644
--- a/examples/gemma3/gemma-3-4b-qlora.yml
+++ b/examples/gemma3/gemma-3-4b-qlora.yml
@@ -60,5 +60,3 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-4b-vision-qlora.yml b/examples/gemma3/gemma-3-4b-vision-qlora.yml
index e9e606b69..183eb88e8 100644
--- a/examples/gemma3/gemma-3-4b-vision-qlora.yml
+++ b/examples/gemma3/gemma-3-4b-vision-qlora.yml
@@ -62,5 +62,3 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/glm4/qlora-32b.yaml b/examples/glm4/qlora-32b.yaml
index 8973cedd4..86d9b43f8 100644
--- a/examples/glm4/qlora-32b.yaml
+++ b/examples/glm4/qlora-32b.yaml
@@ -60,5 +60,3 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora.yaml b/examples/jamba/qlora.yaml
index 494154886..2cb0eea41 100644
--- a/examples/jamba/qlora.yaml
+++ b/examples/jamba/qlora.yaml
@@ -54,5 +54,3 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora_deepspeed.yaml b/examples/jamba/qlora_deepspeed.yaml
index 64db8f2ff..d13ce6483 100644
--- a/examples/jamba/qlora_deepspeed.yaml
+++ b/examples/jamba/qlora_deepspeed.yaml
@@ -55,5 +55,3 @@ saves_per_epoch: 1
deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora_fsdp_large.yaml b/examples/jamba/qlora_fsdp_large.yaml
index fda30e2d2..6badaba19 100644
--- a/examples/jamba/qlora_fsdp_large.yaml
+++ b/examples/jamba/qlora_fsdp_large.yaml
@@ -64,5 +64,3 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/lfm2/lfm2-350m-fft.yaml b/examples/lfm2/lfm2-350m-fft.yaml
index 74c90c1e1..95961557e 100644
--- a/examples/lfm2/lfm2-350m-fft.yaml
+++ b/examples/lfm2/lfm2-350m-fft.yaml
@@ -46,5 +46,3 @@ evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml
index c44cd2230..86b1b6a21 100644
--- a/examples/llama-2/fft_optimized.yml
+++ b/examples/llama-2/fft_optimized.yml
@@ -55,5 +55,3 @@ saves_per_epoch: 1
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
weight_decay: 0.1
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml
index 580fabdf8..0f1b34016 100644
--- a/examples/llama-2/gptq-lora.yml
+++ b/examples/llama-2/gptq-lora.yml
@@ -64,5 +64,3 @@ special_tokens:
bos_token: ""
eos_token: ""
unk_token: ""
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/lisa.yml b/examples/llama-2/lisa.yml
index a44e261be..a76a792ae 100644
--- a/examples/llama-2/lisa.yml
+++ b/examples/llama-2/lisa.yml
@@ -60,5 +60,3 @@ special_tokens:
bos_token: ""
eos_token: ""
unk_token: ""
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/loftq.yml b/examples/llama-2/loftq.yml
index 085627f63..22dbf2d99 100644
--- a/examples/llama-2/loftq.yml
+++ b/examples/llama-2/loftq.yml
@@ -52,5 +52,3 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml
index 759fce044..679aed3a9 100644
--- a/examples/llama-2/lora.yml
+++ b/examples/llama-2/lora.yml
@@ -52,5 +52,3 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/qlora-fsdp.yml b/examples/llama-2/qlora-fsdp.yml
index 3bf30120b..a42eabd4b 100644
--- a/examples/llama-2/qlora-fsdp.yml
+++ b/examples/llama-2/qlora-fsdp.yml
@@ -67,5 +67,3 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml
index 09596c71e..de65928bc 100644
--- a/examples/llama-2/qlora.yml
+++ b/examples/llama-2/qlora.yml
@@ -53,5 +53,3 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml
index ca8b14a1c..e0a5f7068 100644
--- a/examples/llama-2/relora.yml
+++ b/examples/llama-2/relora.yml
@@ -58,5 +58,3 @@ special_tokens:
bos_token: ""
eos_token: ""
unk_token: ""
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml
index 64d749b5a..2b0ae2c70 100644
--- a/examples/llama-3-vision/lora-11b.yaml
+++ b/examples/llama-3-vision/lora-11b.yaml
@@ -57,5 +57,3 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/3b-qat-fsdp2.yaml b/examples/llama-3/3b-qat-fsdp2.yaml
index 08d8ee5c1..5d979c96c 100644
--- a/examples/llama-3/3b-qat-fsdp2.yaml
+++ b/examples/llama-3/3b-qat-fsdp2.yaml
@@ -77,5 +77,3 @@ fsdp_config:
special_tokens:
pad_token: <|end_of_text|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml
index e2808935f..eccfa6d8c 100644
--- a/examples/llama-3/fft-8b-liger-fsdp.yaml
+++ b/examples/llama-3/fft-8b-liger-fsdp.yaml
@@ -72,5 +72,3 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/fft-8b.yaml b/examples/llama-3/fft-8b.yaml
index 2dfe6d492..fdae3e6c4 100644
--- a/examples/llama-3/fft-8b.yaml
+++ b/examples/llama-3/fft-8b.yaml
@@ -42,5 +42,3 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml
index 10ab2a320..51f1c768b 100644
--- a/examples/llama-3/instruct-dpo-lora-8b.yml
+++ b/examples/llama-3/instruct-dpo-lora-8b.yml
@@ -71,5 +71,3 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml
index 83b7f9a37..acab862f6 100644
--- a/examples/llama-3/instruct-lora-8b.yml
+++ b/examples/llama-3/instruct-lora-8b.yml
@@ -64,5 +64,3 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-deduplicate-dpo.yml b/examples/llama-3/lora-1b-deduplicate-dpo.yml
index b20dbad84..10e9747cb 100644
--- a/examples/llama-3/lora-1b-deduplicate-dpo.yml
+++ b/examples/llama-3/lora-1b-deduplicate-dpo.yml
@@ -83,5 +83,3 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-deduplicate-sft.yml b/examples/llama-3/lora-1b-deduplicate-sft.yml
index 67e518184..630ec92f6 100644
--- a/examples/llama-3/lora-1b-deduplicate-sft.yml
+++ b/examples/llama-3/lora-1b-deduplicate-sft.yml
@@ -61,5 +61,3 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-kernels.yml b/examples/llama-3/lora-1b-kernels.yml
index 92a948c2e..a2d07ca49 100644
--- a/examples/llama-3/lora-1b-kernels.yml
+++ b/examples/llama-3/lora-1b-kernels.yml
@@ -65,5 +65,3 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-ray.yml b/examples/llama-3/lora-1b-ray.yml
index 178a1fb89..bb23164eb 100644
--- a/examples/llama-3/lora-1b-ray.yml
+++ b/examples/llama-3/lora-1b-ray.yml
@@ -64,5 +64,3 @@ special_tokens:
use_ray: true
ray_num_workers: 4
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-sample-packing-sequentially.yml b/examples/llama-3/lora-1b-sample-packing-sequentially.yml
index c4ce3eb0f..769dd32e6 100644
--- a/examples/llama-3/lora-1b-sample-packing-sequentially.yml
+++ b/examples/llama-3/lora-1b-sample-packing-sequentially.yml
@@ -63,5 +63,3 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b.yml b/examples/llama-3/lora-1b.yml
index 82085483f..acc17e21f 100644
--- a/examples/llama-3/lora-1b.yml
+++ b/examples/llama-3/lora-1b.yml
@@ -60,5 +60,3 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-8b.yml b/examples/llama-3/lora-8b.yml
index c39389755..ad50cd38a 100644
--- a/examples/llama-3/lora-8b.yml
+++ b/examples/llama-3/lora-8b.yml
@@ -57,5 +57,3 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-1b-kto.yaml b/examples/llama-3/qlora-1b-kto.yaml
index f156e23d3..89a51ea68 100644
--- a/examples/llama-3/qlora-1b-kto.yaml
+++ b/examples/llama-3/qlora-1b-kto.yaml
@@ -61,5 +61,3 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-1b.yml b/examples/llama-3/qlora-1b.yml
index 6b76ea8d9..5c8fe6628 100644
--- a/examples/llama-3/qlora-1b.yml
+++ b/examples/llama-3/qlora-1b.yml
@@ -62,5 +62,3 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-fsdp-405b.yaml b/examples/llama-3/qlora-fsdp-405b.yaml
index 1ee922b59..2b7d51925 100644
--- a/examples/llama-3/qlora-fsdp-405b.yaml
+++ b/examples/llama-3/qlora-fsdp-405b.yaml
@@ -60,5 +60,3 @@ fsdp_config:
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|finetune_right_pad_id|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-fsdp-70b.yaml b/examples/llama-3/qlora-fsdp-70b.yaml
index 5edd8353a..412b6721c 100644
--- a/examples/llama-3/qlora-fsdp-70b.yaml
+++ b/examples/llama-3/qlora-fsdp-70b.yaml
@@ -69,5 +69,3 @@ fsdp_config:
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|end_of_text|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora.yml b/examples/llama-3/qlora.yml
index a674eca27..4cc9fc3db 100644
--- a/examples/llama-3/qlora.yml
+++ b/examples/llama-3/qlora.yml
@@ -54,5 +54,3 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/sparse-finetuning.yaml b/examples/llama-3/sparse-finetuning.yaml
index 8577a19d2..1bbb88028 100644
--- a/examples/llama-3/sparse-finetuning.yaml
+++ b/examples/llama-3/sparse-finetuning.yaml
@@ -75,5 +75,3 @@ llmcompressor:
]
start: 0
save_compressed: true
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
index d4a038e11..2be94f4ef 100644
--- a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
+++ b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
@@ -86,5 +86,3 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
index bea10d979..eeae872a6 100644
--- a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
@@ -90,5 +90,3 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
index 737d93812..17ad70634 100644
--- a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
@@ -83,5 +83,3 @@ weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
index 390be5af7..eff708e4d 100644
--- a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
@@ -86,5 +86,3 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
index b319349c4..9a411883e 100644
--- a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
+++ b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
@@ -84,5 +84,3 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-qlora-single-h100-flex.yaml b/examples/llama-4/scout-qlora-single-h100-flex.yaml
index 6be3988ef..20352f81e 100644
--- a/examples/llama-4/scout-qlora-single-h100-flex.yaml
+++ b/examples/llama-4/scout-qlora-single-h100-flex.yaml
@@ -82,5 +82,3 @@ weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
index a67936cf1..9fbd34107 100644
--- a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
+++ b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
@@ -87,5 +87,3 @@ fsdp_config:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot|>
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llava/lora-7b.yaml b/examples/llava/lora-7b.yaml
index a4bac8987..5198c8e74 100644
--- a/examples/llava/lora-7b.yaml
+++ b/examples/llava/lora-7b.yaml
@@ -53,5 +53,3 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/magistral/magistral-small-fsdp-qlora.yaml b/examples/magistral/magistral-small-fsdp-qlora.yaml
index b23d2309a..b10e8baf6 100644
--- a/examples/magistral/magistral-small-fsdp-qlora.yaml
+++ b/examples/magistral/magistral-small-fsdp-qlora.yaml
@@ -70,5 +70,3 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
fsdp_activation_checkpointing: true
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/magistral/magistral-small-qlora.yaml b/examples/magistral/magistral-small-qlora.yaml
index f0fce014f..e3e746f22 100644
--- a/examples/magistral/magistral-small-qlora.yaml
+++ b/examples/magistral/magistral-small-qlora.yaml
@@ -61,5 +61,3 @@ flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
index 2261bd215..3d4583932 100644
--- a/examples/mamba/config.yml
+++ b/examples/mamba/config.yml
@@ -48,5 +48,3 @@ weight_decay: 0.0
special_tokens:
tokens:
save_safetensors: False
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/bigstral-ds-zero3.yaml b/examples/mistral/bigstral-ds-zero3.yaml
index e9bcbb7d6..f626a92a1 100644
--- a/examples/mistral/bigstral-ds-zero3.yaml
+++ b/examples/mistral/bigstral-ds-zero3.yaml
@@ -53,5 +53,3 @@ special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml
index 8c4d80f79..15edffb44 100644
--- a/examples/mistral/config.yml
+++ b/examples/mistral/config.yml
@@ -43,5 +43,3 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/lora-mps.yml b/examples/mistral/lora-mps.yml
index d54c3e30b..e6f46affb 100644
--- a/examples/mistral/lora-mps.yml
+++ b/examples/mistral/lora-mps.yml
@@ -64,5 +64,3 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/lora.yml b/examples/mistral/lora.yml
index 161255468..9af4274fd 100644
--- a/examples/mistral/lora.yml
+++ b/examples/mistral/lora.yml
@@ -64,5 +64,3 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-dpo-qlora.yml b/examples/mistral/mistral-dpo-qlora.yml
index 8d0378690..af707973f 100644
--- a/examples/mistral/mistral-dpo-qlora.yml
+++ b/examples/mistral/mistral-dpo-qlora.yml
@@ -80,5 +80,3 @@ weight_decay: 0.0
special_tokens:
bos_token: "<|im_start|>"
eos_token: "<|im_end|>"
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-qlora-fsdp.yml b/examples/mistral/mistral-qlora-fsdp.yml
index cec958c54..e234b19a2 100644
--- a/examples/mistral/mistral-qlora-fsdp.yml
+++ b/examples/mistral/mistral-qlora-fsdp.yml
@@ -74,5 +74,3 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-qlora-orpo.yml b/examples/mistral/mistral-qlora-orpo.yml
index f37dc09fa..6c0212b7c 100644
--- a/examples/mistral/mistral-qlora-orpo.yml
+++ b/examples/mistral/mistral-qlora-orpo.yml
@@ -69,5 +69,3 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-small-3.1-24B-lora.yml b/examples/mistral/mistral-small-3.1-24B-lora.yml
index 4a492c595..3e3b45862 100644
--- a/examples/mistral/mistral-small-3.1-24B-lora.yml
+++ b/examples/mistral/mistral-small-3.1-24B-lora.yml
@@ -56,5 +56,3 @@ evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
index 64ef9930c..af6ba5a76 100644
--- a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
+++ b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
@@ -72,5 +72,3 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral-qlora-fsdp.yml
index c8d0a2711..b1843a138 100644
--- a/examples/mistral/mixtral-qlora-fsdp.yml
+++ b/examples/mistral/mixtral-qlora-fsdp.yml
@@ -77,5 +77,3 @@ fsdp_config:
fsdp_forward_prefetch: false
fsdp_backward_prefetch: BACKWARD_PRE
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml
index 5be9b4db8..4c256420c 100644
--- a/examples/mistral/mixtral.yml
+++ b/examples/mistral/mixtral.yml
@@ -81,5 +81,3 @@ saves_per_epoch: 1
deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral_22.yml b/examples/mistral/mixtral_22.yml
index 100e4464f..25e1d7155 100644
--- a/examples/mistral/mixtral_22.yml
+++ b/examples/mistral/mixtral_22.yml
@@ -51,5 +51,3 @@ special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml
index 08df36e15..607e33701 100644
--- a/examples/mistral/qlora.yml
+++ b/examples/mistral/qlora.yml
@@ -64,5 +64,3 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/orpheus/finetune.yml b/examples/orpheus/finetune.yml
index 57f65d966..9bcbbeee0 100644
--- a/examples/orpheus/finetune.yml
+++ b/examples/orpheus/finetune.yml
@@ -50,5 +50,3 @@ weight_decay: 0.05
special_tokens:
pad_token:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/lora-3.5.yaml b/examples/phi/lora-3.5.yaml
index 9f3bbdf53..ad4ce9cd4 100644
--- a/examples/phi/lora-3.5.yaml
+++ b/examples/phi/lora-3.5.yaml
@@ -63,5 +63,3 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 4
weight_decay: 0.0
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml
index fc6d649d7..1562a7353 100644
--- a/examples/phi/phi-ft.yml
+++ b/examples/phi/phi-ft.yml
@@ -57,5 +57,3 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml
index ccd92c817..4cd53db97 100644
--- a/examples/phi/phi-qlora.yml
+++ b/examples/phi/phi-qlora.yml
@@ -60,5 +60,3 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi2-ft.yml b/examples/phi/phi2-ft.yml
index 853250ccb..ca733cc71 100644
--- a/examples/phi/phi2-ft.yml
+++ b/examples/phi/phi2-ft.yml
@@ -57,5 +57,3 @@ weight_decay: 0.1
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi3-ft-fsdp.yml b/examples/phi/phi3-ft-fsdp.yml
index 130298bc0..d0d14fea6 100644
--- a/examples/phi/phi3-ft-fsdp.yml
+++ b/examples/phi/phi3-ft-fsdp.yml
@@ -71,5 +71,3 @@ fsdp_config:
resize_token_embeddings_to_32x: true
special_tokens:
pad_token: "<|endoftext|>"
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi3-ft.yml b/examples/phi/phi3-ft.yml
index 42b87e8d0..17c48da6f 100644
--- a/examples/phi/phi3-ft.yml
+++ b/examples/phi/phi3-ft.yml
@@ -59,5 +59,3 @@ warmup_ratio: 0.2
debug: true
weight_decay: 0.1
resize_token_embeddings_to_32x: true
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/pixtral/lora-12b.yml b/examples/pixtral/lora-12b.yml
index ea769d202..6ad0a5e99 100644
--- a/examples/pixtral/lora-12b.yml
+++ b/examples/pixtral/lora-12b.yml
@@ -55,5 +55,3 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
pad_token:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2-vl/lora-7b.yaml b/examples/qwen2-vl/lora-7b.yaml
index 8ea608199..e8932b968 100644
--- a/examples/qwen2-vl/lora-7b.yaml
+++ b/examples/qwen2-vl/lora-7b.yaml
@@ -53,5 +53,3 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/dpo.yaml b/examples/qwen2/dpo.yaml
index 69a74ae4a..bd896c2b3 100644
--- a/examples/qwen2/dpo.yaml
+++ b/examples/qwen2/dpo.yaml
@@ -54,5 +54,3 @@ warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/prm.yaml b/examples/qwen2/prm.yaml
index af188f75d..4afa24f3c 100644
--- a/examples/qwen2/prm.yaml
+++ b/examples/qwen2/prm.yaml
@@ -55,5 +55,3 @@ eval_steps: 100
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/qlora-fsdp.yaml b/examples/qwen2/qlora-fsdp.yaml
index 861ce5517..ed2670ab6 100644
--- a/examples/qwen2/qlora-fsdp.yaml
+++ b/examples/qwen2/qlora-fsdp.yaml
@@ -67,5 +67,3 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/reward-model.yaml b/examples/qwen2/reward-model.yaml
index 1854b8216..822407a1f 100644
--- a/examples/qwen2/reward-model.yaml
+++ b/examples/qwen2/reward-model.yaml
@@ -26,6 +26,7 @@ wandb_watch:
wandb_name:
wandb_log_model:
+
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
@@ -49,5 +50,3 @@ evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2_5-vl/lora-7b.yaml b/examples/qwen2_5-vl/lora-7b.yaml
index 13a97dec3..25d02805f 100644
--- a/examples/qwen2_5-vl/lora-7b.yaml
+++ b/examples/qwen2_5-vl/lora-7b.yaml
@@ -53,5 +53,3 @@ warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/32b-qlora.yaml b/examples/qwen3/32b-qlora.yaml
index 1f148ece5..45a4395ac 100644
--- a/examples/qwen3/32b-qlora.yaml
+++ b/examples/qwen3/32b-qlora.yaml
@@ -67,5 +67,3 @@ evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/8b-qat-fsdp2.yml b/examples/qwen3/8b-qat-fsdp2.yml
index e4d0ed4fb..6832b6af7 100644
--- a/examples/qwen3/8b-qat-fsdp2.yml
+++ b/examples/qwen3/8b-qat-fsdp2.yml
@@ -76,5 +76,3 @@ fsdp_config:
fsdp_activation_checkpointing: true
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/qlora-fsdp.yaml b/examples/qwen3/qlora-fsdp.yaml
index 762f9648d..dc3377b4f 100644
--- a/examples/qwen3/qlora-fsdp.yaml
+++ b/examples/qwen3/qlora-fsdp.yaml
@@ -66,5 +66,3 @@ fsdp_config:
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
-
-# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index d3a3b3242..4df010040 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -36,7 +36,6 @@ from axolotl.utils.callbacks import (
GCCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
- SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
@@ -136,8 +135,6 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
- if self.cfg.save_first_step:
- callbacks.append(SaveModelOnFirstStepCallback())
callbacks.append(GPUStatsCallback(cfg=self.cfg))
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index bb777fc90..5f804d6af 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -64,7 +64,7 @@ class SaveBetterTransformerModelCallback(
state: TrainerState,
control: TrainerControl,
**kwargs,
- ) -> TrainerControl:
+ ):
# Save
if (
args.save_strategy == IntervalStrategy.STEPS
@@ -100,11 +100,11 @@ class GPUStatsCallback(
def on_step_end(
self,
- args: TrainingArguments, # pylint: disable=unused-argument
+ args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
- ) -> TrainerControl:
+ ):
if not self.logged and state.global_step > 1:
log_gpu_memory_usage(LOG, "while training", self.cfg.device)
self.logged = True
@@ -116,17 +116,18 @@ class LossWatchDogCallback(TrainerCallback):
def __init__(self, cfg):
self.cfg = cfg
+ self.logged = False
self.violations = 0
self.threshold = cfg.loss_watchdog_threshold
self.patience = cfg.loss_watchdog_patience or 3
def on_step_end(
self,
- args: TrainingArguments, # pylint: disable=unused-argument
+ _args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
- ) -> TrainerControl:
+ ):
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
if state.log_history[-1]["loss"] > self.threshold:
self.violations += 1
@@ -140,21 +141,6 @@ class LossWatchDogCallback(TrainerCallback):
return control
-class SaveModelOnFirstStepCallback(TrainerCallback):
- """Callback to save the model on the first step of training if enabled"""
-
- def on_step_end(
- self,
- args: TrainingArguments, # pylint: disable=unused-argument
- state: TrainerState,
- control: TrainerControl,
- **_kwargs,
- ) -> TrainerControl:
- if state.global_step == 1:
- control.should_save = True
- return control
-
-
def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy")
abcd_idx = [
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index e20cdaf47..909fd637c 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -706,7 +706,6 @@ class AxolotlInputConfig(
"description": "Set to `no` to skip evaluation, `epoch` at end of each epoch, leave empty to infer from `eval_steps`"
},
)
-
save_steps: int | float | None = Field(
default=None,
json_schema_extra={
@@ -728,13 +727,6 @@ class AxolotlInputConfig(
save_total_limit: int | None = Field(
default=None, json_schema_extra={"description": "Checkpoints saved at a time"}
)
- save_first_step: bool | None = Field(
- default=None,
- json_schema_extra={
- "description": "Whether to checkpoint a model after the first step of training. Defaults to False."
- },
- )
-
logging_steps: int | None = Field(
default=None, json_schema_extra={"description": "Logging frequency"}
)
diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py
index 34e6c9644..790b34f3e 100644
--- a/tests/e2e/integrations/test_cut_cross_entropy.py
+++ b/tests/e2e/integrations/test_cut_cross_entropy.py
@@ -44,7 +44,6 @@ def min_cfg(temp_dir):
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
- "save_first_step": False,
}
@@ -99,7 +98,6 @@ class TestCutCrossEntropyIntegration:
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/integrations/test_hooks.py b/tests/e2e/integrations/test_hooks.py
index 8743efb98..4734449fe 100644
--- a/tests/e2e/integrations/test_hooks.py
+++ b/tests/e2e/integrations/test_hooks.py
@@ -153,7 +153,6 @@ class TestPluginHooks:
"max_steps": 5,
"flash_attention": True,
"bf16": "auto",
- "save_first_step": False,
}
)
diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py
index 1ac3b537e..212450e89 100644
--- a/tests/e2e/integrations/test_kd.py
+++ b/tests/e2e/integrations/test_kd.py
@@ -67,7 +67,6 @@ def min_cfg(temp_dir):
"output_dir": temp_dir,
"save_safetensors": True,
"use_tensorboard": True,
- "save_first_step": False,
}
diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py
index b1f5befdd..6ab3d7ab8 100644
--- a/tests/e2e/integrations/test_liger.py
+++ b/tests/e2e/integrations/test_liger.py
@@ -50,7 +50,6 @@ class LigerIntegrationTestCase:
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
- "save_first_step": False,
}
)
# pylint: disable=duplicate-code
@@ -97,7 +96,6 @@ class LigerIntegrationTestCase:
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
- "save_first_step": False,
}
)
# pylint: disable=duplicate-code
diff --git a/tests/e2e/integrations/test_llm_compressor.py b/tests/e2e/integrations/test_llm_compressor.py
index dceecea9f..247ae3bac 100644
--- a/tests/e2e/integrations/test_llm_compressor.py
+++ b/tests/e2e/integrations/test_llm_compressor.py
@@ -81,7 +81,6 @@ class TestLLMCompressorIntegration:
},
"save_compressed": save_compressed,
},
- "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/patched/test_sp.py b/tests/e2e/multigpu/patched/test_sp.py
index 80098e684..5593c7eb6 100644
--- a/tests/e2e/multigpu/patched/test_sp.py
+++ b/tests/e2e/multigpu/patched/test_sp.py
@@ -69,7 +69,6 @@ class TestSequenceParallelism:
"use_tensorboard": True,
"sequence_parallel_degree": 2,
"ring_attn_func": ring_attn_func,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/solo/test_flex.py b/tests/e2e/multigpu/solo/test_flex.py
index cbdf8de96..bdf5ada6b 100644
--- a/tests/e2e/multigpu/solo/test_flex.py
+++ b/tests/e2e/multigpu/solo/test_flex.py
@@ -61,7 +61,6 @@ class TestPackedFlex:
"max_steps": 2,
"use_tensorboard": True,
"save_strategy": "no",
- "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py
index d022ae2d9..c04734345 100644
--- a/tests/e2e/multigpu/solo/test_grpo.py
+++ b/tests/e2e/multigpu/solo/test_grpo.py
@@ -223,7 +223,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
- "save_first_step": False,
}
)
@@ -318,7 +317,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
- "save_first_step": False,
}
)
@@ -411,7 +409,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_eval.py b/tests/e2e/multigpu/test_eval.py
index 4f86278ff..d6429cf63 100644
--- a/tests/e2e/multigpu/test_eval.py
+++ b/tests/e2e/multigpu/test_eval.py
@@ -67,7 +67,6 @@ class TestMultiGPUEval:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
- "save_first_step": False,
}
)
@@ -139,7 +138,6 @@ class TestMultiGPUEval:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_gemma3.py b/tests/e2e/multigpu/test_gemma3.py
index 4a7b101a8..3868d90f0 100644
--- a/tests/e2e/multigpu/test_gemma3.py
+++ b/tests/e2e/multigpu/test_gemma3.py
@@ -71,7 +71,6 @@ class TestMultiGPUGemma3:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py
index aab14dcc4..f0c74fbf8 100644
--- a/tests/e2e/multigpu/test_llama.py
+++ b/tests/e2e/multigpu/test_llama.py
@@ -69,7 +69,6 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
- "save_first_step": False,
}
)
@@ -136,7 +135,6 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
- "save_first_step": False,
}
)
@@ -212,7 +210,6 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
- "save_first_step": False,
}
)
@@ -292,7 +289,6 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
- "save_first_step": False,
}
)
@@ -369,7 +365,6 @@ class TestMultiGPULlama:
},
"use_tensorboard": True,
"seed": 42,
- "save_first_step": False,
}
)
@@ -447,7 +442,6 @@ class TestMultiGPULlama:
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
- "save_first_step": False,
}
)
@@ -526,7 +520,6 @@ class TestMultiGPULlama:
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
},
"use_tensorboard": True,
- "save_first_step": False,
}
)
if attention_backend == "flash":
@@ -612,7 +605,6 @@ class TestMultiGPULlama:
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
- "save_first_step": False,
}
)
@@ -697,7 +689,6 @@ class TestMultiGPULlama:
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / deepspeed),
"use_tensorboard": True,
- "save_first_step": False,
**adapter,
}
)
@@ -774,7 +765,6 @@ class TestMultiGPULlama:
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
"use_tensorboard": True,
"seed": 42,
- "save_first_step": False,
**adapter,
}
)
@@ -850,7 +840,6 @@ class TestMultiGPULlama:
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
- "save_first_step": False,
**adapter,
}
)
@@ -919,7 +908,6 @@ class TestMultiGPULlama:
"save_safetensors": True,
# "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py
index dd1422296..43a722b48 100644
--- a/tests/e2e/multigpu/test_ray.py
+++ b/tests/e2e/multigpu/test_ray.py
@@ -56,7 +56,6 @@ class TestMultiGPURay:
"use_tensorboard": True,
"use_ray": True,
"ray_num_workers": 2,
- "save_first_step": False,
}
)
@@ -116,7 +115,6 @@ class TestMultiGPURay:
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
"use_tensorboard": True,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py
index 1824443e7..08b62accc 100644
--- a/tests/e2e/patched/test_4d_multipack_llama.py
+++ b/tests/e2e/patched/test_4d_multipack_llama.py
@@ -55,7 +55,6 @@ class Test4dMultipackLlama(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -103,7 +102,6 @@ class Test4dMultipackLlama(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_activation_checkpointing.py b/tests/e2e/patched/test_activation_checkpointing.py
index 3d5b3dc56..d494ed1eb 100644
--- a/tests/e2e/patched/test_activation_checkpointing.py
+++ b/tests/e2e/patched/test_activation_checkpointing.py
@@ -69,7 +69,6 @@ class TestActivationCheckpointing:
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": gradient_checkpointing,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py
index 38099b220..ca8b21178 100644
--- a/tests/e2e/patched/test_fa_xentropy.py
+++ b/tests/e2e/patched/test_fa_xentropy.py
@@ -62,7 +62,6 @@ class TestFAXentropyLlama:
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
- "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py
index ef31b11c7..a593b0791 100644
--- a/tests/e2e/patched/test_falcon_samplepack.py
+++ b/tests/e2e/patched/test_falcon_samplepack.py
@@ -58,7 +58,6 @@ class TestFalconPatched(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -100,7 +99,6 @@ class TestFalconPatched(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_flattening.py b/tests/e2e/patched/test_flattening.py
index fdaab558d..f77a1fbe5 100644
--- a/tests/e2e/patched/test_flattening.py
+++ b/tests/e2e/patched/test_flattening.py
@@ -61,7 +61,6 @@ class TestFAFlattening:
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
- "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py
index a3fe591ee..1bbc82a38 100644
--- a/tests/e2e/patched/test_fused_llama.py
+++ b/tests/e2e/patched/test_fused_llama.py
@@ -53,7 +53,6 @@ class TestFusedLlama(unittest.TestCase):
"max_steps": 10,
"save_steps": 5,
"eval_steps": 5,
- "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py
index ba5556a59..d2dcc5e4b 100644
--- a/tests/e2e/patched/test_llama_s2_attention.py
+++ b/tests/e2e/patched/test_llama_s2_attention.py
@@ -58,7 +58,6 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
"save_steps": 5,
"eval_steps": 5,
"bf16": "auto",
- "save_first_step": False,
}
)
@@ -101,7 +100,6 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
"save_steps": 5,
"eval_steps": 5,
"bf16": "auto",
- "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py
index fdf6adbc6..5df6bfecc 100644
--- a/tests/e2e/patched/test_lora_llama_multipack.py
+++ b/tests/e2e/patched/test_lora_llama_multipack.py
@@ -55,7 +55,6 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
- "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -109,7 +108,6 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py
index bea0f9c68..442089bae 100644
--- a/tests/e2e/patched/test_mistral_samplepack.py
+++ b/tests/e2e/patched/test_mistral_samplepack.py
@@ -56,7 +56,6 @@ class TestMistral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -98,7 +97,6 @@ class TestMistral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py
index 09e427abd..5f778660b 100644
--- a/tests/e2e/patched/test_mixtral_samplepack.py
+++ b/tests/e2e/patched/test_mixtral_samplepack.py
@@ -52,7 +52,6 @@ class TestMixtral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -91,7 +90,6 @@ class TestMixtral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py
index b90be23e4..5ea88b001 100644
--- a/tests/e2e/patched/test_model_patches.py
+++ b/tests/e2e/patched/test_model_patches.py
@@ -45,7 +45,6 @@ class TestModelPatches(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -79,7 +78,6 @@ class TestModelPatches(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_peft_embeddings.py b/tests/e2e/patched/test_peft_embeddings.py
index 4769319ae..d4f59a128 100644
--- a/tests/e2e/patched/test_peft_embeddings.py
+++ b/tests/e2e/patched/test_peft_embeddings.py
@@ -49,7 +49,6 @@ class TestLlamaPeftEmbeddings:
"bf16": "auto",
"save_safetensors": True,
"embeddings_skip_upcast": True,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py
index 1f0ddd630..d241ce185 100644
--- a/tests/e2e/patched/test_phi_multipack.py
+++ b/tests/e2e/patched/test_phi_multipack.py
@@ -54,7 +54,6 @@ class TestPhiMultipack(unittest.TestCase):
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
- "save_first_step": False,
}
)
@@ -106,7 +105,6 @@ class TestPhiMultipack(unittest.TestCase):
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
- "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py
index 54b8245ee..363956733 100644
--- a/tests/e2e/patched/test_resume.py
+++ b/tests/e2e/patched/test_resume.py
@@ -58,7 +58,6 @@ class TestResumeLlama:
"max_steps": 15,
"use_tensorboard": True,
"save_safetensors": True,
- "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py
index 4a2c69d45..2b4d11b30 100644
--- a/tests/e2e/patched/test_sp.py
+++ b/tests/e2e/patched/test_sp.py
@@ -47,7 +47,6 @@ def fixture_cfg():
"special_tokens": {
"pad_token": "<|endoftext|>",
},
- "save_first_step": False,
}
)
diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py
index 2c8ee4eb0..69171481c 100644
--- a/tests/e2e/patched/test_unsloth_qlora.py
+++ b/tests/e2e/patched/test_unsloth_qlora.py
@@ -62,7 +62,6 @@ class TestUnslothQLoRA:
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
- "save_first_step": False,
}
)
@@ -113,7 +112,6 @@ class TestUnslothQLoRA:
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
- "save_first_step": False,
}
)
@@ -169,7 +167,6 @@ class TestUnslothQLoRA:
"lr_scheduler": "cosine",
"use_tensorboard": True,
"fp16": True,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/solo/test_flex.py b/tests/e2e/solo/test_flex.py
index 76364fc0e..279913713 100644
--- a/tests/e2e/solo/test_flex.py
+++ b/tests/e2e/solo/test_flex.py
@@ -49,7 +49,6 @@ class TestPackedFlex(unittest.TestCase):
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
- "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/solo/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py
index f6fcad841..7af550496 100644
--- a/tests/e2e/solo/test_relora_llama.py
+++ b/tests/e2e/solo/test_relora_llama.py
@@ -65,7 +65,6 @@ class TestReLoraLlama(unittest.TestCase):
"lr_scheduler": "cosine",
"save_safetensors": True,
"use_tensorboard": True,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/test_deepseekv3.py b/tests/e2e/test_deepseekv3.py
index e4a47fb0a..7dfc4ae15 100644
--- a/tests/e2e/test_deepseekv3.py
+++ b/tests/e2e/test_deepseekv3.py
@@ -67,7 +67,6 @@ class TestDeepseekV3:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -117,7 +116,6 @@ class TestDeepseekV3:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py
index a1df69535..2cdb57689 100644
--- a/tests/e2e/test_dpo.py
+++ b/tests/e2e/test_dpo.py
@@ -56,7 +56,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
- "save_first_step": False,
}
)
@@ -106,7 +105,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
- "save_first_step": False,
}
)
@@ -156,7 +154,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
- "save_first_step": False,
}
)
@@ -206,7 +203,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
- "save_first_step": False,
}
)
@@ -255,7 +251,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
- "save_first_step": False,
}
)
@@ -307,7 +302,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
- "save_first_step": False,
}
)
@@ -376,7 +370,6 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
- "save_first_step": False,
}
)
diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py
index e4a06ad14..9b65f8feb 100644
--- a/tests/e2e/test_embeddings_lr.py
+++ b/tests/e2e/test_embeddings_lr.py
@@ -48,7 +48,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
- "save_first_step": False,
}
)
@@ -94,7 +93,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_evaluate.py b/tests/e2e/test_evaluate.py
index 977497e5e..6271bba28 100644
--- a/tests/e2e/test_evaluate.py
+++ b/tests/e2e/test_evaluate.py
@@ -36,7 +36,6 @@ class TestE2eEvaluate:
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py
index 5be6efcf6..4f88e740c 100644
--- a/tests/e2e/test_falcon.py
+++ b/tests/e2e/test_falcon.py
@@ -60,7 +60,6 @@ class TestFalcon(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
- "save_first_step": False,
}
)
@@ -116,7 +115,6 @@ class TestFalcon(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
- "save_first_step": False,
}
)
@@ -158,7 +156,6 @@ class TestFalcon(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
- "save_first_step": False,
}
)
diff --git a/tests/e2e/test_gemma3_text.py b/tests/e2e/test_gemma3_text.py
index ef38d028d..3f00a1384 100644
--- a/tests/e2e/test_gemma3_text.py
+++ b/tests/e2e/test_gemma3_text.py
@@ -63,7 +63,6 @@ class TestGemma3Text:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -114,7 +113,6 @@ class TestGemma3Text:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py
index 1e6df0be9..2b180029c 100644
--- a/tests/e2e/test_llama.py
+++ b/tests/e2e/test_llama.py
@@ -45,7 +45,6 @@ class TestLlama:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
- "save_first_step": False,
}
)
@@ -93,7 +92,6 @@ class TestLlama:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
- "save_first_step": False,
}
)
@@ -138,7 +136,6 @@ class TestLlama:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
- "save_first_step": False,
}
)
@@ -179,7 +176,6 @@ class TestLlama:
"batch_flattening": True,
"bf16": True,
"save_safetensors": True,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py
index bd5502300..fdebf2173 100644
--- a/tests/e2e/test_llama_pretrain.py
+++ b/tests/e2e/test_llama_pretrain.py
@@ -53,7 +53,6 @@ class TestPretrainLlama:
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py
index 760759bca..ad4a83c6a 100644
--- a/tests/e2e/test_llama_vision.py
+++ b/tests/e2e/test_llama_vision.py
@@ -54,7 +54,6 @@ class TestLlamaVision(unittest.TestCase):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
- "save_first_step": False,
}
)
@@ -101,7 +100,6 @@ class TestLlamaVision(unittest.TestCase):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index 7e0ff46cf..301565302 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -49,7 +49,6 @@ class TestLoraLlama(unittest.TestCase):
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py
index 73d3bdc26..1824619a6 100644
--- a/tests/e2e/test_mamba.py
+++ b/tests/e2e/test_mamba.py
@@ -51,7 +51,6 @@ class TestMamba(unittest.TestCase):
"save_steps": 10,
"eval_steps": None,
"save_safetensors": False,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py
index f47f794e0..5d9b8ba8c 100644
--- a/tests/e2e/test_mistral.py
+++ b/tests/e2e/test_mistral.py
@@ -55,7 +55,6 @@ class TestMistral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
- "save_first_step": False,
}
)
@@ -96,7 +95,6 @@ class TestMistral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
- "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py
index 3fe2bf70f..761e59391 100644
--- a/tests/e2e/test_mixtral.py
+++ b/tests/e2e/test_mixtral.py
@@ -61,7 +61,6 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
- "save_first_step": False,
}
)
@@ -117,7 +116,6 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
- "save_first_step": False,
}
)
@@ -172,7 +170,6 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
- "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -231,7 +228,6 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
- "save_first_step": False,
}
)
@@ -277,7 +273,6 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
- "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
index 1d233a201..53ef86022 100644
--- a/tests/e2e/test_optimizers.py
+++ b/tests/e2e/test_optimizers.py
@@ -55,7 +55,6 @@ class TestCustomOptimizers(unittest.TestCase):
"optimizer": "optimi_adamw",
"max_steps": 5,
"lr_scheduler": "cosine",
- "save_first_step": False,
}
)
@@ -101,7 +100,6 @@ class TestCustomOptimizers(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adopt_adamw",
"lr_scheduler": "cosine",
- "save_first_step": False,
}
)
@@ -148,7 +146,6 @@ class TestCustomOptimizers(unittest.TestCase):
"optimizer": "muon",
"lr_scheduler": "cosine",
"weight_decay": 0.01,
- "save_first_step": False,
}
)
@@ -187,7 +184,6 @@ class TestCustomOptimizers(unittest.TestCase):
"lr_scheduler": "constant",
"save_safetensors": True,
"max_steps": 10,
- "save_first_step": False,
}
)
# pylint: disable=duplicate-code
@@ -236,7 +232,6 @@ class TestCustomOptimizers(unittest.TestCase):
"adam_epsilon2": 1e-16,
"max_steps": 5,
"lr_scheduler": "cosine",
- "save_first_step": False,
}
)
diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py
index aec9d95f8..cc2db72e0 100644
--- a/tests/e2e/test_packing_loss.py
+++ b/tests/e2e/test_packing_loss.py
@@ -48,7 +48,6 @@ class TestPackedLlama(unittest.TestCase):
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
- "save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py
index ab3a63674..88fda9191 100644
--- a/tests/e2e/test_phi.py
+++ b/tests/e2e/test_phi.py
@@ -53,7 +53,6 @@ class TestPhi(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -103,7 +102,6 @@ class TestPhi(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_process_reward_model_smollm2.py b/tests/e2e/test_process_reward_model_smollm2.py
index bd9eec48b..abfe1b0c5 100644
--- a/tests/e2e/test_process_reward_model_smollm2.py
+++ b/tests/e2e/test_process_reward_model_smollm2.py
@@ -49,7 +49,6 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
"use_tensorboard": True,
"special_tokens": {"pad_token": "<|endoftext|>"},
"seed": 42,
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_qat.py b/tests/e2e/test_qat.py
index 139ae155a..ef726079d 100644
--- a/tests/e2e/test_qat.py
+++ b/tests/e2e/test_qat.py
@@ -57,7 +57,6 @@ class TestQATLlama:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -116,7 +115,6 @@ class TestQATLlama:
"weight_dtype": "int8",
"group_size": 8,
},
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_qwen.py b/tests/e2e/test_qwen.py
index 59267d14d..aa8b9f6c0 100644
--- a/tests/e2e/test_qwen.py
+++ b/tests/e2e/test_qwen.py
@@ -59,7 +59,6 @@ class TestE2eQwen:
"bf16": "auto",
"tf32": True,
"gradient_checkpointing": True,
- "save_first_step": False,
}
)
diff --git a/tests/e2e/test_reward_model_smollm2.py b/tests/e2e/test_reward_model_smollm2.py
index 82513f99f..5d52bcc86 100644
--- a/tests/e2e/test_reward_model_smollm2.py
+++ b/tests/e2e/test_reward_model_smollm2.py
@@ -58,7 +58,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
"gradient_checkpointing": True,
"warmup_ratio": 0.1,
"use_tensorboard": True,
- "save_first_step": False,
}
)
cfg = validate_config(cfg)
diff --git a/tests/e2e/test_save_first_step.py b/tests/e2e/test_save_first_step.py
deleted file mode 100644
index 5bbd2302b..000000000
--- a/tests/e2e/test_save_first_step.py
+++ /dev/null
@@ -1,102 +0,0 @@
-"""
-E2E tests for relora llama
-"""
-
-import unittest
-from pathlib import Path
-
-import pytest
-
-from axolotl.common.datasets import load_datasets
-from axolotl.train import train
-from axolotl.utils.config import normalize_config, validate_config
-from axolotl.utils.dict import DictDefault
-
-from .utils import check_model_output_exists, with_temp_dir
-
-
-class TestSaveFirstStepCallback(unittest.TestCase):
- """Test cases for save_first_step callback config."""
-
- @with_temp_dir
- def test_save_first_step(self, temp_dir):
- # pylint: disable=duplicate-code
- cfg = DictDefault(
- {
- "base_model": "HuggingFaceTB/SmolLM2-135M",
- "tokenizer_type": "AutoTokenizer",
- "sequence_len": 512,
- "val_set_size": 0.02,
- "special_tokens": {
- "pad_token": "<|endoftext|>",
- },
- "datasets": [
- {
- "path": "mhenrichsen/alpaca_2k_test",
- "type": "alpaca",
- },
- ],
- "num_epochs": 1,
- "max_steps": 3,
- "micro_batch_size": 2,
- "gradient_accumulation_steps": 1,
- "output_dir": temp_dir,
- "learning_rate": 0.00001,
- "optimizer": "adamw_bnb_8bit",
- "lr_scheduler": "cosine",
- "flash_attention": True,
- "sample_packing": True,
- "bf16": True,
- "save_safetensors": True,
- "save_first_step": True,
- }
- )
-
- cfg = validate_config(cfg)
- normalize_config(cfg)
- dataset_meta = load_datasets(cfg=cfg)
-
- train(cfg=cfg, dataset_meta=dataset_meta)
- check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)
-
- @with_temp_dir
- def test_no_save_first_step(self, temp_dir):
- # pylint: disable=duplicate-code
- cfg = DictDefault(
- {
- "base_model": "HuggingFaceTB/SmolLM2-135M",
- "tokenizer_type": "AutoTokenizer",
- "sequence_len": 512,
- "val_set_size": 0.02,
- "special_tokens": {
- "pad_token": "<|endoftext|>",
- },
- "datasets": [
- {
- "path": "mhenrichsen/alpaca_2k_test",
- "type": "alpaca",
- },
- ],
- "num_epochs": 1,
- "max_steps": 3,
- "micro_batch_size": 2,
- "gradient_accumulation_steps": 1,
- "output_dir": temp_dir,
- "learning_rate": 0.00001,
- "optimizer": "adamw_bnb_8bit",
- "lr_scheduler": "cosine",
- "flash_attention": True,
- "sample_packing": True,
- "bf16": True,
- "save_safetensors": True,
- "save_first_step": False,
- }
- )
-
- cfg = validate_config(cfg)
- normalize_config(cfg)
- dataset_meta = load_datasets(cfg=cfg)
-
- train(cfg=cfg, dataset_meta=dataset_meta)
- with pytest.raises(AssertionError):
- check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)
diff --git a/tests/e2e/test_schedulers.py b/tests/e2e/test_schedulers.py
index 8f7a13aee..e98378f08 100644
--- a/tests/e2e/test_schedulers.py
+++ b/tests/e2e/test_schedulers.py
@@ -51,7 +51,6 @@ class TestCustomSchedulers(unittest.TestCase):
"lr_scheduler": "rex",
"warmup_steps": 5,
"cosine_min_lr_ratio": 0.05,
- "save_first_step": False,
}
)