diff --git a/examples/cloud/modal.yaml b/examples/cloud/modal.yaml
index 195031494..bbe8785f1 100644
--- a/examples/cloud/modal.yaml
+++ b/examples/cloud/modal.yaml
@@ -26,3 +26,5 @@ timeout: 86400
 # Preprocess specific configurations
 memory_preprocess: 32
 timeout_preprocess: 14400
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/cohere/command-r-7b-qlora.yml b/examples/cohere/command-r-7b-qlora.yml
index 4a30e9a77..da2777270 100644
--- a/examples/cohere/command-r-7b-qlora.yml
+++ b/examples/cohere/command-r-7b-qlora.yml
@@ -35,7 +35,6 @@
 wandb_watch:
 wandb_name:
 wandb_log_model:
-
 gradient_accumulation_steps: 4
 micro_batch_size: 1
 num_epochs: 4
@@ -56,3 +55,5 @@ evals_per_epoch:
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
index 2c0495ced..1a051b98b 100644
--- a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
+++ b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml
@@ -56,3 +56,5 @@ evals_per_epoch: 1
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
index de9c956e0..807342641 100644
--- a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
+++ b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml
@@ -56,3 +56,5 @@ evals_per_epoch: 1
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepseek-v2/fft-fsdp-16b.yaml b/examples/deepseek-v2/fft-fsdp-16b.yaml
index 0ed97db36..78bf6b179 100644
--- a/examples/deepseek-v2/fft-fsdp-16b.yaml
+++ b/examples/deepseek-v2/fft-fsdp-16b.yaml
@@ -55,3 +55,5 @@ fsdp_config:
   fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
   fsdp_state_dict_type: FULL_STATE_DICT
   fsdp_sharding_strategy: FULL_SHARD
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/deepseek-v2/qlora-fsdp-2_5.yaml b/examples/deepseek-v2/qlora-fsdp-2_5.yaml
index 34dbeaafe..da1d9aefd 100644
--- a/examples/deepseek-v2/qlora-fsdp-2_5.yaml
+++ b/examples/deepseek-v2/qlora-fsdp-2_5.yaml
@@ -79,3 +79,5 @@ fsdp_config:
   fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
   fsdp_state_dict_type: FULL_STATE_DICT
   fsdp_sharding_strategy: FULL_SHARD
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/devstral/devstral-small-qlora.yml b/examples/devstral/devstral-small-qlora.yml
index dc0051bd5..9d92e8662 100644
--- a/examples/devstral/devstral-small-qlora.yml
+++ b/examples/devstral/devstral-small-qlora.yml
@@ -62,3 +62,5 @@
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
index 1dd901154..484c31fec 100644
--- a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch:
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-1b-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-qlora.yaml
index 24dc7cae3..dea2a6e6d 100644
--- a/examples/falcon-h1/falcon-h1-1b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-1b-qlora.yaml
@@ -46,7 +46,6 @@
 wandb_watch:
 wandb_name:
 wandb_log_model:
-
 gradient_accumulation_steps: 4
 micro_batch_size: 1
 num_epochs: 4
@@ -69,3 +68,5 @@ evals_per_epoch:
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-34b-qlora.yaml b/examples/falcon-h1/falcon-h1-34b-qlora.yaml
index 43eb1967b..b187efbf6 100644
--- a/examples/falcon-h1/falcon-h1-34b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-34b-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch:
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-3b-qlora.yaml b/examples/falcon-h1/falcon-h1-3b-qlora.yaml
index 00929bbf0..4d981ad95 100644
--- a/examples/falcon-h1/falcon-h1-3b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-3b-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch: 1
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-500m-qlora.yaml b/examples/falcon-h1/falcon-h1-500m-qlora.yaml
index e2640de7b..5ee13facd 100644
--- a/examples/falcon-h1/falcon-h1-500m-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-500m-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch:
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/falcon-h1/falcon-h1-7b-qlora.yaml b/examples/falcon-h1/falcon-h1-7b-qlora.yaml
index 183e423b5..4b665c3cd 100644
--- a/examples/falcon-h1/falcon-h1-7b-qlora.yaml
+++ b/examples/falcon-h1/falcon-h1-7b-qlora.yaml
@@ -69,3 +69,5 @@ evals_per_epoch: 1
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma2/qlora.yml b/examples/gemma2/qlora.yml
index cb96a32c1..68d213fad 100644
--- a/examples/gemma2/qlora.yml
+++ b/examples/gemma2/qlora.yml
@@ -60,3 +60,5 @@ evals_per_epoch:
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma2/reward-model.yaml b/examples/gemma2/reward-model.yaml
index ce01a4572..624ebdcd2 100644
--- a/examples/gemma2/reward-model.yaml
+++ b/examples/gemma2/reward-model.yaml
@@ -50,3 +50,5 @@ evals_per_epoch:
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-1b-qlora.yml b/examples/gemma3/gemma-3-1b-qlora.yml
index 217c887aa..99921770d 100644
--- a/examples/gemma3/gemma-3-1b-qlora.yml
+++ b/examples/gemma3/gemma-3-1b-qlora.yml
@@ -66,3 +66,5 @@ evals_per_epoch:
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-4b-qlora.yml b/examples/gemma3/gemma-3-4b-qlora.yml
index d78559ae3..025cb9240 100644
--- a/examples/gemma3/gemma-3-4b-qlora.yml
+++ b/examples/gemma3/gemma-3-4b-qlora.yml
@@ -60,3 +60,5 @@ warmup_ratio: 0.1
 evals_per_epoch: 1
 saves_per_epoch: 1
 weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/gemma3/gemma-3-4b-vision-qlora.yml b/examples/gemma3/gemma-3-4b-vision-qlora.yml
index 183eb88e8..e9e606b69 100644
--- a/examples/gemma3/gemma-3-4b-vision-qlora.yml
+++ b/examples/gemma3/gemma-3-4b-vision-qlora.yml
@@ -62,3 +62,5 @@ warmup_ratio: 0.1
 evals_per_epoch: 1
 saves_per_epoch: 1
 weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/glm4/qlora-32b.yaml b/examples/glm4/qlora-32b.yaml
index 86d9b43f8..8973cedd4 100644
--- a/examples/glm4/qlora-32b.yaml
+++ b/examples/glm4/qlora-32b.yaml
@@ -60,3 +60,5 @@ evals_per_epoch: 1
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora.yaml b/examples/jamba/qlora.yaml
index 2cb0eea41..494154886 100644
--- a/examples/jamba/qlora.yaml
+++ b/examples/jamba/qlora.yaml
@@ -54,3 +54,5 @@ evals_per_epoch:
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora_deepspeed.yaml b/examples/jamba/qlora_deepspeed.yaml
index d13ce6483..64db8f2ff 100644
--- a/examples/jamba/qlora_deepspeed.yaml
+++ b/examples/jamba/qlora_deepspeed.yaml
@@ -55,3 +55,5 @@ saves_per_epoch: 1
 deepspeed: deepspeed_configs/zero2.json
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/jamba/qlora_fsdp_large.yaml b/examples/jamba/qlora_fsdp_large.yaml
index 6badaba19..fda30e2d2 100644
--- a/examples/jamba/qlora_fsdp_large.yaml
+++ b/examples/jamba/qlora_fsdp_large.yaml
@@ -64,3 +64,5 @@ fsdp_config:
   fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
   fsdp_state_dict_type: FULL_STATE_DICT
   fsdp_sharding_strategy: FULL_SHARD
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/lfm2/lfm2-350m-fft.yaml b/examples/lfm2/lfm2-350m-fft.yaml
index 95961557e..74c90c1e1 100644
--- a/examples/lfm2/lfm2-350m-fft.yaml
+++ b/examples/lfm2/lfm2-350m-fft.yaml
@@ -46,3 +46,5 @@
 evals_per_epoch: 2
 saves_per_epoch: 1
 weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml
index 86b1b6a21..c44cd2230 100644
--- a/examples/llama-2/fft_optimized.yml
+++ b/examples/llama-2/fft_optimized.yml
@@ -55,3 +55,5 @@ saves_per_epoch: 1
 deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
 weight_decay: 0.1
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml
index 0f1b34016..580fabdf8 100644
--- a/examples/llama-2/gptq-lora.yml
+++ b/examples/llama-2/gptq-lora.yml
@@ -64,3 +64,5 @@ special_tokens:
   bos_token: "<s>"
   eos_token: "</s>"
   unk_token: "<unk>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/lisa.yml b/examples/llama-2/lisa.yml
index a76a792ae..a44e261be 100644
--- a/examples/llama-2/lisa.yml
+++ b/examples/llama-2/lisa.yml
@@ -60,3 +60,5 @@ special_tokens:
   bos_token: "<s>"
   eos_token: "</s>"
   unk_token: "<unk>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/loftq.yml b/examples/llama-2/loftq.yml
index 22dbf2d99..085627f63 100644
--- a/examples/llama-2/loftq.yml
+++ b/examples/llama-2/loftq.yml
@@ -52,3 +52,5 @@ evals_per_epoch: 4
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml
index 679aed3a9..759fce044 100644
--- a/examples/llama-2/lora.yml
+++ b/examples/llama-2/lora.yml
@@ -52,3 +52,5 @@ evals_per_epoch: 4
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/qlora-fsdp.yml b/examples/llama-2/qlora-fsdp.yml
index a42eabd4b..3bf30120b 100644
--- a/examples/llama-2/qlora-fsdp.yml
+++ b/examples/llama-2/qlora-fsdp.yml
@@ -67,3 +67,5 @@ fsdp_config:
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
   fsdp_state_dict_type: FULL_STATE_DICT
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml
index de65928bc..09596c71e 100644
--- a/examples/llama-2/qlora.yml
+++ b/examples/llama-2/qlora.yml
@@ -53,3 +53,5 @@ evals_per_epoch: 4
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml
index e0a5f7068..ca8b14a1c 100644
--- a/examples/llama-2/relora.yml
+++ b/examples/llama-2/relora.yml
@@ -58,3 +58,5 @@ special_tokens:
   bos_token: "<s>"
   eos_token: "</s>"
   unk_token: "<unk>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml
index 2b0ae2c70..64d749b5a 100644
--- a/examples/llama-3-vision/lora-11b.yaml
+++ b/examples/llama-3-vision/lora-11b.yaml
@@ -57,3 +57,5 @@ warmup_ratio: 0.1
 evals_per_epoch: 1
 saves_per_epoch: 1
 weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/3b-qat-fsdp2.yaml b/examples/llama-3/3b-qat-fsdp2.yaml
index 5d979c96c..08d8ee5c1 100644
--- a/examples/llama-3/3b-qat-fsdp2.yaml
+++ b/examples/llama-3/3b-qat-fsdp2.yaml
@@ -77,3 +77,5 @@ fsdp_config:
 
 special_tokens:
   pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml
index eccfa6d8c..e2808935f 100644
--- a/examples/llama-3/fft-8b-liger-fsdp.yaml
+++ b/examples/llama-3/fft-8b-liger-fsdp.yaml
@@ -72,3 +72,5 @@ fsdp_config:
 special_tokens:
   pad_token: <|finetune_right_pad_id|>
   eos_token: <|eot_id|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/fft-8b.yaml b/examples/llama-3/fft-8b.yaml
index fdae3e6c4..2dfe6d492 100644
--- a/examples/llama-3/fft-8b.yaml
+++ b/examples/llama-3/fft-8b.yaml
@@ -42,3 +42,5 @@ saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
   pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml
index 51f1c768b..10ab2a320 100644
--- a/examples/llama-3/instruct-dpo-lora-8b.yml
+++ b/examples/llama-3/instruct-dpo-lora-8b.yml
@@ -71,3 +71,5 @@ warmup_steps: 10
 evals_per_epoch: 4
 saves_per_epoch: 1
 weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml
index acab862f6..83b7f9a37 100644
--- a/examples/llama-3/instruct-lora-8b.yml
+++ b/examples/llama-3/instruct-lora-8b.yml
@@ -64,3 +64,5 @@ saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
   pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-deduplicate-dpo.yml b/examples/llama-3/lora-1b-deduplicate-dpo.yml
index 10e9747cb..b20dbad84 100644
--- a/examples/llama-3/lora-1b-deduplicate-dpo.yml
+++ b/examples/llama-3/lora-1b-deduplicate-dpo.yml
@@ -83,3 +83,5 @@ warmup_steps: 10
 evals_per_epoch: 4
 saves_per_epoch: 1
 weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-deduplicate-sft.yml b/examples/llama-3/lora-1b-deduplicate-sft.yml
index 630ec92f6..67e518184 100644
--- a/examples/llama-3/lora-1b-deduplicate-sft.yml
+++ b/examples/llama-3/lora-1b-deduplicate-sft.yml
@@ -61,3 +61,5 @@ saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
   pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-kernels.yml b/examples/llama-3/lora-1b-kernels.yml
index a2d07ca49..92a948c2e 100644
--- a/examples/llama-3/lora-1b-kernels.yml
+++ b/examples/llama-3/lora-1b-kernels.yml
@@ -65,3 +65,5 @@ saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
   pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-ray.yml b/examples/llama-3/lora-1b-ray.yml
index bb23164eb..178a1fb89 100644
--- a/examples/llama-3/lora-1b-ray.yml
+++ b/examples/llama-3/lora-1b-ray.yml
@@ -64,3 +64,5 @@ special_tokens:
 
 use_ray: true
 ray_num_workers: 4
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b-sample-packing-sequentially.yml b/examples/llama-3/lora-1b-sample-packing-sequentially.yml
index 769dd32e6..c4ce3eb0f 100644
--- a/examples/llama-3/lora-1b-sample-packing-sequentially.yml
+++ b/examples/llama-3/lora-1b-sample-packing-sequentially.yml
@@ -63,3 +63,5 @@ saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
   pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-1b.yml b/examples/llama-3/lora-1b.yml
index acc17e21f..82085483f 100644
--- a/examples/llama-3/lora-1b.yml
+++ b/examples/llama-3/lora-1b.yml
@@ -60,3 +60,5 @@ saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
   pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/lora-8b.yml b/examples/llama-3/lora-8b.yml
index ad50cd38a..c39389755 100644
--- a/examples/llama-3/lora-8b.yml
+++ b/examples/llama-3/lora-8b.yml
@@ -57,3 +57,5 @@ saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
   pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-1b-kto.yaml b/examples/llama-3/qlora-1b-kto.yaml
index 89a51ea68..f156e23d3 100644
--- a/examples/llama-3/qlora-1b-kto.yaml
+++ b/examples/llama-3/qlora-1b-kto.yaml
@@ -61,3 +61,5 @@ saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
   pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-1b.yml b/examples/llama-3/qlora-1b.yml
index 5c8fe6628..6b76ea8d9 100644
--- a/examples/llama-3/qlora-1b.yml
+++ b/examples/llama-3/qlora-1b.yml
@@ -62,3 +62,5 @@ saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
   pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-fsdp-405b.yaml b/examples/llama-3/qlora-fsdp-405b.yaml
index 2b7d51925..1ee922b59 100644
--- a/examples/llama-3/qlora-fsdp-405b.yaml
+++ b/examples/llama-3/qlora-fsdp-405b.yaml
@@ -60,3 +60,5 @@ fsdp_config:
   fsdp_sharding_strategy: FULL_SHARD
 special_tokens:
   pad_token: <|finetune_right_pad_id|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora-fsdp-70b.yaml b/examples/llama-3/qlora-fsdp-70b.yaml
index 412b6721c..5edd8353a 100644
--- a/examples/llama-3/qlora-fsdp-70b.yaml
+++ b/examples/llama-3/qlora-fsdp-70b.yaml
@@ -69,3 +69,5 @@ fsdp_config:
   fsdp_sharding_strategy: FULL_SHARD
 special_tokens:
   pad_token: <|end_of_text|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/qlora.yml b/examples/llama-3/qlora.yml
index 4cc9fc3db..a674eca27 100644
--- a/examples/llama-3/qlora.yml
+++ b/examples/llama-3/qlora.yml
@@ -54,3 +54,5 @@ saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
   pad_token: "<|end_of_text|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-3/sparse-finetuning.yaml b/examples/llama-3/sparse-finetuning.yaml
index 1bbb88028..8577a19d2 100644
--- a/examples/llama-3/sparse-finetuning.yaml
+++ b/examples/llama-3/sparse-finetuning.yaml
@@ -75,3 +75,5 @@ llmcompressor:
         ]
         start: 0
   save_compressed: true
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
index 2be94f4ef..d4a038e11 100644
--- a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
+++ b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml
@@ -86,3 +86,5 @@ fsdp_config:
 special_tokens:
   pad_token: <|finetune_right_pad_id|>
   eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
index eeae872a6..bea10d979 100644
--- a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml
@@ -90,3 +90,5 @@ fsdp_config:
 special_tokens:
   pad_token: <|finetune_right_pad_id|>
   eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
index 17ad70634..737d93812 100644
--- a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml
@@ -83,3 +83,5 @@ weight_decay: 0.0
 special_tokens:
   pad_token: <|finetune_right_pad_id|>
   eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
index eff708e4d..390be5af7 100644
--- a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
+++ b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml
@@ -86,3 +86,5 @@ fsdp_config:
 special_tokens:
   pad_token: <|finetune_right_pad_id|>
   eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
index 9a411883e..b319349c4 100644
--- a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
+++ b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml
@@ -84,3 +84,5 @@ fsdp_config:
 special_tokens:
   pad_token: <|finetune_right_pad_id|>
   eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-qlora-single-h100-flex.yaml b/examples/llama-4/scout-qlora-single-h100-flex.yaml
index 20352f81e..6be3988ef 100644
--- a/examples/llama-4/scout-qlora-single-h100-flex.yaml
+++ b/examples/llama-4/scout-qlora-single-h100-flex.yaml
@@ -82,3 +82,5 @@ weight_decay: 0.0
 special_tokens:
   pad_token: <|finetune_right_pad_id|>
   eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
index 9fbd34107..a67936cf1 100644
--- a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
+++ b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml
@@ -87,3 +87,5 @@ fsdp_config:
 special_tokens:
   pad_token: <|finetune_right_pad_id|>
   eos_token: <|eot|>
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/llava/lora-7b.yaml b/examples/llava/lora-7b.yaml
index 5198c8e74..a4bac8987 100644
--- a/examples/llava/lora-7b.yaml
+++ b/examples/llava/lora-7b.yaml
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
 evals_per_epoch: 1
 saves_per_epoch: 1
 weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/magistral/magistral-small-fsdp-qlora.yaml b/examples/magistral/magistral-small-fsdp-qlora.yaml
index b10e8baf6..b23d2309a 100644
--- a/examples/magistral/magistral-small-fsdp-qlora.yaml
+++ b/examples/magistral/magistral-small-fsdp-qlora.yaml
@@ -70,3 +70,5 @@ fsdp_config:
   fsdp_state_dict_type: FULL_STATE_DICT
   fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
   fsdp_activation_checkpointing: true
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/magistral/magistral-small-qlora.yaml b/examples/magistral/magistral-small-qlora.yaml
index e3e746f22..f0fce014f 100644
--- a/examples/magistral/magistral-small-qlora.yaml
+++ b/examples/magistral/magistral-small-qlora.yaml
@@ -61,3 +61,5 @@ flash_attention: true
 warmup_ratio: 0.1
 evals_per_epoch: 1
 saves_per_epoch: 1
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml
index 3d4583932..2261bd215 100644
--- a/examples/mamba/config.yml
+++ b/examples/mamba/config.yml
@@ -48,3 +48,5 @@ weight_decay: 0.0
 special_tokens:
 tokens:
 save_safetensors: False
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/bigstral-ds-zero3.yaml b/examples/mistral/bigstral-ds-zero3.yaml
index f626a92a1..e9bcbb7d6 100644
--- a/examples/mistral/bigstral-ds-zero3.yaml
+++ b/examples/mistral/bigstral-ds-zero3.yaml
@@ -53,3 +53,5 @@ special_tokens:
   eos_token: "<|im_end|>"
 tokens:
   - "<|im_start|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml
index 15edffb44..8c4d80f79 100644
--- a/examples/mistral/config.yml
+++ b/examples/mistral/config.yml
@@ -43,3 +43,5 @@ evals_per_epoch: 4
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/lora-mps.yml b/examples/mistral/lora-mps.yml
index e6f46affb..d54c3e30b 100644
--- a/examples/mistral/lora-mps.yml
+++ b/examples/mistral/lora-mps.yml
@@ -64,3 +64,5 @@ evals_per_epoch: 4
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/lora.yml b/examples/mistral/lora.yml
index 9af4274fd..161255468 100644
--- a/examples/mistral/lora.yml
+++ b/examples/mistral/lora.yml
@@ -64,3 +64,5 @@ evals_per_epoch: 4
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-dpo-qlora.yml b/examples/mistral/mistral-dpo-qlora.yml
index af707973f..8d0378690 100644
--- a/examples/mistral/mistral-dpo-qlora.yml
+++ b/examples/mistral/mistral-dpo-qlora.yml
@@ -80,3 +80,5 @@ weight_decay: 0.0
 special_tokens:
   bos_token: "<|im_start|>"
   eos_token: "<|im_end|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-qlora-fsdp.yml b/examples/mistral/mistral-qlora-fsdp.yml
index e234b19a2..cec958c54 100644
--- a/examples/mistral/mistral-qlora-fsdp.yml
+++ b/examples/mistral/mistral-qlora-fsdp.yml
@@ -74,3 +74,5 @@ fsdp_config:
   fsdp_state_dict_type: FULL_STATE_DICT
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-qlora-orpo.yml b/examples/mistral/mistral-qlora-orpo.yml
index 6c0212b7c..f37dc09fa 100644
--- a/examples/mistral/mistral-qlora-orpo.yml
+++ b/examples/mistral/mistral-qlora-orpo.yml
@@ -69,3 +69,5 @@ evals_per_epoch: 4
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mistral-small-3.1-24B-lora.yml b/examples/mistral/mistral-small-3.1-24B-lora.yml
index 3e3b45862..4a492c595 100644
--- a/examples/mistral/mistral-small-3.1-24B-lora.yml
+++ b/examples/mistral/mistral-small-3.1-24B-lora.yml
@@ -56,3 +56,5 @@ evals_per_epoch: 1
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
index af6ba5a76..64ef9930c 100644
--- a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
+++ b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml
@@ -72,3 +72,5 @@ fsdp_config:
   fsdp_state_dict_type: FULL_STATE_DICT
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral-qlora-fsdp.yml
index b1843a138..c8d0a2711 100644
--- a/examples/mistral/mixtral-qlora-fsdp.yml
+++ b/examples/mistral/mixtral-qlora-fsdp.yml
@@ -77,3 +77,5 @@ fsdp_config:
   fsdp_forward_prefetch: false
   fsdp_backward_prefetch: BACKWARD_PRE
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml
index 4c256420c..5be9b4db8 100644
--- a/examples/mistral/mixtral.yml
+++ b/examples/mistral/mixtral.yml
@@ -81,3 +81,5 @@ saves_per_epoch: 1
 deepspeed: deepspeed_configs/zero2.json
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/mixtral_22.yml b/examples/mistral/mixtral_22.yml
index 25e1d7155..100e4464f 100644
--- a/examples/mistral/mixtral_22.yml
+++ b/examples/mistral/mixtral_22.yml
@@ -51,3 +51,5 @@ special_tokens:
   eos_token: "<|im_end|>"
 tokens:
   - "<|im_start|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml
index 607e33701..08df36e15 100644
--- a/examples/mistral/qlora.yml
+++ b/examples/mistral/qlora.yml
@@ -64,3 +64,5 @@ evals_per_epoch: 4
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/orpheus/finetune.yml b/examples/orpheus/finetune.yml
index 9bcbbeee0..57f65d966 100644
--- a/examples/orpheus/finetune.yml
+++ b/examples/orpheus/finetune.yml
@@ -50,3 +50,5 @@
 weight_decay: 0.05
 special_tokens:
   pad_token:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/lora-3.5.yaml b/examples/phi/lora-3.5.yaml
index ad4ce9cd4..9f3bbdf53 100644
--- a/examples/phi/lora-3.5.yaml
+++ b/examples/phi/lora-3.5.yaml
@@ -63,3 +63,5 @@ warmup_steps: 10
 evals_per_epoch: 4
 saves_per_epoch: 4
 weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml
index 1562a7353..fc6d649d7 100644
--- a/examples/phi/phi-ft.yml
+++ b/examples/phi/phi-ft.yml
@@ -57,3 +57,5 @@ weight_decay: 0.1
 resize_token_embeddings_to_32x: true
 special_tokens:
   pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml
index 4cd53db97..ccd92c817 100644
--- a/examples/phi/phi-qlora.yml
+++ b/examples/phi/phi-qlora.yml
@@ -60,3 +60,5 @@ weight_decay: 0.1
 resize_token_embeddings_to_32x: true
 special_tokens:
   pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi2-ft.yml b/examples/phi/phi2-ft.yml
index ca733cc71..853250ccb 100644
--- a/examples/phi/phi2-ft.yml
+++ b/examples/phi/phi2-ft.yml
@@ -57,3 +57,5 @@ weight_decay: 0.1
 resize_token_embeddings_to_32x: true
 special_tokens:
   pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi3-ft-fsdp.yml b/examples/phi/phi3-ft-fsdp.yml
index d0d14fea6..130298bc0 100644
--- a/examples/phi/phi3-ft-fsdp.yml
+++ b/examples/phi/phi3-ft-fsdp.yml
@@ -71,3 +71,5 @@ fsdp_config:
 resize_token_embeddings_to_32x: true
 special_tokens:
   pad_token: "<|endoftext|>"
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/phi/phi3-ft.yml b/examples/phi/phi3-ft.yml
index 17c48da6f..42b87e8d0 100644
--- a/examples/phi/phi3-ft.yml
+++ b/examples/phi/phi3-ft.yml
@@ -59,3 +59,5 @@ warmup_ratio: 0.2
 debug: true
 weight_decay: 0.1
 resize_token_embeddings_to_32x: true
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/pixtral/lora-12b.yml b/examples/pixtral/lora-12b.yml
index 6ad0a5e99..ea769d202 100644
--- a/examples/pixtral/lora-12b.yml
+++ b/examples/pixtral/lora-12b.yml
@@ -55,3 +55,5 @@ saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
   pad_token:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2-vl/lora-7b.yaml b/examples/qwen2-vl/lora-7b.yaml
index e8932b968..8ea608199 100644
--- a/examples/qwen2-vl/lora-7b.yaml
+++ b/examples/qwen2-vl/lora-7b.yaml
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
 evals_per_epoch: 1
 saves_per_epoch: 1
 weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/dpo.yaml b/examples/qwen2/dpo.yaml
index bd896c2b3..69a74ae4a 100644
--- a/examples/qwen2/dpo.yaml
+++ b/examples/qwen2/dpo.yaml
@@ -54,3 +54,5 @@ warmup_steps: 10
 evals_per_epoch: 4
 saves_per_epoch: 1
 weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/prm.yaml b/examples/qwen2/prm.yaml
index 4afa24f3c..af188f75d 100644
--- a/examples/qwen2/prm.yaml
+++ b/examples/qwen2/prm.yaml
@@ -55,3 +55,5 @@ eval_steps: 100
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/qlora-fsdp.yaml b/examples/qwen2/qlora-fsdp.yaml
index ed2670ab6..861ce5517 100644
--- a/examples/qwen2/qlora-fsdp.yaml
+++ b/examples/qwen2/qlora-fsdp.yaml
@@ -67,3 +67,5 @@ fsdp_config:
   fsdp_state_dict_type: FULL_STATE_DICT
   fsdp_sharding_strategy: FULL_SHARD
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2/reward-model.yaml b/examples/qwen2/reward-model.yaml
index 822407a1f..1854b8216 100644
--- a/examples/qwen2/reward-model.yaml
+++ b/examples/qwen2/reward-model.yaml
@@ -26,7 +26,6 @@
 wandb_watch:
 wandb_name:
 wandb_log_model:
-
 gradient_accumulation_steps: 4
 micro_batch_size: 2
 num_epochs: 4
@@ -50,3 +49,5 @@ evals_per_epoch:
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen2_5-vl/lora-7b.yaml b/examples/qwen2_5-vl/lora-7b.yaml
index 25d02805f..13a97dec3 100644
--- a/examples/qwen2_5-vl/lora-7b.yaml
+++ b/examples/qwen2_5-vl/lora-7b.yaml
@@ -53,3 +53,5 @@ warmup_ratio: 0.1
 evals_per_epoch: 1
 saves_per_epoch: 1
 weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/32b-qlora.yaml b/examples/qwen3/32b-qlora.yaml
index 45a4395ac..1f148ece5 100644
--- a/examples/qwen3/32b-qlora.yaml
+++ b/examples/qwen3/32b-qlora.yaml
@@ -67,3 +67,5 @@ evals_per_epoch: 4
 saves_per_epoch: 1
 weight_decay: 0.0
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/8b-qat-fsdp2.yml b/examples/qwen3/8b-qat-fsdp2.yml
index 6832b6af7..e4d0ed4fb 100644
--- a/examples/qwen3/8b-qat-fsdp2.yml
+++ b/examples/qwen3/8b-qat-fsdp2.yml
@@ -76,3 +76,5 @@ fsdp_config:
   fsdp_activation_checkpointing: true
 
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/examples/qwen3/qlora-fsdp.yaml b/examples/qwen3/qlora-fsdp.yaml
index dc3377b4f..762f9648d 100644
--- a/examples/qwen3/qlora-fsdp.yaml
+++ b/examples/qwen3/qlora-fsdp.yaml
@@ -66,3 +66,5 @@ fsdp_config:
   fsdp_state_dict_type: FULL_STATE_DICT
   fsdp_sharding_strategy: FULL_SHARD
 special_tokens:
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index 4df010040..d3a3b3242 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -36,6 +36,7 @@ from axolotl.utils.callbacks import (
     GCCallback,
     GPUStatsCallback,
     SaveAxolotlConfigtoWandBCallback,
+    SaveModelOnFirstStepCallback,
 )
 from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
 from axolotl.utils.schemas.enums import CustomSupportedOptimizers
@@ -135,6 +136,8 @@ class TrainerBuilderBase(abc.ABC):
             callbacks.append(
                 SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
            )
+        if self.cfg.save_first_step:
+            callbacks.append(SaveModelOnFirstStepCallback())
 
         callbacks.append(GPUStatsCallback(cfg=self.cfg))
 
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 5f804d6af..bb777fc90 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -64,7 +64,7 @@ class SaveBetterTransformerModelCallback(
         state: TrainerState,
         control: TrainerControl,
         **kwargs,
-    ):
+    ) -> TrainerControl:
         # Save
         if (
             args.save_strategy == IntervalStrategy.STEPS
@@ -100,11 +100,11 @@ class GPUStatsCallback(
 
     def on_step_end(
         self,
-        args: TrainingArguments,
+        args: TrainingArguments,  # pylint: disable=unused-argument
         state: TrainerState,
         control: TrainerControl,
         **kwargs,
-    ):
+    ) -> TrainerControl:
         if not self.logged and state.global_step > 1:
             log_gpu_memory_usage(LOG, "while training", self.cfg.device)
             self.logged = True
@@ -116,18 +116,17 @@ class LossWatchDogCallback(TrainerCallback):
 
     def __init__(self, cfg):
         self.cfg = cfg
-        self.logged = False
         self.violations = 0
         self.threshold = cfg.loss_watchdog_threshold
         self.patience = cfg.loss_watchdog_patience or 3
 
     def on_step_end(
         self,
-        _args: TrainingArguments,
+        args: TrainingArguments,  # pylint: disable=unused-argument
         state: TrainerState,
         control: TrainerControl,
         **_kwargs,
-    ):
+    ) -> TrainerControl:
         if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
             if state.log_history[-1]["loss"] > self.threshold:
                 self.violations += 1
@@ -141,6 +140,21 @@ class LossWatchDogCallback(TrainerCallback):
         return control
 
 
+class SaveModelOnFirstStepCallback(TrainerCallback):
+    """Callback to save the model on the first step of training if enabled"""
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,  # pylint: disable=unused-argument
+        state: TrainerState,
+        control: TrainerControl,
+        **_kwargs,
+    ) -> TrainerControl:
+        if state.global_step == 1:
+            control.should_save = True
+        return control
+
+
 def bench_eval_callback_factory(trainer, tokenizer):
     accuracy = evaluate.load("accuracy")
     abcd_idx = [
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 909fd637c..e20cdaf47 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -706,6 +706,7 @@ class AxolotlInputConfig(
             "description": "Set to `no` to skip evaluation, `epoch` at end of each epoch, leave empty to infer from `eval_steps`"
         },
     )
+
     save_steps: int | float | None = Field(
         default=None,
         json_schema_extra={
@@ -727,6 +728,13 @@
     save_total_limit: int | None = Field(
         default=None, json_schema_extra={"description": "Checkpoints saved at a time"}
     )
+    save_first_step: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to checkpoint a model after the first step of training. Defaults to False."
+        },
+    )
+
     logging_steps: int | None = Field(
         default=None, json_schema_extra={"description": "Logging frequency"}
     )
diff --git a/tests/e2e/integrations/test_cut_cross_entropy.py b/tests/e2e/integrations/test_cut_cross_entropy.py
index 790b34f3e..34e6c9644 100644
--- a/tests/e2e/integrations/test_cut_cross_entropy.py
+++ b/tests/e2e/integrations/test_cut_cross_entropy.py
@@ -44,6 +44,7 @@ def min_cfg(temp_dir):
         "save_safetensors": True,
         "max_steps": 10,
         "bf16": "auto",
+        "save_first_step": False,
     }
 
 
@@ -98,6 +99,7 @@ class TestCutCrossEntropyIntegration:
                 "save_safetensors": True,
                 "max_steps": 10,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
diff --git a/tests/e2e/integrations/test_hooks.py b/tests/e2e/integrations/test_hooks.py
index 4734449fe..8743efb98 100644
--- a/tests/e2e/integrations/test_hooks.py
+++ b/tests/e2e/integrations/test_hooks.py
@@ -153,6 +153,7 @@ class TestPluginHooks:
                 "max_steps": 5,
                 "flash_attention": True,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py
index 212450e89..1ac3b537e 100644
--- a/tests/e2e/integrations/test_kd.py
+++ b/tests/e2e/integrations/test_kd.py
@@ -67,6 +67,7 @@ def min_cfg(temp_dir):
         "output_dir": temp_dir,
         "save_safetensors": True,
         "use_tensorboard": True,
+        "save_first_step": False,
     }
 
 
diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py
index 6ab3d7ab8..b1f5befdd 100644
--- a/tests/e2e/integrations/test_liger.py
+++ b/tests/e2e/integrations/test_liger.py
@@ -50,6 +50,7 @@ class LigerIntegrationTestCase:
                 "save_safetensors": True,
                 "bf16": "auto",
                 "max_steps": 5,
+                "save_first_step": False,
             }
         )
         # pylint: disable=duplicate-code
@@ -96,6 +97,7 @@ class LigerIntegrationTestCase:
                 "save_safetensors": True,
                 "bf16": "auto",
                 "max_steps": 5,
+                "save_first_step": False,
             }
         )
         # pylint: disable=duplicate-code
diff --git a/tests/e2e/integrations/test_llm_compressor.py b/tests/e2e/integrations/test_llm_compressor.py
index 247ae3bac..dceecea9f 100644
--- a/tests/e2e/integrations/test_llm_compressor.py
+++ b/tests/e2e/integrations/test_llm_compressor.py
@@ -81,6 +81,7 @@ class TestLLMCompressorIntegration:
                     },
                     "save_compressed": save_compressed,
                 },
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/multigpu/patched/test_sp.py b/tests/e2e/multigpu/patched/test_sp.py
index 5593c7eb6..80098e684 100644
--- a/tests/e2e/multigpu/patched/test_sp.py
+++ b/tests/e2e/multigpu/patched/test_sp.py
@@ -69,6 +69,7 @@ class TestSequenceParallelism:
                 "use_tensorboard": True,
                 "sequence_parallel_degree": 2,
                 "ring_attn_func": ring_attn_func,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/multigpu/solo/test_flex.py b/tests/e2e/multigpu/solo/test_flex.py
index bdf5ada6b..cbdf8de96 100644
--- a/tests/e2e/multigpu/solo/test_flex.py
+++ b/tests/e2e/multigpu/solo/test_flex.py
@@ -61,6 +61,7 @@ class TestPackedFlex:
                 "max_steps": 2,
                 "use_tensorboard": True,
                 "save_strategy": "no",
+                "save_first_step": False,
             }
         )
         if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py
index c04734345..d022ae2d9 100644
--- a/tests/e2e/multigpu/solo/test_grpo.py
+++ b/tests/e2e/multigpu/solo/test_grpo.py
@@ -223,6 +223,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
             "save_safetensors": True,
             "bf16": "auto",
             "use_tensorboard": True,
+            "save_first_step": False,
         }
     )
 
@@ -317,6 +318,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
             "save_safetensors": True,
             "bf16": "auto",
             "use_tensorboard": True,
+            "save_first_step": False,
         }
     )
 
@@ -409,6 +411,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
             "save_safetensors": True,
             "bf16": "auto",
             "use_tensorboard": True,
+            "save_first_step": False,
         }
     )
 
diff --git a/tests/e2e/multigpu/test_eval.py b/tests/e2e/multigpu/test_eval.py
index d6429cf63..4f86278ff 100644
--- a/tests/e2e/multigpu/test_eval.py
+++ b/tests/e2e/multigpu/test_eval.py
@@ -67,6 +67,7 @@ class TestMultiGPUEval:
                 "logging_steps": 1,
                 "weight_decay": 0.0,
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
 
@@ -138,6 +139,7 @@ class TestMultiGPUEval:
                 "logging_steps": 1,
                 "weight_decay": 0.0,
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/multigpu/test_gemma3.py b/tests/e2e/multigpu/test_gemma3.py
index 3868d90f0..4a7b101a8 100644
--- a/tests/e2e/multigpu/test_gemma3.py
+++ b/tests/e2e/multigpu/test_gemma3.py
@@ -71,6 +71,7 @@ class TestMultiGPUGemma3:
                 "flash_attention": True,
                 "use_tensorboard": True,
                 "bf16": True,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py
index f0c74fbf8..aab14dcc4 100644
--- a/tests/e2e/multigpu/test_llama.py
+++ b/tests/e2e/multigpu/test_llama.py
@@ -69,6 +69,7 @@ class TestMultiGPULlama:
                 "flash_attention": True,
                 "use_tensorboard": True,
                 "bf16": True,
+                "save_first_step": False,
             }
         )
 
@@ -135,6 +136,7 @@ class TestMultiGPULlama:
                 "flash_attention": True,
                 "use_tensorboard": True,
                 "bf16": True,
+                "save_first_step": False,
             }
         )
 
@@ -210,6 +212,7 @@ class TestMultiGPULlama:
                 "flash_attention": True,
                 "use_tensorboard": True,
                 "bf16": True,
+                "save_first_step": False,
             }
         )
 
@@ -289,6 +292,7 @@ class TestMultiGPULlama:
                 "flash_attention": True,
                 "use_tensorboard": True,
                 "bf16": True,
+                "save_first_step": False,
             }
         )
 
@@ -365,6 +369,7 @@ class TestMultiGPULlama:
                 },
                 "use_tensorboard": True,
                 "seed": 42,
+                "save_first_step": False,
             }
         )
 
@@ -442,6 +447,7 @@ class TestMultiGPULlama:
                     "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
                 },
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
 
@@ -520,6 +526,7 @@ class TestMultiGPULlama:
                     "fsdp_reshard_after_forward": fsdp_reshard_after_forward,
                 },
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
         if attention_backend == "flash":
@@ -605,6 +612,7 @@ class TestMultiGPULlama:
                     "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
                 },
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
 
@@ -689,6 +697,7 @@ class TestMultiGPULlama:
                 "flash_attention": True,
                 "deepspeed": str(AXOLOTL_ROOT / deepspeed),
                 "use_tensorboard": True,
+                "save_first_step": False,
                 **adapter,
             }
         )
@@ -765,6 +774,7 @@ class TestMultiGPULlama:
                 "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
                 "use_tensorboard": True,
                 "seed": 42,
+                "save_first_step": False,
                 **adapter,
             }
         )
@@ -840,6 +850,7 @@ class TestMultiGPULlama:
                 "flash_attention": True,
                 "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
                 "use_tensorboard": True,
+                "save_first_step": False,
                 **adapter,
             }
         )
@@ -908,6 +919,7 @@ class TestMultiGPULlama:
                 "save_safetensors": True,
                 # "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py
index 43a722b48..dd1422296 100644
--- a/tests/e2e/multigpu/test_ray.py
+++ b/tests/e2e/multigpu/test_ray.py
@@ -56,6 +56,7 @@ class TestMultiGPURay:
                 "use_tensorboard": True,
                 "use_ray": True,
                 "ray_num_workers": 2,
+                "save_first_step": False,
             }
         )
 
@@ -115,6 +116,7 @@ class TestMultiGPURay:
                 "flash_attention": True,
                 "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py
index 08b62accc..1824443e7 100644
--- a/tests/e2e/patched/test_4d_multipack_llama.py
+++ b/tests/e2e/patched/test_4d_multipack_llama.py
@@ -55,6 +55,7 @@ class Test4dMultipackLlama(unittest.TestCase):
                 "save_steps": 3,
                 "eval_steps": 4,
                 "fp16": True,
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
@@ -102,6 +103,7 @@ class Test4dMultipackLlama(unittest.TestCase):
                 "save_steps": 3,
                 "eval_steps": 4,
                 "fp16": True,
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_activation_checkpointing.py b/tests/e2e/patched/test_activation_checkpointing.py
index d494ed1eb..3d5b3dc56 100644
--- a/tests/e2e/patched/test_activation_checkpointing.py
+++ b/tests/e2e/patched/test_activation_checkpointing.py
@@ -69,6 +69,7 @@ class TestActivationCheckpointing:
                 "bf16": True,
                 "save_safetensors": True,
                 "gradient_checkpointing": gradient_checkpointing,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/patched/test_fa_xentropy.py b/tests/e2e/patched/test_fa_xentropy.py
index ca8b21178..38099b220 100644
--- a/tests/e2e/patched/test_fa_xentropy.py
+++ b/tests/e2e/patched/test_fa_xentropy.py
@@ -62,6 +62,7 @@ class TestFAXentropyLlama:
                 "optimizer": "adamw_8bit",
                 "lr_scheduler": "cosine",
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
         if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_falcon_samplepack.py b/tests/e2e/patched/test_falcon_samplepack.py
index a593b0791..ef31b11c7 100644
--- a/tests/e2e/patched/test_falcon_samplepack.py
+++ b/tests/e2e/patched/test_falcon_samplepack.py
@@ -58,6 +58,7 @@ class TestFalconPatched(unittest.TestCase):
                 "save_steps": 10,
                 "eval_steps": 10,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
@@ -99,6 +100,7 @@ class TestFalconPatched(unittest.TestCase):
                 "save_steps": 10,
                 "eval_steps": 10,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_flattening.py b/tests/e2e/patched/test_flattening.py
index f77a1fbe5..fdaab558d 100644
--- a/tests/e2e/patched/test_flattening.py
+++ b/tests/e2e/patched/test_flattening.py
@@ -61,6 +61,7 @@ class TestFAFlattening:
                 "optimizer": "adamw_8bit",
                 "lr_scheduler": "cosine",
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
         if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_fused_llama.py b/tests/e2e/patched/test_fused_llama.py
index 1bbc82a38..a3fe591ee 100644
--- a/tests/e2e/patched/test_fused_llama.py
+++ b/tests/e2e/patched/test_fused_llama.py
@@ -53,6 +53,7 @@ class TestFusedLlama(unittest.TestCase):
                 "max_steps": 10,
                 "save_steps": 5,
                 "eval_steps": 5,
+                "save_first_step": False,
             }
         )
         if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_llama_s2_attention.py b/tests/e2e/patched/test_llama_s2_attention.py
index d2dcc5e4b..ba5556a59 100644
--- a/tests/e2e/patched/test_llama_s2_attention.py
+++ b/tests/e2e/patched/test_llama_s2_attention.py
@@ -58,6 +58,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
                 "save_steps": 5,
                 "eval_steps": 5,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
 
@@ -100,6 +101,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
                 "save_steps": 5,
                 "eval_steps": 5,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/patched/test_lora_llama_multipack.py b/tests/e2e/patched/test_lora_llama_multipack.py
index 5df6bfecc..fdf6adbc6 100644
--- a/tests/e2e/patched/test_lora_llama_multipack.py
+++ b/tests/e2e/patched/test_lora_llama_multipack.py
@@ -55,6 +55,7 @@ class TestLoraLlama(unittest.TestCase):
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch_fused",
                 "lr_scheduler": "cosine",
+                "save_first_step": False,
             }
         )
         if is_torch_bf16_gpu_available():
@@ -108,6 +109,7 @@ class TestLoraLlama(unittest.TestCase):
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch_fused",
                 "lr_scheduler": "cosine",
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py
index 442089bae..bea0f9c68 100644
--- a/tests/e2e/patched/test_mistral_samplepack.py
+++ b/tests/e2e/patched/test_mistral_samplepack.py
@@ -56,6 +56,7 @@ class TestMistral(unittest.TestCase):
                 "save_steps": 3,
                 "eval_steps": 4,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
@@ -97,6 +98,7 @@ class TestMistral(unittest.TestCase):
                 "save_steps": 3,
                 "eval_steps": 4,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py
index 5f778660b..09e427abd 100644
--- a/tests/e2e/patched/test_mixtral_samplepack.py
+++ b/tests/e2e/patched/test_mixtral_samplepack.py
@@ -52,6 +52,7 @@ class TestMixtral(unittest.TestCase):
                 "save_steps": 3,
                 "eval_steps": 4,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
@@ -90,6 +91,7 @@ class TestMixtral(unittest.TestCase):
                 "save_steps": 3,
                 "eval_steps": 4,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_model_patches.py b/tests/e2e/patched/test_model_patches.py
index 5ea88b001..b90be23e4 100644
--- a/tests/e2e/patched/test_model_patches.py
+++ b/tests/e2e/patched/test_model_patches.py
@@ -45,6 +45,7 @@ class TestModelPatches(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
@@ -78,6 +79,7 @@ class TestModelPatches(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
diff --git a/tests/e2e/patched/test_peft_embeddings.py b/tests/e2e/patched/test_peft_embeddings.py
index d4f59a128..4769319ae 100644
--- a/tests/e2e/patched/test_peft_embeddings.py
+++ b/tests/e2e/patched/test_peft_embeddings.py
@@ -49,6 +49,7 @@ class TestLlamaPeftEmbeddings:
                 "bf16": "auto",
                 "save_safetensors": True,
                 "embeddings_skip_upcast": True,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py
index d241ce185..1f0ddd630 100644
--- a/tests/e2e/patched/test_phi_multipack.py
+++ b/tests/e2e/patched/test_phi_multipack.py
@@ -54,6 +54,7 @@ class TestPhiMultipack(unittest.TestCase):
                 "eval_steps": 3,
                 "save_steps": 4,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
 
@@ -105,6 +106,7 @@ class TestPhiMultipack(unittest.TestCase):
                 "eval_steps": 3,
                 "save_steps": 4,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py
index 363956733..54b8245ee 100644
--- a/tests/e2e/patched/test_resume.py
+++ b/tests/e2e/patched/test_resume.py
@@ -58,6 +58,7 @@ class TestResumeLlama:
                 "max_steps": 15,
                 "use_tensorboard": True,
                 "save_safetensors": True,
+                "save_first_step": False,
             }
         )
         if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py
index 2b4d11b30..4a2c69d45 100644
--- a/tests/e2e/patched/test_sp.py
+++ b/tests/e2e/patched/test_sp.py
@@ -47,6 +47,7 @@ def fixture_cfg():
             "special_tokens": {
                 "pad_token": "<|endoftext|>",
             },
+            "save_first_step": False,
         }
     )
 
diff --git a/tests/e2e/patched/test_unsloth_qlora.py b/tests/e2e/patched/test_unsloth_qlora.py
index 69171481c..2c8ee4eb0 100644
--- a/tests/e2e/patched/test_unsloth_qlora.py
+++ b/tests/e2e/patched/test_unsloth_qlora.py
@@ -62,6 +62,7 @@ class TestUnslothQLoRA:
                 "lr_scheduler": "cosine",
                 "use_tensorboard": True,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
 
@@ -112,6 +113,7 @@ class TestUnslothQLoRA:
                 "lr_scheduler": "cosine",
                 "use_tensorboard": True,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
 
@@ -167,6 +169,7 @@ class TestUnslothQLoRA:
                 "lr_scheduler": "cosine",
                 "use_tensorboard": True,
                 "fp16": True,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/solo/test_flex.py b/tests/e2e/solo/test_flex.py
index 279913713..76364fc0e 100644
--- a/tests/e2e/solo/test_flex.py
+++ b/tests/e2e/solo/test_flex.py
@@ -49,6 +49,7 @@ class TestPackedFlex(unittest.TestCase):
                 "lr_scheduler": "cosine",
                 "max_steps": 5,
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
         if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/solo/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py
index 7af550496..f6fcad841 100644
--- a/tests/e2e/solo/test_relora_llama.py
+++ b/tests/e2e/solo/test_relora_llama.py
@@ -65,6 +65,7 @@ class TestReLoraLlama(unittest.TestCase):
                 "lr_scheduler": "cosine",
                 "save_safetensors": True,
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/test_deepseekv3.py b/tests/e2e/test_deepseekv3.py
index 7dfc4ae15..e4a47fb0a 100644
--- a/tests/e2e/test_deepseekv3.py
+++ b/tests/e2e/test_deepseekv3.py
@@ -67,6 +67,7 @@ class TestDeepseekV3:
                 "max_steps": 5,
                 "save_safetensors": True,
                 "bf16": True,
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
@@ -116,6 +117,7 @@ class TestDeepseekV3:
                 "max_steps": 5,
                 "save_safetensors": True,
                 "bf16": True,
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py
index 2cdb57689..a1df69535 100644
--- a/tests/e2e/test_dpo.py
+++ b/tests/e2e/test_dpo.py
@@ -56,6 +56,7 @@ class TestDPOLlamaLora(unittest.TestCase):
                 "warmup_steps": 5,
                 "gradient_checkpointing": True,
                 "gradient_checkpointing_kwargs": {"use_reentrant": True},
+                "save_first_step": False,
             }
         )
 
@@ -105,6 +106,7 @@ class TestDPOLlamaLora(unittest.TestCase):
                 "warmup_steps": 5,
                 "gradient_checkpointing": True,
                 "gradient_checkpointing_kwargs": {"use_reentrant": True},
+                "save_first_step": False,
             }
         )
 
@@ -154,6 +156,7 @@ class TestDPOLlamaLora(unittest.TestCase):
                 "warmup_steps": 5,
                 "gradient_checkpointing": True,
                 "gradient_checkpointing_kwargs": {"use_reentrant": True},
+                "save_first_step": False,
             }
         )
 
@@ -203,6 +206,7 @@ class TestDPOLlamaLora(unittest.TestCase):
                 "warmup_steps": 5,
                 "gradient_checkpointing": True,
                 "gradient_checkpointing_kwargs": {"use_reentrant": True},
+                "save_first_step": False,
             }
         )
 
@@ -251,6 +255,7 @@ class TestDPOLlamaLora(unittest.TestCase):
                 "warmup_steps": 5,
                 "gradient_checkpointing": True,
                 "gradient_checkpointing_kwargs": {"use_reentrant": True},
+                "save_first_step": False,
             }
         )
 
@@ -302,6 +307,7 @@ class TestDPOLlamaLora(unittest.TestCase):
                 "warmup_steps": 5,
                 "gradient_checkpointing": True,
                 "gradient_checkpointing_kwargs": {"use_reentrant": True},
+                "save_first_step": False,
             }
         )
 
@@ -370,6 +376,7 @@ class TestDPOLlamaLora(unittest.TestCase):
                 "warmup_steps": 5,
                 "gradient_checkpointing": True,
                 "gradient_checkpointing_kwargs": {"use_reentrant": True},
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/test_embeddings_lr.py b/tests/e2e/test_embeddings_lr.py
index 9b65f8feb..e4a06ad14 100644
--- a/tests/e2e/test_embeddings_lr.py
+++ b/tests/e2e/test_embeddings_lr.py
@@ -48,6 +48,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
                 "save_safetensors": True,
                 "bf16": "auto",
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
 
@@ -93,6 +94,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
                 "save_safetensors": True,
                 "bf16": "auto",
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
diff --git a/tests/e2e/test_evaluate.py b/tests/e2e/test_evaluate.py
index 6271bba28..977497e5e 100644
--- a/tests/e2e/test_evaluate.py
+++ b/tests/e2e/test_evaluate.py
@@ -36,6 +36,7 @@ class TestE2eEvaluate:
                 "optimizer": "adamw_torch_fused",
                 "lr_scheduler": "cosine",
                 "max_steps": 20,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/test_falcon.py b/tests/e2e/test_falcon.py
index 4f88e740c..5be6efcf6 100644
--- a/tests/e2e/test_falcon.py
+++ b/tests/e2e/test_falcon.py
@@ -60,6 +60,7 @@ class TestFalcon(unittest.TestCase):
                 "save_steps": 10,
                 "eval_steps": 10,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
 
@@ -115,6 +116,7 @@ class TestFalcon(unittest.TestCase):
                 "save_steps": 10,
                 "eval_steps": 10,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
 
@@ -156,6 +158,7 @@ class TestFalcon(unittest.TestCase):
                 "save_steps": 10,
                 "eval_steps": 10,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/test_gemma3_text.py b/tests/e2e/test_gemma3_text.py
index 3f00a1384..ef38d028d 100644
--- a/tests/e2e/test_gemma3_text.py
+++ b/tests/e2e/test_gemma3_text.py
@@ -63,6 +63,7 @@ class TestGemma3Text:
                 "max_steps": 5,
                 "save_safetensors": True,
                 "bf16": True,
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
@@ -113,6 +114,7 @@ class TestGemma3Text:
                 "max_steps": 5,
                 "save_safetensors": True,
                 "bf16": True,
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py
index 2b180029c..1e6df0be9 100644
--- a/tests/e2e/test_llama.py
+++ b/tests/e2e/test_llama.py
@@ -45,6 +45,7 @@ class TestLlama:
                 "sample_packing": True,
                 "bf16": True,
                 "save_safetensors": True,
+                "save_first_step": False,
             }
         )
 
@@ -92,6 +93,7 @@ class TestLlama:
                 "sample_packing": True,
                 "bf16": True,
                 "save_safetensors": True,
+                "save_first_step": False,
             }
         )
 
@@ -136,6 +138,7 @@ class TestLlama:
                 "sample_packing": True,
                 "bf16": True,
                 "save_safetensors": True,
+                "save_first_step": False,
             }
         )
 
@@ -176,6 +179,7 @@ class TestLlama:
                 "batch_flattening": True,
                 "bf16": True,
                 "save_safetensors": True,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/test_llama_pretrain.py b/tests/e2e/test_llama_pretrain.py
index fdebf2173..bd5502300 100644
--- a/tests/e2e/test_llama_pretrain.py
+++ b/tests/e2e/test_llama_pretrain.py
@@ -53,6 +53,7 @@ class TestPretrainLlama:
                 "save_safetensors": True,
                 "bf16": "auto",
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/test_llama_vision.py b/tests/e2e/test_llama_vision.py
index ad4a83c6a..760759bca 100644
--- a/tests/e2e/test_llama_vision.py
+++ b/tests/e2e/test_llama_vision.py
@@ -54,6 +54,7 @@ class TestLlamaVision(unittest.TestCase):
                 "max_steps": 5,
                 "save_safetensors": True,
                 "bf16": True,
+                "save_first_step": False,
             }
         )
 
@@ -100,6 +101,7 @@ class TestLlamaVision(unittest.TestCase):
                 "max_steps": 5,
                 "save_safetensors": True,
                 "bf16": True,
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index 301565302..7e0ff46cf 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -49,6 +49,7 @@ class TestLoraLlama(unittest.TestCase):
                 "optimizer": "adamw_torch_fused",
                 "lr_scheduler": "cosine",
                 "max_steps": 5,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/test_mamba.py b/tests/e2e/test_mamba.py
index 1824619a6..73d3bdc26 100644
--- a/tests/e2e/test_mamba.py
+++ b/tests/e2e/test_mamba.py
@@ -51,6 +51,7 @@ class TestMamba(unittest.TestCase):
                 "save_steps": 10,
                 "eval_steps": None,
                 "save_safetensors": False,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/test_mistral.py b/tests/e2e/test_mistral.py
index 5d9b8ba8c..f47f794e0 100644
--- a/tests/e2e/test_mistral.py
+++ b/tests/e2e/test_mistral.py
@@ -55,6 +55,7 @@ class TestMistral(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
+                "save_first_step": False,
             }
         )
 
@@ -95,6 +96,7 @@ class TestMistral(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
+                "save_first_step": False,
             }
         )
         if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_mixtral.py b/tests/e2e/test_mixtral.py
index 761e59391..3fe2bf70f 100644
--- a/tests/e2e/test_mixtral.py
+++ b/tests/e2e/test_mixtral.py
@@ -61,6 +61,7 @@ class TestMixtral(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
+                "save_first_step": False,
             }
         )
 
@@ -116,6 +117,7 @@ class TestMixtral(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
+                "save_first_step": False,
             }
         )
 
@@ -170,6 +172,7 @@ class TestMixtral(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
+                "save_first_step": False,
             }
         )
         if is_torch_bf16_gpu_available():
@@ -228,6 +231,7 @@ class TestMixtral(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
+                "save_first_step": False,
             }
         )
 
@@ -273,6 +277,7 @@ class TestMixtral(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
+                "save_first_step": False,
             }
         )
         if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
index 53ef86022..1d233a201 100644
--- a/tests/e2e/test_optimizers.py
+++ b/tests/e2e/test_optimizers.py
@@ -55,6 +55,7 @@ class TestCustomOptimizers(unittest.TestCase):
                 "optimizer": "optimi_adamw",
                 "max_steps": 5,
                 "lr_scheduler": "cosine",
+                "save_first_step": False,
             }
         )
 
@@ -100,6 +101,7 @@ class TestCustomOptimizers(unittest.TestCase):
                 "learning_rate": 0.00001,
                 "optimizer": "adopt_adamw",
                 "lr_scheduler": "cosine",
+                "save_first_step": False,
             }
         )
 
@@ -146,6 +148,7 @@ class TestCustomOptimizers(unittest.TestCase):
                 "optimizer": "muon",
                 "lr_scheduler": "cosine",
                 "weight_decay": 0.01,
+                "save_first_step": False,
             }
         )
 
@@ -184,6 +187,7 @@ class TestCustomOptimizers(unittest.TestCase):
                 "lr_scheduler": "constant",
                 "save_safetensors": True,
                 "max_steps": 10,
+                "save_first_step": False,
             }
         )
         # pylint: disable=duplicate-code
@@ -232,6 +236,7 @@ class TestCustomOptimizers(unittest.TestCase):
                 "adam_epsilon2": 1e-16,
                 "max_steps": 5,
                 "lr_scheduler": "cosine",
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/test_packing_loss.py b/tests/e2e/test_packing_loss.py
index cc2db72e0..aec9d95f8 100644
--- a/tests/e2e/test_packing_loss.py
+++ b/tests/e2e/test_packing_loss.py
@@ -48,6 +48,7 @@ class TestPackedLlama(unittest.TestCase):
                 "lr_scheduler": "cosine",
                 "max_steps": 5,
                 "use_tensorboard": True,
+                "save_first_step": False,
             }
         )
         if is_torch_bf16_gpu_available():
diff --git a/tests/e2e/test_phi.py b/tests/e2e/test_phi.py
index 88fda9191..ab3a63674 100644
--- a/tests/e2e/test_phi.py
+++ b/tests/e2e/test_phi.py
@@ -53,6 +53,7 @@ class TestPhi(unittest.TestCase):
                 "save_steps": 10,
                 "eval_steps": 10,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
@@ -102,6 +103,7 @@ class TestPhi(unittest.TestCase):
                 "save_steps": 10,
                 "eval_steps": 10,
                 "bf16": "auto",
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
diff --git a/tests/e2e/test_process_reward_model_smollm2.py b/tests/e2e/test_process_reward_model_smollm2.py
index abfe1b0c5..bd9eec48b 100644
--- a/tests/e2e/test_process_reward_model_smollm2.py
+++ b/tests/e2e/test_process_reward_model_smollm2.py
@@ -49,6 +49,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
                 "use_tensorboard": True,
                 "special_tokens": {"pad_token": "<|endoftext|>"},
                 "seed": 42,
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
diff --git a/tests/e2e/test_qat.py b/tests/e2e/test_qat.py
index ef726079d..139ae155a 100644
--- a/tests/e2e/test_qat.py
+++ b/tests/e2e/test_qat.py
@@ -57,6 +57,7 @@ class TestQATLlama:
                 "max_steps": 5,
                 "save_safetensors": True,
                 "bf16": True,
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
@@ -115,6 +116,7 @@ class TestQATLlama:
                     "weight_dtype": "int8",
                     "group_size": 8,
                 },
+                "save_first_step": False,
             }
         )
         cfg = validate_config(cfg)
diff --git a/tests/e2e/test_qwen.py b/tests/e2e/test_qwen.py
index aa8b9f6c0..59267d14d 100644
--- a/tests/e2e/test_qwen.py
+++ b/tests/e2e/test_qwen.py
@@ -59,6 +59,7 @@ class TestE2eQwen:
                 "bf16": "auto",
                 "tf32": True,
                 "gradient_checkpointing": True,
+                "save_first_step": False,
             }
         )
 
diff --git a/tests/e2e/test_reward_model_smollm2.py b/tests/e2e/test_reward_model_smollm2.py
index 5d52bcc86..82513f99f 100644
---
--- a/tests/e2e/test_reward_model_smollm2.py
+++ b/tests/e2e/test_reward_model_smollm2.py
@@ -58,6 +58,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
                "gradient_checkpointing": True,
                "warmup_ratio": 0.1,
                "use_tensorboard": True,
+                "save_first_step": False,
            }
        )
        cfg = validate_config(cfg)
diff --git a/tests/e2e/test_save_first_step.py b/tests/e2e/test_save_first_step.py
new file mode 100644
index 000000000..5bbd2302b
--- /dev/null
+++ b/tests/e2e/test_save_first_step.py
@@ -0,0 +1,102 @@
+"""
+E2E tests for the save_first_step config option
+"""
+
+import unittest
+from pathlib import Path
+
+import pytest
+
+from axolotl.common.datasets import load_datasets
+from axolotl.train import train
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import check_model_output_exists, with_temp_dir
+
+
+class TestSaveFirstStepCallback(unittest.TestCase):
+    """Test cases for save_first_step callback config."""
+
+    @with_temp_dir
+    def test_save_first_step(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "tokenizer_type": "AutoTokenizer",
+                "sequence_len": 512,
+                "val_set_size": 0.02,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 3,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "sample_packing": True,
+                "bf16": True,
+                "save_safetensors": True,
+                "save_first_step": True,
+            }
+        )
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        dataset_meta = load_datasets(cfg=cfg)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)
+
+    @with_temp_dir
+    def test_no_save_first_step(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "tokenizer_type": "AutoTokenizer",
+                "sequence_len": 512,
+                "val_set_size": 0.02,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 3,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "sample_packing": True,
+                "bf16": True,
+                "save_safetensors": True,
+                "save_first_step": False,
+            }
+        )
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        dataset_meta = load_datasets(cfg=cfg)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        with pytest.raises(AssertionError):
+            check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)
diff --git a/tests/e2e/test_schedulers.py b/tests/e2e/test_schedulers.py
index e98378f08..8f7a13aee 100644
--- a/tests/e2e/test_schedulers.py
+++ b/tests/e2e/test_schedulers.py
@@ -51,6 +51,7 @@ class TestCustomSchedulers(unittest.TestCase):
                "lr_scheduler": "rex",
                "warmup_steps": 5,
                "cosine_min_lr_ratio": 0.05,
+                "save_first_step": False,
            }
        )
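Reviewer note: the tests above exercise save_first_step only through the config flag, and test_save_first_step asserts that output_dir/checkpoint-1 exists after training. For context, below is a minimal sketch of the generic Hugging Face Trainer mechanism such a flag can hook into: a TrainerCallback that requests a save after the first optimizer step. This is an illustrative assumption, not the axolotl implementation; the class name SaveFirstStepCallback here is hypothetical.

# Illustrative sketch only: shows the standard TrainerCallback pattern a
# save_first_step flag could toggle, not axolotl's actual code.
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class SaveFirstStepCallback(TrainerCallback):
    """Force a checkpoint save after the first optimizer step."""

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        if state.global_step == 1:
            # Requesting a save here yields output_dir/checkpoint-1,
            # matching what the e2e test checks for.
            control.should_save = True
        return control

Saving at step 1 lets a run fail fast if checkpoint serialization is broken for a given config, rather than only surfacing the problem at the first regular save interval.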