From e4032fc90f462fd9c2eefd2433c2f4fd0845641f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 5 May 2026 10:15:18 -0400 Subject: [PATCH] Refactor separate attention flags with attn_implementation and capability/concerns feature flags (#3602) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * upgrade to torchao 0.17.0 * chore: lint * refactor attention handling * replace legacy attention boolean flags with capability properties Replace checks with capability-based properties derived from attn_implementation This separates three concerns that were conflated under flash_attention: 1. Backend selection -> attn_implementation enum 2. Packing capability -> attn_supports_packing property 3. Flash-attn library dependency -> attn_uses_flash_lib property * compute attn capability flags in normalizer instead of properties * make attn_implementation the single source of truth * move attention-dependent validators to mode=after * migrate remaining consumers to canonical attn_implementation * expand attention tests + rewrite docs * migrate example configs to canonical attn_implementation * update doc snippets + reject gemma4-hybrid with non-FA2 backend * remove dead gemma4 branch in _set_attention_config * fix duplicate attn_implementation in gpt-oss yamls and flaky caplog tests * drop "Phase 2" naming from attn-implementation tests * regroup attn_implementation tests by feature concern * clean up verbose comments and remove MD Signed-off-by: Wing Lian Co-authored-by: Axolotl Swarm * fix(collator): pass return_dict=True at apply_chat_template top level for transformers 5.x In transformers 5.x, ProcessorMixin.apply_chat_template gained its own `return_dict` parameter (defaulting to False). When return_dict=False and tokenize=True the method returns out["input_ids"] directly — a 2-D tensor — rather than the full BatchFeature dict. The old code placed `return_dict=True` inside processor_kwargs. 
In transformers 5.x those kwargs are forwarded to the underlying processor call self(...) where _merge_kwargs silently ignores any key not present in MllamaProcessorKwargs (emitting a warning). The outer return_dict therefore stayed False, apply_chat_template returned the raw input_ids tensor, and the subsequent `batch["input_ids"]` attempted to index a 2-D tensor with the 9-character string "input_ids", producing: IndexError: too many indices for tensor of dimension 2 The fix is to pass return_dict=True as a top-level keyword argument to apply_chat_template (where it is actually consumed) and remove it from processor_kwargs (where it was silently dropped). No version guard is needed: transformers is pinned to ==5.5.4 in pyproject.toml. Adds a unit-level regression test (tests/test_mm_chat_collator.py) that mocks the processor to return a raw tensor when apply_chat_template is called without top-level return_dict=True, verifying the four invariants: process_rows returns a dict, input_ids is 2-D, labels is 2-D, and apply_chat_template receives return_dict=True as a top-level kwarg. Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_multimodal_dataset Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_text_only_dataset Signed-off-by: Wing Lian Co-authored-by: Axolotl Swarm * fix(collator): process_rows returns dict (BatchFeature) shape Two related changes for the multimodal chat collator under transformers 5.x: 1. Wrap apply_chat_template result in dict(...) so process_rows returns a plain dict rather than a BatchFeature instance. BatchFeature is a Mapping but not a dict; downstream code that did batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"]) would index on a tensor when the result wasn't dict-shaped, raising IndexError: too many indices for tensor of dimension 2 2. 
Soften the regression test's contract from `dict` to `Mapping` so it exercises the actual semantic guarantee (key/value access) rather than the implementation detail (dict vs BatchFeature). Test guards against the original transformers 5.x breakage where apply_chat_template's return_dict default went from True to False. Includes regression test under tests/test_mm_chat_collator.py. Bug surfaced via swarm dispatch task_01KQHPNAYD8XARSNSDJVW1GPF6 against attn-implementation-refactor; squash-merged from agent commits 4de886fd + dc9fcf4f. Signed-off-by: Wing Lian --------- Signed-off-by: Wing Lian Co-authored-by: Axolotl Swarm --- docs/agents/new_model_support.md | 10 +- docs/agents/sft.md | 2 +- docs/attention.qmd | 184 +++++--- docs/ebft.qmd | 4 +- docs/grpo.qmd | 2 +- docs/optimizations.qmd | 10 +- docs/rlhf.qmd | 3 +- docs/sequence_parallelism.qmd | 2 +- docs/training_stability.qmd | 2 +- examples/LiquidAI/lfm2-350m-fft.yaml | 2 +- examples/LiquidAI/lfm2-8b-a1b-lora.yaml | 2 +- examples/LiquidAI/lfm2-vl-lora.yaml | 3 +- examples/alst/llama3-8b-deepspeed-alst.yaml | 2 +- examples/alst/llama3-8b-fsdp2-alst.yaml | 2 +- examples/apertus/apertus-8b-qlora.yaml | 2 +- examples/arcee/afm-4.5b-qlora.yaml | 2 +- examples/archived/cerebras/btlm-ft.yml | 3 +- examples/archived/cerebras/qlora.yml | 3 +- examples/archived/code-llama/13b/lora.yml | 2 +- examples/archived/code-llama/13b/qlora.yml | 2 +- examples/archived/code-llama/34b/lora.yml | 2 +- examples/archived/code-llama/34b/qlora.yml | 2 +- examples/archived/code-llama/7b/lora.yml | 2 +- examples/archived/code-llama/7b/qlora.yml | 2 +- examples/archived/dbrx/16bit-lora.yaml | 2 +- examples/archived/dbrx/8bit-lora.yaml | 2 +- examples/archived/dbrx/fft-ds-zero3.yaml | 2 +- .../deepcoder/deepcoder-14B-preview-lora.yml | 2 +- examples/archived/falcon/config-7b-lora.yml | 3 +- examples/archived/falcon/config-7b-qlora.yml | 3 +- examples/archived/falcon/config-7b.yml | 3 +- examples/archived/gemma/qlora.yml | 2 +- 
examples/archived/gptj/qlora.yml | 3 +- examples/archived/jeopardy-bot/config.yml | 3 +- examples/archived/mpt-7b/config.yml | 1 - examples/archived/openllama-3b/config.yml | 2 +- examples/archived/openllama-3b/lora.yml | 2 +- examples/archived/openllama-3b/qlora.yml | 2 +- examples/archived/qwen/lora.yml | 1 - examples/archived/qwen/qlora.yml | 1 - examples/archived/qwen/qwen2-moe-lora.yaml | 2 +- examples/archived/qwen/qwen2-moe-qlora.yaml | 2 +- examples/archived/redpajama/config-3b.yml | 1 - examples/archived/replit-3b/config-lora.yml | 1 - examples/archived/stablelm-2/1.6b/fft.yml | 2 +- examples/archived/stablelm-2/1.6b/lora.yml | 2 +- examples/archived/starcoder2/qlora.yml | 2 +- examples/archived/tiny-llama/lora-mps.yml | 1 - examples/archived/tiny-llama/lora.yml | 2 +- examples/archived/tiny-llama/pretrain.yml | 2 +- examples/archived/tiny-llama/qlora.yml | 2 +- .../archived/xgen-7b/xgen-7b-8k-qlora.yml | 3 +- examples/archived/yi-34B-chat/qlora.yml | 2 +- examples/cohere/command-r-7b-qlora.yml | 2 +- .../cogito-v1-preview-llama-3B-lora.yml | 2 +- .../cogito-v1-preview-qwen-14B-lora.yml | 2 +- examples/deepseek-v2/fft-fsdp-16b.yaml | 2 +- examples/deepseek-v2/qlora-fsdp-2_5.yaml | 2 +- examples/devstral/devstral-small-qlora.yml | 2 +- .../llama-3_1-8b-hsdp-tp.yaml | 2 +- .../qwen3-8b-fsdp-tp-cp.yaml | 2 +- examples/eaft/eaft-example.yml | 3 +- .../ebft/llama-1b-ebft-opencode-novllm.yaml | 2 +- examples/ebft/llama-1b-ebft-opencode.yaml | 2 +- .../llama-1b-ebft-strided-structured.yaml | 3 +- examples/ebft/llama-1b-ebft-strided.yaml | 1 - examples/ebft/llama-3b-ebft-strided-fft.yaml | 1 - examples/ebft/llama-8b-ebft-strided-fft.yaml | 1 - .../ebft/qwen35-4b-ebft-structured-async.yaml | 2 +- examples/ebft/qwen35-4b-ebft-structured.yaml | 2 +- examples/ebft/qwen35-9b-ebft-structured.yaml | 2 +- .../falcon-h1/falcon-h1-1b-deep-qlora.yaml | 2 +- examples/falcon-h1/falcon-h1-1b-qlora.yaml | 2 +- examples/falcon-h1/falcon-h1-34b-qlora.yaml | 2 +- 
examples/falcon-h1/falcon-h1-3b-qlora.yaml | 2 +- examples/falcon-h1/falcon-h1-500m-qlora.yaml | 2 +- examples/falcon-h1/falcon-h1-7b-qlora.yaml | 2 +- examples/gemma2/qlora.yml | 2 +- examples/gemma2/reward-model.yaml | 2 +- examples/gemma3/gemma-3-1b-qlora.yml | 2 +- examples/gemma3/gemma-3-270m-qlora.yml | 2 +- examples/gemma3/gemma-3-4b-qlora.yml | 3 +- examples/gemma3/gemma-3-4b-vision-qlora.yml | 3 +- examples/gemma4/26b-a4b-moe-qlora.yaml | 2 +- examples/gemma4/31b-qlora-flex.yaml | 2 +- examples/gemma4/31b-qlora.yaml | 2 +- examples/gemma4/e2b-vision-lora.yaml | 2 +- examples/glm4/qlora-32b.yaml | 2 +- examples/glm45/glm-45-air-qlora.yaml | 2 +- examples/glm46v/glm-4-6v-flash-ddp.yaml | 2 +- examples/glm46v/glm-4-6v-flash-qlora.yaml | 2 +- examples/glm47-flash/lora.yaml | 2 +- examples/glm47-flash/lora_fsdp.yaml | 2 +- examples/glm47-flash/qlora.yaml | 2 +- examples/glm47-flash/qlora_fsdp.yaml | 2 +- .../gpt-oss-120b-fft-fsdp2-offload.yaml | 1 - .../gpt-oss-20b-fft-deepspeed-zero3.yaml | 1 - .../gpt-oss-20b-fft-fsdp2-offload.yaml | 1 - examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml | 1 - .../gpt-oss-20b-sft-lora-singlegpu.yaml | 1 - ...-oss-safeguard-20b-sft-lora-singlegpu.yaml | 1 - examples/granite4/granite-4.0-tiny-fft.yaml | 2 +- examples/hunyuan/hunyuan-v1-dense-qlora.yaml | 2 +- examples/internvl3_5/internvl3_5-8b-qlora.yml | 3 +- examples/jamba/qlora.yaml | 2 +- examples/jamba/qlora_deepspeed.yaml | 2 +- examples/jamba/qlora_fsdp_large.yaml | 2 +- examples/kimi-linear/kimi-48b-lora.yaml | 2 +- examples/llama-2/fft_optimized.yml | 2 +- examples/llama-2/gptq-lora.yml | 2 - examples/llama-2/lisa.yml | 2 +- examples/llama-2/loftq.yml | 2 +- examples/llama-2/lora.yml | 2 +- examples/llama-2/qlora-fsdp.yml | 2 +- examples/llama-2/qlora.yml | 2 +- examples/llama-2/relora.yml | 2 +- examples/llama-3-vision/lora-11b.yaml | 2 +- examples/llama-3/3b-fp8-fsdp2.yaml | 2 +- examples/llama-3/3b-qat-fsdp2.yaml | 2 +- examples/llama-3/3b-qat-mxfp4.yaml | 2 +- 
examples/llama-3/3b-qat-nvfp4.yaml | 2 +- examples/llama-3/diffusion/pretrain-1b.yaml | 2 +- examples/llama-3/diffusion/sft-1b.yaml | 2 +- examples/llama-3/fft-8b-liger-fsdp.yaml | 2 +- examples/llama-3/fft-8b.yaml | 2 +- examples/llama-3/instruct-dpo-lora-8b.yml | 2 +- examples/llama-3/instruct-lora-8b.yml | 2 +- examples/llama-3/lora-1b-deduplicate-dpo.yml | 2 +- examples/llama-3/lora-1b-deduplicate-sft.yml | 2 +- examples/llama-3/lora-1b-kernels.yml | 2 +- examples/llama-3/lora-1b-ray.yml | 2 +- .../lora-1b-sample-packing-sequentially.yml | 2 +- examples/llama-3/lora-1b.yml | 2 +- examples/llama-3/lora-8b.yml | 2 +- examples/llama-3/opentelemetry-qlora.yml | 1 - examples/llama-3/qlora-1b-gdpo.yaml | 2 +- examples/llama-3/qlora-1b-kto.yaml | 2 +- examples/llama-3/qlora-1b.yml | 2 +- examples/llama-3/qlora-fsdp-405b.yaml | 2 +- examples/llama-3/qlora-fsdp-70b.yaml | 2 +- examples/llama-3/qlora.yml | 2 +- examples/llama-3/sparse-finetuning.yaml | 3 +- .../do-no-use-fa2/maverick-qlora-fsdp1.yaml | 2 +- .../do-no-use-fa2/scout-qlora-fsdp1.yaml | 2 +- .../scout-qlora-single-h100.yaml | 2 +- .../scout-vision-qlora-fsdp.yaml | 2 +- .../llama-4/scout-qlora-flexattn-fsdp2.yaml | 2 +- .../llama-4/scout-qlora-single-h100-flex.yaml | 2 +- .../scout-vision-qlora-fsdp2-flex.yaml | 2 +- examples/llava/lora-7b.yaml | 3 +- .../magistral/magistral-small-fsdp-qlora.yaml | 2 +- examples/magistral/magistral-small-qlora.yaml | 2 +- .../think/magistral-small-think-qlora.yaml | 2 +- .../magistral-small-vision-24B-qlora.yml | 2 +- examples/mamba/config.yml | 1 - examples/mimo/mimo-7b-qlora.yaml | 2 +- examples/ministral/ministral-small-qlora.yaml | 2 +- examples/ministral3/ministral3-3b-qlora.yaml | 2 +- .../think/ministral3-3b-think-qlora.yaml | 2 +- .../vision/ministral3-3b-vision-qlora.yml | 2 +- .../mistral-small-3.1-24B-lora.yml | 2 +- .../mistral/bigstral/bigstral-ds-zero3.yaml | 2 +- examples/mistral/config.yml | 2 +- examples/mistral/dpo/mistral-dpo-qlora.yml | 1 - 
examples/mistral/lora.yml | 2 +- examples/mistral/mistral-qlora-fsdp.yml | 2 +- .../mixtral/mixtral-8x22b-qlora-fsdp.yml | 2 +- .../mistral/mixtral/mixtral-qlora-fsdp.yml | 2 +- examples/mistral/mixtral/mixtral.yml | 2 +- examples/mistral/mixtral/mixtral_22.yml | 2 +- examples/mistral/mps/lora-mps.yml | 3 +- examples/mistral/orpo/mistral-qlora-orpo.yml | 2 +- examples/mistral/qlora.yml | 2 +- examples/mistral4/fft-text.yml | 2 +- examples/mistral4/fft-vision.yml | 2 +- examples/mistral4/qlora-text.yml | 2 +- examples/mistral4/qlora-vision.yml | 2 +- examples/nemotron-h/120b-a12b-qlora.yaml | 2 +- examples/nemotron-h/nano-30b-a3b-qlora.yaml | 2 +- examples/nemotron/nemotron-mini-4b-qlora.yaml | 2 +- examples/olmo3/olmo3-7b-qlora.yaml | 2 +- examples/orpheus/finetune.yml | 2 +- examples/phi/phi-ft.yml | 2 +- examples/phi/phi-qlora.yml | 2 +- examples/phi/phi2-ft.yml | 2 +- examples/phi/phi3-ft-fsdp.yml | 2 +- examples/phi/phi3-ft.yml | 2 +- examples/pixtral/lora-12b.yml | 2 +- examples/plano/plano-4b-qlora.yaml | 2 +- examples/qat_nvfp4/Gemma3-12B_baseline.yml | 2 +- examples/qat_nvfp4/Gemma3-12B_qat.yml | 2 +- .../qat_nvfp4/Math-Gemma3-12B_baseline.yml | 2 +- examples/qat_nvfp4/Math-Gemma3-12B_qat.yml | 2 +- .../qat_nvfp4/Math-Gemma3-27B_baseline.yml | 2 +- examples/qat_nvfp4/Math-Gemma3-27B_qat.yml | 2 +- .../qat_nvfp4/Math-Qwen2.5-72B_baseline.yml | 2 +- examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml | 2 +- examples/qat_nvfp4/Qwen2.5-72B_baseline.yml | 2 +- examples/qat_nvfp4/Qwen2.5-72B_qat.yml | 2 +- examples/qwen2-vl/lora-7b.yaml | 3 +- examples/qwen2/adamw-pretrain-fsdp2.yaml | 2 +- examples/qwen2/dpo.yaml | 2 +- examples/qwen2/muon-pretrain-fsdp2.yaml | 2 +- examples/qwen2/prm.yaml | 2 +- examples/qwen2/qlora-fsdp.yaml | 2 +- examples/qwen2/reward-model.yaml | 2 +- examples/qwen2_5-vl/lora-7b.yaml | 3 +- .../qwen3-next/qwen3-next-80b-a3b-qlora.yaml | 2 +- .../qwen3.5/122b-a10b-moe-qlora-fsdp.yaml | 2 +- examples/qwen3.5/122b-a10b-moe-qlora.yaml | 2 +- 
examples/qwen3.5/27b-fft.yaml | 2 +- examples/qwen3.5/27b-qlora-fsdp.yaml | 2 +- examples/qwen3.5/27b-qlora.yaml | 2 +- examples/qwen3.5/35b-a3b-moe-qlora-fsdp.yaml | 2 +- examples/qwen3.5/35b-a3b-moe-qlora.yaml | 2 +- examples/qwen3.5/35b-a3b-moe-vision-lora.yaml | 2 +- examples/qwen3.5/9b-fft-vision.yaml | 2 +- examples/qwen3.5/9b-lora-vision.yaml | 2 +- examples/qwen3/32b-qlora.yaml | 2 +- examples/qwen3/8b-qat-fsdp2.yml | 2 +- examples/qwen3/qlora-fsdp.yaml | 2 +- examples/seed-oss/seed-oss-36b-qlora.yaml | 2 +- examples/smolvlm2/smolvlm2-2B-lora.yaml | 3 +- examples/streaming/pretrain.yaml | 2 +- examples/streaming/sft.yaml | 2 +- examples/swanlab/dpo-swanlab-completions.yml | 2 +- .../swanlab/dpo-swanlab-full-featured.yml | 2 +- examples/swanlab/lora-swanlab-profiling.yml | 2 +- .../trinity/trinity-nano-preview-qlora.yaml | 2 +- examples/voxtral/voxtral-mini-audio-qlora.yml | 2 +- examples/voxtral/voxtral-mini-qlora.yml | 2 +- src/axolotl/cli/merge_lora.py | 2 +- src/axolotl/core/builders/causal.py | 16 +- src/axolotl/integrations/lm_eval/__init__.py | 2 +- src/axolotl/integrations/lm_eval/cli.py | 10 +- src/axolotl/integrations/swanlab/plugins.py | 4 +- src/axolotl/loaders/model.py | 44 +- src/axolotl/loaders/patch_manager.py | 52 ++- src/axolotl/loaders/tokenizer.py | 6 +- src/axolotl/monkeypatch/attention/__init__.py | 26 ++ src/axolotl/monkeypatch/attention/fp8_attn.py | 30 ++ .../monkeypatch/attention/sage_attn.py | 18 +- src/axolotl/utils/callbacks/__init__.py | 5 +- src/axolotl/utils/collators/mm_chat.py | 20 +- src/axolotl/utils/schemas/config.py | 248 +++++++++-- src/axolotl/utils/schemas/enums.py | 62 +++ src/axolotl/utils/schemas/validation.py | 215 ++++----- src/axolotl/utils/trainer.py | 2 +- tests/e2e/multigpu/test_llama.py | 4 +- tests/test_attn_implementation.py | 418 ++++++++++++++++++ tests/test_mm_chat_collator.py | 163 +++++++ tests/test_no_legacy_attn_reads.py | 62 +++ 252 files changed, 1502 insertions(+), 572 deletions(-) create mode 
100644 src/axolotl/monkeypatch/attention/fp8_attn.py create mode 100644 tests/test_attn_implementation.py create mode 100644 tests/test_mm_chat_collator.py create mode 100644 tests/test_no_legacy_attn_reads.py diff --git a/docs/agents/new_model_support.md b/docs/agents/new_model_support.md index 8e6028896..bc42ada86 100644 --- a/docs/agents/new_model_support.md +++ b/docs/agents/new_model_support.md @@ -121,11 +121,11 @@ Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2 | Backend | Config | head_dim limit | torch_compile | Notes | |---------|--------|---------------|---------------|-------| -| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported | -| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ | -| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback | -| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims | -| eager | neither set | None | ✅ | Slowest, always works | +| FA2 | `attn_implementation: flash_attention_2` | 256 | ✅ | Fastest when supported | +| FA4 | auto with `attn_implementation: flash_attention_2` | 256 (SM90+) | ✅ | Auto-detected on H100+ | +| SDPA | `attn_implementation: sdpa` | None | ✅ | Universal fallback | +| flex | `attn_implementation: flex_attention` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims | +| eager | `attn_implementation: eager` | None | ✅ | Slowest, always works | **Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class. 
diff --git a/docs/agents/sft.md b/docs/agents/sft.md index d3dfd39f7..f601cb0f5 100644 --- a/docs/agents/sft.md +++ b/docs/agents/sft.md @@ -83,7 +83,7 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go | Issue | Fix | |-------|-----| | OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` | -| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` | +| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `attn_implementation: flash_attention_2` or disable `sample_packing` | | Missing chat template error | Set `chat_template: chatml` explicitly | | Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels | | Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples | diff --git a/docs/attention.qmd b/docs/attention.qmd index b9644e074..f7fa5b456 100644 --- a/docs/attention.qmd +++ b/docs/attention.qmd @@ -3,28 +3,71 @@ title: Attention description: Supported attention modules in Axolotl --- -## SDP Attention - -This is the default built-in attention in PyTorch. +Axolotl routes attention via a single config field: ```yaml -sdp_attention: true +attn_implementation: <backend> ``` -For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +`attn_implementation` is passed through to `transformers` verbatim (via +`model.config._attn_implementation`). Accepted values are the HF-native +backends, axolotl-registered backends, or a hub-kernel path. -## Flash Attention +## Backends -Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically -based on your installed packages and GPU. +| `attn_implementation` | Description | +|---|---| +| `eager` | Plain PyTorch attention. No packing support. | +| `sdpa` | PyTorch `scaled_dot_product_attention`. No packing support. 
| +| `flash_attention_2` | Dao-AILab Flash Attention 2. | +| `flash_attention_3` | Dao-AILab Flash Attention 3 (Hopper+). | +| `flex_attention` | Torch Flex Attention (requires torch ≥ 2.6). | +| `xformers` | xFormers memory-efficient attention. | +| `sage` | SageAttention (QK int8 / PV fp16). | +| `s2` | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). | +| `fp8` | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. | +| `kernels-community/flash-attn3` | HF hub FA3 kernel. | +| `kernels-community/sage-attention` | HF hub SageAttention kernel. | +| Other `<org>/<name>` path | Any hub-kernel path supported by `transformers`. | + +Short-form aliases (`flash`, `fa2`, `flex`, `sdp`, etc.) are **not accepted** — +set the canonical name above. + +### Capability flags + +Axolotl derives three boolean capability flags from `attn_implementation` and +exposes them on the validated config: + +- `cfg.attn_supports_packing` — backend supports varlen sample packing via + `position_ids`. Gates multipack patches and `sample_packing_drop_attention_mask`. +- `cfg.attn_uses_flash_lib` — backend needs the `flash_attn` (Dao-AILab) + monkeypatches (FA4 auto, LLaMA flash hijack, ring-FA). +- `cfg.attn_needs_dtype_cast` — backend requires fp16/bf16 embeddings + (everything except `eager` and `sdpa`). + +These are **computed** — they cannot be overridden from YAML. + +## Per-backend notes + +### SDPA + +Default PyTorch attention. See +[PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). ```yaml -flash_attention: true +attn_implementation: sdpa ``` -For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/) +### Flash Attention -### Flash Attention 2 +Axolotl supports FA2, FA3, and FA4. The best available version is used +automatically based on your installed packages and GPU. 
+ +```yaml +attn_implementation: flash_attention_2 # or flash_attention_3 +``` + +#### Flash Attention 2 Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported) @@ -39,20 +82,20 @@ Alternatively, try reinstall or downgrade a version. ::: -### Flash Attention 3 +#### Flash Attention 3 Requirements: Hopper only and CUDA 12.8 (recommended) ```bash git clone https://github.com/Dao-AILab/flash-attention.git cd flash-attention/hopper - python setup.py install ``` -### Flash Attention 4 +#### Flash Attention 4 -Requirements: Hopper or Blackwell GPUs +Requirements: Hopper or Blackwell GPUs. Auto-applied when `attn_uses_flash_lib` +is true and FA4 is importable. FA4 is still a pre-release on PyPI, so `--pre` is required: @@ -65,7 +108,6 @@ Or from source: ```bash git clone https://github.com/Dao-AILab/flash-attention.git cd flash-attention/flash_attn/cute - pip install -e . # FA2's flash_attn package includes a cute/ stub that shadows FA4. @@ -88,93 +130,113 @@ and falls back to FA2/3. ::: -For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute) - ### AMD -Requirements: ROCm 6.0 and above. +Requirements: ROCm 6.0 and above. See +[Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support). -See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support). - -## Flex Attention - -A flexible PyTorch API for attention used in combination with `torch.compile`. +### Flex Attention ```yaml -flex_attention: true - -# recommended -torch_compile: true +attn_implementation: flex_attention +torch_compile: true # recommended ``` -::: {.callout-note} +Requires torch ≥ 2.6. See [PyTorch docs](https://pytorch.org/blog/flexattention/). -We recommend using latest stable version of PyTorch for best performance. 
+### SageAttention -::: - -For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/) - -## SageAttention - -Attention kernels with QK Int8 and PV FP16 accumulator. +Requirements: Ampere, Ada, or Hopper GPUs. ```yaml -sage_attention: true +attn_implementation: sage ``` -Requirements: Ampere, Ada, or Hopper GPUs - ```bash pip install sageattention==2.2.0 --no-build-isolation ``` ::: {.callout-warning} -Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198). +Only LoRA/QLoRA recommended. Full finetuning has been observed to drop loss to 0. See +[GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198). ::: -For more details: [Sage Attention](https://github.com/thu-ml/SageAttention) +For more details: [Sage Attention](https://github.com/thu-ml/SageAttention). -::: {.callout-note} - -We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue. - -::: - - -## xFormers +### xFormers ```yaml -xformers_attention: true +attn_implementation: xformers ``` ::: {.callout-tip} -We recommend using with Turing GPUs or below (such as on Colab). +Recommended for Turing GPUs or below (e.g. Colab T4). ::: -For more details: [xFormers](https://github.com/facebookresearch/xformers) - -## Shifted Sparse Attention +### Shifted Sparse Attention ::: {.callout-warning} -We plan to deprecate this! If you use this feature, we recommend switching to methods above. +Planned for deprecation. Prefer one of the backends above. ::: -Requirements: LLaMA model architecture +Requirements: LLaMA model architecture. Loaded as FA2 under the hood and +patched to implement shifted-sparse attention. Does not support sample packing. ```yaml -flash_attention: true -s2_attention: true +attn_implementation: s2 ``` -::: {.callout-tip} +### FP8 -No sample packing support! 
+torchao low-precision attention. Loaded as SDPA and patched post-load. + +Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17, +flash-attn with FA3. KV caching must be disabled. + +```yaml +attn_implementation: fp8 +``` + +### Hub kernels + +```yaml +attn_implementation: kernels-community/flash-attn3 +``` + +Passed through to `transformers`; axolotl does not install the kernel itself. +For recognized hub paths the capability flags are set automatically; for +arbitrary paths axolotl uses conservative defaults (`attn_supports_packing=False`, +`attn_uses_flash_lib=False`). + +## Migrating from legacy boolean flags + +The following legacy config fields are **deprecated** and will be removed in a +future release. Each emits a `DeprecationWarning` when set and is stripped from +the validated config. + +| Legacy | Canonical | +|---|---| +| `flash_attention: true` | `attn_implementation: flash_attention_2` | +| `sdp_attention: true` | `attn_implementation: sdpa` | +| `xformers_attention: true` | `attn_implementation: xformers` | +| `flex_attention: true` | `attn_implementation: flex_attention` | +| `sage_attention: true` | `attn_implementation: sage` | +| `s2_attention: true` | `attn_implementation: s2` | +| `eager_attention: true` | `attn_implementation: eager` | + +Combining `attn_implementation` with a legacy flag (e.g. `attn_implementation: +flash_attention_2` **and** `flash_attention: true`) raises — pick one. + +::: {.callout-note} + +Existing example configs under `examples/` still use the legacy flags. They +continue to work with a deprecation warning; they will be migrated in a +follow-up pass. 
::: diff --git a/docs/ebft.qmd b/docs/ebft.qmd index eb7c95eca..d9afc3307 100644 --- a/docs/ebft.qmd +++ b/docs/ebft.qmd @@ -129,7 +129,7 @@ gradient_accumulation_steps: 4 max_steps: 20 learning_rate: 5.0e-6 bf16: auto -flash_attention: true +attn_implementation: flash_attention_2 gradient_checkpointing: true output_dir: ./outputs/ebft-quickstart ``` @@ -304,7 +304,7 @@ lora_alpha: 32 lora_target_linear: true bf16: auto -flex_attention: true +attn_implementation: flex_attention gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: true # Required with flex_attention diff --git a/docs/grpo.qmd b/docs/grpo.qmd index 35631f136..a98dbe11d 100644 --- a/docs/grpo.qmd +++ b/docs/grpo.qmd @@ -154,7 +154,7 @@ lr_scheduler: cosine warmup_steps: 10 bf16: true -flash_attention: true +attn_implementation: flash_attention_2 gradient_checkpointing: true special_tokens: diff --git a/docs/optimizations.qmd b/docs/optimizations.qmd index b180387ed..720519ec0 100644 --- a/docs/optimizations.qmd +++ b/docs/optimizations.qmd @@ -22,12 +22,12 @@ Improves GPU utilization by combining multiple short sequences into a single pac Using an optimized attention implementation is critical for training speed. -- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support). -- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`. -- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation. -- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16. 
+- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `attn_implementation: flash_attention_2`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support). +- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `attn_implementation: flex_attention`. +- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `attn_implementation: sdpa`. PyTorch's native implementation. +- **[Xformers](https://github.com/facebookresearch/xformers)**: `attn_implementation: xformers`. Works with FP16. -*Note: You should only enable one attention backend.* +See [Attention](attention.qmd) for the full list of backends and the canonical values. ### LoRA Optimizations diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd index 75d20414c..a27bb2966 100644 --- a/docs/rlhf.qmd +++ b/docs/rlhf.qmd @@ -1147,8 +1147,7 @@ datasets: type: ebft_strided_structured.transform split: train[:1%] -flash_attention: false -flex_attention: true # Strided mode uses flex_attention +attn_implementation: flex_attention # Strided mode uses flex_attention gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: true # Required for flex_attention diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd index d1933a145..9799c8a70 100644 --- a/docs/sequence_parallelism.qmd +++ b/docs/sequence_parallelism.qmd @@ -55,7 +55,7 @@ To use sequence parallelism, you need: ## Limitations -- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML) +- Flash attention must be enabled for this to work (`attn_implementation: flash_attention_2` in config YAML) - May have a small performance overhead due to communication between GPUs ## Example diff --git a/docs/training_stability.qmd b/docs/training_stability.qmd index 
e2cd79f89..9849a35d1 100644 --- a/docs/training_stability.qmd +++ b/docs/training_stability.qmd @@ -245,7 +245,7 @@ For GRPO, also reduce `max_completion_length`. Memory scales quadratically with Reduces attention memory from O(n^2) to O(n): ```yaml -flash_attention: true +attn_implementation: flash_attention_2 ``` ### Step 6: Offload with DeepSpeed diff --git a/examples/LiquidAI/lfm2-350m-fft.yaml b/examples/LiquidAI/lfm2-350m-fft.yaml index 145b56dd1..cd5942206 100644 --- a/examples/LiquidAI/lfm2-350m-fft.yaml +++ b/examples/LiquidAI/lfm2-350m-fft.yaml @@ -39,7 +39,7 @@ tf32: true gradient_checkpointing: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 2 diff --git a/examples/LiquidAI/lfm2-8b-a1b-lora.yaml b/examples/LiquidAI/lfm2-8b-a1b-lora.yaml index 73cbfcce7..4932ea06e 100644 --- a/examples/LiquidAI/lfm2-8b-a1b-lora.yaml +++ b/examples/LiquidAI/lfm2-8b-a1b-lora.yaml @@ -48,7 +48,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 2 diff --git a/examples/LiquidAI/lfm2-vl-lora.yaml b/examples/LiquidAI/lfm2-vl-lora.yaml index 313da8274..9a125da5e 100644 --- a/examples/LiquidAI/lfm2-vl-lora.yaml +++ b/examples/LiquidAI/lfm2-vl-lora.yaml @@ -50,8 +50,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true -eager_attention: +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/alst/llama3-8b-deepspeed-alst.yaml b/examples/alst/llama3-8b-deepspeed-alst.yaml index dea23c5ee..e844c6823 100644 --- a/examples/alst/llama3-8b-deepspeed-alst.yaml +++ b/examples/alst/llama3-8b-deepspeed-alst.yaml @@ -39,7 +39,7 @@ activation_offloading: legacy resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_steps: 100 
saves_per_epoch: 1 diff --git a/examples/alst/llama3-8b-fsdp2-alst.yaml b/examples/alst/llama3-8b-fsdp2-alst.yaml index c8a978264..a7da92637 100644 --- a/examples/alst/llama3-8b-fsdp2-alst.yaml +++ b/examples/alst/llama3-8b-fsdp2-alst.yaml @@ -39,7 +39,7 @@ activation_offloading: legacy resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_steps: 100 saves_per_epoch: 1 diff --git a/examples/apertus/apertus-8b-qlora.yaml b/examples/apertus/apertus-8b-qlora.yaml index 521b282da..f43901363 100644 --- a/examples/apertus/apertus-8b-qlora.yaml +++ b/examples/apertus/apertus-8b-qlora.yaml @@ -55,7 +55,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/arcee/afm-4.5b-qlora.yaml b/examples/arcee/afm-4.5b-qlora.yaml index 2cb42cacd..8e70847ad 100644 --- a/examples/arcee/afm-4.5b-qlora.yaml +++ b/examples/arcee/afm-4.5b-qlora.yaml @@ -55,7 +55,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/archived/cerebras/btlm-ft.yml b/examples/archived/cerebras/btlm-ft.yml index c3495d287..5a5f8dc12 100644 --- a/examples/archived/cerebras/btlm-ft.yml +++ b/examples/archived/cerebras/btlm-ft.yml @@ -59,8 +59,7 @@ gradient_checkpointing: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true -sdp_attention: +attn_implementation: flash_attention_2 flash_optimum: gptq_groupsize: diff --git a/examples/archived/cerebras/qlora.yml b/examples/archived/cerebras/qlora.yml index 4598a8338..22f52e682 100644 --- a/examples/archived/cerebras/qlora.yml +++ b/examples/archived/cerebras/qlora.yml @@ -39,8 +39,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -xformers_attention: true 
-flash_attention: +attn_implementation: xformers gptq_groupsize: gptq_model_v1: warmup_ratio: 0.1 diff --git a/examples/archived/code-llama/13b/lora.yml b/examples/archived/code-llama/13b/lora.yml index ace94b619..43f623357 100644 --- a/examples/archived/code-llama/13b/lora.yml +++ b/examples/archived/code-llama/13b/lora.yml @@ -45,7 +45,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/archived/code-llama/13b/qlora.yml b/examples/archived/code-llama/13b/qlora.yml index f4ed17af5..086f5e3d8 100644 --- a/examples/archived/code-llama/13b/qlora.yml +++ b/examples/archived/code-llama/13b/qlora.yml @@ -46,7 +46,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/archived/code-llama/34b/lora.yml b/examples/archived/code-llama/34b/lora.yml index 0a1d71467..19aa898be 100644 --- a/examples/archived/code-llama/34b/lora.yml +++ b/examples/archived/code-llama/34b/lora.yml @@ -45,7 +45,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/archived/code-llama/34b/qlora.yml b/examples/archived/code-llama/34b/qlora.yml index ec17bf200..2ec78f0d8 100644 --- a/examples/archived/code-llama/34b/qlora.yml +++ b/examples/archived/code-llama/34b/qlora.yml @@ -46,7 +46,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/archived/code-llama/7b/lora.yml b/examples/archived/code-llama/7b/lora.yml index 174c17d2c..30bc63355 100644 --- a/examples/archived/code-llama/7b/lora.yml 
+++ b/examples/archived/code-llama/7b/lora.yml @@ -45,7 +45,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/archived/code-llama/7b/qlora.yml b/examples/archived/code-llama/7b/qlora.yml index 08e67d8c2..0c3b38519 100644 --- a/examples/archived/code-llama/7b/qlora.yml +++ b/examples/archived/code-llama/7b/qlora.yml @@ -46,7 +46,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/archived/dbrx/16bit-lora.yaml b/examples/archived/dbrx/16bit-lora.yaml index 05946dfe9..eca58f94c 100644 --- a/examples/archived/dbrx/16bit-lora.yaml +++ b/examples/archived/dbrx/16bit-lora.yaml @@ -52,7 +52,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/archived/dbrx/8bit-lora.yaml b/examples/archived/dbrx/8bit-lora.yaml index f159bf7fa..59f5241b4 100644 --- a/examples/archived/dbrx/8bit-lora.yaml +++ b/examples/archived/dbrx/8bit-lora.yaml @@ -55,7 +55,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/archived/dbrx/fft-ds-zero3.yaml b/examples/archived/dbrx/fft-ds-zero3.yaml index 13cd0d997..2cb3e6da1 100644 --- a/examples/archived/dbrx/fft-ds-zero3.yaml +++ b/examples/archived/dbrx/fft-ds-zero3.yaml @@ -39,7 +39,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git 
a/examples/archived/deepcoder/deepcoder-14B-preview-lora.yml b/examples/archived/deepcoder/deepcoder-14B-preview-lora.yml index 3223ec19a..b125e9e3f 100644 --- a/examples/archived/deepcoder/deepcoder-14B-preview-lora.yml +++ b/examples/archived/deepcoder/deepcoder-14B-preview-lora.yml @@ -45,7 +45,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/archived/falcon/config-7b-lora.yml b/examples/archived/falcon/config-7b-lora.yml index f4fedbede..71dd572b3 100644 --- a/examples/archived/falcon/config-7b-lora.yml +++ b/examples/archived/falcon/config-7b-lora.yml @@ -43,8 +43,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -xformers_attention: true -flash_attention: +attn_implementation: xformers gptq_groupsize: gptq_model_v1: warmup_ratio: 0.1 diff --git a/examples/archived/falcon/config-7b-qlora.yml b/examples/archived/falcon/config-7b-qlora.yml index a44cc40a6..edd6550a7 100644 --- a/examples/archived/falcon/config-7b-qlora.yml +++ b/examples/archived/falcon/config-7b-qlora.yml @@ -73,8 +73,7 @@ early_stopping_patience: 3 resume_from_checkpoint: auto_resume_from_checkpoints: true logging_steps: 1 -xformers_attention: true -flash_attention: +attn_implementation: xformers gptq_groupsize: gptq_model_v1: warmup_ratio: 0.1 diff --git a/examples/archived/falcon/config-7b.yml b/examples/archived/falcon/config-7b.yml index 5481fb236..6da39d7ab 100644 --- a/examples/archived/falcon/config-7b.yml +++ b/examples/archived/falcon/config-7b.yml @@ -40,8 +40,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -xformers_attention: true -flash_attention: +attn_implementation: xformers gptq_groupsize: gptq_model_v1: warmup_ratio: 0.1 diff --git a/examples/archived/gemma/qlora.yml b/examples/archived/gemma/qlora.yml index 80829b3c9..5b5ec4a9f 100644 --- 
a/examples/archived/gemma/qlora.yml +++ b/examples/archived/gemma/qlora.yml @@ -47,7 +47,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/archived/gptj/qlora.yml b/examples/archived/gptj/qlora.yml index 6348566c2..7e10adeaa 100644 --- a/examples/archived/gptj/qlora.yml +++ b/examples/archived/gptj/qlora.yml @@ -36,8 +36,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -xformers_attention: true -flash_attention: +attn_implementation: xformers gptq_groupsize: gptq_model_v1: warmup_ratio: 0.1 diff --git a/examples/archived/jeopardy-bot/config.yml b/examples/archived/jeopardy-bot/config.yml index ab1d19784..90ca3b4bc 100644 --- a/examples/archived/jeopardy-bot/config.yml +++ b/examples/archived/jeopardy-bot/config.yml @@ -37,8 +37,7 @@ bf16: auto tf32: true resume_from_checkpoint: logging_steps: 5 -xformers_attention: true -flash_attention: +attn_implementation: xformers gptq_groupsize: gptq_model_v1: warmup_ratio: 0.1 diff --git a/examples/archived/mpt-7b/config.yml b/examples/archived/mpt-7b/config.yml index 1fff51b6e..588981bf7 100644 --- a/examples/archived/mpt-7b/config.yml +++ b/examples/archived/mpt-7b/config.yml @@ -39,7 +39,6 @@ bf16: auto tf32: true resume_from_checkpoint: logging_steps: 5 -flash_attention: gptq_groupsize: gptq_model_v1: warmup_ratio: 0.1 diff --git a/examples/archived/openllama-3b/config.yml b/examples/archived/openllama-3b/config.yml index 63056ed6d..14104ff4b 100644 --- a/examples/archived/openllama-3b/config.yml +++ b/examples/archived/openllama-3b/config.yml @@ -39,7 +39,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 gptq_groupsize: gptq_model_v1: warmup_ratio: 0.1 diff --git a/examples/archived/openllama-3b/lora.yml 
b/examples/archived/openllama-3b/lora.yml index b70821ce2..30d3888f1 100644 --- a/examples/archived/openllama-3b/lora.yml +++ b/examples/archived/openllama-3b/lora.yml @@ -47,7 +47,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 gptq_groupsize: gptq_model_v1: warmup_ratio: 0.1 diff --git a/examples/archived/openllama-3b/qlora.yml b/examples/archived/openllama-3b/qlora.yml index a34f2964b..fc9d1d703 100644 --- a/examples/archived/openllama-3b/qlora.yml +++ b/examples/archived/openllama-3b/qlora.yml @@ -40,7 +40,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 gptq_groupsize: gptq_model_v1: warmup_ratio: 0.1 diff --git a/examples/archived/qwen/lora.yml b/examples/archived/qwen/lora.yml index 29de25611..362a848a8 100644 --- a/examples/archived/qwen/lora.yml +++ b/examples/archived/qwen/lora.yml @@ -47,7 +47,6 @@ tf32: false gradient_checkpointing: false resume_from_checkpoint: logging_steps: 1 -flash_attention: warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/archived/qwen/qlora.yml b/examples/archived/qwen/qlora.yml index d46669444..bce3012e7 100644 --- a/examples/archived/qwen/qlora.yml +++ b/examples/archived/qwen/qlora.yml @@ -47,7 +47,6 @@ tf32: false gradient_checkpointing: false resume_from_checkpoint: logging_steps: 1 -flash_attention: warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/archived/qwen/qwen2-moe-lora.yaml b/examples/archived/qwen/qwen2-moe-lora.yaml index 1d5e1b524..97c0d51a6 100644 --- a/examples/archived/qwen/qwen2-moe-lora.yaml +++ b/examples/archived/qwen/qwen2-moe-lora.yaml @@ -43,7 +43,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git 
a/examples/archived/qwen/qwen2-moe-qlora.yaml b/examples/archived/qwen/qwen2-moe-qlora.yaml index 08731441b..a16089eed 100644 --- a/examples/archived/qwen/qwen2-moe-qlora.yaml +++ b/examples/archived/qwen/qwen2-moe-qlora.yaml @@ -46,7 +46,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/archived/redpajama/config-3b.yml b/examples/archived/redpajama/config-3b.yml index c5b229c3d..676f31476 100644 --- a/examples/archived/redpajama/config-3b.yml +++ b/examples/archived/redpajama/config-3b.yml @@ -40,7 +40,6 @@ bf16: auto tf32: true resume_from_checkpoint: logging_steps: 5 -flash_attention: gptq_groupsize: gptq_model_v1: warmup_ratio: 0.1 diff --git a/examples/archived/replit-3b/config-lora.yml b/examples/archived/replit-3b/config-lora.yml index d8561762c..b0a0c9089 100644 --- a/examples/archived/replit-3b/config-lora.yml +++ b/examples/archived/replit-3b/config-lora.yml @@ -38,7 +38,6 @@ tf32: true gradient_checkpointing: resume_from_checkpoint: logging_steps: 1 -flash_attention: gptq_groupsize: gptq_model_v1: warmup_ratio: 0.1 diff --git a/examples/archived/stablelm-2/1.6b/fft.yml b/examples/archived/stablelm-2/1.6b/fft.yml index 585888f43..05f59544c 100644 --- a/examples/archived/stablelm-2/1.6b/fft.yml +++ b/examples/archived/stablelm-2/1.6b/fft.yml @@ -44,7 +44,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 flash_attn_cross_entropy: false flash_attn_rms_norm: true flash_attn_fuse_mlp: true diff --git a/examples/archived/stablelm-2/1.6b/lora.yml b/examples/archived/stablelm-2/1.6b/lora.yml index 6d358bdd8..1edb56e0c 100644 --- a/examples/archived/stablelm-2/1.6b/lora.yml +++ b/examples/archived/stablelm-2/1.6b/lora.yml @@ -47,7 +47,7 @@ tf32: false gradient_checkpointing: true 
resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 flash_attn_cross_entropy: false flash_attn_rms_norm: true diff --git a/examples/archived/starcoder2/qlora.yml b/examples/archived/starcoder2/qlora.yml index fecf98d23..0fd0f453c 100644 --- a/examples/archived/starcoder2/qlora.yml +++ b/examples/archived/starcoder2/qlora.yml @@ -46,7 +46,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/archived/tiny-llama/lora-mps.yml b/examples/archived/tiny-llama/lora-mps.yml index 125090a78..bf3292c35 100644 --- a/examples/archived/tiny-llama/lora-mps.yml +++ b/examples/archived/tiny-llama/lora-mps.yml @@ -47,7 +47,6 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: false warmup_ratio: 0.1 evals_per_epoch: 0 diff --git a/examples/archived/tiny-llama/lora.yml b/examples/archived/tiny-llama/lora.yml index 817481e18..a12d63746 100644 --- a/examples/archived/tiny-llama/lora.yml +++ b/examples/archived/tiny-llama/lora.yml @@ -45,7 +45,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/archived/tiny-llama/pretrain.yml b/examples/archived/tiny-llama/pretrain.yml index f15c6ce19..4d1686138 100644 --- a/examples/archived/tiny-llama/pretrain.yml +++ b/examples/archived/tiny-llama/pretrain.yml @@ -36,7 +36,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/archived/tiny-llama/qlora.yml b/examples/archived/tiny-llama/qlora.yml index d3ff59cb8..b1adcb2e6 100644 --- a/examples/archived/tiny-llama/qlora.yml +++ 
b/examples/archived/tiny-llama/qlora.yml @@ -47,7 +47,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/archived/xgen-7b/xgen-7b-8k-qlora.yml b/examples/archived/xgen-7b/xgen-7b-8k-qlora.yml index fc09a1e7b..d548032b9 100644 --- a/examples/archived/xgen-7b/xgen-7b-8k-qlora.yml +++ b/examples/archived/xgen-7b/xgen-7b-8k-qlora.yml @@ -71,8 +71,7 @@ early_stopping_patience: 3 resume_from_checkpoint: auto_resume_from_checkpoints: true logging_steps: 1 -xformers_attention: true -flash_attention: +attn_implementation: xformers gptq_groupsize: gptq_model_v1: warmup_ratio: 0.1 diff --git a/examples/archived/yi-34B-chat/qlora.yml b/examples/archived/yi-34B-chat/qlora.yml index ba8d12fc8..5d3d54dc6 100644 --- a/examples/archived/yi-34B-chat/qlora.yml +++ b/examples/archived/yi-34B-chat/qlora.yml @@ -10,7 +10,7 @@ load_in_4bit: true sequence_len: 1024 bf16: auto tf32: false -flash_attention: true +attn_implementation: flash_attention_2 special_tokens: bos_token: "<|startoftext|>" eos_token: "<|endoftext|>" diff --git a/examples/cohere/command-r-7b-qlora.yml b/examples/cohere/command-r-7b-qlora.yml index b4741636b..c4d03b0ec 100644 --- a/examples/cohere/command-r-7b-qlora.yml +++ b/examples/cohere/command-r-7b-qlora.yml @@ -48,7 +48,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml index 97d1bb6b3..c36b0e74a 100644 --- a/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml +++ b/examples/deepcogito/cogito-v1-preview-llama-3B-lora.yml @@ -45,7 +45,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true 
+attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml index b80cc5bc0..2b2aafd75 100644 --- a/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml +++ b/examples/deepcogito/cogito-v1-preview-qwen-14B-lora.yml @@ -45,7 +45,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/deepseek-v2/fft-fsdp-16b.yaml b/examples/deepseek-v2/fft-fsdp-16b.yaml index 6e936da16..2eac9aea3 100644 --- a/examples/deepseek-v2/fft-fsdp-16b.yaml +++ b/examples/deepseek-v2/fft-fsdp-16b.yaml @@ -35,7 +35,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 2 diff --git a/examples/deepseek-v2/qlora-fsdp-2_5.yaml b/examples/deepseek-v2/qlora-fsdp-2_5.yaml index aab5034a0..0e23a0266 100644 --- a/examples/deepseek-v2/qlora-fsdp-2_5.yaml +++ b/examples/deepseek-v2/qlora-fsdp-2_5.yaml @@ -59,7 +59,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 2 diff --git a/examples/devstral/devstral-small-qlora.yml b/examples/devstral/devstral-small-qlora.yml index 3eafb9219..6ee0e014d 100644 --- a/examples/devstral/devstral-small-qlora.yml +++ b/examples/devstral/devstral-small-qlora.yml @@ -50,7 +50,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 # scaling_softmax: true # needs flex_attention loss_watchdog_threshold: 5.0 diff --git a/examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml 
b/examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml index f10dc9bd2..a99a6bef8 100644 --- a/examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml +++ b/examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml @@ -29,7 +29,7 @@ output_dir: ./outputs/ndp-out/ sequence_len: 2048 sample_packing: true -flash_attention: true +attn_implementation: flash_attention_2 gradient_accumulation_steps: 1 micro_batch_size: 1 diff --git a/examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml b/examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml index 584a33f44..a12b524ed 100644 --- a/examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml +++ b/examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml @@ -26,7 +26,7 @@ output_dir: ./outputs/ndp-out/ sequence_len: 8192 sample_packing: true -flash_attention: true +attn_implementation: flash_attention_2 gradient_accumulation_steps: 1 micro_batch_size: 1 # must be 1 when using context parallel diff --git a/examples/eaft/eaft-example.yml b/examples/eaft/eaft-example.yml index fed4179d2..b4b13a14c 100644 --- a/examples/eaft/eaft-example.yml +++ b/examples/eaft/eaft-example.yml @@ -65,8 +65,7 @@ early_stopping_patience: resume_from_checkpoint: local_rank: logging_steps: 1 -xformers_attention: -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 weight_decay: 0.0 diff --git a/examples/ebft/llama-1b-ebft-opencode-novllm.yaml b/examples/ebft/llama-1b-ebft-opencode-novllm.yaml index 0891033f0..7d7edad33 100644 --- a/examples/ebft/llama-1b-ebft-opencode-novllm.yaml +++ b/examples/ebft/llama-1b-ebft-opencode-novllm.yaml @@ -46,7 +46,7 @@ lora_dropout: 0.05 lora_target_linear: true bf16: auto -flash_attention: true +attn_implementation: flash_attention_2 gradient_checkpointing: true special_tokens: diff --git a/examples/ebft/llama-1b-ebft-opencode.yaml b/examples/ebft/llama-1b-ebft-opencode.yaml index d0d1069d8..c77c36677 100644 --- a/examples/ebft/llama-1b-ebft-opencode.yaml +++ 
b/examples/ebft/llama-1b-ebft-opencode.yaml @@ -66,7 +66,7 @@ lora_target_linear: true # --- Hardware --- bf16: auto -flash_attention: true +attn_implementation: flash_attention_2 gradient_checkpointing: true special_tokens: diff --git a/examples/ebft/llama-1b-ebft-strided-structured.yaml b/examples/ebft/llama-1b-ebft-strided-structured.yaml index 8ba63b64b..02e89dea0 100644 --- a/examples/ebft/llama-1b-ebft-strided-structured.yaml +++ b/examples/ebft/llama-1b-ebft-strided-structured.yaml @@ -47,8 +47,7 @@ lora_dropout: 0.05 lora_target_linear: true bf16: auto -flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime -flex_attention: true # fused flex_attention kernel compiles itself; don't set torch_compile: true +attn_implementation: flex_attention # fused flex_attention kernel compiles itself; don't set torch_compile: true # (full-model compile conflicts with gradient checkpointing + flex_attention) gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/examples/ebft/llama-1b-ebft-strided.yaml b/examples/ebft/llama-1b-ebft-strided.yaml index c9519f160..e3cfe8040 100644 --- a/examples/ebft/llama-1b-ebft-strided.yaml +++ b/examples/ebft/llama-1b-ebft-strided.yaml @@ -46,7 +46,6 @@ lora_dropout: 0.05 lora_target_linear: true bf16: auto -flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime gradient_checkpointing: true special_tokens: diff --git a/examples/ebft/llama-3b-ebft-strided-fft.yaml b/examples/ebft/llama-3b-ebft-strided-fft.yaml index 5695efa40..e39d3bcfa 100644 --- a/examples/ebft/llama-3b-ebft-strided-fft.yaml +++ b/examples/ebft/llama-3b-ebft-strided-fft.yaml @@ -48,7 +48,6 @@ lora_target_linear: true bf16: auto torch_dtype: bfloat16 -flash_attention: false gradient_checkpointing: true torch_compile: true gradient_checkpointing_kwargs: diff --git a/examples/ebft/llama-8b-ebft-strided-fft.yaml b/examples/ebft/llama-8b-ebft-strided-fft.yaml index 8cf962849..caed98085 100644 --- a/examples/ebft/llama-8b-ebft-strided-fft.yaml +++ 
b/examples/ebft/llama-8b-ebft-strided-fft.yaml @@ -41,7 +41,6 @@ warmup_steps: 10 weight_decay: 0.01 bf16: auto -flash_attention: false # strided EBFT uses flex_attention at runtime gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false diff --git a/examples/ebft/qwen35-4b-ebft-structured-async.yaml b/examples/ebft/qwen35-4b-ebft-structured-async.yaml index 759a31730..daa77d6f6 100644 --- a/examples/ebft/qwen35-4b-ebft-structured-async.yaml +++ b/examples/ebft/qwen35-4b-ebft-structured-async.yaml @@ -72,7 +72,7 @@ lora_dropout: 0.0 lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj" bf16: auto -flash_attention: true +attn_implementation: flash_attention_2 gradient_checkpointing: true special_tokens: diff --git a/examples/ebft/qwen35-4b-ebft-structured.yaml b/examples/ebft/qwen35-4b-ebft-structured.yaml index 9108e87e9..d1b2a72f2 100644 --- a/examples/ebft/qwen35-4b-ebft-structured.yaml +++ b/examples/ebft/qwen35-4b-ebft-structured.yaml @@ -63,7 +63,7 @@ lora_dropout: 0.0 lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj" bf16: auto -flash_attention: true +attn_implementation: flash_attention_2 gradient_checkpointing: true special_tokens: diff --git a/examples/ebft/qwen35-9b-ebft-structured.yaml b/examples/ebft/qwen35-9b-ebft-structured.yaml index e79fb5fbf..ad3b8538e 100644 --- a/examples/ebft/qwen35-9b-ebft-structured.yaml +++ b/examples/ebft/qwen35-9b-ebft-structured.yaml @@ -68,7 +68,7 @@ lora_dropout: 0.0 lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj" bf16: auto -flash_attention: true +attn_implementation: flash_attention_2 gradient_checkpointing: true special_tokens: diff --git a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml index 2473179f0..f59f0df5c 100644 --- 
a/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml +++ b/examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml @@ -62,7 +62,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/falcon-h1/falcon-h1-1b-qlora.yaml b/examples/falcon-h1/falcon-h1-1b-qlora.yaml index bfb7836ef..8c3eb080d 100644 --- a/examples/falcon-h1/falcon-h1-1b-qlora.yaml +++ b/examples/falcon-h1/falcon-h1-1b-qlora.yaml @@ -61,7 +61,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/falcon-h1/falcon-h1-34b-qlora.yaml b/examples/falcon-h1/falcon-h1-34b-qlora.yaml index 80a9d45b5..28e7de956 100644 --- a/examples/falcon-h1/falcon-h1-34b-qlora.yaml +++ b/examples/falcon-h1/falcon-h1-34b-qlora.yaml @@ -62,7 +62,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/falcon-h1/falcon-h1-3b-qlora.yaml b/examples/falcon-h1/falcon-h1-3b-qlora.yaml index 02be8ac5d..71b38e2f7 100644 --- a/examples/falcon-h1/falcon-h1-3b-qlora.yaml +++ b/examples/falcon-h1/falcon-h1-3b-qlora.yaml @@ -62,7 +62,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/falcon-h1/falcon-h1-500m-qlora.yaml b/examples/falcon-h1/falcon-h1-500m-qlora.yaml index b112d5d85..91602ae71 100644 --- a/examples/falcon-h1/falcon-h1-500m-qlora.yaml +++ b/examples/falcon-h1/falcon-h1-500m-qlora.yaml @@ -62,7 +62,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 
-flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/falcon-h1/falcon-h1-7b-qlora.yaml b/examples/falcon-h1/falcon-h1-7b-qlora.yaml index c5505873d..cc7e8f6cd 100644 --- a/examples/falcon-h1/falcon-h1-7b-qlora.yaml +++ b/examples/falcon-h1/falcon-h1-7b-qlora.yaml @@ -62,7 +62,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/gemma2/qlora.yml b/examples/gemma2/qlora.yml index 8a295a1f8..b2fca74da 100644 --- a/examples/gemma2/qlora.yml +++ b/examples/gemma2/qlora.yml @@ -53,7 +53,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/gemma2/reward-model.yaml b/examples/gemma2/reward-model.yaml index 67b1228b2..f48bff626 100644 --- a/examples/gemma2/reward-model.yaml +++ b/examples/gemma2/reward-model.yaml @@ -43,7 +43,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/gemma3/gemma-3-1b-qlora.yml b/examples/gemma3/gemma-3-1b-qlora.yml index 4bcbf09f4..95b99a0da 100644 --- a/examples/gemma3/gemma-3-1b-qlora.yml +++ b/examples/gemma3/gemma-3-1b-qlora.yml @@ -62,7 +62,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/gemma3/gemma-3-270m-qlora.yml b/examples/gemma3/gemma-3-270m-qlora.yml index 1f247ab05..800a88a1b 100644 --- a/examples/gemma3/gemma-3-270m-qlora.yml +++ b/examples/gemma3/gemma-3-270m-qlora.yml @@ -62,7 +62,7 @@ gradient_checkpointing_kwargs: 
use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/gemma3/gemma-3-4b-qlora.yml b/examples/gemma3/gemma-3-4b-qlora.yml index 5d939da19..e7c43ddef 100644 --- a/examples/gemma3/gemma-3-4b-qlora.yml +++ b/examples/gemma3/gemma-3-4b-qlora.yml @@ -58,8 +58,7 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false logging_steps: 1 -flash_attention: true -eager_attention: +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/gemma3/gemma-3-4b-vision-qlora.yml b/examples/gemma3/gemma-3-4b-vision-qlora.yml index a12e84bee..790d9543a 100644 --- a/examples/gemma3/gemma-3-4b-vision-qlora.yml +++ b/examples/gemma3/gemma-3-4b-vision-qlora.yml @@ -55,8 +55,7 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false logging_steps: 1 -flash_attention: true -eager_attention: +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/gemma4/26b-a4b-moe-qlora.yaml b/examples/gemma4/26b-a4b-moe-qlora.yaml index e7bdb6f46..cdc70ef4a 100644 --- a/examples/gemma4/26b-a4b-moe-qlora.yaml +++ b/examples/gemma4/26b-a4b-moe-qlora.yaml @@ -84,7 +84,7 @@ activation_offloading: true logging_steps: 1 # FA2 not supported -sdp_attention: true +attn_implementation: sdpa warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/gemma4/31b-qlora-flex.yaml b/examples/gemma4/31b-qlora-flex.yaml index 8456c9c13..87221c515 100644 --- a/examples/gemma4/31b-qlora-flex.yaml +++ b/examples/gemma4/31b-qlora-flex.yaml @@ -62,7 +62,7 @@ activation_offloading: true logging_steps: 1 # FA not supported -flex_attention: true +attn_implementation: flex_attention warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/gemma4/31b-qlora.yaml b/examples/gemma4/31b-qlora.yaml index 42086a43c..4a633436e 100644 --- a/examples/gemma4/31b-qlora.yaml +++ 
b/examples/gemma4/31b-qlora.yaml @@ -60,7 +60,7 @@ activation_offloading: true logging_steps: 1 # FA not supported -sdp_attention: true +attn_implementation: sdpa warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/gemma4/e2b-vision-lora.yaml b/examples/gemma4/e2b-vision-lora.yaml index c779aaea5..ae90bc1cb 100644 --- a/examples/gemma4/e2b-vision-lora.yaml +++ b/examples/gemma4/e2b-vision-lora.yaml @@ -50,7 +50,7 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false logging_steps: 1 -sdp_attention: true +attn_implementation: sdpa warmup_ratio: 0.1 weight_decay: 0.0 diff --git a/examples/glm4/qlora-32b.yaml b/examples/glm4/qlora-32b.yaml index 832abde05..151820924 100644 --- a/examples/glm4/qlora-32b.yaml +++ b/examples/glm4/qlora-32b.yaml @@ -50,7 +50,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/glm45/glm-45-air-qlora.yaml b/examples/glm45/glm-45-air-qlora.yaml index accb8898f..5723d3c45 100644 --- a/examples/glm45/glm-45-air-qlora.yaml +++ b/examples/glm45/glm-45-air-qlora.yaml @@ -55,7 +55,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/glm46v/glm-4-6v-flash-ddp.yaml b/examples/glm46v/glm-4-6v-flash-ddp.yaml index c67ac5e28..274f041a3 100644 --- a/examples/glm46v/glm-4-6v-flash-ddp.yaml +++ b/examples/glm46v/glm-4-6v-flash-ddp.yaml @@ -45,7 +45,7 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false logging_steps: 1 -sdp_attention: true +attn_implementation: sdpa warmup_ratio: 0.1 evals_per_epoch: 0 diff --git a/examples/glm46v/glm-4-6v-flash-qlora.yaml b/examples/glm46v/glm-4-6v-flash-qlora.yaml index 287944ae8..9fe8d6e43 100644 --- 
a/examples/glm46v/glm-4-6v-flash-qlora.yaml +++ b/examples/glm46v/glm-4-6v-flash-qlora.yaml @@ -42,7 +42,7 @@ tf32: false gradient_checkpointing: true logging_steps: 1 -sdp_attention: true +attn_implementation: sdpa warmup_ratio: 0.1 evals_per_epoch: 0 diff --git a/examples/glm47-flash/lora.yaml b/examples/glm47-flash/lora.yaml index 2586babb7..5f3de36e9 100644 --- a/examples/glm47-flash/lora.yaml +++ b/examples/glm47-flash/lora.yaml @@ -58,7 +58,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/glm47-flash/lora_fsdp.yaml b/examples/glm47-flash/lora_fsdp.yaml index bee20bf02..cf1d2de55 100644 --- a/examples/glm47-flash/lora_fsdp.yaml +++ b/examples/glm47-flash/lora_fsdp.yaml @@ -57,7 +57,7 @@ tf32: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/glm47-flash/qlora.yaml b/examples/glm47-flash/qlora.yaml index 834c46af8..a05bf54d2 100644 --- a/examples/glm47-flash/qlora.yaml +++ b/examples/glm47-flash/qlora.yaml @@ -58,7 +58,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/glm47-flash/qlora_fsdp.yaml b/examples/glm47-flash/qlora_fsdp.yaml index 0bb87813f..9ad5a6212 100644 --- a/examples/glm47-flash/qlora_fsdp.yaml +++ b/examples/glm47-flash/qlora_fsdp.yaml @@ -57,7 +57,7 @@ tf32: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml b/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml index b7082f986..71692958f 100644 --- a/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml 
+++ b/examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml @@ -47,7 +47,6 @@ learning_rate: 2e-5 bf16: true tf32: true -flash_attention: true attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 gradient_checkpointing: true diff --git a/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml b/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml index b718ff2eb..5912f876b 100644 --- a/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml +++ b/examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml @@ -43,7 +43,6 @@ learning_rate: 2e-5 bf16: true tf32: true -flash_attention: true attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 gradient_checkpointing: true diff --git a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml index af1c93bc0..b1a0fef4a 100644 --- a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml +++ b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml @@ -44,7 +44,6 @@ learning_rate: 2e-5 bf16: true tf32: true -flash_attention: true attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 gradient_checkpointing: true diff --git a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml index 894ba99b8..f97174cd9 100644 --- a/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml +++ b/examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml @@ -43,7 +43,6 @@ learning_rate: 2e-5 bf16: true tf32: true -flash_attention: true attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 gradient_checkpointing: true diff --git a/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml b/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml index 7c4f97846..122fb0b6c 100644 --- a/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml +++ b/examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml @@ -56,7 
+56,6 @@ learning_rate: 2e-4 bf16: true tf32: true -flash_attention: true attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 gradient_checkpointing: true diff --git a/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml b/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml index cbb9efc8e..7ba5f29b5 100644 --- a/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml +++ b/examples/gpt-oss/gpt-oss-safeguard-20b-sft-lora-singlegpu.yaml @@ -56,7 +56,6 @@ learning_rate: 2e-4 bf16: true tf32: true -flash_attention: true attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3 gradient_checkpointing: true diff --git a/examples/granite4/granite-4.0-tiny-fft.yaml b/examples/granite4/granite-4.0-tiny-fft.yaml index 7ff8207ae..fd7d2a312 100644 --- a/examples/granite4/granite-4.0-tiny-fft.yaml +++ b/examples/granite4/granite-4.0-tiny-fft.yaml @@ -36,7 +36,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/hunyuan/hunyuan-v1-dense-qlora.yaml b/examples/hunyuan/hunyuan-v1-dense-qlora.yaml index a94345a61..1ae6b000d 100644 --- a/examples/hunyuan/hunyuan-v1-dense-qlora.yaml +++ b/examples/hunyuan/hunyuan-v1-dense-qlora.yaml @@ -55,7 +55,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/internvl3_5/internvl3_5-8b-qlora.yml b/examples/internvl3_5/internvl3_5-8b-qlora.yml index 9a72d078a..2d924c6f1 100644 --- a/examples/internvl3_5/internvl3_5-8b-qlora.yml +++ b/examples/internvl3_5/internvl3_5-8b-qlora.yml @@ -50,8 +50,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true -eager_attention: 
+attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/jamba/qlora.yaml b/examples/jamba/qlora.yaml index 538ed3a10..f625fb6f5 100644 --- a/examples/jamba/qlora.yaml +++ b/examples/jamba/qlora.yaml @@ -47,7 +47,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/jamba/qlora_deepspeed.yaml b/examples/jamba/qlora_deepspeed.yaml index b288635e7..8ec74f905 100644 --- a/examples/jamba/qlora_deepspeed.yaml +++ b/examples/jamba/qlora_deepspeed.yaml @@ -46,7 +46,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/jamba/qlora_fsdp_large.yaml b/examples/jamba/qlora_fsdp_large.yaml index 4db889fbc..76cc0ef18 100644 --- a/examples/jamba/qlora_fsdp_large.yaml +++ b/examples/jamba/qlora_fsdp_large.yaml @@ -44,7 +44,7 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: true logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/kimi-linear/kimi-48b-lora.yaml b/examples/kimi-linear/kimi-48b-lora.yaml index 8e855dd72..befa29891 100644 --- a/examples/kimi-linear/kimi-48b-lora.yaml +++ b/examples/kimi-linear/kimi-48b-lora.yaml @@ -65,7 +65,7 @@ early_stopping_patience: resume_from_checkpoint: local_rank: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml index ea119348e..7af25dd17 100644 --- a/examples/llama-2/fft_optimized.yml +++ b/examples/llama-2/fft_optimized.yml @@ -42,7 +42,7 @@ tf32: false gradient_checkpointing: true 
resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 flash_attn_cross_entropy: false flash_attn_rms_norm: true flash_attn_fuse_mlp: true diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml index de1caaa05..c4073b80a 100644 --- a/examples/llama-2/gptq-lora.yml +++ b/examples/llama-2/gptq-lora.yml @@ -53,8 +53,6 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: -sdp_attention: flash_optimum: warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-2/lisa.yml b/examples/llama-2/lisa.yml index d21c01a49..40ba6d0d0 100644 --- a/examples/llama-2/lisa.yml +++ b/examples/llama-2/lisa.yml @@ -46,7 +46,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 flash_attn_cross_entropy: false flash_attn_rms_norm: true flash_attn_fuse_mlp: true diff --git a/examples/llama-2/loftq.yml b/examples/llama-2/loftq.yml index 619e5bcce..f1562ec29 100644 --- a/examples/llama-2/loftq.yml +++ b/examples/llama-2/loftq.yml @@ -45,7 +45,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index 0a677f11a..8c2242b71 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -45,7 +45,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-2/qlora-fsdp.yml b/examples/llama-2/qlora-fsdp.yml index 1e7064de8..102eb7af7 100644 --- a/examples/llama-2/qlora-fsdp.yml +++ b/examples/llama-2/qlora-fsdp.yml @@ -48,7 +48,7 @@ gradient_checkpointing_kwargs: use_reentrant: true resume_from_checkpoint: 
logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index 327d88c15..87e710792 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -46,7 +46,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml index fabdf0e0f..8e3df58bf 100644 --- a/examples/llama-2/relora.yml +++ b/examples/llama-2/relora.yml @@ -51,7 +51,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-3-vision/lora-11b.yaml b/examples/llama-3-vision/lora-11b.yaml index adbb61643..4e5eb4c4e 100644 --- a/examples/llama-3-vision/lora-11b.yaml +++ b/examples/llama-3-vision/lora-11b.yaml @@ -50,7 +50,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 # flash_attention: true # use for text-only mode -sdp_attention: true +attn_implementation: sdpa warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/llama-3/3b-fp8-fsdp2.yaml b/examples/llama-3/3b-fp8-fsdp2.yaml index 57b308abd..cfc15870f 100644 --- a/examples/llama-3/3b-fp8-fsdp2.yaml +++ b/examples/llama-3/3b-fp8-fsdp2.yaml @@ -25,7 +25,7 @@ sample_packing: true pad_to_sequence_len: true sequence_len: 512 -flex_attention: true +attn_implementation: flex_attention flex_attn_compile_kwargs: dynamic: false mode: max-autotune-no-cudagraphs diff --git a/examples/llama-3/3b-qat-fsdp2.yaml b/examples/llama-3/3b-qat-fsdp2.yaml index 0c5a87891..99c975351 100644 --- a/examples/llama-3/3b-qat-fsdp2.yaml +++ b/examples/llama-3/3b-qat-fsdp2.yaml @@ -26,7 +26,7 @@ dataset_prepared_path: ./outputs/qat_out/dataset_prepared sample_packing: false 
sequence_len: 8192 -flash_attention: true +attn_implementation: flash_attention_2 qat: activation_dtype: int8 diff --git a/examples/llama-3/3b-qat-mxfp4.yaml b/examples/llama-3/3b-qat-mxfp4.yaml index 7ae941e9e..4e9f64685 100644 --- a/examples/llama-3/3b-qat-mxfp4.yaml +++ b/examples/llama-3/3b-qat-mxfp4.yaml @@ -24,7 +24,7 @@ output_dir: ./outputs/qat_out/ dataset_prepared_path: ./outputs/dataset_prepared sequence_len: 2048 -flash_attention: true +attn_implementation: flash_attention_2 qat: activation_dtype: mxfp4 diff --git a/examples/llama-3/3b-qat-nvfp4.yaml b/examples/llama-3/3b-qat-nvfp4.yaml index 1ec809bbe..77cf2b19b 100644 --- a/examples/llama-3/3b-qat-nvfp4.yaml +++ b/examples/llama-3/3b-qat-nvfp4.yaml @@ -24,7 +24,7 @@ output_dir: ./outputs/qat_out/ dataset_prepared_path: ./outputs/dataset_prepared sequence_len: 8192 -flash_attention: true +attn_implementation: flash_attention_2 qat: activation_dtype: nvfp4 diff --git a/examples/llama-3/diffusion/pretrain-1b.yaml b/examples/llama-3/diffusion/pretrain-1b.yaml index 8d05e4c60..1b488db7a 100644 --- a/examples/llama-3/diffusion/pretrain-1b.yaml +++ b/examples/llama-3/diffusion/pretrain-1b.yaml @@ -35,7 +35,7 @@ warmup_ratio: 0.1 optimizer: adamw_8bit lr_scheduler: cosine learning_rate: 3e-4 -sdp_attention: true +attn_implementation: sdpa bf16: auto tf32: true diff --git a/examples/llama-3/diffusion/sft-1b.yaml b/examples/llama-3/diffusion/sft-1b.yaml index f3b29a809..b6de76af3 100644 --- a/examples/llama-3/diffusion/sft-1b.yaml +++ b/examples/llama-3/diffusion/sft-1b.yaml @@ -41,7 +41,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: -sdp_attention: true +attn_implementation: sdpa logging_steps: 1 save_strategy: best diff --git a/examples/llama-3/fft-8b-liger-fsdp.yaml b/examples/llama-3/fft-8b-liger-fsdp.yaml index a655b97a9..b96bc920e 100644 --- a/examples/llama-3/fft-8b-liger-fsdp.yaml +++ b/examples/llama-3/fft-8b-liger-fsdp.yaml @@ -49,7 +49,7 @@ gradient_checkpointing_kwargs: 
use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 2 diff --git a/examples/llama-3/fft-8b.yaml b/examples/llama-3/fft-8b.yaml index c72ec6662..3e2809196 100644 --- a/examples/llama-3/fft-8b.yaml +++ b/examples/llama-3/fft-8b.yaml @@ -34,7 +34,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 2 diff --git a/examples/llama-3/instruct-dpo-lora-8b.yml b/examples/llama-3/instruct-dpo-lora-8b.yml index cf823353b..b49ace2ed 100644 --- a/examples/llama-3/instruct-dpo-lora-8b.yml +++ b/examples/llama-3/instruct-dpo-lora-8b.yml @@ -65,7 +65,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-3/instruct-lora-8b.yml b/examples/llama-3/instruct-lora-8b.yml index 401df1d72..1c61ce9e4 100644 --- a/examples/llama-3/instruct-lora-8b.yml +++ b/examples/llama-3/instruct-lora-8b.yml @@ -47,7 +47,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-3/lora-1b-deduplicate-dpo.yml b/examples/llama-3/lora-1b-deduplicate-dpo.yml index 2897636f4..2be72c4d0 100644 --- a/examples/llama-3/lora-1b-deduplicate-dpo.yml +++ b/examples/llama-3/lora-1b-deduplicate-dpo.yml @@ -77,7 +77,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-3/lora-1b-deduplicate-sft.yml b/examples/llama-3/lora-1b-deduplicate-sft.yml index c5190d892..ad21cb266 100644 --- 
a/examples/llama-3/lora-1b-deduplicate-sft.yml +++ b/examples/llama-3/lora-1b-deduplicate-sft.yml @@ -53,7 +53,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-3/lora-1b-kernels.yml b/examples/llama-3/lora-1b-kernels.yml index 0bcf46b17..b0914f87a 100644 --- a/examples/llama-3/lora-1b-kernels.yml +++ b/examples/llama-3/lora-1b-kernels.yml @@ -54,7 +54,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/llama-3/lora-1b-ray.yml b/examples/llama-3/lora-1b-ray.yml index 46c83348e..a3aa1cf5e 100644 --- a/examples/llama-3/lora-1b-ray.yml +++ b/examples/llama-3/lora-1b-ray.yml @@ -48,7 +48,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/llama-3/lora-1b-sample-packing-sequentially.yml b/examples/llama-3/lora-1b-sample-packing-sequentially.yml index dba78597b..f6c24bc74 100644 --- a/examples/llama-3/lora-1b-sample-packing-sequentially.yml +++ b/examples/llama-3/lora-1b-sample-packing-sequentially.yml @@ -55,7 +55,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-3/lora-1b.yml b/examples/llama-3/lora-1b.yml index 2ae2f0056..d01c618bc 100644 --- a/examples/llama-3/lora-1b.yml +++ b/examples/llama-3/lora-1b.yml @@ -49,7 +49,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 loss_watchdog_threshold: 5.0 
loss_watchdog_patience: 3 diff --git a/examples/llama-3/lora-8b.yml b/examples/llama-3/lora-8b.yml index d72b6527d..90084ec95 100644 --- a/examples/llama-3/lora-8b.yml +++ b/examples/llama-3/lora-8b.yml @@ -49,7 +49,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-3/opentelemetry-qlora.yml b/examples/llama-3/opentelemetry-qlora.yml index d8ce7b1ec..0c9995dae 100644 --- a/examples/llama-3/opentelemetry-qlora.yml +++ b/examples/llama-3/opentelemetry-qlora.yml @@ -39,7 +39,6 @@ tf32: false gradient_checkpointing: true logging_steps: 1 -flash_attention: false warmup_ratio: 0.1 evals_per_epoch: 2 diff --git a/examples/llama-3/qlora-1b-gdpo.yaml b/examples/llama-3/qlora-1b-gdpo.yaml index d806fcf26..f754a6887 100644 --- a/examples/llama-3/qlora-1b-gdpo.yaml +++ b/examples/llama-3/qlora-1b-gdpo.yaml @@ -56,7 +56,7 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false -flash_attention: true +attn_implementation: flash_attention_2 logging_steps: 1 save_steps: 50 save_safetensors: true diff --git a/examples/llama-3/qlora-1b-kto.yaml b/examples/llama-3/qlora-1b-kto.yaml index a6a84e7b1..18c240d97 100644 --- a/examples/llama-3/qlora-1b-kto.yaml +++ b/examples/llama-3/qlora-1b-kto.yaml @@ -53,7 +53,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-3/qlora-1b.yml b/examples/llama-3/qlora-1b.yml index 1e4f97438..d1e5e18ae 100644 --- a/examples/llama-3/qlora-1b.yml +++ b/examples/llama-3/qlora-1b.yml @@ -51,7 +51,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 
diff --git a/examples/llama-3/qlora-fsdp-405b.yaml b/examples/llama-3/qlora-fsdp-405b.yaml index 5c236f2cf..b801af845 100644 --- a/examples/llama-3/qlora-fsdp-405b.yaml +++ b/examples/llama-3/qlora-fsdp-405b.yaml @@ -38,7 +38,7 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: true logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-3/qlora-fsdp-70b.yaml b/examples/llama-3/qlora-fsdp-70b.yaml index c052bc19d..5ce774e18 100644 --- a/examples/llama-3/qlora-fsdp-70b.yaml +++ b/examples/llama-3/qlora-fsdp-70b.yaml @@ -48,7 +48,7 @@ gradient_checkpointing_kwargs: use_reentrant: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-3/qlora.yml b/examples/llama-3/qlora.yml index a8f47a0e2..fad507cd9 100644 --- a/examples/llama-3/qlora.yml +++ b/examples/llama-3/qlora.yml @@ -46,7 +46,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/llama-3/sparse-finetuning.yaml b/examples/llama-3/sparse-finetuning.yaml index 348756b70..0ce4aa03d 100644 --- a/examples/llama-3/sparse-finetuning.yaml +++ b/examples/llama-3/sparse-finetuning.yaml @@ -44,8 +44,7 @@ gradient_checkpointing_kwargs: early_stopping_patience: resume_from_checkpoint: logging_steps: 1 -xformers_attention: -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 2 diff --git a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml index b20f79758..2c701a2aa 100644 --- a/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml +++ b/examples/llama-4/do-no-use-fa2/maverick-qlora-fsdp1.yaml @@ -60,7 +60,7 @@ bf16: true tf32: 
true logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 gradient_checkpointing: offload gradient_checkpointing_kwargs: diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml index 40449009c..8197d1629 100644 --- a/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml +++ b/examples/llama-4/do-no-use-fa2/scout-qlora-fsdp1.yaml @@ -67,7 +67,7 @@ bf16: true tf32: true logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml index abdc51378..2dcff36cd 100644 --- a/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml +++ b/examples/llama-4/do-no-use-fa2/scout-qlora-single-h100.yaml @@ -70,7 +70,7 @@ bf16: true tf32: true logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 gradient_checkpointing: offload gradient_checkpointing_kwargs: diff --git a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml index 4136dc14a..de7ae5f50 100644 --- a/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml +++ b/examples/llama-4/do-no-use-fa2/scout-vision-qlora-fsdp.yaml @@ -62,7 +62,7 @@ bf16: true tf32: true logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml index 02c04c691..c5343fa2e 100644 --- a/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml +++ b/examples/llama-4/scout-qlora-flexattn-fsdp2.yaml @@ -59,7 +59,7 @@ bf16: true tf32: true logging_steps: 1 -flex_attention: true +attn_implementation: flex_attention flex_attn_compile_kwargs: dynamic: false mode: max-autotune-no-cudagraphs diff --git 
a/examples/llama-4/scout-qlora-single-h100-flex.yaml b/examples/llama-4/scout-qlora-single-h100-flex.yaml index 33a691189..00491c3b1 100644 --- a/examples/llama-4/scout-qlora-single-h100-flex.yaml +++ b/examples/llama-4/scout-qlora-single-h100-flex.yaml @@ -64,7 +64,7 @@ bf16: true tf32: true torch_compile: true -flex_attention: true +attn_implementation: flex_attention flex_attn_compile_kwargs: dynamic: false mode: max-autotune-no-cudagraphs diff --git a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml index 5972c2ae3..9b3e089b5 100644 --- a/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml +++ b/examples/llama-4/scout-vision-qlora-fsdp2-flex.yaml @@ -61,7 +61,7 @@ bf16: true tf32: true logging_steps: 1 -flex_attention: true +attn_implementation: flex_attention flex_attn_compile_kwargs: dynamic: false mode: max-autotune-no-cudagraphs diff --git a/examples/llava/lora-7b.yaml b/examples/llava/lora-7b.yaml index 77ef7474d..56b48fda9 100644 --- a/examples/llava/lora-7b.yaml +++ b/examples/llava/lora-7b.yaml @@ -45,8 +45,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true -eager_attention: +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/magistral/magistral-small-fsdp-qlora.yaml b/examples/magistral/magistral-small-fsdp-qlora.yaml index d46c49fe0..f31ca7326 100644 --- a/examples/magistral/magistral-small-fsdp-qlora.yaml +++ b/examples/magistral/magistral-small-fsdp-qlora.yaml @@ -59,7 +59,7 @@ tf32: false gradient_checkpointing: resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/magistral/magistral-small-qlora.yaml b/examples/magistral/magistral-small-qlora.yaml index 188924d39..90f6b6f91 100644 --- a/examples/magistral/magistral-small-qlora.yaml +++ b/examples/magistral/magistral-small-qlora.yaml @@ -58,7 +58,7 @@ 
tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/magistral/think/magistral-small-think-qlora.yaml b/examples/magistral/think/magistral-small-think-qlora.yaml index b715b3156..85abe18da 100644 --- a/examples/magistral/think/magistral-small-think-qlora.yaml +++ b/examples/magistral/think/magistral-small-think-qlora.yaml @@ -58,7 +58,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/magistral/vision/magistral-small-vision-24B-qlora.yml b/examples/magistral/vision/magistral-small-vision-24B-qlora.yml index 397db383e..abd244647 100644 --- a/examples/magistral/vision/magistral-small-vision-24B-qlora.yml +++ b/examples/magistral/vision/magistral-small-vision-24B-qlora.yml @@ -53,7 +53,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/mamba/config.yml b/examples/mamba/config.yml index 5f36595a3..0c39768d8 100644 --- a/examples/mamba/config.yml +++ b/examples/mamba/config.yml @@ -39,7 +39,6 @@ tf32: true gradient_checkpointing: false resume_from_checkpoint: logging_steps: 1 -flash_attention: warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/mimo/mimo-7b-qlora.yaml b/examples/mimo/mimo-7b-qlora.yaml index 689213bcd..7ced584e1 100644 --- a/examples/mimo/mimo-7b-qlora.yaml +++ b/examples/mimo/mimo-7b-qlora.yaml @@ -58,7 +58,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/ministral/ministral-small-qlora.yaml b/examples/ministral/ministral-small-qlora.yaml index 
0d5300ef6..4c3bdfe94 100644 --- a/examples/ministral/ministral-small-qlora.yaml +++ b/examples/ministral/ministral-small-qlora.yaml @@ -58,7 +58,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/ministral3/ministral3-3b-qlora.yaml b/examples/ministral3/ministral3-3b-qlora.yaml index 4efe5bd2f..985f2fad9 100644 --- a/examples/ministral3/ministral3-3b-qlora.yaml +++ b/examples/ministral3/ministral3-3b-qlora.yaml @@ -58,7 +58,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 # scaling_softmax: true # needs flex_attention warmup_ratio: 0.1 diff --git a/examples/ministral3/think/ministral3-3b-think-qlora.yaml b/examples/ministral3/think/ministral3-3b-think-qlora.yaml index 987c0bd54..508575cac 100644 --- a/examples/ministral3/think/ministral3-3b-think-qlora.yaml +++ b/examples/ministral3/think/ministral3-3b-think-qlora.yaml @@ -58,7 +58,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/ministral3/vision/ministral3-3b-vision-qlora.yml b/examples/ministral3/vision/ministral3-3b-vision-qlora.yml index 0a0fdce4a..f1430ba53 100644 --- a/examples/ministral3/vision/ministral3-3b-vision-qlora.yml +++ b/examples/ministral3/vision/ministral3-3b-vision-qlora.yml @@ -53,7 +53,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/mistral-small/mistral-small-3.1-24B-lora.yml b/examples/mistral-small/mistral-small-3.1-24B-lora.yml index d45d13ac6..4d3f78a13 100644 --- a/examples/mistral-small/mistral-small-3.1-24B-lora.yml +++ 
b/examples/mistral-small/mistral-small-3.1-24B-lora.yml @@ -51,7 +51,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/mistral/bigstral/bigstral-ds-zero3.yaml b/examples/mistral/bigstral/bigstral-ds-zero3.yaml index a8dc36216..4648ae4b4 100644 --- a/examples/mistral/bigstral/bigstral-ds-zero3.yaml +++ b/examples/mistral/bigstral/bigstral-ds-zero3.yaml @@ -42,7 +42,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 save_total_limit: 1 save_steps: diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml index e74162537..aa1066733 100644 --- a/examples/mistral/config.yml +++ b/examples/mistral/config.yml @@ -36,7 +36,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/mistral/dpo/mistral-dpo-qlora.yml b/examples/mistral/dpo/mistral-dpo-qlora.yml index 8fea14a0f..604eada74 100644 --- a/examples/mistral/dpo/mistral-dpo-qlora.yml +++ b/examples/mistral/dpo/mistral-dpo-qlora.yml @@ -71,7 +71,6 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: false warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/mistral/lora.yml b/examples/mistral/lora.yml index 757287f19..b157fcc21 100644 --- a/examples/mistral/lora.yml +++ b/examples/mistral/lora.yml @@ -54,7 +54,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral/mistral-qlora-fsdp.yml b/examples/mistral/mistral-qlora-fsdp.yml index 8e1f03d24..27d8be3cd 100644 --- 
a/examples/mistral/mistral-qlora-fsdp.yml +++ b/examples/mistral/mistral-qlora-fsdp.yml @@ -51,7 +51,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral/mixtral/mixtral-8x22b-qlora-fsdp.yml b/examples/mistral/mixtral/mixtral-8x22b-qlora-fsdp.yml index dc7bd9c37..1b66de8f0 100644 --- a/examples/mistral/mixtral/mixtral-8x22b-qlora-fsdp.yml +++ b/examples/mistral/mixtral/mixtral-8x22b-qlora-fsdp.yml @@ -49,7 +49,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral/mixtral/mixtral-qlora-fsdp.yml b/examples/mistral/mixtral/mixtral-qlora-fsdp.yml index 5151e1292..bd7c8620e 100644 --- a/examples/mistral/mixtral/mixtral-qlora-fsdp.yml +++ b/examples/mistral/mixtral/mixtral-qlora-fsdp.yml @@ -51,7 +51,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral/mixtral/mixtral.yml b/examples/mistral/mixtral/mixtral.yml index d1981a699..b493ed317 100644 --- a/examples/mistral/mixtral/mixtral.yml +++ b/examples/mistral/mixtral/mixtral.yml @@ -69,7 +69,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral/mixtral/mixtral_22.yml b/examples/mistral/mixtral/mixtral_22.yml index 0b606b7d7..3b87af04e 100644 --- a/examples/mistral/mixtral/mixtral_22.yml +++ b/examples/mistral/mixtral/mixtral_22.yml @@ -40,7 +40,7 @@ tf32: false gradient_checkpointing: true 
resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 save_total_limit: 1 save_steps: diff --git a/examples/mistral/mps/lora-mps.yml b/examples/mistral/mps/lora-mps.yml index 07ce191dc..1b8021085 100644 --- a/examples/mistral/mps/lora-mps.yml +++ b/examples/mistral/mps/lora-mps.yml @@ -53,8 +53,7 @@ tf32: true gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: false -sdp_attention: true +attn_implementation: sdpa loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral/orpo/mistral-qlora-orpo.yml b/examples/mistral/orpo/mistral-qlora-orpo.yml index 850d286f3..d1c0065e5 100644 --- a/examples/mistral/orpo/mistral-qlora-orpo.yml +++ b/examples/mistral/orpo/mistral-qlora-orpo.yml @@ -59,7 +59,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml index 2a7495e95..4fa82d11e 100644 --- a/examples/mistral/qlora.yml +++ b/examples/mistral/qlora.yml @@ -54,7 +54,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 loss_watchdog_threshold: 5.0 loss_watchdog_patience: 3 diff --git a/examples/mistral4/fft-text.yml b/examples/mistral4/fft-text.yml index 3acb5b2ed..2cdab6a42 100644 --- a/examples/mistral4/fft-text.yml +++ b/examples/mistral4/fft-text.yml @@ -40,7 +40,7 @@ bf16: true tf32: true logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/mistral4/fft-vision.yml b/examples/mistral4/fft-vision.yml index baff37fe4..22262c55a 100644 --- a/examples/mistral4/fft-vision.yml +++ b/examples/mistral4/fft-vision.yml @@ -39,7 +39,7 @@ bf16: true tf32: true logging_steps: 1 
-flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/mistral4/qlora-text.yml b/examples/mistral4/qlora-text.yml index ae0cdcead..887ce6da0 100644 --- a/examples/mistral4/qlora-text.yml +++ b/examples/mistral4/qlora-text.yml @@ -50,7 +50,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/mistral4/qlora-vision.yml b/examples/mistral4/qlora-vision.yml index a80d166dd..d01f8e85b 100644 --- a/examples/mistral4/qlora-vision.yml +++ b/examples/mistral4/qlora-vision.yml @@ -55,7 +55,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/nemotron-h/120b-a12b-qlora.yaml b/examples/nemotron-h/120b-a12b-qlora.yaml index 03e6d3b5e..1174cec21 100644 --- a/examples/nemotron-h/120b-a12b-qlora.yaml +++ b/examples/nemotron-h/120b-a12b-qlora.yaml @@ -72,7 +72,7 @@ gradient_checkpointing_kwargs: resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 2 diff --git a/examples/nemotron-h/nano-30b-a3b-qlora.yaml b/examples/nemotron-h/nano-30b-a3b-qlora.yaml index 3994ab08e..206bd5df8 100644 --- a/examples/nemotron-h/nano-30b-a3b-qlora.yaml +++ b/examples/nemotron-h/nano-30b-a3b-qlora.yaml @@ -73,7 +73,7 @@ gradient_checkpointing_kwargs: resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/nemotron/nemotron-mini-4b-qlora.yaml b/examples/nemotron/nemotron-mini-4b-qlora.yaml index e796c149c..3f3772071 100644 --- a/examples/nemotron/nemotron-mini-4b-qlora.yaml +++ b/examples/nemotron/nemotron-mini-4b-qlora.yaml @@ -48,7 +48,7 @@ tf32: false gradient_checkpointing: true 
resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/olmo3/olmo3-7b-qlora.yaml b/examples/olmo3/olmo3-7b-qlora.yaml index de2bf1d3d..b494699e0 100644 --- a/examples/olmo3/olmo3-7b-qlora.yaml +++ b/examples/olmo3/olmo3-7b-qlora.yaml @@ -55,7 +55,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/orpheus/finetune.yml b/examples/orpheus/finetune.yml index f4bc8054e..86a488c84 100644 --- a/examples/orpheus/finetune.yml +++ b/examples/orpheus/finetune.yml @@ -41,7 +41,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 5 diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml index 717a45929..c16b15d8a 100644 --- a/examples/phi/phi-ft.yml +++ b/examples/phi/phi-ft.yml @@ -48,7 +48,7 @@ gradient_checkpointing_kwargs: use_reentrant: True resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml index 0fe1abea5..ac4970355 100644 --- a/examples/phi/phi-qlora.yml +++ b/examples/phi/phi-qlora.yml @@ -51,7 +51,7 @@ gradient_checkpointing_kwargs: use_reentrant: True resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/phi/phi2-ft.yml b/examples/phi/phi2-ft.yml index e470c0d24..5702cc9b8 100644 --- a/examples/phi/phi2-ft.yml +++ b/examples/phi/phi2-ft.yml @@ -48,7 +48,7 @@ gradient_checkpointing_kwargs: use_reentrant: True resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: 
flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/phi/phi3-ft-fsdp.yml b/examples/phi/phi3-ft-fsdp.yml index 1793737b5..49d3e44cb 100644 --- a/examples/phi/phi3-ft-fsdp.yml +++ b/examples/phi/phi3-ft-fsdp.yml @@ -49,7 +49,7 @@ gradient_checkpointing_kwargs: use_reentrant: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/phi/phi3-ft.yml b/examples/phi/phi3-ft.yml index 0b204963c..d36317f7b 100644 --- a/examples/phi/phi3-ft.yml +++ b/examples/phi/phi3-ft.yml @@ -44,7 +44,7 @@ gradient_checkpointing_kwargs: use_reentrant: True early_stopping_patience: 3 logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 eval_steps: 1000 save_steps: 5000 diff --git a/examples/pixtral/lora-12b.yml b/examples/pixtral/lora-12b.yml index 0e6489914..2e36688a1 100644 --- a/examples/pixtral/lora-12b.yml +++ b/examples/pixtral/lora-12b.yml @@ -45,7 +45,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/plano/plano-4b-qlora.yaml b/examples/plano/plano-4b-qlora.yaml index 106e44205..30e0c36ff 100644 --- a/examples/plano/plano-4b-qlora.yaml +++ b/examples/plano/plano-4b-qlora.yaml @@ -56,7 +56,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/qat_nvfp4/Gemma3-12B_baseline.yml b/examples/qat_nvfp4/Gemma3-12B_baseline.yml index be4e86635..e1c7e998a 100644 --- a/examples/qat_nvfp4/Gemma3-12B_baseline.yml +++ b/examples/qat_nvfp4/Gemma3-12B_baseline.yml @@ -24,7 +24,7 @@ output_dir: ./outputs/out_gemma/ sequence_len: 8096 sample_packing: true -flash_attention: true +attn_implementation: flash_attention_2 wandb_entity: wandb_watch: diff --git 
a/examples/qat_nvfp4/Gemma3-12B_qat.yml b/examples/qat_nvfp4/Gemma3-12B_qat.yml index 7fa81163f..061fd6061 100644 --- a/examples/qat_nvfp4/Gemma3-12B_qat.yml +++ b/examples/qat_nvfp4/Gemma3-12B_qat.yml @@ -24,7 +24,7 @@ output_dir: ./outputs/qat_out_gemma/ sequence_len: 8096 sample_packing: true -flash_attention: true +attn_implementation: flash_attention_2 qat: activation_dtype: nvfp4 diff --git a/examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml b/examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml index 9f209515b..f11f604b4 100644 --- a/examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml +++ b/examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml @@ -24,7 +24,7 @@ output_dir: ./outputs/out_math_gemma/ sequence_len: 4096 sample_packing: true -flash_attention: true +attn_implementation: flash_attention_2 wandb_entity: wandb_watch: diff --git a/examples/qat_nvfp4/Math-Gemma3-12B_qat.yml b/examples/qat_nvfp4/Math-Gemma3-12B_qat.yml index ef7e754be..f9c71321e 100644 --- a/examples/qat_nvfp4/Math-Gemma3-12B_qat.yml +++ b/examples/qat_nvfp4/Math-Gemma3-12B_qat.yml @@ -24,7 +24,7 @@ output_dir: ./outputs/qat_out_math_gemma/ sequence_len: 4096 sample_packing: true -flash_attention: true +attn_implementation: flash_attention_2 qat: activation_dtype: nvfp4 diff --git a/examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml b/examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml index 3a262d342..de8bc1807 100644 --- a/examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml +++ b/examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml @@ -24,7 +24,7 @@ output_dir: ./outputs/out_math_gemma27/ sequence_len: 4096 sample_packing: true -flash_attention: true +attn_implementation: flash_attention_2 wandb_entity: wandb_watch: diff --git a/examples/qat_nvfp4/Math-Gemma3-27B_qat.yml b/examples/qat_nvfp4/Math-Gemma3-27B_qat.yml index 87016ae9c..c77060ee2 100644 --- a/examples/qat_nvfp4/Math-Gemma3-27B_qat.yml +++ b/examples/qat_nvfp4/Math-Gemma3-27B_qat.yml @@ -24,7 +24,7 @@ output_dir: ./outputs/qat_out_math_gemma27/ 
sequence_len: 4096 sample_packing: true -flash_attention: true +attn_implementation: flash_attention_2 qat: activation_dtype: nvfp4 diff --git a/examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml b/examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml index efec25c54..487fc8e4e 100644 --- a/examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml +++ b/examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml @@ -24,7 +24,7 @@ output_dir: ./outputs/out_math_72b/ sequence_len: 4096 sample_packing: true -flash_attention: true +attn_implementation: flash_attention_2 wandb_entity: wandb_watch: diff --git a/examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml b/examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml index 427d7af52..12812d859 100644 --- a/examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml +++ b/examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml @@ -24,7 +24,7 @@ output_dir: ./outputs/qat_out_math_72b/ sequence_len: 4096 sample_packing: true -flash_attention: true +attn_implementation: flash_attention_2 qat: activation_dtype: nvfp4 diff --git a/examples/qat_nvfp4/Qwen2.5-72B_baseline.yml b/examples/qat_nvfp4/Qwen2.5-72B_baseline.yml index e1eaba61f..c52fd6b0a 100644 --- a/examples/qat_nvfp4/Qwen2.5-72B_baseline.yml +++ b/examples/qat_nvfp4/Qwen2.5-72B_baseline.yml @@ -24,7 +24,7 @@ output_dir: ./outputs/out_qwen72b/ sequence_len: 8096 sample_packing: true -flash_attention: true +attn_implementation: flash_attention_2 wandb_entity: wandb_watch: diff --git a/examples/qat_nvfp4/Qwen2.5-72B_qat.yml b/examples/qat_nvfp4/Qwen2.5-72B_qat.yml index dad7e5422..cc67107c0 100644 --- a/examples/qat_nvfp4/Qwen2.5-72B_qat.yml +++ b/examples/qat_nvfp4/Qwen2.5-72B_qat.yml @@ -24,7 +24,7 @@ output_dir: ./outputs/qat_out_qwen72b/ sequence_len: 8096 sample_packing: true -flash_attention: true +attn_implementation: flash_attention_2 qat: activation_dtype: nvfp4 diff --git a/examples/qwen2-vl/lora-7b.yaml b/examples/qwen2-vl/lora-7b.yaml index 285a35cbb..d9bc4826b 100644 --- a/examples/qwen2-vl/lora-7b.yaml +++ 
b/examples/qwen2-vl/lora-7b.yaml @@ -46,8 +46,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true -eager_attention: +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/qwen2/adamw-pretrain-fsdp2.yaml b/examples/qwen2/adamw-pretrain-fsdp2.yaml index 43fb17aab..4129338db 100644 --- a/examples/qwen2/adamw-pretrain-fsdp2.yaml +++ b/examples/qwen2/adamw-pretrain-fsdp2.yaml @@ -49,7 +49,7 @@ tf32: false gradient_checkpointing: false logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_steps: 10 evals_per_epoch: 0 diff --git a/examples/qwen2/dpo.yaml b/examples/qwen2/dpo.yaml index 3e87766d6..6096053fd 100644 --- a/examples/qwen2/dpo.yaml +++ b/examples/qwen2/dpo.yaml @@ -48,7 +48,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/qwen2/muon-pretrain-fsdp2.yaml b/examples/qwen2/muon-pretrain-fsdp2.yaml index 35c0b71f4..40dcff7be 100644 --- a/examples/qwen2/muon-pretrain-fsdp2.yaml +++ b/examples/qwen2/muon-pretrain-fsdp2.yaml @@ -49,7 +49,7 @@ tf32: false gradient_checkpointing: false logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_steps: 10 evals_per_epoch: 0 diff --git a/examples/qwen2/prm.yaml b/examples/qwen2/prm.yaml index a709a598d..1b3579fd4 100644 --- a/examples/qwen2/prm.yaml +++ b/examples/qwen2/prm.yaml @@ -47,7 +47,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/qwen2/qlora-fsdp.yaml b/examples/qwen2/qlora-fsdp.yaml index 337619b61..7bb035c3a 100644 --- a/examples/qwen2/qlora-fsdp.yaml +++ b/examples/qwen2/qlora-fsdp.yaml @@ -47,7 +47,7 @@ gradient_checkpointing_kwargs: use_reentrant: 
false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/qwen2/reward-model.yaml b/examples/qwen2/reward-model.yaml index 08b8b4552..b7039cba0 100644 --- a/examples/qwen2/reward-model.yaml +++ b/examples/qwen2/reward-model.yaml @@ -42,7 +42,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/examples/qwen2_5-vl/lora-7b.yaml b/examples/qwen2_5-vl/lora-7b.yaml index 7d499d841..e78aac78b 100644 --- a/examples/qwen2_5-vl/lora-7b.yaml +++ b/examples/qwen2_5-vl/lora-7b.yaml @@ -46,8 +46,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true -eager_attention: +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml b/examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml index f63b1d1ce..e8e7e08c7 100644 --- a/examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml +++ b/examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml @@ -68,7 +68,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/qwen3.5/122b-a10b-moe-qlora-fsdp.yaml b/examples/qwen3.5/122b-a10b-moe-qlora-fsdp.yaml index f66bcd370..47842c561 100644 --- a/examples/qwen3.5/122b-a10b-moe-qlora-fsdp.yaml +++ b/examples/qwen3.5/122b-a10b-moe-qlora-fsdp.yaml @@ -65,7 +65,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/qwen3.5/122b-a10b-moe-qlora.yaml b/examples/qwen3.5/122b-a10b-moe-qlora.yaml index 4447cf73c..f2675c7d7 100644 --- 
a/examples/qwen3.5/122b-a10b-moe-qlora.yaml +++ b/examples/qwen3.5/122b-a10b-moe-qlora.yaml @@ -65,7 +65,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/qwen3.5/27b-fft.yaml b/examples/qwen3.5/27b-fft.yaml index 9f875ec26..ab206b772 100644 --- a/examples/qwen3.5/27b-fft.yaml +++ b/examples/qwen3.5/27b-fft.yaml @@ -50,7 +50,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/qwen3.5/27b-qlora-fsdp.yaml b/examples/qwen3.5/27b-qlora-fsdp.yaml index 79b87a32f..7a5423c77 100644 --- a/examples/qwen3.5/27b-qlora-fsdp.yaml +++ b/examples/qwen3.5/27b-qlora-fsdp.yaml @@ -61,7 +61,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/qwen3.5/27b-qlora.yaml b/examples/qwen3.5/27b-qlora.yaml index 18c0af95b..2401a4865 100644 --- a/examples/qwen3.5/27b-qlora.yaml +++ b/examples/qwen3.5/27b-qlora.yaml @@ -61,7 +61,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/qwen3.5/35b-a3b-moe-qlora-fsdp.yaml b/examples/qwen3.5/35b-a3b-moe-qlora-fsdp.yaml index ad17366cb..2fb7f15f8 100644 --- a/examples/qwen3.5/35b-a3b-moe-qlora-fsdp.yaml +++ b/examples/qwen3.5/35b-a3b-moe-qlora-fsdp.yaml @@ -65,7 +65,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git 
a/examples/qwen3.5/35b-a3b-moe-qlora.yaml b/examples/qwen3.5/35b-a3b-moe-qlora.yaml index 22468a178..a6afc1aa2 100644 --- a/examples/qwen3.5/35b-a3b-moe-qlora.yaml +++ b/examples/qwen3.5/35b-a3b-moe-qlora.yaml @@ -75,7 +75,7 @@ gradient_checkpointing: true activation_offloading: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/qwen3.5/35b-a3b-moe-vision-lora.yaml b/examples/qwen3.5/35b-a3b-moe-vision-lora.yaml index a7c85f785..7cfad3290 100644 --- a/examples/qwen3.5/35b-a3b-moe-vision-lora.yaml +++ b/examples/qwen3.5/35b-a3b-moe-vision-lora.yaml @@ -50,7 +50,7 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 weight_decay: 0.0 diff --git a/examples/qwen3.5/9b-fft-vision.yaml b/examples/qwen3.5/9b-fft-vision.yaml index b6aeb859d..e8427b884 100644 --- a/examples/qwen3.5/9b-fft-vision.yaml +++ b/examples/qwen3.5/9b-fft-vision.yaml @@ -40,7 +40,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/qwen3.5/9b-lora-vision.yaml b/examples/qwen3.5/9b-lora-vision.yaml index 1c3717724..9c2b9397e 100644 --- a/examples/qwen3.5/9b-lora-vision.yaml +++ b/examples/qwen3.5/9b-lora-vision.yaml @@ -58,7 +58,7 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/qwen3/32b-qlora.yaml b/examples/qwen3/32b-qlora.yaml index f4a4f2816..dd5dd696e 100644 --- a/examples/qwen3/32b-qlora.yaml +++ b/examples/qwen3/32b-qlora.yaml @@ -60,7 +60,7 @@ gradient_checkpointing_kwargs: use_reentrant: false 
resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/qwen3/8b-qat-fsdp2.yml b/examples/qwen3/8b-qat-fsdp2.yml index cfbe5a4b7..3c9607a9a 100644 --- a/examples/qwen3/8b-qat-fsdp2.yml +++ b/examples/qwen3/8b-qat-fsdp2.yml @@ -23,7 +23,7 @@ output_dir: ./outputs/qat_out/ sequence_len: 2048 sample_packing: true -flex_attention: true +attn_implementation: flex_attention flex_attn_compile_kwargs: diff --git a/examples/qwen3/qlora-fsdp.yaml b/examples/qwen3/qlora-fsdp.yaml index e4d584dc7..a3852d457 100644 --- a/examples/qwen3/qlora-fsdp.yaml +++ b/examples/qwen3/qlora-fsdp.yaml @@ -46,7 +46,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 4 diff --git a/examples/seed-oss/seed-oss-36b-qlora.yaml b/examples/seed-oss/seed-oss-36b-qlora.yaml index 00e7cf3eb..a8423f851 100644 --- a/examples/seed-oss/seed-oss-36b-qlora.yaml +++ b/examples/seed-oss/seed-oss-36b-qlora.yaml @@ -47,7 +47,7 @@ tf32: false gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/smolvlm2/smolvlm2-2B-lora.yaml b/examples/smolvlm2/smolvlm2-2B-lora.yaml index 1aeff408d..4cd8d5b0d 100644 --- a/examples/smolvlm2/smolvlm2-2B-lora.yaml +++ b/examples/smolvlm2/smolvlm2-2B-lora.yaml @@ -45,8 +45,7 @@ tf32: true gradient_checkpointing: true logging_steps: 1 -flash_attention: true -eager_attention: +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/streaming/pretrain.yaml b/examples/streaming/pretrain.yaml index bc8edefd6..a0d8b17c0 100644 --- a/examples/streaming/pretrain.yaml +++ b/examples/streaming/pretrain.yaml @@ -20,7 +20,7 @@ output_dir: 
./outputs/smollm2-135m-pretrain-streaming sequence_len: 1024 sample_packing: true pretrain_multipack_attn: true # Prevent cross-attention between packed sequences -flash_attention: true +attn_implementation: flash_attention_2 # Batch size settings gradient_accumulation_steps: 8 diff --git a/examples/streaming/sft.yaml b/examples/streaming/sft.yaml index 47b9f493f..4a43c34eb 100644 --- a/examples/streaming/sft.yaml +++ b/examples/streaming/sft.yaml @@ -18,7 +18,7 @@ output_dir: ./outputs/smollm2-135m-sft-streaming # Sequence and packing settings sequence_len: 1024 sample_packing: true -flash_attention: true +attn_implementation: flash_attention_2 # Batch size settings gradient_accumulation_steps: 4 diff --git a/examples/swanlab/dpo-swanlab-completions.yml b/examples/swanlab/dpo-swanlab-completions.yml index 5615ca638..fb21dbbba 100644 --- a/examples/swanlab/dpo-swanlab-completions.yml +++ b/examples/swanlab/dpo-swanlab-completions.yml @@ -78,7 +78,7 @@ tf32: false # Performance gradient_checkpointing: true -flash_attention: true +attn_implementation: flash_attention_2 # Checkpointing and Logging logging_steps: 1 diff --git a/examples/swanlab/dpo-swanlab-full-featured.yml b/examples/swanlab/dpo-swanlab-full-featured.yml index c25178c63..ac52e6a85 100644 --- a/examples/swanlab/dpo-swanlab-full-featured.yml +++ b/examples/swanlab/dpo-swanlab-full-featured.yml @@ -102,7 +102,7 @@ bf16: auto tf32: false gradient_checkpointing: true -flash_attention: true +attn_implementation: flash_attention_2 # ============================================================================ # Checkpointing and Logging diff --git a/examples/swanlab/lora-swanlab-profiling.yml b/examples/swanlab/lora-swanlab-profiling.yml index 1255105a6..3dff6e315 100644 --- a/examples/swanlab/lora-swanlab-profiling.yml +++ b/examples/swanlab/lora-swanlab-profiling.yml @@ -59,7 +59,7 @@ tf32: false # Performance gradient_checkpointing: true -flash_attention: true +attn_implementation: flash_attention_2 # 
Checkpointing and Logging logging_steps: 1 diff --git a/examples/trinity/trinity-nano-preview-qlora.yaml b/examples/trinity/trinity-nano-preview-qlora.yaml index d8bf9f073..52c0c0c60 100644 --- a/examples/trinity/trinity-nano-preview-qlora.yaml +++ b/examples/trinity/trinity-nano-preview-qlora.yaml @@ -58,7 +58,7 @@ gradient_checkpointing: true resume_from_checkpoint: logging_steps: 1 # flash_attention: true # Not supported -sdp_attention: true +attn_implementation: sdpa warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/voxtral/voxtral-mini-audio-qlora.yml b/examples/voxtral/voxtral-mini-audio-qlora.yml index 59150c4ca..cfa351ccd 100644 --- a/examples/voxtral/voxtral-mini-audio-qlora.yml +++ b/examples/voxtral/voxtral-mini-audio-qlora.yml @@ -70,7 +70,7 @@ gradient_checkpointing: true gradient_checkpointing_kwargs: use_reentrant: false logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: 1 diff --git a/examples/voxtral/voxtral-mini-qlora.yml b/examples/voxtral/voxtral-mini-qlora.yml index bdbc5f867..61e8933d0 100644 --- a/examples/voxtral/voxtral-mini-qlora.yml +++ b/examples/voxtral/voxtral-mini-qlora.yml @@ -64,7 +64,7 @@ gradient_checkpointing_kwargs: use_reentrant: false resume_from_checkpoint: logging_steps: 1 -flash_attention: true +attn_implementation: flash_attention_2 warmup_ratio: 0.1 evals_per_epoch: diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index 00e5303bd..e111ca9f9 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -147,7 +147,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: load_in_8bit=False, load_in_4bit=False, quantize_moe_experts=False, - flash_attention=False, + attn_implementation=None, context_parallel_size=None, deepspeed=None, fsdp=None, diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index fe832dd45..15624173d 100644 --- 
a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -257,19 +257,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing) - training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool( - self.cfg.flash_attention - or self.cfg.xformers_attention - or self.cfg.flex_attention + training_arguments_kwargs["sample_packing_drop_attention_mask"] = ( + self.cfg.attn_supports_packing ) training_arguments_kwargs["multipack_real_batches"] = ( self.cfg.multipack_real_batches if self.cfg.multipack_real_batches is not None - else not ( - self.cfg.flash_attention - or self.cfg.flex_attention - or self.cfg.xformers_attention - ) + else not self.cfg.attn_supports_packing ) training_arguments_kwargs["eval_sample_packing"] = bool( self.cfg.eval_sample_packing @@ -508,11 +502,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention, # supported multipack models, or non-flash-attention llama if ( - self.cfg.flex_attention + self.cfg.attn_implementation == "flex_attention" or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or ( self.cfg.model_config_type in ["llama"] - and self.cfg.flash_attention is not True + and self.cfg.attn_implementation != "flash_attention_2" ) ): collator = V2BatchSamplerDataCollatorForSeq2Seq diff --git a/src/axolotl/integrations/lm_eval/__init__.py b/src/axolotl/integrations/lm_eval/__init__.py index 6a82dd6cf..732ce1592 100644 --- a/src/axolotl/integrations/lm_eval/__init__.py +++ b/src/axolotl/integrations/lm_eval/__init__.py @@ -23,7 +23,7 @@ class LMEvalPlugin(BasePlugin): for lm_eval_args in build_lm_eval_command( cfg.lm_eval_tasks, bfloat16=cfg.bfloat16 or cfg.bf16, - flash_attention=cfg.flash_attention, + flash_attention=cfg.attn_uses_flash_lib, output_dir=cfg.output_dir, 
batch_size=cfg.lm_eval_batch_size, wandb_project=cfg.wandb_project, diff --git a/src/axolotl/integrations/lm_eval/cli.py b/src/axolotl/integrations/lm_eval/cli.py index 4b905d476..a20f4d154 100644 --- a/src/axolotl/integrations/lm_eval/cli.py +++ b/src/axolotl/integrations/lm_eval/cli.py @@ -114,10 +114,18 @@ def lm_eval(config: str, cloud: Optional[str] = None): with open(config, encoding="utf-8") as file: cfg: DictDefault = DictDefault(yaml.safe_load(file)) + # This path operates on raw YAML via DictDefault (not the validated + # AxolotlInputConfig), so we resolve flash-attn from either the canonical + # `attn_implementation` field or the deprecated `flash_attention` boolean. + _flash_attn_impls = {"flash_attention_2", "flash_attention_3"} + lm_eval_flash_attention = bool( + cfg.flash_attention or cfg.attn_implementation in _flash_attn_impls + ) + for lm_eval_args in build_lm_eval_command( cfg.lm_eval_tasks, bfloat16=cfg.bfloat16 or cfg.bf16, - flash_attention=cfg.flash_attention, + flash_attention=lm_eval_flash_attention, output_dir=cfg.output_dir, batch_size=cfg.lm_eval_batch_size, wandb_project=cfg.wandb_project, diff --git a/src/axolotl/integrations/swanlab/plugins.py b/src/axolotl/integrations/swanlab/plugins.py index 16218d39d..55f19ac59 100644 --- a/src/axolotl/integrations/swanlab/plugins.py +++ b/src/axolotl/integrations/swanlab/plugins.py @@ -383,7 +383,9 @@ class SwanLabPlugin(BasePlugin): "seed": safe_convert(getattr(cfg, "seed", None)), "bf16": safe_convert(getattr(cfg, "bf16", None)), "tf32": safe_convert(getattr(cfg, "tf32", None)), - "flash_attention": safe_convert(getattr(cfg, "flash_attention", None)), + "attn_implementation": safe_convert( + getattr(cfg, "attn_implementation", None) + ), "sample_packing": safe_convert(getattr(cfg, "sample_packing", None)), } diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 57aabfbab..061509d39 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -343,12 
+343,7 @@ class ModelLoader: # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so # we need to convert them back to fp16/bf16 for flash-attn compatibility. ( - ( - needs_fa2_dtype - or self.cfg.flash_attention - or self.cfg.flex_attention - or self.cfg.sage_attention - ) + (needs_fa2_dtype or self.cfg.attn_needs_dtype_cast) and not self.is_qlora_and_fsdp_enabled ) or ( @@ -633,35 +628,14 @@ class ModelLoader: ) def _set_attention_config(self): - """Sample packing uses custom FA2 patch""" - if self.cfg.gemma4_hybrid_attn_impl: - # Load model with flash_attention_2 for sliding window layers; - # global layers will be patched to sdpa post-load. - self.model_kwargs["attn_implementation"] = "flash_attention_2" - self.model_config._attn_implementation = "flash_attention_2" - # Set flash_attention so multipack/sample_packing patches activate - self.cfg.flash_attention = True - elif self.cfg.attn_implementation: - self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation - elif self.cfg.flex_attention: - self.model_kwargs["attn_implementation"] = "flex_attention" - self.model_config._attn_implementation = "flex_attention" - - elif self.cfg.flash_attention: - if not self.cfg.sample_packing and self.cfg.s2_attention: - pass - self.model_kwargs["attn_implementation"] = "flash_attention_2" - self.model_config._attn_implementation = "flash_attention_2" - elif self.cfg.sdp_attention: - self.model_kwargs["attn_implementation"] = "sdpa" - self.model_config._attn_implementation = "sdpa" - elif self.cfg.sage_attention: - # sets FA2 attention to re-use same internal handling like masking - self.model_kwargs["attn_implementation"] = "flash_attention_2" - self.model_config._attn_implementation = "flash_attention_2" - elif self.cfg.eager_attention: - self.model_kwargs["attn_implementation"] = "eager" - self.model_config._attn_implementation = "eager" + # s2 patches FA2 internals (load as FA2); fp8 replaces sdpa post-load (load as sdpa). 
+ _LOAD_TIME_OVERRIDE = {"s2": "flash_attention_2", "fp8": "sdpa"} + if self.cfg.attn_implementation: + hf_impl = _LOAD_TIME_OVERRIDE.get( + self.cfg.attn_implementation, self.cfg.attn_implementation + ) + self.model_kwargs["attn_implementation"] = hf_impl + self.model_config._attn_implementation = hf_impl if self.cfg.low_cpu_mem_usage: self.model_kwargs["low_cpu_mem_usage"] = True diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 01d9997d7..68952014f 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -172,6 +172,7 @@ class PatchManager: self._apply_llama_flash_attn_patches(model) self._apply_lora_kernel_patch(model) self._apply_scaling_softmax_patch(model) + self._apply_fp8_attention_patches(model) def _apply_gemma_hybrid_attention(self, model: PreTrainedModel): """Apply hybrid attention: FA2 for sliding window layers, SDPA for global layers. @@ -252,11 +253,28 @@ class PatchManager: def _apply_flash_attention_patches(self): """Apply patches related to Flash Attention.""" - if self.cfg.xformers_attention and self.cfg.sample_packing: - from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2 + if self.cfg.attn_implementation == "xformers": + from axolotl.monkeypatch.attention import register_xformers_attn - patch_xformers_attn_over_fa2() - self.cfg.flash_attention = True + register_xformers_attn() + + if self.cfg.sample_packing: + # Also patch FA2 slot for legacy code paths that use it directly + from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2 + + patch_xformers_attn_over_fa2() + + if self.cfg.attn_implementation == "sage": + from axolotl.monkeypatch.attention import register_sage_attn + + register_sage_attn() + + def _apply_fp8_attention_patches(self, model): + """Apply FP8 low-precision attention via torchao.""" + if self.cfg.attn_implementation == "fp8": + from axolotl.monkeypatch.attention.fp8_attn import patch_fp8_attention + + 
patch_fp8_attention(model) def _apply_chunked_cross_entropy_patch(self): if self.cfg.chunked_cross_entropy: @@ -315,7 +333,7 @@ class PatchManager: def _apply_flex_attention_patches(self): """Apply patches for flexible attention.""" - if self.cfg.flex_attention: + if self.cfg.attn_implementation == "flex_attention": from axolotl.monkeypatch.attention.flex_attn import ( patch_flex_wrapper, ) @@ -325,14 +343,14 @@ class PatchManager: def _apply_sageattn_patches(self): """Apply patches for SageAttention.""" - if self.cfg.sage_attention: + if self.cfg.attn_implementation == "sage": from axolotl.monkeypatch.attention.sage_attn import patch_sageattn patch_sageattn() def _apply_flash_attn_4_patches(self): """Auto-apply FA4 when flash_attention is enabled and FA4 is available on SM90+.""" - if not self.cfg.flash_attention: + if not self.cfg.attn_uses_flash_lib: return from axolotl.monkeypatch.attention.flash_attn_4 import patch_flash_attn_4 @@ -401,7 +419,7 @@ class PatchManager: if ( self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"] and self.cfg.is_multimodal - and self.cfg.flash_attention + and self.cfg.attn_uses_flash_lib ): from axolotl.monkeypatch.models.qwen3_5.modeling import ( patch_qwen3_5_vlm_flash_attention, @@ -553,7 +571,7 @@ class PatchManager: """Apply multipack patches if necessary.""" if ( self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES - and (self.cfg.flash_attention or self.cfg.flex_attention) + and self.cfg.attn_supports_packing and self.cfg.sample_packing ): # Get automap config if it exists @@ -674,7 +692,9 @@ class PatchManager: def _patch_attention(self): """Apply attention-specific patches based on model type.""" - if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")): + if not ( + self.cfg.attn_uses_flash_lib and hasattr(self.model_config, "model_type") + ): return if self.model_config.model_type == "btlm": @@ -720,7 +740,7 @@ class PatchManager: replace_llama_attn_with_flash_attn, ) - if 
self.cfg.s2_attention: + if self.cfg.attn_implementation == "s2": LOG.info("patching w/ flash-enabled, shifted-sparse attention") replace_llama_attn_with_flash_attn( cross_entropy=self.cfg.flash_attn_cross_entropy, @@ -746,14 +766,14 @@ class PatchManager: """Modify all llama derived models in one block.""" if self.cfg.is_llama_derived_model and not ( self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES - and (self.cfg.flash_attention or self.cfg.flex_attention) + and self.cfg.attn_supports_packing and self.cfg.sample_packing ): - if self.cfg.flash_attention: + if self.cfg.attn_uses_flash_lib: self._patch_llama_flash_attention() - elif self.cfg.xformers_attention: + elif self.cfg.attn_implementation == "xformers": self._patch_llama_xformers_attention() - elif self.cfg.s2_attention: + elif self.cfg.attn_implementation == "s2": raise NotImplementedError( "Shifted-sparse attention not currently implemented without flash attention." ) @@ -765,7 +785,7 @@ class PatchManager: in ["llama", "llama4", "ernie4_5", "ernie4_5_moe"] and not self.cfg.trust_remote_code and not self.cfg.gptq - and self.cfg.flash_attention + and self.cfg.attn_uses_flash_lib and is_flash_attn_available() and not self.inference ): diff --git a/src/axolotl/loaders/tokenizer.py b/src/axolotl/loaders/tokenizer.py index 52f714604..572a880bd 100644 --- a/src/axolotl/loaders/tokenizer.py +++ b/src/axolotl/loaders/tokenizer.py @@ -205,7 +205,11 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer: os.environ["TOKENIZERS_PARALLELISM"] = "false" # Mistral's official FA implementation requires left padding - if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing: + if ( + cfg.is_mistral_derived_model + and cfg.attn_implementation == "flash_attention_2" + and not cfg.sample_packing + ): tokenizer.padding_side = "left" # Qwen base only has single token, so we need to set the special tokens diff --git a/src/axolotl/monkeypatch/attention/__init__.py 
b/src/axolotl/monkeypatch/attention/__init__.py index 15ed764f4..74bd61e77 100644 --- a/src/axolotl/monkeypatch/attention/__init__.py +++ b/src/axolotl/monkeypatch/attention/__init__.py @@ -17,3 +17,29 @@ def unpatch_xformers_attn_over_fa2(): from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward() + + +def register_xformers_attn(): + """Register xformers as its own attention backend with FA2 mask behavior.""" + from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + from .xformers import xformers_attention_forward + + ALL_ATTENTION_FUNCTIONS.register("xformers", xformers_attention_forward) + ALL_MASK_ATTENTION_FUNCTIONS.register( + "xformers", ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] + ) + + +def register_sage_attn(): + """Register sage as its own attention backend with FA2 mask behavior.""" + from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + from .sage_attn import sage_attention_forward + + ALL_ATTENTION_FUNCTIONS.register("sage", sage_attention_forward) + ALL_MASK_ATTENTION_FUNCTIONS.register( + "sage", ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] + ) diff --git a/src/axolotl/monkeypatch/attention/fp8_attn.py b/src/axolotl/monkeypatch/attention/fp8_attn.py new file mode 100644 index 000000000..224e8c3b7 --- /dev/null +++ b/src/axolotl/monkeypatch/attention/fp8_attn.py @@ -0,0 +1,30 @@ +"""FP8 low-precision attention via torchao. + +Requires: + - PyTorch >= 2.11.0 + - SM90+ (Hopper/Blackwell) GPU + - flash-attn package with FA3 support + - torchao >= 0.17.0 + +Uses per-head FP8 quantized attention with automatic RoPE fusion under torch.compile. 
+The torchao patch replaces F.scaled_dot_product_attention, so the model must use +HF's "sdpa" attention implementation for the patch to intercept attention calls. +""" + +import logging + +import torch + +LOG = logging.getLogger(__name__) + + +def patch_fp8_attention(model: torch.nn.Module) -> torch.nn.Module: + """Apply FP8 low-precision attention to a model. + + Must be called after model loading and before torch.compile. + KV caching should be disabled (config.use_cache = False). + """ + from torchao.prototype.attention import apply_low_precision_attention + + LOG.info("Applying FP8 low-precision attention (torchao)") + return apply_low_precision_attention(model) diff --git a/src/axolotl/monkeypatch/attention/sage_attn.py b/src/axolotl/monkeypatch/attention/sage_attn.py index cc9fdb94d..6e9ba0f85 100644 --- a/src/axolotl/monkeypatch/attention/sage_attn.py +++ b/src/axolotl/monkeypatch/attention/sage_attn.py @@ -191,21 +191,9 @@ def sage_attention_forward( def patch_sageattn(): - """Patch SageAttention for use with transformers.""" + """Validate SageAttention is available. 
Registration in the attention/mask + function registries is handled by register_sage_attn() in __init__.py.""" _check_sageattn_imported() - from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS - - # Replace flash attention with sage attention - ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", sage_attention_forward) - - # Note: New method after transformers refactor to use ALL_MASK_ATTENTION_FUNCTIONS - # Register sage_attention with the global attention interface - # ALL_ATTENTION_FUNCTIONS.register("sage_attention", sage_attention_forward) - - # from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, flash_attention_mask - - # ALL_MASK_ATTENTION_FUNCTIONS.register("sage_attention", flash_attention_mask) - - LOG.info("SageAttention patched successfully") + LOG.info("SageAttention validated successfully") diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 8137bac0c..edb61441b 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -955,7 +955,10 @@ def colab_inference_post_train_callback(trainer: Trainer): """ handle T4 gpu, we need to convert attention to eager for inference """ - if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention: + if ( + "Tesla T4" in self.gpu_name + and self.cfg.attn_implementation == "xformers" + ): trainer.model.config._attn_implementation = "eager" trainer.model.gradient_checkpointing_disable() trainer.model.config.use_cache = True diff --git a/src/axolotl/utils/collators/mm_chat.py b/src/axolotl/utils/collators/mm_chat.py index b81612cbc..12dbde7d1 100644 --- a/src/axolotl/utils/collators/mm_chat.py +++ b/src/axolotl/utils/collators/mm_chat.py @@ -43,16 +43,16 @@ class MultiModalChatDataCollator(DataCollatorMixin): # Initialize batch messages = [ex["messages"] for ex in examples] - batch = self.processing_strategy.processor.apply_chat_template( - messages, - add_generation_prompt=False, - tokenize=True, - 
chat_template=self.processing_strategy.chat_template, - processor_kwargs={ - "return_tensors": "pt", - "padding": True, - "return_dict": True, - }, + batch = dict( + self.processing_strategy.processor.apply_chat_template( + messages, + add_generation_prompt=False, + tokenize=True, + return_tensors="pt", + padding=True, + return_dict=True, + chat_template=self.processing_strategy.chat_template, + ) ) # Process the labels diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index c52ddce1a..9874b9da5 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -10,7 +10,9 @@ from pydantic import ( BaseModel, Field, StringConstraints, + computed_field, field_serializer, + field_validator, model_validator, ) @@ -27,7 +29,17 @@ from axolotl.utils.schemas.datasets import ( ) from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters from axolotl.utils.schemas.dynamic_checkpoint import DynamicCheckpointConfig -from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType +from axolotl.utils.schemas.enums import ( + ATTN_IMPLS_SUPPORTING_PACKING, + ATTN_IMPLS_USING_FLASH_LIB, + ATTN_IMPLS_WITHOUT_DTYPE_CAST, + CANONICAL_ATTN_IMPLS, + LEGACY_ATTN_FLAG_TO_IMPL, + SHORT_FORM_ALIAS_TO_CANONICAL, + ChatTemplate, + RingAttnFunc, + RLType, +) from axolotl.utils.schemas.fsdp import FSDPConfig from axolotl.utils.schemas.integrations import ( CometConfig, @@ -731,28 +743,35 @@ class AxolotlInputConfig( xformers_attention: bool | None = Field( default=None, + deprecated="Use `attn_implementation: xformers` instead.", json_schema_extra={ - "description": "Whether to use xformers attention patch https://github.com/facebookresearch/xformers" + "description": "[DEPRECATED] Use `attn_implementation: xformers`. 
https://github.com/facebookresearch/xformers" }, ) sdp_attention: bool | None = Field( default=None, + deprecated="Use `attn_implementation: sdpa` instead.", json_schema_extra={ - "description": "Whether to use scaled-dot-product attention https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html" + "description": "[DEPRECATED] Use `attn_implementation: sdpa`." }, ) s2_attention: bool | None = Field( default=None, + deprecated="Use `attn_implementation: s2` instead.", json_schema_extra={ - "description": "Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf" + "description": "[DEPRECATED] Use `attn_implementation: s2`. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf" }, ) - flex_attention: bool | None = None + flex_attention: bool | None = Field( + default=None, + deprecated="Use `attn_implementation: flex_attention` instead.", + ) flex_attn_compile_kwargs: dict[str, Any] | None = None flash_attention: bool | None = Field( default=None, + deprecated="Use `attn_implementation: flash_attention_2` instead.", json_schema_extra={ - "description": "Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention" + "description": "[DEPRECATED] Use `attn_implementation: flash_attention_2`. https://github.com/Dao-AILab/flash-attention" }, ) flash_attn_cross_entropy: bool | None = Field( @@ -779,17 +798,26 @@ class AxolotlInputConfig( ) sage_attention: bool | None = Field( default=None, + deprecated="Use `attn_implementation: sage` instead.", json_schema_extra={ - "description": "Whether to use SageAttention https://github.com/thu-ml/SageAttention" + "description": "[DEPRECATED] Use `attn_implementation: sage`. 
https://github.com/thu-ml/SageAttention" }, ) - eager_attention: bool | None = None + eager_attention: bool | None = Field( + default=None, + deprecated="Use `attn_implementation: eager` instead.", + ) attn_implementation: str | None = Field( default=None, json_schema_extra={ - "description": "Specify a custom attention implementation, used mostly for kernels." + "description": ( + "Attention backend. Canonical values: eager, sdpa, flash_attention_2, " + "flash_attention_3, flex_attention, xformers, sage, s2, fp8. Hub-kernel " + "paths (e.g. kernels-community/flash-attn3) are also accepted and passed " + "through to transformers." + ) }, ) @@ -1327,6 +1355,25 @@ class AxolotlInputConfig( return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs] return None + # --- Attention capability flags (derived from attn_implementation) --- + + @computed_field # type: ignore[misc] + @property + def attn_supports_packing(self) -> bool: + return self.attn_implementation in ATTN_IMPLS_SUPPORTING_PACKING + + @computed_field # type: ignore[misc] + @property + def attn_uses_flash_lib(self) -> bool: + return self.attn_implementation in ATTN_IMPLS_USING_FLASH_LIB + + @computed_field # type: ignore[misc] + @property + def attn_needs_dtype_cast(self) -> bool: + if self.attn_implementation is None: + return False + return self.attn_implementation not in ATTN_IMPLS_WITHOUT_DTYPE_CAST + @model_validator(mode="before") @classmethod def warn_peft_trainable_token_to_fix_untrained(cls, data): @@ -1349,24 +1396,104 @@ class AxolotlInputConfig( @model_validator(mode="before") @classmethod - def check_sageattn_wo_sample_packing(cls, data): - if (not data.get("sample_packing", False)) and data.get("sage_attention"): - if not data.get("pad_to_sequence_len", False): - LOG.warning( - "We recommend turning on `pad_to_sequence_len` for SageAttention without packing." - "This is because there has been signs that the loss explodes after a few steps." 
+ def normalize_attn_implementation(cls, data): + """Map legacy boolean attention flags to canonical attn_implementation, warn, then strip.""" + if not isinstance(data, dict): + return data + + attn_impl = data.get("attn_implementation") + set_flags = [f for f in LEGACY_ATTN_FLAG_TO_IMPL if data.get(f)] + + # gemma4_hybrid requires flash_attention_2 for the sliding-window layers; + # post-load patching swaps global layers to sdpa (see + # `_apply_gemma_hybrid_attention`). Default it in when the user didn't + # pick a backend; reject any incompatible explicit choice. + if data.get("gemma4_hybrid_attn_impl"): + if not attn_impl and not set_flags: + data["attn_implementation"] = "flash_attention_2" + attn_impl = "flash_attention_2" + elif attn_impl and attn_impl != "flash_attention_2": + raise ValueError( + f"gemma4_hybrid_attn_impl requires attn_implementation=" + f"flash_attention_2 (sliding-window layers run under FA2); " + f"got {attn_impl!r}." ) + + if attn_impl and set_flags: + raise ValueError( + f"attn_implementation={attn_impl!r} cannot be combined with legacy " + f"attention flags ({', '.join(sorted(set_flags))}). The legacy " + f"flags are deprecated — set only `attn_implementation`." + ) + + if not attn_impl and set_flags: + # Priority: specific backends beat generic flash/sdp/eager fallbacks. + for flag in LEGACY_ATTN_FLAG_TO_IMPL: + if flag in set_flags: + canonical = LEGACY_ATTN_FLAG_TO_IMPL[flag] + data["attn_implementation"] = canonical + LOG.warning( + "`%s: true` is deprecated and will be removed in a future " + "release. Use `attn_implementation: %s` instead.", + flag, + canonical, + ) + break + + # Strip legacy flags from validated data — canonical field is authoritative. 
+ for flag in LEGACY_ATTN_FLAG_TO_IMPL: + data.pop(flag, None) + return data - @model_validator(mode="before") + @field_validator("attn_implementation", mode="before") @classmethod - def check_sageattn_fft(cls, data): - if (not data.get("adapter", False)) and data.get("sage_attention"): - LOG.warning( - "We found loss to drop to 0 with SageAttention full finetuning." - "Please observe the loss, otherwise switch to LoRA/QLoRA or another attention method." + def validate_attn_implementation(cls, value): + """Accept canonical names and hub-kernel paths; reject short-form aliases.""" + if value is None: + return None + if not isinstance(value, str): + raise TypeError( + f"attn_implementation must be a string, got {type(value).__name__}" ) - return data + if value in CANONICAL_ATTN_IMPLS: + return value + if "/" in value: + # Hub-kernel path, e.g. "kernels-community/flash-attn3". Pass through. + return value + if value in SHORT_FORM_ALIAS_TO_CANONICAL: + canonical = SHORT_FORM_ALIAS_TO_CANONICAL[value] + raise ValueError( + f"attn_implementation={value!r} is not accepted. " + f"Use the canonical name {canonical!r} instead." + ) + raise ValueError( + f"attn_implementation={value!r} is not a recognized backend. " + f"Expected one of: {sorted(CANONICAL_ATTN_IMPLS)}, or a hub-kernel " + f"path containing '/'." + ) + + @model_validator(mode="after") + def check_sageattn_wo_sample_packing(self): + if ( + self.attn_implementation == "sage" + and not self.sample_packing + and not self.pad_to_sequence_len + ): + LOG.warning( + "We recommend turning on `pad_to_sequence_len` for SageAttention " + "without packing. The loss has been observed to explode otherwise." + ) + return self + + @model_validator(mode="after") + def check_sageattn_fft(self): + if self.attn_implementation == "sage" and not self.adapter: + LOG.warning( + "SageAttention full finetuning has been observed to drop loss to 0. " + "Monitor the loss, or switch to LoRA/QLoRA or another attention method." 
+ ) + return self @model_validator(mode="before") @classmethod @@ -1442,17 +1569,13 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): ) return self - @model_validator(mode="before") - @classmethod - def check_sample_packing_w_sdpa_bf16(cls, data): - is_sm_90: bool = ( - data["capabilities"] - and data["capabilities"].get("compute_capability") == "sm_90" - ) + @model_validator(mode="after") + def check_sample_packing_w_sdpa_bf16(self): + is_sm_90 = self.capabilities and self.capabilities.compute_capability == "sm_90" if ( - data.get("sample_packing") - and data.get("sdp_attention") - and (data.get("bfloat16") or data.get("bf16")) + self.sample_packing + and self.attn_implementation == "sdpa" + and (self.bfloat16 or self.bf16) and not is_sm_90 ): # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450 @@ -1460,23 +1583,51 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. " "This may work on H100s." ) + return self - return data - - @model_validator(mode="before") - @classmethod - def check_compute_capability_w_sageattn(cls, data): + @model_validator(mode="after") + def check_compute_capability_w_sageattn(self): if ( - data.get("sage_attention") - and data.get("capabilities") - and data.get("capabilities").get("compute_capability") + self.attn_implementation == "sage" + and self.capabilities + and self.capabilities.compute_capability not in ["sm_80", "sm_86", "sm_89", "sm_90", "sm_120"] ): raise ValueError( "SageAttention supports compute capability between sm_80 and sm_120. " "Please use a different attention implementation." 
) - return data + return self + + @model_validator(mode="after") + def check_fp8_attention_preflight(self): + """fp8 attention requires SM90+ and torch >= 2.11 (torchao >= 0.17 is pinned).""" + if self.attn_implementation != "fp8": + return self + + if self.capabilities and self.capabilities.compute_capability: + cc = self.capabilities.compute_capability + # Accept sm_90 (H100/H200), sm_100 (B100/B200), sm_120 (B300-class). + if not cc.startswith("sm_") or int(cc.split("_", 1)[1]) < 90: + raise ValueError( + f"attn_implementation=fp8 requires compute capability sm_90 or " + f"higher (Hopper+). Detected {cc!r}." + ) + + torch_version = ( + self.env_capabilities.torch_version if self.env_capabilities else None + ) + if torch_version is None: + import torch + + torch_version = str(torch.__version__).split("+", maxsplit=1)[0] + if version.parse(torch_version) < version.parse("2.11.0"): + raise ValueError( + f"attn_implementation=fp8 requires PyTorch >= 2.11.0. " + f"Detected {torch_version}." 
+ ) + + return self @model_validator(mode="before") @classmethod @@ -1632,13 +1783,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): ) return data - @model_validator(mode="before") - @classmethod - def check_flex_torch_version(cls, data): - if (data.get("flex_attention") is not None) and (data.get("flex_attention")): - env_capabilities = data.get("env_capabilities", {}) - torch_version = env_capabilities.get("torch_version") - + @model_validator(mode="after") + def check_flex_torch_version(self): + if self.attn_implementation == "flex_attention": + torch_version = ( + self.env_capabilities.torch_version if self.env_capabilities else None + ) if torch_version is None: import torch @@ -1648,7 +1798,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): raise ValueError( "Flex attention is not supported on torch version < 2.6.0" ) - return data + return self @model_validator(mode="before") @classmethod diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py index d4ff27ac9..801548645 100644 --- a/src/axolotl/utils/schemas/enums.py +++ b/src/axolotl/utils/schemas/enums.py @@ -97,6 +97,68 @@ class CustomSupportedOptimizers(str, Enum): flash_lion = "flash_lion" +# Accepted canonical names; hub-kernel paths (containing "/") bypass this set. +CANONICAL_ATTN_IMPLS = frozenset( + { + "eager", + "sdpa", + "flash_attention_2", + "flash_attention_3", + "flex_attention", + "xformers", + "sage", + "s2", + "fp8", + } +) + +# Legacy boolean flags → canonical attn_implementation. Priority: specific before generic. +LEGACY_ATTN_FLAG_TO_IMPL = { + "xformers_attention": "xformers", + "s2_attention": "s2", + "sage_attention": "sage", + "flex_attention": "flex_attention", + "flash_attention": "flash_attention_2", + "sdp_attention": "sdpa", + "eager_attention": "eager", +} + +# Short-form aliases rejected at validation; mapped to canonical names for error messages. 
+SHORT_FORM_ALIAS_TO_CANONICAL = { + "flash": "flash_attention_2", + "flex": "flex_attention", + "sdp": "sdpa", +} + +# Backends that support varlen sample packing via `position_ids`. +ATTN_IMPLS_SUPPORTING_PACKING = frozenset( + { + "flash_attention_2", + "flash_attention_3", + "flex_attention", + "xformers", + "sage", + "kernels-community/flash-attn2", + "kernels-community/flash-attn3", + "kernels-community/sage-attention", + } +) + +# Backends that require the flash_attn library for axolotl's own monkeypatches. +ATTN_IMPLS_USING_FLASH_LIB = frozenset( + { + "flash_attention_2", + "flash_attention_3", + "s2", + "kernels-community/flash-attn2", + "kernels-community/flash-attn3", + } +) + +# Backends for which embeddings stay in fp32. Everything else needs fp16/bf16. +ATTN_IMPLS_WITHOUT_DTYPE_CAST = frozenset({"eager", "sdpa"}) + + class RingAttnFunc(str, Enum): """Enum class for supported `ring-flash-attn` implementations""" diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 76b09bfdb..ec11d9658 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -12,7 +12,11 @@ from pydantic import ( from transformers.utils.import_utils import is_torch_npu_available from axolotl.utils.logging import get_logger -from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType +from axolotl.utils.schemas.enums import ( + ChatTemplate, + RingAttnFunc, + RLType, +) LOG = get_logger(__name__) @@ -179,58 +183,42 @@ class DatasetValidationMixin: class AttentionValidationMixin: """Validation methods related to attention mechanisms.""" - @model_validator(mode="before") - @classmethod - def check_attention_fields(cls, data): - fields = ( - "xformers_attention", - "sdp_attention", - # "s2_attention", # requires both FA and this to be enabled - "flash_attention", - "flex_attention", - "sage_attention", - ) - non_empty_count = sum(1 for field in fields if data.get(field)) + 
@model_validator(mode="after") + def check_sample_packing_without_attention(self): + if self.sample_packing and not self.attn_supports_packing: + if self.attn_implementation: + LOG.warning( + "`sample_packing` with `attn_implementation=%r` does not handle " + "cross-sample decontamination. Use a varlen-capable backend " + "(e.g. flash_attention_2, flex_attention, xformers, sage) to " + "isolate samples.", + self.attn_implementation, + ) + else: + LOG.warning( + "`sample_packing` without an attention backend does not handle " + "cross-sample decontamination. Set `attn_implementation` to a " + "varlen-capable backend (e.g. flash_attention_2)." + ) + return self - if non_empty_count > 1: - raise ValueError(f"Only one of {', '.join(fields)} must be set") - return data - - @model_validator(mode="before") - @classmethod - def check_sample_packing_without_attention(cls, data): - if ( - data.get("sample_packing") - and not data.get("flash_attention") - and not data.get("sdp_attention") - and not data.get("flex_attention") - and not data.get("xformers_attention") - and not data.get("sage_attention") - ): - LOG.warning( - "sample_packing without flash, sdp, xformers, sage, or flex attention does not handle cross sample decontamination." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_sample_packing_with_s2attn(cls, data): - if data.get("sample_packing") and data.get("s2_attention"): + @model_validator(mode="after") + def check_sample_packing_with_s2attn(self): + if self.sample_packing and self.attn_implementation == "s2": raise ValueError( - "Received `sample_packing=true` and `s2_attention=true`; however, \ - shifted-sparse attention does not currently support sample packing." + "Received `sample_packing=true` and `attn_implementation=s2`; " + "shifted-sparse attention does not currently support sample packing." 
) - return data + return self - @model_validator(mode="before") - @classmethod - def check_scaling_softmax_requires_flex(cls, data): - if data.get("scaling_softmax") and not data.get("flex_attention"): + @model_validator(mode="after") + def check_scaling_softmax_requires_flex(self): + if self.scaling_softmax and self.attn_implementation != "flex_attention": raise ValueError( - "scaling_softmax requires flex_attention: true\n" - "Add 'flex_attention: true' to your config file.\n" + "scaling_softmax requires flex attention. " + "Add `attn_implementation: flex_attention` to your config." ) - return data + return self class TrainingValidationMixin: @@ -431,7 +419,7 @@ class TrainingValidationMixin: not (self.bf16 or self.bfloat16) and (self.fp16 or self.float16) and not self.adapter - and not self.flash_attention + and not self.attn_uses_flash_lib and self.sample_packing ): LOG.warning( @@ -942,40 +930,45 @@ class OptimizationValidationMixin: ) return data - @model_validator(mode="before") - @classmethod - def check_batch_flattening_fa(cls, data): - if data.get("batch_flattening"): - batch_flattening_auto = data.get("batch_flattening") == "auto" - if not data.get("flash_attention") and not batch_flattening_auto: - raise ValueError("batch_flattening requires flash attention") - if data.get("sample_packing") and not batch_flattening_auto: - raise ValueError("batch_flattening not compatible with sample_packing") - if data.get("micro_batch_size") == 1 and not batch_flattening_auto: - LOG.warning("batch_flattening has no effect with micro_batch_size == 1") + @model_validator(mode="after") + def check_batch_flattening_fa(self): + if not self.batch_flattening: + return self - # Liger loss takes a separate code path (compute_liger_loss) that - # bypasses the flattened training forward pass. Batch flattening - # still applies to the scoring/deferred logprobs path. 
- trl_cfg = data.get("trl") or {} - if isinstance(trl_cfg, dict) and trl_cfg.get("use_liger_loss"): - LOG.warning( - "batch_flattening with use_liger_loss: flattening will only " - "apply to the scoring path (deferred logprobs). The training " - "forward pass uses Liger's fused lm_head+loss kernel instead." - ) + batch_flattening_auto = self.batch_flattening == "auto" + has_varlen_attn = self.attn_supports_packing - if ( - batch_flattening_auto - and data.get("flash_attention") - and not data.get("sample_packing") - and data.get("micro_batch_size") > 1 - ): - data["batch_flattening"] = True - elif batch_flattening_auto: - data["batch_flattening"] = False + if not has_varlen_attn and not batch_flattening_auto: + raise ValueError( + "batch_flattening requires a varlen-capable attention backend " + "(e.g., attn_implementation: flash_attention_2)." + ) + if self.sample_packing and not batch_flattening_auto: + raise ValueError("batch_flattening not compatible with sample_packing") + if self.micro_batch_size == 1 and not batch_flattening_auto: + LOG.warning("batch_flattening has no effect with micro_batch_size == 1") - return data + # Liger loss takes a separate code path (compute_liger_loss) that + # bypasses the flattened training forward pass. Batch flattening + # still applies to the scoring/deferred logprobs path. + if self.trl and getattr(self.trl, "use_liger_loss", False): + LOG.warning( + "batch_flattening with use_liger_loss: flattening will only " + "apply to the scoring path (deferred logprobs). The training " + "forward pass uses Liger's fused lm_head+loss kernel instead." 
+ ) + + if ( + batch_flattening_auto + and has_varlen_attn + and not self.sample_packing + and self.micro_batch_size > 1 + ): + self.batch_flattening = True + elif batch_flattening_auto: + self.batch_flattening = False + + return self @model_validator(mode="before") @classmethod @@ -1212,6 +1205,18 @@ class SystemValidationMixin: def check_npu_config(cls, data): if is_torch_npu_available(): # check attention config + unsupported_npu_impls = { + "flash_attention_2", + "flash_attention_3", + "sdpa", + "s2", + } + attn_impl = data.get("attn_implementation") + if attn_impl and attn_impl in unsupported_npu_impls: + raise NotImplementedError( + f"attn_implementation={attn_impl!r} is currently not supported on Ascend NPU." + ) + # Legacy flags still present at this point (normalizer strips them later). attn_list = ["flash_attention", "sdp_attention", "s2_attention"] for attn in attn_list: if data.get(attn): @@ -1520,9 +1525,10 @@ class ComplexValidationMixin: if not self.context_parallel_size: self.context_parallel_size = 1 elif self.context_parallel_size > 1: - if not self.flash_attention: + if not self.attn_uses_flash_lib: raise ValueError( - "flash_attention: true must be set with context_parallel_size > 1" + "context_parallel_size > 1 requires flash attention " + "(attn_implementation: flash or s2)." 
) if self.sample_packing and self.micro_batch_size > 1: @@ -1652,47 +1658,46 @@ class EBFTValidationMixin: ) return data - @model_validator(mode="before") - @classmethod - def check_ebft_gradient_checkpointing_reentrant(cls, data): + @model_validator(mode="after") + def check_ebft_gradient_checkpointing_reentrant(self): """flex_attention + non-reentrant gradient checkpointing causes CheckpointError.""" if ( - data.get("rl") == "ebft" - and data.get("ebft", {}).get("mode") == "strided" - and data.get("flex_attention") - and data.get("gradient_checkpointing") + self.rl == "ebft" + and (self.ebft or {}).get("mode") == "strided" + and self.attn_implementation == "flex_attention" + and self.gradient_checkpointing ): - gc_kwargs = data.get("gradient_checkpointing_kwargs") or {} + gc_kwargs = self.gradient_checkpointing_kwargs or {} if not gc_kwargs.get("use_reentrant"): LOG.warning( "EBFT strided mode with flex_attention: setting `use_reentrant: true` in " "gradient_checkpointing_kwargs (required for flex_attention compatibility). " "Non-reentrant checkpointing causes CheckpointError with BlockMask metadata." 
) - if data.get("gradient_checkpointing_kwargs") is None: - data["gradient_checkpointing_kwargs"] = {} - data["gradient_checkpointing_kwargs"]["use_reentrant"] = True - return data + if self.gradient_checkpointing_kwargs is None: + self.gradient_checkpointing_kwargs = {} + self.gradient_checkpointing_kwargs["use_reentrant"] = True + return self - @model_validator(mode="before") - @classmethod - def check_ebft_activation_offloading(cls, data): + @model_validator(mode="after") + def check_ebft_activation_offloading(self): """activation_offloading replaces gradient checkpointing with FSDP-style wrapping, which conflicts with flex_attention's use_reentrant requirement.""" if ( - data.get("rl") == "ebft" - and data.get("ebft", {}).get("mode") == "strided" - and data.get("activation_offloading") is True - and data.get("flex_attention") + self.rl == "ebft" + and (self.ebft or {}).get("mode") == "strided" + and self.activation_offloading is True + and self.attn_implementation == "flex_attention" ): raise ValueError( "EBFT strided mode: `activation_offloading: true` is incompatible with " - "`flex_attention: true`. Activation offloading replaces gradient checkpointing " - "with FSDP-style wrapping that conflicts with flex_attention's reentrant " - "checkpoint requirement. Remove `activation_offloading` — the strided trainer " - "uses micro-batched forward passes for memory efficiency instead." + "`attn_implementation: flex_attention`. Activation offloading replaces " + "gradient checkpointing with FSDP-style wrapping that conflicts with " + "flex_attention's reentrant checkpoint requirement. Remove " + "`activation_offloading` — the strided trainer uses micro-batched forward " + "passes for memory efficiency instead." 
) - return data + return self @model_validator(mode="before") @classmethod diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 91982137b..3fb940364 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -462,7 +462,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}" ) else: - if cfg.flash_attention and not cfg.multipack_real_batches: + if cfg.attn_supports_packing and not cfg.multipack_real_batches: sampler_batch_size = 1 batch_max_len = cfg.micro_batch_size * cfg.sequence_len else: diff --git a/tests/e2e/multigpu/test_llama.py b/tests/e2e/multigpu/test_llama.py index 1e3757dcf..b89c93522 100644 --- a/tests/e2e/multigpu/test_llama.py +++ b/tests/e2e/multigpu/test_llama.py @@ -521,9 +521,9 @@ class TestMultiGPULlama: } ) if attention_backend == "flash": - cfg.flash_attention = True + cfg.attn_implementation = "flash_attention_2" elif attention_backend == "flex": - cfg.flex_attention = True + cfg.attn_implementation = "flex_attention" # write cfg to yaml file Path(temp_dir).mkdir(parents=True, exist_ok=True) diff --git a/tests/test_attn_implementation.py b/tests/test_attn_implementation.py new file mode 100644 index 000000000..e85d713a8 --- /dev/null +++ b/tests/test_attn_implementation.py @@ -0,0 +1,418 @@ +"""Tests for attn_implementation: normalization, canonical-value acceptance, +capability flags, backend registration, and downstream validators. 
+""" + +import logging +from contextlib import contextmanager + +import pytest + +from axolotl.utils.config import validate_config +from axolotl.utils.dict import DictDefault +from axolotl.utils.schemas.config import AxolotlInputConfig +from axolotl.utils.schemas.enums import ( + ATTN_IMPLS_SUPPORTING_PACKING, + ATTN_IMPLS_USING_FLASH_LIB, + ATTN_IMPLS_WITHOUT_DTYPE_CAST, + CANONICAL_ATTN_IMPLS, +) + + +@contextmanager +def _capture_axolotl_warnings(caplog): + """Capture WARNINGs from `axolotl.*` loggers via caplog. + + `axolotl.cli` calls `configure_logging()` at import time, which sets + `propagate=False` on the `axolotl` logger so records do not reach the root + logger that pytest's `caplog` hooks. This helper temporarily re-enables + propagation for the duration of the block. + """ + ax_logger = logging.getLogger("axolotl") + old_propagate = ax_logger.propagate + ax_logger.propagate = True + try: + with caplog.at_level(logging.WARNING, logger="axolotl"): + yield + finally: + ax_logger.propagate = old_propagate + + +def _xformers_available(): + try: + import xformers.ops # noqa: F401 + + return True + except (ImportError, OSError): + return False + + +class TestCapabilityTables: + """Backend capability classification via frozensets and computed_field properties.""" + + @pytest.mark.parametrize( + "impl", + [ + "flash_attention_2", + "flash_attention_3", + "flex_attention", + "xformers", + "sage", + ], + ) + def test_supports_packing(self, impl): + assert impl in ATTN_IMPLS_SUPPORTING_PACKING + + @pytest.mark.parametrize("impl", ["eager", "sdpa", "s2", "fp8"]) + def test_does_not_support_packing(self, impl): + assert impl not in ATTN_IMPLS_SUPPORTING_PACKING + + @pytest.mark.parametrize("impl", ["flash_attention_2", "flash_attention_3", "s2"]) + def test_uses_flash_lib(self, impl): + assert impl in ATTN_IMPLS_USING_FLASH_LIB + + @pytest.mark.parametrize( + "impl", ["eager", "sdpa", "xformers", "flex_attention", "sage", "fp8"] + ) + def 
test_does_not_use_flash_lib(self, impl): + assert impl not in ATTN_IMPLS_USING_FLASH_LIB + + @pytest.mark.parametrize("impl", ["eager", "sdpa"]) + def test_no_dtype_cast(self, impl): + assert impl in ATTN_IMPLS_WITHOUT_DTYPE_CAST + + @pytest.mark.parametrize( + "impl", + [ + "flash_attention_2", + "flash_attention_3", + "flex_attention", + "xformers", + "sage", + "s2", + "fp8", + ], + ) + def test_needs_dtype_cast(self, impl): + assert impl not in ATTN_IMPLS_WITHOUT_DTYPE_CAST + + def test_known_hub_kernels_classified(self): + assert "kernels-community/flash-attn3" in ATTN_IMPLS_SUPPORTING_PACKING + assert "kernels-community/flash-attn3" in ATTN_IMPLS_USING_FLASH_LIB + assert "kernels-community/sage-attention" in ATTN_IMPLS_SUPPORTING_PACKING + + def test_computed_flags_readable_on_validated_cfg(self, min_base_cfg): + cfg = min_base_cfg | DictDefault(attn_implementation="sdpa") + validated = validate_config(cfg) + assert validated.attn_implementation == "sdpa" + assert validated.attn_supports_packing is False + assert validated.attn_uses_flash_lib is False + assert validated.attn_needs_dtype_cast is False + + def test_computed_flags_not_overridable_from_yaml(self, min_base_cfg): + """YAML attempts to override a computed field must not win.""" + cfg = min_base_cfg | DictDefault( + attn_implementation="eager", attn_uses_flash_lib=True + ) + validated = validate_config(cfg) + # The computed field reflects the backend, not the YAML input. 
+ assert validated.attn_uses_flash_lib is False + + +class TestBackendRegistration: + """Axolotl-owned backends register under their canonical names in HF's registries.""" + + @pytest.mark.skipif(not _xformers_available(), reason="xformers not available") + def test_register_xformers(self): + from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + from axolotl.monkeypatch.attention import register_xformers_attn + + register_xformers_attn() + + assert "xformers" in ALL_ATTENTION_FUNCTIONS + assert "xformers" in ALL_MASK_ATTENTION_FUNCTIONS + assert ( + ALL_MASK_ATTENTION_FUNCTIONS["xformers"] + == ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] + ) + + def test_register_sage(self): + from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + from axolotl.monkeypatch.attention import register_sage_attn + + register_sage_attn() + + assert "sage" in ALL_ATTENTION_FUNCTIONS + assert "sage" in ALL_MASK_ATTENTION_FUNCTIONS + assert ( + ALL_MASK_ATTENTION_FUNCTIONS["sage"] + == ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] + ) + + @pytest.mark.skipif(not _xformers_available(), reason="xformers not available") + def test_xformers_does_not_overwrite_fa2(self): + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"] + + from axolotl.monkeypatch.attention import register_xformers_attn + + register_xformers_attn() + + assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is original_fa2 + + def test_sage_does_not_overwrite_fa2(self): + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + original_fa2 = ALL_ATTENTION_FUNCTIONS["flash_attention_2"] + + from axolotl.monkeypatch.attention import register_sage_attn + + register_sage_attn() + + assert ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is original_fa2 + + +class 
TestLegacyFlagDeprecation: + """Legacy boolean flags (flash_attention, sdp_attention, ...) map to a + canonical attn_implementation value, are stripped from the validated + config, and cannot be combined with an explicit canonical value. + """ + + @staticmethod + def _normalize(data): + return AxolotlInputConfig.normalize_attn_implementation(data) + + @pytest.mark.parametrize( + "flag,expected", + [ + ("flash_attention", "flash_attention_2"), + ("sdp_attention", "sdpa"), + ("xformers_attention", "xformers"), + ("flex_attention", "flex_attention"), + ("sage_attention", "sage"), + ("eager_attention", "eager"), + ("s2_attention", "s2"), + ], + ) + def test_legacy_flag_maps_to_canonical(self, flag, expected): + result = self._normalize({flag: True}) + assert result["attn_implementation"] == expected + + def test_legacy_flags_are_stripped_after_mapping(self): + result = self._normalize({"flash_attention": True}) + for flag in [ + "flash_attention", + "sdp_attention", + "xformers_attention", + "flex_attention", + "sage_attention", + "eager_attention", + "s2_attention", + ]: + assert flag not in result + + def test_s2_plus_flash_priority_is_s2(self): + result = self._normalize({"s2_attention": True, "flash_attention": True}) + assert result["attn_implementation"] == "s2" + + def test_sage_plus_flash_priority_is_sage(self): + result = self._normalize({"sage_attention": True, "flash_attention": True}) + assert result["attn_implementation"] == "sage" + + def test_canonical_plus_legacy_flag_raises(self): + with pytest.raises(ValueError, match="cannot be combined with legacy"): + self._normalize( + {"attn_implementation": "flash_attention_2", "flash_attention": True} + ) + + def test_canonical_plus_unrelated_legacy_flag_raises(self): + with pytest.raises(ValueError, match="cannot be combined with legacy"): + self._normalize( + {"attn_implementation": "xformers", "flash_attention": True} + ) + + def test_legacy_flag_stripped_on_validated_cfg(self, min_base_cfg): + cfg = 
min_base_cfg | DictDefault(flash_attention=True) + validated = validate_config(cfg) + assert validated.attn_implementation == "flash_attention_2" + # Legacy flag must not survive to the validated DictDefault + # (normalizer pops it, model_dump excludes Nones). + assert "flash_attention" not in dict(validated) + + def test_canonical_plus_legacy_rejected_on_full_validation(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + attn_implementation="flash_attention_2", flash_attention=True + ) + with pytest.raises(ValueError, match="cannot be combined with legacy"): + validate_config(cfg) + + def test_s2_plus_flash_maps_to_s2_on_full_validation(self, min_base_cfg): + """Priority resolution applies through the full validator chain too.""" + cfg = min_base_cfg | DictDefault(s2_attention=True, flash_attention=True) + validated = validate_config(cfg) + assert validated.attn_implementation == "s2" + + +class TestCanonicalValueAcceptance: + """`attn_implementation` accepts canonical names and `org/name` hub-kernel + paths. Short-form aliases (`flash`, `flex`, `sdp`) and unknown bare names + are rejected. Absent input is a noop. 
+ """ + + @staticmethod + def _normalize(data): + return AxolotlInputConfig.normalize_attn_implementation(data) + + def test_canonical_value_is_passthrough(self): + data = {"attn_implementation": "flash_attention_2"} + result = self._normalize(data) + assert result["attn_implementation"] == "flash_attention_2" + + def test_hub_kernel_is_passthrough(self): + data = {"attn_implementation": "kernels-community/flash-attn3"} + result = self._normalize(data) + assert result["attn_implementation"] == "kernels-community/flash-attn3" + + def test_no_attention_set_is_noop(self): + result = self._normalize({"some_other_config": True}) + assert result.get("attn_implementation") is None + + def test_field_validator_accepts_all_canonical(self): + for impl in CANONICAL_ATTN_IMPLS: + assert AxolotlInputConfig.validate_attn_implementation(impl) == impl + + def test_field_validator_accepts_hub_kernels(self): + for impl in ( + "kernels-community/flash-attn3", + "kernels-community/sage-attention", + "someorg/custom-kernel", + ): + assert AxolotlInputConfig.validate_attn_implementation(impl) == impl + + def test_field_validator_accepts_none(self): + assert AxolotlInputConfig.validate_attn_implementation(None) is None + + @pytest.mark.parametrize("alias", ["flash", "flex", "sdp"]) + def test_short_form_alias_rejected(self, alias): + with pytest.raises(ValueError, match="is not accepted"): + AxolotlInputConfig.validate_attn_implementation(alias) + + def test_unknown_bare_name_rejected(self): + with pytest.raises(ValueError, match="not a recognized backend"): + AxolotlInputConfig.validate_attn_implementation("not_a_real_backend") + + def test_canonical_value_passes_through_full_validation(self, min_base_cfg): + cfg = min_base_cfg | DictDefault(attn_implementation="flash_attention_3") + validated = validate_config(cfg) + assert validated.attn_implementation == "flash_attention_3" + assert validated.attn_uses_flash_lib is True + assert validated.attn_supports_packing is True + + def 
test_hub_kernel_passes_through_full_validation(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + attn_implementation="kernels-community/flash-attn3" + ) + validated = validate_config(cfg) + assert validated.attn_implementation == "kernels-community/flash-attn3" + assert validated.attn_uses_flash_lib is True + assert validated.attn_supports_packing is True + + def test_short_form_alias_rejected_on_full_validation(self, min_base_cfg): + cfg = min_base_cfg | DictDefault(attn_implementation="flash") + with pytest.raises(ValueError, match="is not accepted"): + validate_config(cfg) + + +class TestGemma4HybridMode: + """`gemma4_hybrid_attn_impl` pins `attn_implementation` to `flash_attention_2`.""" + + @staticmethod + def _normalize(data): + return AxolotlInputConfig.normalize_attn_implementation(data) + + def test_defaults_to_flash_attention_2(self): + result = self._normalize({"gemma4_hybrid_attn_impl": True}) + assert result["attn_implementation"] == "flash_attention_2" + + def test_explicit_fa2_passes(self): + result = self._normalize( + { + "gemma4_hybrid_attn_impl": True, + "attn_implementation": "flash_attention_2", + } + ) + assert result["attn_implementation"] == "flash_attention_2" + + def test_non_fa2_raises(self): + with pytest.raises( + ValueError, match="requires attn_implementation=flash_attention_2" + ): + self._normalize( + {"gemma4_hybrid_attn_impl": True, "attn_implementation": "sdpa"} + ) + + +class TestSamplePackingValidation: + """`sample_packing` warns for non-varlen backends; s2 raises outright.""" + + def test_eager_warns(self, min_base_cfg, caplog): + cfg = min_base_cfg | DictDefault( + attn_implementation="eager", sample_packing=True + ) + with _capture_axolotl_warnings(caplog): + validate_config(cfg) + assert any( + "does not handle cross-sample decontamination" in r.getMessage() + for r in caplog.records + ) + + def test_sdpa_warns(self, min_base_cfg, caplog): + cfg = min_base_cfg | DictDefault( + attn_implementation="sdpa", 
sample_packing=True + ) + with _capture_axolotl_warnings(caplog): + validate_config(cfg) + assert any( + "does not handle cross-sample decontamination" in r.getMessage() + for r in caplog.records + ) + + def test_flash_attention_2_does_not_warn(self, min_base_cfg, caplog): + cfg = min_base_cfg | DictDefault( + attn_implementation="flash_attention_2", sample_packing=True + ) + with _capture_axolotl_warnings(caplog): + validate_config(cfg) + assert not any( + "does not handle cross-sample decontamination" in r.getMessage() + for r in caplog.records + ) + + def test_s2_raises(self, min_base_cfg): + cfg = min_base_cfg | DictDefault(attn_implementation="s2", sample_packing=True) + with pytest.raises( + ValueError, match="shifted-sparse attention does not currently support" + ): + validate_config(cfg) + + +class TestScalingSoftmaxValidation: + """`scaling_softmax` is only implemented under flex_attention.""" + + def test_non_flex_raises(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + attn_implementation="flash_attention_2", scaling_softmax=True + ) + with pytest.raises(ValueError, match="scaling_softmax requires flex"): + validate_config(cfg) + + def test_flex_passes(self, min_base_cfg): + cfg = min_base_cfg | DictDefault( + attn_implementation="flex_attention", scaling_softmax=True + ) + validated = validate_config(cfg) + assert validated.attn_implementation == "flex_attention" diff --git a/tests/test_mm_chat_collator.py b/tests/test_mm_chat_collator.py new file mode 100644 index 000000000..d26a011a7 --- /dev/null +++ b/tests/test_mm_chat_collator.py @@ -0,0 +1,163 @@ +""" +Regression tests for MultiModalChatDataCollator shape contracts. 
+ +Guard against the transformers 5.x breakage where apply_chat_template's +own `return_dict` parameter (default False) caused it to return the raw +input_ids tensor instead of the full BatchFeature dict, leading to + IndexError: too many indices for tensor of dimension 2 +when downstream code did batch["input_ids"] on the resulting tensor. +""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch +from transformers import BatchFeature + + +@pytest.fixture(name="mock_processor") +def fixture_mock_processor(): + """ + A mock processor whose apply_chat_template returns a BatchFeature + when called with return_dict=True (the correct call convention), + or a raw input_ids tensor when called without return_dict=True + (the broken call convention that the bug introduced). + """ + processor = MagicMock() + processor.tokenizer = MagicMock() + processor.tokenizer.pad_token_id = 0 + processor.image_token = "<|image|>" + processor.tokenizer.convert_tokens_to_ids = MagicMock(return_value=128256) + + batch_size, seq_len = 2, 16 + input_ids = torch.ones(batch_size, seq_len, dtype=torch.long) + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long) + + batch_feature = BatchFeature( + data={ + "input_ids": input_ids, + "attention_mask": attention_mask, + } + ) + + def _apply_chat_template(*args, **kwargs): + if kwargs.get("return_dict", False): + return batch_feature + # Simulate transformers 5.x default behaviour: returns out["input_ids"] + return input_ids + + processor.apply_chat_template = MagicMock(side_effect=_apply_chat_template) + processor.chat_template = None + return processor + + +@pytest.fixture(name="mock_processing_strategy") +def fixture_mock_processing_strategy(mock_processor): + from axolotl.processing_strategies import ProcessingStrategy + + strategy = ProcessingStrategy(processor=mock_processor) + return strategy + + +class TestMultiModalChatDataCollatorShapeContract: + """ + Verify that 
MultiModalChatDataCollator.process_rows returns a dict with + 2-D input_ids and labels, not a raw tensor. This is the shape contract + that process_labels depends on. + """ + + def _make_collator(self, mock_processing_strategy): + from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator + + tokenizer = mock_processing_strategy.processor.tokenizer + return MultiModalChatDataCollator( + tokenizer=tokenizer, + processing_strategy=mock_processing_strategy, + ) + + def _make_examples(self): + return [ + { + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + ] + } + ] + + def test_process_rows_returns_dict(self, mock_processing_strategy): + """batch must be a dict, not a raw tensor.""" + collator = self._make_collator(mock_processing_strategy) + examples = self._make_examples() + + with patch.object( + mock_processing_strategy, + "__call__", + return_value=examples, + ): + batch = collator.process_rows(examples) + + assert isinstance(batch, dict), ( + "process_rows must return a dict (BatchFeature), not a raw tensor. " + "If it returns a tensor, apply_chat_template was called without " + "return_dict=True at the top level." 
+ ) + + def test_process_rows_input_ids_shape(self, mock_processing_strategy): + """batch['input_ids'] must be a 2-D tensor (batch, seq_len).""" + collator = self._make_collator(mock_processing_strategy) + examples = self._make_examples() + + with patch.object( + mock_processing_strategy, + "__call__", + return_value=examples, + ): + batch = collator.process_rows(examples) + + assert "input_ids" in batch + assert isinstance(batch["input_ids"], torch.Tensor) + assert batch["input_ids"].ndim == 2, ( + f"input_ids must be 2-D (batch, seq_len), got shape {batch['input_ids'].shape}" + ) + + def test_process_rows_labels_shape(self, mock_processing_strategy): + """batch['labels'] must be a 2-D tensor matching input_ids shape.""" + collator = self._make_collator(mock_processing_strategy) + examples = self._make_examples() + + with patch.object( + mock_processing_strategy, + "__call__", + return_value=examples, + ): + batch = collator.process_rows(examples) + + assert "labels" in batch + assert isinstance(batch["labels"], torch.Tensor) + assert batch["labels"].ndim == 2 + assert batch["labels"].shape == batch["input_ids"].shape + + def test_apply_chat_template_called_with_return_dict_true( + self, mock_processing_strategy + ): + """apply_chat_template must be called with return_dict=True as a keyword arg.""" + collator = self._make_collator(mock_processing_strategy) + examples = self._make_examples() + + with patch.object( + mock_processing_strategy, + "__call__", + return_value=examples, + ): + collator.process_rows(examples) + + call_kwargs = ( + mock_processing_strategy.processor.apply_chat_template.call_args.kwargs + ) + assert call_kwargs.get("return_dict") is True, ( + "apply_chat_template must be called with return_dict=True as a top-level " + "keyword argument (not inside processor_kwargs). In transformers 5.x, " + "apply_chat_template has its own return_dict param (default False) that " + "controls whether it returns the full BatchFeature or just input_ids." 
+ ) diff --git a/tests/test_no_legacy_attn_reads.py b/tests/test_no_legacy_attn_reads.py new file mode 100644 index 000000000..2435f9fa8 --- /dev/null +++ b/tests/test_no_legacy_attn_reads.py @@ -0,0 +1,62 @@ +"""Enforce attn_implementation as the single source of truth. + +Fails if src/ contains a cfg.<legacy>_attention read. Migrate offending sites +to cfg.attn_implementation or the attn_supports_packing/attn_uses_flash_lib/ +attn_needs_dtype_cast computed flags. +""" + +from __future__ import annotations + +import re +from pathlib import Path + +LEGACY_FLAGS = ( + "flash_attention", + "sdp_attention", + "xformers_attention", + "flex_attention", + "sage_attention", + "s2_attention", + "eager_attention", +) + +# The normalizer is allowed to read the legacy keys (that's its job). +# lm_eval/cli.py is a raw-YAML entry point (bypasses AxolotlInputConfig) that +# honors both forms during the deprecation period — when we remove the legacy +# flags entirely, drop this allowlist entry and the BC branch in that file. +ALLOWED_FILES = { + Path("src/axolotl/utils/schemas/config.py"), + Path("src/axolotl/integrations/lm_eval/cli.py"), +} + +# `cfg.<flag>`, `self.cfg.<flag>`, `data.get("<flag>")`, `data["<flag>"]` +_PATTERNS = [re.compile(rf"\bcfg\.{flag}\b") for flag in LEGACY_FLAGS] + [ + re.compile(rf'\bdata\.get\("{flag}"\)') for flag in LEGACY_FLAGS +] + + +def _repo_root() -> Path: + return Path(__file__).resolve().parent.parent + + +def test_no_legacy_attn_reads_in_src(): + root = _repo_root() + src = root / "src" + offenders: list[str] = [] + + for py_file in src.rglob("*.py"): + rel = py_file.relative_to(root) + if rel in ALLOWED_FILES: + continue + text = py_file.read_text(encoding="utf-8") + for pattern in _PATTERNS: + for match in pattern.finditer(text): + # Line number for the user's convenience. + line_no = text.count("\n", 0, match.start()) + 1 + offenders.append(f"{rel}:{line_no} {match.group(0)}") + + assert not offenders, ( + "Found legacy attention-flag reads in src/. 
Migrate to " + "`cfg.attn_implementation` / capability flags:\n " + + "\n ".join(sorted(offenders)) + )