From ba47adc24b6ff1fba185b3f1bf409d652e81961a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 4 Apr 2025 23:37:30 -0400 Subject: [PATCH] replace attention in the yaml config with an enum --- examples/cerebras/btlm-ft.yml | 4 +- examples/cerebras/qlora.yml | 3 +- examples/code-llama/13b/lora.yml | 3 +- examples/code-llama/13b/qlora.yml | 3 +- examples/code-llama/34b/lora.yml | 3 +- examples/code-llama/34b/qlora.yml | 3 +- examples/code-llama/7b/lora.yml | 3 +- examples/code-llama/7b/qlora.yml | 3 +- examples/cohere/command-r-7b-qlora.yml | 3 +- .../colab-axolotl-example.ipynb | 4 +- examples/dbrx/16bit-lora.yaml | 3 +- examples/dbrx/8bit-lora.yaml | 3 +- examples/dbrx/fft-ds-zero3.yaml | 3 +- examples/deepseek-v2/fft-fsdp-16b.yaml | 3 +- examples/deepseek-v2/qlora-fsdp-2_5.yaml | 3 +- examples/falcon/config-7b-lora.yml | 3 +- examples/falcon/config-7b-qlora.yml | 3 +- examples/falcon/config-7b.yml | 3 +- examples/gemma/qlora.yml | 3 +- examples/gemma2/qlora.yml | 3 +- examples/gemma2/reward-model.yaml | 3 +- examples/gemma3/gemma-3-1b-qlora.yml | 3 +- examples/gemma3/gemma-3-4b-qlora.yml | 3 +- examples/gemma3/gemma-3-4b-vision-qlora.yml | 3 +- examples/gptj/qlora.yml | 3 +- examples/jamba/qlora.yaml | 3 +- examples/jamba/qlora_deepspeed.yaml | 3 +- examples/jamba/qlora_fsdp_large.yaml | 3 +- examples/jeopardy-bot/config.yml | 3 +- examples/llama-2/fft_optimized.yml | 3 +- examples/llama-2/gptq-lora.yml | 4 +- examples/llama-2/lisa.yml | 3 +- examples/llama-2/loftq.yml | 3 +- examples/llama-2/lora.yml | 3 +- examples/llama-2/qlora-fsdp.yml | 3 +- examples/llama-2/qlora.yml | 3 +- examples/llama-2/relora.yml | 3 +- examples/llama-3-vision/lora-11b.yaml | 3 +- examples/llama-3/fft-8b-liger-fsdp.yaml | 3 +- examples/llama-3/fft-8b.yaml | 3 +- examples/llama-3/instruct-dpo-lora-8b.yml | 3 +- examples/llama-3/instruct-lora-8b.yml | 3 +- examples/llama-3/lora-1b-deduplicate-dpo.yml | 3 +- examples/llama-3/lora-1b-deduplicate-sft.yml | 3 +- examples/llama-3/lora-1b-kernels.yml | 3 +- examples/llama-3/lora-1b-ray.yml | 3 +- .../lora-1b-sample-packing-sequentially.yml | 3 +- examples/llama-3/lora-1b.yml | 3 +- examples/llama-3/lora-8b.yml | 3 +- examples/llama-3/qlora-1b-kto.yaml | 3 +- examples/llama-3/qlora-1b.yml | 3 +- examples/llama-3/qlora-fsdp-405b.yaml | 3 +- examples/llama-3/qlora-fsdp-70b.yaml | 3 +- examples/llama-3/qlora.yml | 3 +- examples/llava/lora-7b.yaml | 3 +- examples/mamba/config.yml | 2 +- examples/mistral/bigstral-ds-zero3.yaml | 3 +- examples/mistral/config.yml | 3 +- examples/mistral/lora-mps.yml | 3 +- examples/mistral/lora.yml | 3 +- examples/mistral/mistral-dpo-qlora.yml | 2 +- examples/mistral/mistral-qlora-fsdp.yml | 3 +- examples/mistral/mistral-qlora-orpo.yml | 3 +- .../mistral/mistral-small-3.1-24B-lora.yml | 4 +- examples/mistral/mixtral-8x22b-qlora-fsdp.yml | 3 +- examples/mistral/mixtral-qlora-fsdp.yml | 3 +- examples/mistral/mixtral.yml | 3 +- examples/mistral/mixtral_22.yml | 3 +- examples/mistral/qlora.yml | 3 +- examples/mpt-7b/config.yml | 2 +- examples/openllama-3b/config.yml | 3 +- examples/openllama-3b/lora.yml | 3 +- examples/openllama-3b/qlora.yml | 3 +- examples/phi/phi-ft.yml | 3 +- examples/phi/phi-qlora.yml | 3 +- examples/phi/phi2-ft.yml | 3 +- examples/phi/phi3-ft-fsdp.yml | 3 +- examples/phi/phi3-ft.yml | 3 +- examples/pixtral/lora-12b.yml | 3 +- examples/qwen/lora.yml | 2 +- examples/qwen/qlora.yml | 2 +- examples/qwen/qwen2-moe-lora.yaml | 3 +- examples/qwen/qwen2-moe-qlora.yaml | 3 +- examples/qwen2-vl/lora-7b.yaml | 3 +- examples/qwen2/dpo.yaml | 3 +- examples/qwen2/prm.yaml | 3 +- examples/qwen2/qlora-fsdp.yaml | 3 +- examples/qwen2/reward-model.yaml | 3 +- examples/redpajama/config-3b.yml | 2 +- examples/replit-3b/config-lora.yml | 2 +- examples/stablelm-2/1.6b/fft.yml | 3 +- examples/stablelm-2/1.6b/lora.yml | 3 +- examples/starcoder2/qlora.yml | 3 +- examples/tiny-llama/lora-mps.yml | 2 +- examples/tiny-llama/lora.yml | 3 +- examples/tiny-llama/pretrain.yml | 3 +- examples/tiny-llama/qlora.yml | 3 +- examples/xgen-7b/xgen-7b-8k-qlora.yml | 3 +- examples/yi-34B-chat/qlora.yml | 3 +- src/axolotl/utils/schemas/config.py | 86 ++++++++++++++++++- src/axolotl/utils/schemas/enums.py | 11 +++ 101 files changed, 268 insertions(+), 122 deletions(-) diff --git a/examples/cerebras/btlm-ft.yml b/examples/cerebras/btlm-ft.yml index c9878779d..50270c162 100644 --- a/examples/cerebras/btlm-ft.yml +++ b/examples/cerebras/btlm-ft.yml @@ -59,9 +59,7 @@ gradient_checkpointing: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true -sdp_attention: -flash_optimum: +attention: flash gptq_groupsize: gptq_model_v1: diff --git a/examples/cerebras/qlora.yml b/examples/cerebras/qlora.yml index 55cc597f1..dd27f3a41 100644 --- a/examples/cerebras/qlora.yml +++ b/examples/cerebras/qlora.yml @@ -39,8 +39,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -xformers_attention: true -flash_attention: +attention: xformers gptq_groupsize: gptq_model_v1: warmup_steps: 10 diff --git a/examples/code-llama/13b/lora.yml b/examples/code-llama/13b/lora.yml index 0ed2382ba..cc13199ca 100644 --- a/examples/code-llama/13b/lora.yml +++ b/examples/code-llama/13b/lora.yml @@ -45,7 +45,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/code-llama/13b/qlora.yml b/examples/code-llama/13b/qlora.yml index 22bd1691b..7b68c650d 100644 --- a/examples/code-llama/13b/qlora.yml +++ b/examples/code-llama/13b/qlora.yml @@ -46,7 +46,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/code-llama/34b/lora.yml b/examples/code-llama/34b/lora.yml index 25dc9f421..bd2fdea41 100644 --- a/examples/code-llama/34b/lora.yml +++ b/examples/code-llama/34b/lora.yml @@ -45,7 +45,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/code-llama/34b/qlora.yml b/examples/code-llama/34b/qlora.yml index 0e33e2a45..b450fd842 100644 --- a/examples/code-llama/34b/qlora.yml +++ b/examples/code-llama/34b/qlora.yml @@ -46,7 +46,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/code-llama/7b/lora.yml b/examples/code-llama/7b/lora.yml index d288b9f65..e8a975f1f 100644 --- a/examples/code-llama/7b/lora.yml +++ b/examples/code-llama/7b/lora.yml @@ -45,7 +45,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/code-llama/7b/qlora.yml b/examples/code-llama/7b/qlora.yml index de41c0123..ddffb2343 100644 --- a/examples/code-llama/7b/qlora.yml +++ b/examples/code-llama/7b/qlora.yml @@ -46,7 +46,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/cohere/command-r-7b-qlora.yml b/examples/cohere/command-r-7b-qlora.yml index 4a30e9a77..cdaf06cd7 100644 --- a/examples/cohere/command-r-7b-qlora.yml +++ b/examples/cohere/command-r-7b-qlora.yml @@ -49,7 +49,8 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index 0b373c28c..963d0052a 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -112,9 +112,7 @@ "early_stopping_patience:\n", "resume_from_checkpoint:\n", "logging_steps: 1\n", - "xformers_attention:\n", - "flash_attention: false\n", - "sdp_attention: true\n", + "attention: sdpa\n", "\n", "warmup_steps: 1\n", "max_steps: 25\n", diff --git a/examples/dbrx/16bit-lora.yaml b/examples/dbrx/16bit-lora.yaml index 852654d49..33ea210a1 100644 --- a/examples/dbrx/16bit-lora.yaml +++ b/examples/dbrx/16bit-lora.yaml @@ -52,7 +52,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: diff --git a/examples/dbrx/8bit-lora.yaml b/examples/dbrx/8bit-lora.yaml index 0b9402194..322f79687 100644 --- a/examples/dbrx/8bit-lora.yaml +++ b/examples/dbrx/8bit-lora.yaml @@ -55,7 +55,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: diff --git a/examples/dbrx/fft-ds-zero3.yaml b/examples/dbrx/fft-ds-zero3.yaml index e42c16673..8e26fbe1a 100644 --- a/examples/dbrx/fft-ds-zero3.yaml +++ b/examples/dbrx/fft-ds-zero3.yaml @@ -39,7 +39,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: diff --git a/examples/deepseek-v2/fft-fsdp-16b.yaml b/examples/deepseek-v2/fft-fsdp-16b.yaml index 0ed97db36..09954f4e9 100644 --- a/examples/deepseek-v2/fft-fsdp-16b.yaml +++ b/examples/deepseek-v2/fft-fsdp-16b.yaml @@ -35,7 +35,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 100 evals_per_epoch: 2 diff --git a/examples/deepseek-v2/qlora-fsdp-2_5.yaml b/examples/deepseek-v2/qlora-fsdp-2_5.yaml index 34dbeaafe..437f30cfc 100644 --- a/examples/deepseek-v2/qlora-fsdp-2_5.yaml +++ b/examples/deepseek-v2/qlora-fsdp-2_5.yaml @@ -59,7 +59,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 100 evals_per_epoch: 2 diff --git a/examples/falcon/config-7b-lora.yml b/examples/falcon/config-7b-lora.yml index 391d4dd94..d963fff38 100644 --- a/examples/falcon/config-7b-lora.yml +++ b/examples/falcon/config-7b-lora.yml @@ -43,8 +43,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -xformers_attention: true -flash_attention: +attention: xformers gptq_groupsize: gptq_model_v1: warmup_steps: 40 diff --git a/examples/falcon/config-7b-qlora.yml b/examples/falcon/config-7b-qlora.yml index a9af8574c..b2447652f 100644 --- a/examples/falcon/config-7b-qlora.yml +++ b/examples/falcon/config-7b-qlora.yml @@ -73,8 +73,7 @@ early_stopping_patience: 3 resume_from_checkpoint: auto_resume_from_checkpoints: true logging_steps: 1 -xformers_attention: true -flash_attention: +attention: xformers gptq_groupsize: gptq_model_v1: warmup_steps: 10 diff --git a/examples/falcon/config-7b.yml b/examples/falcon/config-7b.yml index 3cc553daa..821e1f5db 100644 --- a/examples/falcon/config-7b.yml +++ b/examples/falcon/config-7b.yml @@ -40,8 +40,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -xformers_attention: true -flash_attention: +attention: xformers gptq_groupsize: gptq_model_v1: warmup_steps: 40 diff --git a/examples/gemma/qlora.yml b/examples/gemma/qlora.yml index 2738112b4..a3b022a86 100644 --- a/examples/gemma/qlora.yml +++ b/examples/gemma/qlora.yml @@ -47,7 +47,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/gemma2/qlora.yml b/examples/gemma2/qlora.yml index cb96a32c1..2ca2c9c44 100644 --- a/examples/gemma2/qlora.yml +++ b/examples/gemma2/qlora.yml @@ -53,7 +53,8 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/gemma2/reward-model.yaml b/examples/gemma2/reward-model.yaml index ce01a4572..6d9f51bce 100644 --- a/examples/gemma2/reward-model.yaml +++ b/examples/gemma2/reward-model.yaml @@ -43,7 +43,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/gemma3/gemma-3-1b-qlora.yml b/examples/gemma3/gemma-3-1b-qlora.yml index 44310558c..6796cf98c 100644 --- a/examples/gemma3/gemma-3-1b-qlora.yml +++ b/examples/gemma3/gemma-3-1b-qlora.yml @@ -57,7 +57,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/gemma3/gemma-3-4b-qlora.yml b/examples/gemma3/gemma-3-4b-qlora.yml index 29f8cc1e1..b4c1b1232 100644 --- a/examples/gemma3/gemma-3-4b-qlora.yml +++ b/examples/gemma3/gemma-3-4b-qlora.yml @@ -51,8 +51,7 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false logging_steps: 1 -flash_attention: true -eager_attention: +attention: flash warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/gemma3/gemma-3-4b-vision-qlora.yml b/examples/gemma3/gemma-3-4b-vision-qlora.yml index 3fd9eb5f0..53c40ec18 100644 --- a/examples/gemma3/gemma-3-4b-vision-qlora.yml +++ b/examples/gemma3/gemma-3-4b-vision-qlora.yml @@ -53,8 +53,7 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false logging_steps: 1 -flash_attention: true -eager_attention: +attention: flash warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/gptj/qlora.yml b/examples/gptj/qlora.yml index c3cf9f973..40be1cabc 100644 --- a/examples/gptj/qlora.yml +++ b/examples/gptj/qlora.yml @@ -36,8 +36,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -xformers_attention: true -flash_attention: +attention: xformers gptq_groupsize: gptq_model_v1: warmup_steps: 10 diff --git a/examples/jamba/qlora.yaml b/examples/jamba/qlora.yaml index 2cb0eea41..356be5d2f 100644 --- a/examples/jamba/qlora.yaml +++ b/examples/jamba/qlora.yaml @@ -47,7 +47,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: diff --git a/examples/jamba/qlora_deepspeed.yaml b/examples/jamba/qlora_deepspeed.yaml index d13ce6483..47ff4aee1 100644 --- a/examples/jamba/qlora_deepspeed.yaml +++ b/examples/jamba/qlora_deepspeed.yaml @@ -46,7 +46,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: diff --git a/examples/jamba/qlora_fsdp_large.yaml b/examples/jamba/qlora_fsdp_large.yaml index 6badaba19..73dd10b8d 100644 --- a/examples/jamba/qlora_fsdp_large.yaml +++ b/examples/jamba/qlora_fsdp_large.yaml @@ -45,7 +45,8 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: true logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 1 diff --git a/examples/jeopardy-bot/config.yml b/examples/jeopardy-bot/config.yml index 3609bd97e..3f5505d02 100644 --- a/examples/jeopardy-bot/config.yml +++ b/examples/jeopardy-bot/config.yml @@ -37,8 +37,7 @@ bf16: auto tf32: true resume_from_checkpoint: logging_steps: 5 -xformers_attention: true -flash_attention: +attention: xformers gptq_groupsize: gptq_model_v1: warmup_steps: 20 diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml index 86b1b6a21..a49322f93 100644 --- a/examples/llama-2/fft_optimized.yml +++ b/examples/llama-2/fft_optimized.yml @@ -42,7 +42,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + flash_attn_cross_entropy: false flash_attn_rms_norm: true flash_attn_fuse_qkv: false diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml index 0f1b34016..604891ba5 100644 --- a/examples/llama-2/gptq-lora.yml +++ b/examples/llama-2/gptq-lora.yml @@ -53,9 +53,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: -sdp_attention: -flash_optimum: +attention: flash warmup_steps: 100 evals_per_epoch: 4 saves_per_epoch: 1 diff --git a/examples/llama-2/lisa.yml b/examples/llama-2/lisa.yml index a76a792ae..bfd86a4ac 100644 --- a/examples/llama-2/lisa.yml +++ b/examples/llama-2/lisa.yml @@ -46,7 +46,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + flash_attn_cross_entropy: false flash_attn_rms_norm: true flash_attn_fuse_qkv: false diff --git a/examples/llama-2/loftq.yml b/examples/llama-2/loftq.yml index 22dbf2d99..db787e4ee 100644 --- a/examples/llama-2/loftq.yml +++ b/examples/llama-2/loftq.yml @@ -45,7 +45,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index 679aed3a9..70e114202 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -45,7 +45,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-2/qlora-fsdp.yml b/examples/llama-2/qlora-fsdp.yml index a42eabd4b..84ea167e8 100644 --- a/examples/llama-2/qlora-fsdp.yml +++ b/examples/llama-2/qlora-fsdp.yml @@ -48,7 +48,8 @@ gradient_checkpointing_kwargs: use_reentrant: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index de65928bc..3ae764b21 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -46,7 +46,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml index e0a5f7068..bae72abf7 100644 --- a/examples/llama-2/relora.yml +++ b/examples/llama-2/relora.yml @@ -48,7 +48,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml index f4883e903..b1953d337 100644 --- a/examples/llama-3-vision/lora-11b.yaml +++ b/examples/llama-3-vision/lora-11b.yaml @@ -50,8 +50,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true -eager_attention: +attention: flash warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml index eccfa6d8c..51a8a4a87 100644 --- a/examples/llama-3/fft-8b-liger-fsdp.yaml +++ b/examples/llama-3/fft-8b-liger-fsdp.yaml @@ -49,7 +49,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 100 evals_per_epoch: 2 diff --git a/examples/llama-3/fft-8b.yaml b/examples/llama-3/fft-8b.yaml index fdae3e6c4..35ad3feb1 100644 --- a/examples/llama-3/fft-8b.yaml +++ b/examples/llama-3/fft-8b.yaml @@ -34,7 +34,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 100 evals_per_epoch: 2 diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml index 13082294f..59b060535 100644 --- a/examples/llama-3/instruct-dpo-lora-8b.yml +++ b/examples/llama-3/instruct-dpo-lora-8b.yml @@ -61,7 +61,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml index acab862f6..b2b6866e1 100644 --- a/examples/llama-3/instruct-lora-8b.yml +++ b/examples/llama-3/instruct-lora-8b.yml @@ -56,7 +56,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-3/lora-1b-deduplicate-dpo.yml b/examples/llama-3/lora-1b-deduplicate-dpo.yml index 10e9747cb..89399c0e1 100644 --- a/examples/llama-3/lora-1b-deduplicate-dpo.yml +++ b/examples/llama-3/lora-1b-deduplicate-dpo.yml @@ -77,7 +77,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-3/lora-1b-deduplicate-sft.yml b/examples/llama-3/lora-1b-deduplicate-sft.yml index 630ec92f6..6f61b2c8f 100644 --- a/examples/llama-3/lora-1b-deduplicate-sft.yml +++ b/examples/llama-3/lora-1b-deduplicate-sft.yml @@ -53,7 +53,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-3/lora-1b-kernels.yml b/examples/llama-3/lora-1b-kernels.yml index a2d07ca49..13bde3dd7 100644 --- a/examples/llama-3/lora-1b-kernels.yml +++ b/examples/llama-3/lora-1b-kernels.yml @@ -54,7 +54,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/llama-3/lora-1b-ray.yml b/examples/llama-3/lora-1b-ray.yml index bb23164eb..178bd5c25 100644 --- a/examples/llama-3/lora-1b-ray.yml +++ b/examples/llama-3/lora-1b-ray.yml @@ -48,7 +48,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/llama-3/lora-1b-sample-packing-sequentially.yml b/examples/llama-3/lora-1b-sample-packing-sequentially.yml index 769dd32e6..d52d1c5d6 100644 --- a/examples/llama-3/lora-1b-sample-packing-sequentially.yml +++ b/examples/llama-3/lora-1b-sample-packing-sequentially.yml @@ -55,7 +55,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-3/lora-1b.yml b/examples/llama-3/lora-1b.yml index c31a9f39a..d9a85190b 100644 --- a/examples/llama-3/lora-1b.yml +++ b/examples/llama-3/lora-1b.yml @@ -48,7 +48,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/llama-3/lora-8b.yml b/examples/llama-3/lora-8b.yml index ad50cd38a..a1a096e29 100644 --- a/examples/llama-3/lora-8b.yml +++ b/examples/llama-3/lora-8b.yml @@ -49,7 +49,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-3/qlora-1b-kto.yaml b/examples/llama-3/qlora-1b-kto.yaml index 89a51ea68..ee6f3215b 100644 --- a/examples/llama-3/qlora-1b-kto.yaml +++ b/examples/llama-3/qlora-1b-kto.yaml @@ -53,7 +53,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 20 evals_per_epoch: 4 diff --git a/examples/llama-3/qlora-1b.yml b/examples/llama-3/qlora-1b.yml index 5c8fe6628..944581976 100644 --- a/examples/llama-3/qlora-1b.yml +++ b/examples/llama-3/qlora-1b.yml @@ -51,7 +51,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/llama-3/qlora-fsdp-405b.yaml b/examples/llama-3/qlora-fsdp-405b.yaml index 2b7d51925..dec191513 100644 --- a/examples/llama-3/qlora-fsdp-405b.yaml +++ b/examples/llama-3/qlora-fsdp-405b.yaml @@ -39,7 +39,8 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: true logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-3/qlora-fsdp-70b.yaml b/examples/llama-3/qlora-fsdp-70b.yaml index 412b6721c..80bcff581 100644 --- a/examples/llama-3/qlora-fsdp-70b.yaml +++ b/examples/llama-3/qlora-fsdp-70b.yaml @@ -48,7 +48,8 @@ gradient_checkpointing_kwargs: use_reentrant: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llama-3/qlora.yml b/examples/llama-3/qlora.yml index 4cc9fc3db..c7b5e97fa 100644 --- a/examples/llama-3/qlora.yml +++ b/examples/llama-3/qlora.yml @@ -46,7 +46,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/llava/lora-7b.yaml b/examples/llava/lora-7b.yaml index 54edd04dc..38cd3e448 100644 --- a/examples/llava/lora-7b.yaml +++ b/examples/llava/lora-7b.yaml @@ -46,8 +46,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true -eager_attention: +attention: flash warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml index 3d4583932..68466656b 100644 --- a/examples/mamba/config.yml +++ b/examples/mamba/config.yml @@ -39,7 +39,7 @@ tf32: true gradient_checkpointing: false resume_from_checkpoint: logging_steps: 1 -flash_attention: +attention: eager warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/mistral/bigstral-ds-zero3.yaml b/examples/mistral/bigstral-ds-zero3.yaml index f626a92a1..e2cb60967 100644 --- a/examples/mistral/bigstral-ds-zero3.yaml +++ b/examples/mistral/bigstral-ds-zero3.yaml @@ -42,7 +42,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + save_total_limit: 1 save_steps: diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml index 15edffb44..1d3ea7886 100644 --- a/examples/mistral/config.yml +++ b/examples/mistral/config.yml @@ -36,7 +36,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/mistral/lora-mps.yml b/examples/mistral/lora-mps.yml index e6f46affb..e1fb5fa74 100644 --- a/examples/mistral/lora-mps.yml +++ b/examples/mistral/lora-mps.yml @@ -53,8 +53,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: false -sdp_attention: true +attention: sdpa loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral/lora.yml b/examples/mistral/lora.yml index 9af4274fd..70b6862e8 100644 --- a/examples/mistral/lora.yml +++ b/examples/mistral/lora.yml @@ -54,7 +54,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral/mistral-dpo-qlora.yml b/examples/mistral/mistral-dpo-qlora.yml index af707973f..e0e9b2fce 100644 --- a/examples/mistral/mistral-dpo-qlora.yml +++ b/examples/mistral/mistral-dpo-qlora.yml @@ -71,7 +71,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: false +attention: eager warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/mistral/mistral-qlora-fsdp.yml b/examples/mistral/mistral-qlora-fsdp.yml index e234b19a2..96144f92f 100644 --- a/examples/mistral/mistral-qlora-fsdp.yml +++ b/examples/mistral/mistral-qlora-fsdp.yml @@ -51,7 +51,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral/mistral-qlora-orpo.yml b/examples/mistral/mistral-qlora-orpo.yml index 6c0212b7c..8d8c13883 100644 --- a/examples/mistral/mistral-qlora-orpo.yml +++ b/examples/mistral/mistral-qlora-orpo.yml @@ -59,7 +59,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral/mistral-small-3.1-24B-lora.yml b/examples/mistral/mistral-small-3.1-24B-lora.yml index 198b3f373..cfdc85425 100644 --- a/examples/mistral/mistral-small-3.1-24B-lora.yml +++ b/examples/mistral/mistral-small-3.1-24B-lora.yml @@ -48,9 +48,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet. -eager_attention: - +attention: eager # PixtralVisionModel does not support Flash Attention 2.0 yet. warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 diff --git a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml index af6ba5a76..cd8963120 100644 --- a/examples/mistral/mixtral-8x22b-qlora-fsdp.yml +++ b/examples/mistral/mixtral-8x22b-qlora-fsdp.yml @@ -49,7 +49,8 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral-qlora-fsdp.yml index b1843a138..d5f522a45 100644 --- a/examples/mistral/mixtral-qlora-fsdp.yml +++ b/examples/mistral/mixtral-qlora-fsdp.yml @@ -51,7 +51,8 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml index 4c256420c..6468e3ded 100644 --- a/examples/mistral/mixtral.yml +++ b/examples/mistral/mixtral.yml @@ -69,7 +69,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral/mixtral_22.yml b/examples/mistral/mixtral_22.yml index 25e1d7155..0313b8347 100644 --- a/examples/mistral/mixtral_22.yml +++ b/examples/mistral/mixtral_22.yml @@ -40,7 +40,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + save_total_limit: 1 save_steps: diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml index 607e33701..70c1076fc 100644 --- a/examples/mistral/qlora.yml +++ b/examples/mistral/qlora.yml @@ -54,7 +54,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mpt-7b/config.yml b/examples/mpt-7b/config.yml index e7485fad7..12a6db71d 100644 --- a/examples/mpt-7b/config.yml +++ b/examples/mpt-7b/config.yml @@ -39,7 +39,7 @@ bf16: auto tf32: true resume_from_checkpoint: logging_steps: 5 -flash_attention: +attention: eager gptq_groupsize: gptq_model_v1: warmup_steps: 20 diff --git a/examples/openllama-3b/config.yml b/examples/openllama-3b/config.yml index 17eeb73ae..915f52aaa 100644 --- a/examples/openllama-3b/config.yml +++ b/examples/openllama-3b/config.yml @@ -39,7 +39,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + gptq_groupsize: gptq_model_v1: warmup_steps: 20 diff --git a/examples/openllama-3b/lora.yml b/examples/openllama-3b/lora.yml index 073117f11..9217d8494 100644 --- a/examples/openllama-3b/lora.yml +++ b/examples/openllama-3b/lora.yml @@ -47,7 +47,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + gptq_groupsize: gptq_model_v1: warmup_steps: 20 diff --git a/examples/openllama-3b/qlora.yml b/examples/openllama-3b/qlora.yml index b4fca2c07..87d6bdbae 100644 --- a/examples/openllama-3b/qlora.yml +++ b/examples/openllama-3b/qlora.yml @@ -40,7 +40,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + gptq_groupsize: gptq_model_v1: warmup_steps: 20 diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml index 1562a7353..75f4a16d5 100644 --- a/examples/phi/phi-ft.yml +++ b/examples/phi/phi-ft.yml @@ -48,7 +48,8 @@ gradient_checkpointing_kwargs: use_reentrant: True resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 100 evals_per_epoch: 4 diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml index 4cd53db97..d8c16bb12 100644 --- a/examples/phi/phi-qlora.yml +++ b/examples/phi/phi-qlora.yml @@ -51,7 +51,8 @@ gradient_checkpointing_kwargs: use_reentrant: True resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 100 evals_per_epoch: 4 diff --git a/examples/phi/phi2-ft.yml b/examples/phi/phi2-ft.yml index ca733cc71..829350992 100644 --- a/examples/phi/phi2-ft.yml +++ b/examples/phi/phi2-ft.yml @@ -48,7 +48,8 @@ gradient_checkpointing_kwargs: use_reentrant: True resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 100 evals_per_epoch: 4 diff --git a/examples/phi/phi3-ft-fsdp.yml b/examples/phi/phi3-ft-fsdp.yml index d0d14fea6..a3824951d 100644 --- a/examples/phi/phi3-ft-fsdp.yml +++ b/examples/phi/phi3-ft-fsdp.yml @@ -49,7 +49,8 @@ gradient_checkpointing_kwargs: use_reentrant: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 100 evals_per_epoch: 4 diff --git a/examples/phi/phi3-ft.yml b/examples/phi/phi3-ft.yml index 17c48da6f..11a2ce957 100644 --- a/examples/phi/phi3-ft.yml +++ b/examples/phi/phi3-ft.yml @@ -44,7 +44,8 @@ gradient_checkpointing_kwargs: use_reentrant: True early_stopping_patience: 3 logging_steps: 1 -flash_attention: true +attention: flash + eval_steps: 1000 save_steps: 5000 diff --git a/examples/pixtral/lora-12b.yml b/examples/pixtral/lora-12b.yml index dec8e4b5e..12cc0d031 100644 --- a/examples/pixtral/lora-12b.yml +++ b/examples/pixtral/lora-12b.yml @@ -46,8 +46,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet -eager_attention: +attention: eager # PixtralVisionModel does not support Flash Attention 2.0 yet warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/qwen/lora.yml b/examples/qwen/lora.yml index 9a2843236..0f4a2033f 100644 --- a/examples/qwen/lora.yml +++ b/examples/qwen/lora.yml @@ -47,7 +47,7 @@ tf32: false gradient_checkpointing: false resume_from_checkpoint: logging_steps: 1 -flash_attention: +attention: eager warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/qwen/qlora.yml b/examples/qwen/qlora.yml index 5f85b44dd..4005a6c8b 100644 --- a/examples/qwen/qlora.yml +++ b/examples/qwen/qlora.yml @@ -47,7 +47,7 @@ tf32: false gradient_checkpointing: false resume_from_checkpoint: logging_steps: 1 -flash_attention: +attention: flash warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/qwen/qwen2-moe-lora.yaml b/examples/qwen/qwen2-moe-lora.yaml index afce443a0..025cd2e60 100644 --- a/examples/qwen/qwen2-moe-lora.yaml +++ b/examples/qwen/qwen2-moe-lora.yaml @@ -43,7 +43,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/qwen/qwen2-moe-qlora.yaml b/examples/qwen/qwen2-moe-qlora.yaml index 92a6842cf..af4f0d3c4 100644 --- a/examples/qwen/qwen2-moe-qlora.yaml +++ b/examples/qwen/qwen2-moe-qlora.yaml @@ -46,7 +46,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/qwen2-vl/lora-7b.yaml b/examples/qwen2-vl/lora-7b.yaml index 55773bc3d..6f20625be 100644 --- a/examples/qwen2-vl/lora-7b.yaml +++ b/examples/qwen2-vl/lora-7b.yaml @@ -46,8 +46,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true -eager_attention: +attention: flash warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/qwen2/dpo.yaml b/examples/qwen2/dpo.yaml index 3547c6c98..842ff00a5 100644 --- a/examples/qwen2/dpo.yaml +++ b/examples/qwen2/dpo.yaml @@ -49,7 +49,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/qwen2/prm.yaml b/examples/qwen2/prm.yaml index 4afa24f3c..ebbe4c75a 100644 --- a/examples/qwen2/prm.yaml +++ b/examples/qwen2/prm.yaml @@ -47,7 +47,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/qwen2/qlora-fsdp.yaml b/examples/qwen2/qlora-fsdp.yaml index ed2670ab6..1364e11bd 100644 --- a/examples/qwen2/qlora-fsdp.yaml +++ b/examples/qwen2/qlora-fsdp.yaml @@ -47,7 +47,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/qwen2/reward-model.yaml b/examples/qwen2/reward-model.yaml index 822407a1f..1810cda1f 100644 --- a/examples/qwen2/reward-model.yaml +++ b/examples/qwen2/reward-model.yaml @@ -43,7 +43,8 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/redpajama/config-3b.yml b/examples/redpajama/config-3b.yml index 3e2999df9..ed26c4944 100644 --- a/examples/redpajama/config-3b.yml +++ b/examples/redpajama/config-3b.yml @@ -40,7 +40,7 @@ bf16: auto tf32: true resume_from_checkpoint: logging_steps: 5 -flash_attention: +attention: flash gptq_groupsize: gptq_model_v1: warmup_steps: 20 diff --git a/examples/replit-3b/config-lora.yml b/examples/replit-3b/config-lora.yml index 5a02ba10c..7ffa0bf98 100644 --- a/examples/replit-3b/config-lora.yml +++ b/examples/replit-3b/config-lora.yml @@ -38,7 +38,7 @@ tf32: true gradient_checkpointing: resume_from_checkpoint: logging_steps: 1 -flash_attention: +attention: flash gptq_groupsize: gptq_model_v1: warmup_steps: 20 diff --git a/examples/stablelm-2/1.6b/fft.yml b/examples/stablelm-2/1.6b/fft.yml index 9b45b399f..ca1092848 100644 --- a/examples/stablelm-2/1.6b/fft.yml +++ b/examples/stablelm-2/1.6b/fft.yml @@ -44,7 +44,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + flash_attn_cross_entropy: false flash_attn_rms_norm: true flash_attn_fuse_qkv: false diff --git a/examples/stablelm-2/1.6b/lora.yml b/examples/stablelm-2/1.6b/lora.yml index 31e5ad933..d25b225ab 100644 --- a/examples/stablelm-2/1.6b/lora.yml +++ b/examples/stablelm-2/1.6b/lora.yml @@ -47,7 +47,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + flash_attn_cross_entropy: false flash_attn_rms_norm: true diff --git a/examples/starcoder2/qlora.yml b/examples/starcoder2/qlora.yml index 18d85f9c3..302ee4433 100644 --- a/examples/starcoder2/qlora.yml +++ b/examples/starcoder2/qlora.yml @@ -46,7 +46,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 20 evals_per_epoch: 4 diff --git a/examples/tiny-llama/lora-mps.yml b/examples/tiny-llama/lora-mps.yml index 66cf7cfb3..314c45f0d 100644 --- a/examples/tiny-llama/lora-mps.yml +++ b/examples/tiny-llama/lora-mps.yml @@ -47,7 +47,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: false +attention: eager warmup_steps: 10 evals_per_epoch: 0 diff --git a/examples/tiny-llama/lora.yml b/examples/tiny-llama/lora.yml index 90998880f..7a869ab91 100644 --- a/examples/tiny-llama/lora.yml +++ b/examples/tiny-llama/lora.yml @@ -45,7 +45,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/tiny-llama/pretrain.yml b/examples/tiny-llama/pretrain.yml index 5b3706bcb..d05451b28 100644 --- a/examples/tiny-llama/pretrain.yml +++ b/examples/tiny-llama/pretrain.yml @@ -36,7 +36,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: diff --git a/examples/tiny-llama/qlora.yml b/examples/tiny-llama/qlora.yml index 8b2a4565a..038c4c2ce 100644 --- a/examples/tiny-llama/qlora.yml +++ b/examples/tiny-llama/qlora.yml @@ -47,7 +47,8 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attention: flash + warmup_steps: 10 evals_per_epoch: 4 diff --git a/examples/xgen-7b/xgen-7b-8k-qlora.yml b/examples/xgen-7b/xgen-7b-8k-qlora.yml index 48066b130..91b84aa27 100644 --- a/examples/xgen-7b/xgen-7b-8k-qlora.yml +++ b/examples/xgen-7b/xgen-7b-8k-qlora.yml @@ -71,8 +71,7 @@ early_stopping_patience: 3 resume_from_checkpoint: auto_resume_from_checkpoints: true logging_steps: 1 -xformers_attention: true -flash_attention: +attention: xformers gptq_groupsize: gptq_model_v1: warmup_steps: 10 diff --git a/examples/yi-34B-chat/qlora.yml b/examples/yi-34B-chat/qlora.yml index a0a95d86f..43fd56887 100644 --- a/examples/yi-34B-chat/qlora.yml +++ b/examples/yi-34B-chat/qlora.yml @@ -10,7 +10,8 @@ load_in_4bit: true sequence_len: 1024 bf16: auto tf32: false -flash_attention: true +attention: flash + special_tokens: bos_token: "<|startoftext|>" eos_token: "<|endoftext|>" diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 9db374409..6362bd2d7 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -27,7 +27,7 @@ from axolotl.utils.schemas.datasets import ( StepwiseSupervisedDataset, ) from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters -from axolotl.utils.schemas.enums import ChatTemplate, RLType +from axolotl.utils.schemas.enums import AttentionBackend, ChatTemplate, RLType from axolotl.utils.schemas.integrations import ( CometConfig, GradioConfig, @@ -222,6 +222,10 @@ class AxolotlInputConfig( }, ) + attention: AttentionBackend = Field( + default=AttentionBackend.flash, + json_schema_extra={"description": "attention backend to use"}, + ) xformers_attention: bool | None = None sdp_attention: bool | None = None s2_attention: bool | None = None @@ -436,6 +440,86 @@ class AxolotlInputConfig( ) return data + @model_validator(mode="before") + @classmethod + def normalize_attn(cls, data): # pylint: disable=too-many-return-statements + # cases where both are set and already match + if data.get("attention") == AttentionBackend.eager and data.get( + "eager_attention" + ): + return data + if data.get("attention") == AttentionBackend.flash and data.get( + "flash_attention" + ): + return data + if data.get("attention") == AttentionBackend.s2 and data.get("s2_attention"): + return data + if data.get("attention") == AttentionBackend.sdpa and data.get("sdp_attention"): + return data + if data.get("attention") == AttentionBackend.xformers and data.get( + "xformers_attention" + ): + return data + + # cases where attention is set and the specific *_attention is not set + if not ( + data.get("flash_attention") + or data.get("eager_attention") + or data.get("s2_attention") + or data.get("sdp_attention") + or data.get("xformers_attention") + ): + if data.get("attention") == AttentionBackend.eager: + data["eager_attention"] = True + elif data.get("attention") == AttentionBackend.flash: + data["flash_attention"] = True + elif data.get("attention") == AttentionBackend.s2: + data["s2_attention"] = True + elif data.get("attention") == AttentionBackend.sdpa: + data["sdp_attention"] = True + elif data.get("attention") == AttentionBackend.xformers: + data["xformers_attention"] = True + return data + + # attention should always be set since that's a requirement, defaults to flash + if ( + data.get("eager_attention") + and not data.get("attention") == AttentionBackend.eager + ): + raise ValueError("attention mismatch with eager_attention already set") + if ( + data.get("flash_attention") + and not data.get("attention") == AttentionBackend.flash + ): + raise ValueError("attention mismatch with flash_attention already set") + if ( + data.get("s2_attention") + and not data.get("attention") == AttentionBackend.s2 + ): + raise ValueError("attention mismatch with s2_attention already set") + if ( + data.get("sdp_attention") + and not data.get("attention") == AttentionBackend.sdpa + ): + raise ValueError("attention mismatch with sdp_attention already set") + if ( + data.get("xformers_attention") + and not data.get("attention") == AttentionBackend.xformers + ): + raise ValueError("attention mismatch with xformers_attention already set") + + return data + + @model_validator(mode="before") + @classmethod + def check_sample_packing_w_xformers(cls, data): + if data.get("sample_packing") and data.get("xformers_attention"): + raise ValueError( + "sample_packing not compatible with xformers_attention. Use flash_attention" + ) + + return data + @model_validator(mode="before") @classmethod # pylint: disable=duplicate-code diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index 118176d34..3a72f96c8 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -54,3 +54,14 @@ class CustomSupportedOptimizers(str, Enum): ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name muon = "muon" # pylint: disable=invalid-name + + +class AttentionBackend(str, Enum): + """Attention backend types""" + + eager = "eager" # pylint: disable=invalid-name + flash = "flash" # pylint: disable=invalid-name + flex = "flex" # pylint: disable=invalid-name + s2 = "s2" # pylint: disable=invalid-name + sdpa = "sdpa" # pylint: disable=invalid-name + xformers = "xformers" # pylint: disable=invalid-name