Replace separate attention flags with attn_implementation and capability/concern feature flags (#3602)

* upgrade to torchao 0.17.0

* chore: lint

* refactor attention handling

* replace legacy attention boolean flags with capability properties

Replace checks with capability-based properties derived from attn_implementation

This separates three concerns that were conflated under flash_attention:
1. Backend selection -> attn_implementation enum
2. Packing capability -> attn_supports_packing property
3. Flash-attn library dependency -> attn_uses_flash_lib property
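A minimal sketch of the derivation, assuming illustrative backend sets (the
real mapping lives in the config normalizer and may differ):

```python
# Illustrative sketch only: the flag names match this commit, but which
# backends belong to each set is an assumption, not axolotl's actual table.
PACKING_BACKENDS = {"flash_attention_2", "flash_attention_3", "flex_attention"}
FLASH_LIB_BACKENDS = {"flash_attention_2", "flash_attention_3", "s2"}


def derive_attn_capabilities(attn_implementation: str) -> dict:
    """Compute boolean capability flags from the single backend enum."""
    return {
        "attn_supports_packing": attn_implementation in PACKING_BACKENDS,
        "attn_uses_flash_lib": attn_implementation in FLASH_LIB_BACKENDS,
    }
```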

* compute attn capability flags in normalizer instead of properties

* make attn_implementation the single source of truth

* move attention-dependent validators to mode=after

* migrate remaining consumers to canonical attn_implementation

* expand attention tests + rewrite docs

* migrate example configs to canonical attn_implementation

* update doc snippets + reject gemma4-hybrid with non-FA2 backend

* remove dead gemma4 branch in _set_attention_config

* fix duplicate attn_implementation in gpt-oss yamls and flaky caplog tests

* drop "Phase 2" naming from attn-implementation tests

* regroup attn_implementation tests by feature concern

* clean up verbose comments and remove MD

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* fix(collator): pass return_dict=True at apply_chat_template top level for transformers 5.x

In transformers 5.x, ProcessorMixin.apply_chat_template gained its own
`return_dict` parameter (defaulting to False).  When return_dict=False
and tokenize=True the method returns out["input_ids"] directly — a 2-D
tensor — rather than the full BatchFeature dict.

The old code placed `return_dict=True` inside processor_kwargs.  In
transformers 5.x those kwargs are forwarded to the underlying processor
call self(...), where _merge_kwargs drops any key not present in
MllamaProcessorKwargs (emitting only a warning).  The outer return_dict
therefore stayed False, apply_chat_template returned the raw input_ids
tensor, and the subsequent `batch["input_ids"]` attempted to index a
2-D tensor with the 9-character string "input_ids", producing:

  IndexError: too many indices for tensor of dimension 2

The fix is to pass return_dict=True as a top-level keyword argument to
apply_chat_template (where it is actually consumed) and remove it from
processor_kwargs (where it was silently dropped).  No version guard is
needed: transformers is pinned to ==5.5.4 in pyproject.toml.
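A condensed before/after sketch of the call site (`messages` and
`processor_kwargs` are illustrative names, not the collator's exact code):

```python
# Before (broken on transformers 5.x): return_dict is buried in
# processor_kwargs, forwarded to the inner self(...) call, and dropped by
# _merge_kwargs; the outer default of return_dict=False wins and a raw
# 2-D input_ids tensor comes back.
batch = processor.apply_chat_template(
    messages, tokenize=True, **{"return_dict": True, **processor_kwargs}
)

# After: return_dict=True at the top level, where it is actually consumed.
batch = processor.apply_chat_template(
    messages, tokenize=True, return_dict=True, **processor_kwargs
)
input_ids = batch["input_ids"]  # dict-shaped result; keyed access works
```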

Adds a unit-level regression test (tests/test_mm_chat_collator.py) that
mocks the processor to return a raw tensor when apply_chat_template is
called without top-level return_dict=True, verifying the four invariants:
process_rows returns a dict, input_ids is 2-D, labels is 2-D, and
apply_chat_template receives return_dict=True as a top-level kwarg.
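A hedged sketch of the mock's shape (the real test lives in
tests/test_mm_chat_collator.py; names here are assumptions):

```python
from unittest.mock import MagicMock

import torch

processor = MagicMock()


def fake_apply_chat_template(messages, **kwargs):
    # Mimic transformers 5.x: without a top-level return_dict=True,
    # apply_chat_template hands back only the raw 2-D input_ids tensor.
    input_ids = torch.zeros((1, 8), dtype=torch.long)
    if not kwargs.get("return_dict", False):
        return input_ids
    return {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)}


processor.apply_chat_template.side_effect = fake_apply_chat_template
```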

Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_multimodal_dataset
Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_text_only_dataset
Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* fix(collator): process_rows returns dict (BatchFeature) shape

Two related changes for the multimodal chat collator under transformers 5.x:

1. Wrap apply_chat_template result in dict(...) so process_rows returns
   a plain dict rather than a BatchFeature instance. BatchFeature is a
   Mapping but not a dict; downstream code that did
     batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])
   would index on a tensor when the result wasn't dict-shaped, raising
     IndexError: too many indices for tensor of dimension 2

2. Soften the regression test's contract from `dict` to `Mapping` so it
   exercises the actual semantic guarantee (key/value access) rather
   than the implementation detail (dict vs BatchFeature). Test guards
   against the original transformers 5.x breakage where apply_chat_template's
   return_dict default went from True to False.

Includes regression test under tests/test_mm_chat_collator.py.
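A condensed sketch of change (1), with the surrounding names assumed:

```python
from collections.abc import Mapping

# apply_chat_template returns a BatchFeature (a Mapping, not a dict);
# wrapping it in dict(...) gives process_rows a plain-dict return shape,
# so downstream key assignment can never land on a raw tensor.
out = processor.apply_chat_template(
    messages, tokenize=True, return_dict=True, **processor_kwargs
)
batch = dict(out)
batch["labels"] = processing_strategy.process_labels(batch["input_ids"])

assert isinstance(batch, Mapping)  # the guarantee the regression test asserts
```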

Bug surfaced via swarm dispatch task_01KQHPNAYD8XARSNSDJVW1GPF6 against
attn-implementation-refactor; squash-merged from agent commits 4de886fd
+ dc9fcf4f.

Signed-off-by: Wing Lian <wing@axolotl.ai>

---------

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
Wing Lian
2026-05-05 10:15:18 -04:00
committed by GitHub
parent 6136ae627b
commit e4032fc90f
252 changed files with 1502 additions and 572 deletions

@@ -121,11 +121,11 @@ Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2
| Backend | Config | head_dim limit | torch_compile | Notes |
|---------|--------|---------------|---------------|-------|
| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
| eager | neither set | None | ✅ | Slowest, always works |
| FA2 | `attn_implementation: flash_attention_2` | 256 | ✅ | Fastest when supported |
| FA4 | auto with `attn_implementation: flash_attention_2` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
| SDPA | `attn_implementation: sdpa` | None | ✅ | Universal fallback |
| flex | `attn_implementation: flex_attention` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
| eager | `attn_implementation: eager` | None | ✅ | Slowest, always works |
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.

@@ -83,7 +83,7 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
| Issue | Fix |
|-------|-----|
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` |
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `attn_implementation: flash_attention_2` or disable `sample_packing` |
| Missing chat template error | Set `chat_template: chatml` explicitly |
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |

@@ -3,28 +3,71 @@ title: Attention
description: Supported attention modules in Axolotl
---
## SDP Attention
This is the default built-in attention in PyTorch.
Axolotl routes attention via a single config field:
```yaml
sdp_attention: true
attn_implementation: <backend>
```
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
`attn_implementation` is passed through to `transformers` verbatim (via
`model.config._attn_implementation`). Accepted values are the HF-native
backends, axolotl-registered backends, or a hub-kernel path.
## Flash Attention
## Backends
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
based on your installed packages and GPU.
| `attn_implementation` | Description |
|---|---|
| `eager` | Plain PyTorch attention. No packing support. |
| `sdpa` | PyTorch `scaled_dot_product_attention`. No packing support. |
| `flash_attention_2` | Dao-AILab Flash Attention 2. |
| `flash_attention_3` | Dao-AILab Flash Attention 3 (Hopper+). |
| `flex_attention` | Torch Flex Attention (requires torch ≥ 2.6). |
| `xformers` | xFormers memory-efficient attention. |
| `sage` | SageAttention (QK int8 / PV fp16). |
| `s2` | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). |
| `fp8` | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. |
| `kernels-community/flash-attn3` | HF hub FA3 kernel. |
| `kernels-community/sage-attention` | HF hub SageAttention kernel. |
| Other `<org>/<name>` path | Any hub-kernel path supported by `transformers`. |
Short-form aliases (`flash`, `fa2`, `flex`, `sdp`, etc.) are **not accepted** —
set the canonical name above.
### Capability flags
Axolotl derives three boolean capability flags from `attn_implementation` and
exposes them on the validated config:
- `cfg.attn_supports_packing` — backend supports varlen sample packing via
`position_ids`. Gates multipack patches and `sample_packing_drop_attention_mask`.
- `cfg.attn_uses_flash_lib` — backend needs the `flash_attn` (Dao-AILab)
monkeypatches (FA4 auto, LLaMA flash hijack, ring-FA).
- `cfg.attn_needs_dtype_cast` — backend requires fp16/bf16 embeddings
(everything except `eager` and `sdpa`).
These are **computed** — they cannot be overridden from YAML.
## Per-backend notes
### SDPA
Default PyTorch attention. See
[PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
```yaml
flash_attention: true
attn_implementation: sdpa
```
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Flash Attention
### Flash Attention 2
Axolotl supports FA2, FA3, and FA4. The best available version is used
automatically based on your installed packages and GPU.
```yaml
attn_implementation: flash_attention_2 # or flash_attention_3
```
#### Flash Attention 2
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
@@ -39,20 +82,20 @@ Alternatively, try reinstall or downgrade a version.
:::
### Flash Attention 3
#### Flash Attention 3
Requirements: Hopper GPUs only; CUDA 12.8 recommended
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
### Flash Attention 4
#### Flash Attention 4
Requirements: Hopper or Blackwell GPUs
Requirements: Hopper or Blackwell GPUs. Auto-applied when `attn_uses_flash_lib`
is true and FA4 is importable.
FA4 is still a pre-release on PyPI, so `--pre` is required:
@@ -65,7 +108,6 @@ Or from source:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
@@ -88,93 +130,113 @@ and falls back to FA2/3.
:::
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
### AMD
Requirements: ROCm 6.0 and above.
Requirements: ROCm 6.0 and above. See
[Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
## Flex Attention
A flexible PyTorch API for attention used in combination with `torch.compile`.
### Flex Attention
```yaml
flex_attention: true
# recommended
torch_compile: true
attn_implementation: flex_attention
torch_compile: true # recommended
```
::: {.callout-note}
Requires torch ≥ 2.6. See [PyTorch docs](https://pytorch.org/blog/flexattention/).
We recommend using the latest stable version of PyTorch for best performance.
### SageAttention
:::
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
## SageAttention
Attention kernels with QK Int8 and PV FP16 accumulator.
Requirements: Ampere, Ada, or Hopper GPUs.
```yaml
sage_attention: true
attn_implementation: sage
```
Requirements: Ampere, Ada, or Hopper GPUs
```bash
pip install sageattention==2.2.0 --no-build-isolation
```
::: {.callout-warning}
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
Only LoRA/QLoRA recommended. Full finetuning has been observed to drop loss to 0. See
[GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention).
::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested in adding this or improving the SageAttention implementation, please open an Issue.
:::
## xFormers
### xFormers
```yaml
xformers_attention: true
attn_implementation: xformers
```
::: {.callout-tip}
We recommend using with Turing GPUs or below (such as on Colab).
Recommended for Turing GPUs or below (e.g. Colab T4).
:::
For more details: [xFormers](https://github.com/facebookresearch/xformers)
## Shifted Sparse Attention
### Shifted Sparse Attention
::: {.callout-warning}
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
Planned for deprecation. Prefer one of the backends above.
:::
Requirements: LLaMA model architecture
Requirements: LLaMA model architecture. Loaded as FA2 under the hood and
patched to implement shifted-sparse attention. Does not support sample packing.
```yaml
flash_attention: true
s2_attention: true
attn_implementation: s2
```
::: {.callout-tip}
### FP8
No sample packing support!
torchao low-precision attention. Loaded as SDPA and patched post-load.
Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17,
flash-attn with FA3. KV caching must be disabled.
```yaml
attn_implementation: fp8
```
### Hub kernels
```yaml
attn_implementation: kernels-community/flash-attn3
```
Passed through to `transformers`; axolotl does not install the kernel itself.
For recognized hub paths the capability flags are set automatically; for
arbitrary paths axolotl uses conservative defaults (`attn_supports_packing=False`,
`attn_uses_flash_lib=False`).
## Migrating from legacy boolean flags
The following legacy config fields are **deprecated** and will be removed in a
future release. Each emits a `DeprecationWarning` when set and is stripped from
the validated config.
| Legacy | Canonical |
|---|---|
| `flash_attention: true` | `attn_implementation: flash_attention_2` |
| `sdp_attention: true` | `attn_implementation: sdpa` |
| `xformers_attention: true` | `attn_implementation: xformers` |
| `flex_attention: true` | `attn_implementation: flex_attention` |
| `sage_attention: true` | `attn_implementation: sage` |
| `s2_attention: true` | `attn_implementation: s2` |
| `eager_attention: true` | `attn_implementation: eager` |
Combining `attn_implementation` with a legacy flag (e.g. `attn_implementation:
flash_attention_2` **and** `flash_attention: true`) raises a validation error;
pick one.
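For example, a hypothetical config migrating from the boolean flag:

```yaml
# before (deprecated; emits a DeprecationWarning)
flash_attention: true

# after (canonical)
attn_implementation: flash_attention_2
```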
::: {.callout-note}
Existing example configs under `examples/` still use the legacy flags. They
continue to work with a deprecation warning; they will be migrated in a
follow-up pass.
:::

@@ -129,7 +129,7 @@ gradient_accumulation_steps: 4
max_steps: 20
learning_rate: 5.0e-6
bf16: auto
flash_attention: true
attn_implementation: flash_attention_2
gradient_checkpointing: true
output_dir: ./outputs/ebft-quickstart
```
@@ -304,7 +304,7 @@ lora_alpha: 32
lora_target_linear: true
bf16: auto
flex_attention: true
attn_implementation: flex_attention
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true # Required with flex_attention

@@ -154,7 +154,7 @@ lr_scheduler: cosine
warmup_steps: 10
bf16: true
flash_attention: true
attn_implementation: flash_attention_2
gradient_checkpointing: true
special_tokens:

@@ -22,12 +22,12 @@ Improves GPU utilization by combining multiple short sequences into a single pac
Using an optimized attention implementation is critical for training speed.
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `attn_implementation: flash_attention_2`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `attn_implementation: flex_attention`.
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `attn_implementation: sdpa`. PyTorch's native implementation.
- **[Xformers](https://github.com/facebookresearch/xformers)**: `attn_implementation: xformers`. Works with FP16.
*Note: You should only enable one attention backend.*
See [Attention](attention.qmd) for the full list of backends and the canonical values.
### LoRA Optimizations

@@ -1147,8 +1147,7 @@ datasets:
type: ebft_strided_structured.transform
split: train[:1%]
flash_attention: false
flex_attention: true # Strided mode uses flex_attention
attn_implementation: flex_attention # Strided mode uses flex_attention
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true # Required for flex_attention

@@ -55,7 +55,7 @@ To use sequence parallelism, you need:
## Limitations
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
- Flash attention must be enabled for this to work (`attn_implementation: flash_attention_2` in config YAML)
- May have a small performance overhead due to communication between GPUs
## Example

@@ -245,7 +245,7 @@ For GRPO, also reduce `max_completion_length`. Memory scales quadratically with
Reduces attention memory from O(n^2) to O(n):
```yaml
flash_attention: true
attn_implementation: flash_attention_2
```
### Step 6: Offload with DeepSpeed

@@ -39,7 +39,7 @@ tf32: true
gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 2

@@ -48,7 +48,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 2

@@ -50,8 +50,7 @@ tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -39,7 +39,7 @@ activation_offloading: legacy
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_steps: 100
saves_per_epoch: 1

@@ -39,7 +39,7 @@ activation_offloading: legacy
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_steps: 100
saves_per_epoch: 1

@@ -55,7 +55,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -55,7 +55,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -59,8 +59,7 @@ gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
sdp_attention:
attn_implementation: flash_attention_2
flash_optimum:
gptq_groupsize:

@@ -39,8 +39,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
xformers_attention: true
flash_attention:
attn_implementation: xformers
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

@@ -45,7 +45,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -46,7 +46,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -45,7 +45,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -46,7 +46,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -45,7 +45,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -46,7 +46,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -52,7 +52,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

@@ -55,7 +55,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

@@ -39,7 +39,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

@@ -45,7 +45,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -43,8 +43,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
xformers_attention: true
flash_attention:
attn_implementation: xformers
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

@@ -73,8 +73,7 @@ early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
logging_steps: 1
xformers_attention: true
flash_attention:
attn_implementation: xformers
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

@@ -40,8 +40,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
xformers_attention: true
flash_attention:
attn_implementation: xformers
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -36,8 +36,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
xformers_attention: true
flash_attention:
attn_implementation: xformers
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

@@ -37,8 +37,7 @@ bf16: auto
tf32: true
resume_from_checkpoint:
logging_steps: 5
xformers_attention: true
flash_attention:
attn_implementation: xformers
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

@@ -39,7 +39,6 @@ bf16: auto
tf32: true
resume_from_checkpoint:
logging_steps: 5
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

@@ -39,7 +39,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

@@ -40,7 +40,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

@@ -47,7 +47,6 @@ tf32: false
gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention:
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -47,7 +47,6 @@ tf32: false
gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention:
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -43,7 +43,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -46,7 +46,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -40,7 +40,6 @@ bf16: auto
tf32: true
resume_from_checkpoint:
logging_steps: 5
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

@@ -38,7 +38,6 @@ tf32: true
gradient_checkpointing:
resume_from_checkpoint:
logging_steps: 1
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

@@ -44,7 +44,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_mlp: true

@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
flash_attn_cross_entropy: false
flash_attn_rms_norm: true

@@ -46,7 +46,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -47,7 +47,6 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: false
warmup_ratio: 0.1
evals_per_epoch: 0

@@ -45,7 +45,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -36,7 +36,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -71,8 +71,7 @@ early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
logging_steps: 1
xformers_attention: true
flash_attention:
attn_implementation: xformers
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

@@ -10,7 +10,7 @@ load_in_4bit: true
sequence_len: 1024
bf16: auto
tf32: false
flash_attention: true
attn_implementation: flash_attention_2
special_tokens:
bos_token: "<|startoftext|>"
eos_token: "<|endoftext|>"

@@ -48,7 +48,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

@@ -45,7 +45,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -45,7 +45,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -35,7 +35,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 2

@@ -59,7 +59,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 2

@@ -50,7 +50,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
# scaling_softmax: true # needs flex_attention
loss_watchdog_threshold: 5.0

@@ -29,7 +29,7 @@ output_dir: ./outputs/ndp-out/
sequence_len: 2048
sample_packing: true
flash_attention: true
attn_implementation: flash_attention_2
gradient_accumulation_steps: 1
micro_batch_size: 1

@@ -26,7 +26,7 @@ output_dir: ./outputs/ndp-out/
sequence_len: 8192
sample_packing: true
flash_attention: true
attn_implementation: flash_attention_2
gradient_accumulation_steps: 1
micro_batch_size: 1 # must be 1 when using context parallel

@@ -65,8 +65,7 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
weight_decay: 0.0

@@ -46,7 +46,7 @@ lora_dropout: 0.05
lora_target_linear: true
bf16: auto
flash_attention: true
attn_implementation: flash_attention_2
gradient_checkpointing: true
special_tokens:

@@ -66,7 +66,7 @@ lora_target_linear: true
# --- Hardware ---
bf16: auto
flash_attention: true
attn_implementation: flash_attention_2
gradient_checkpointing: true
special_tokens:

@@ -47,8 +47,7 @@ lora_dropout: 0.05
lora_target_linear: true
bf16: auto
flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
flex_attention: true # fused flex_attention kernel compiles itself; don't set torch_compile: true
attn_implementation: flex_attention
# (full-model compile conflicts with gradient checkpointing + flex_attention)
gradient_checkpointing: true
gradient_checkpointing_kwargs:

@@ -46,7 +46,6 @@ lora_dropout: 0.05
lora_target_linear: true
bf16: auto
flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
gradient_checkpointing: true
special_tokens:

@@ -48,7 +48,6 @@ lora_target_linear: true
bf16: auto
torch_dtype: bfloat16
flash_attention: false
gradient_checkpointing: true
torch_compile: true
gradient_checkpointing_kwargs:

@@ -41,7 +41,6 @@ warmup_steps: 10
weight_decay: 0.01
bf16: auto
flash_attention: false # strided EBFT uses flex_attention at runtime
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false

@@ -72,7 +72,7 @@ lora_dropout: 0.0
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
bf16: auto
flash_attention: true
attn_implementation: flash_attention_2
gradient_checkpointing: true
special_tokens:

@@ -63,7 +63,7 @@ lora_dropout: 0.0
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
bf16: auto
flash_attention: true
attn_implementation: flash_attention_2
gradient_checkpointing: true
special_tokens:

@@ -68,7 +68,7 @@ lora_dropout: 0.0
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
bf16: auto
flash_attention: true
attn_implementation: flash_attention_2
gradient_checkpointing: true
special_tokens:

@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

@@ -61,7 +61,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -53,7 +53,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

@@ -43,7 +43,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch:

@@ -58,8 +58,7 @@ gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
flash_attention: true
eager_attention:
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -55,8 +55,7 @@ gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
flash_attention: true
eager_attention:
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -84,7 +84,7 @@ activation_offloading: true
logging_steps: 1
# FA2 not supported
sdp_attention: true
attn_implementation: sdpa
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -62,7 +62,7 @@ activation_offloading: true
logging_steps: 1
# FA not supported
flex_attention: true
attn_implementation: flex_attention
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -60,7 +60,7 @@ activation_offloading: true
logging_steps: 1
# FA not supported
sdp_attention: true
attn_implementation: sdpa
warmup_ratio: 0.1
evals_per_epoch: 4

@@ -50,7 +50,7 @@ gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
sdp_attention: true
attn_implementation: sdpa
warmup_ratio: 0.1
weight_decay: 0.0

@@ -50,7 +50,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

@@ -55,7 +55,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -45,7 +45,7 @@ gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
sdp_attention: true
attn_implementation: sdpa
warmup_ratio: 0.1
evals_per_epoch: 0

@@ -42,7 +42,7 @@ tf32: false
gradient_checkpointing: true
logging_steps: 1
sdp_attention: true
attn_implementation: sdpa
warmup_ratio: 0.1
evals_per_epoch: 0

@@ -58,7 +58,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -57,7 +57,7 @@ tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -58,7 +58,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -57,7 +57,7 @@ tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
attn_implementation: flash_attention_2
warmup_ratio: 0.1
evals_per_epoch: 1

@@ -47,7 +47,6 @@ learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true

@@ -43,7 +43,6 @@ learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true

@@ -44,7 +44,6 @@ learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true

@@ -43,7 +43,6 @@ learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true

@@ -56,7 +56,6 @@ learning_rate: 2e-4
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true

Some files were not shown because too many files have changed in this diff.