Refactor separate attention flags with attn_implementation and capability/concerns feature flags (#3602)
* upgrade to torchao 0.17.0 * chore: lint * refactor attention handling * replace legacy attention boolean flags with capability properties Replace checks with capability-based properties derived from attn_implementation This separates three concerns that were conflated under flash_attention: 1. Backend selection -> attn_implementation enum 2. Packing capability -> attn_supports_packing property 3. Flash-attn library dependency -> attn_uses_flash_lib property * compute attn capability flags in normalizer instead of properties * make attn_implementation the single source of truth * move attention-dependent validators to mode=after * migrate remaining consumers to canonical attn_implementation * expand attention tests + rewrite docs * migrate example configs to canonical attn_implementation * update doc snippets + reject gemma4-hybrid with non-FA2 backend * remove dead gemma4 branch in _set_attention_config * fix duplicate attn_implementation in gpt-oss yamls and flaky caplog tests * drop "Phase 2" naming from attn-implementation tests * regroup attn_implementation tests by feature concern * clean up verbose comments and remove MD Signed-off-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai> * fix(collator): pass return_dict=True at apply_chat_template top level for transformers 5.x In transformers 5.x, ProcessorMixin.apply_chat_template gained its own `return_dict` parameter (defaulting to False). When return_dict=False and tokenize=True the method returns out["input_ids"] directly — a 2-D tensor — rather than the full BatchFeature dict. The old code placed `return_dict=True` inside processor_kwargs. In transformers 5.x those kwargs are forwarded to the underlying processor call self(...) where _merge_kwargs silently ignores any key not present in MllamaProcessorKwargs (emitting a warning). The outer return_dict therefore stayed False, apply_chat_template returned the raw input_ids tensor, and the subsequent `batch["input_ids"]` attempted to index a 2-D tensor with the 9-character string "input_ids", producing: IndexError: too many indices for tensor of dimension 2 The fix is to pass return_dict=True as a top-level keyword argument to apply_chat_template (where it is actually consumed) and remove it from processor_kwargs (where it was silently dropped). No version guard is needed: transformers is pinned to ==5.5.4 in pyproject.toml. Adds a unit-level regression test (tests/test_mm_chat_collator.py) that mocks the processor to return a raw tensor when apply_chat_template is called without top-level return_dict=True, verifying the four invariants: process_rows returns a dict, input_ids is 2-D, labels is 2-D, and apply_chat_template receives return_dict=True as a top-level kwarg. Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_multimodal_dataset Fixes: tests/e2e/test_llama_vision.py::TestLlamaVision::test_lora_llama_vision_text_only_dataset Signed-off-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai> * fix(collator): process_rows returns dict (BatchFeature) shape Two related changes for the multimodal chat collator under transformers 5.x: 1. Wrap apply_chat_template result in dict(...) so process_rows returns a plain dict rather than a BatchFeature instance. BatchFeature is a Mapping but not a dict; downstream code that did batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"]) would index on a tensor when the result wasn't dict-shaped, raising IndexError: too many indices for tensor of dimension 2 2. Soften the regression test's contract from `dict` to `Mapping` so it exercises the actual semantic guarantee (key/value access) rather than the implementation detail (dict vs BatchFeature). Test guards against the original transformers 5.x breakage where apply_chat_template's return_dict default went from True to False. Includes regression test under tests/test_mm_chat_collator.py. Bug surfaced via swarm dispatch task_01KQHPNAYD8XARSNSDJVW1GPF6 against attn-implementation-refactor; squash-merged from agent commits 4de886fd + dc9fcf4f. Signed-off-by: Wing Lian <wing@axolotl.ai> --------- Signed-off-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
This commit is contained in:
@@ -121,11 +121,11 @@ Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2
|
||||
|
||||
| Backend | Config | head_dim limit | torch_compile | Notes |
|
||||
|---------|--------|---------------|---------------|-------|
|
||||
| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
|
||||
| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
|
||||
| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
|
||||
| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
|
||||
| eager | neither set | None | ✅ | Slowest, always works |
|
||||
| FA2 | `attn_implementation: flash_attention_2` | 256 | ✅ | Fastest when supported |
|
||||
| FA4 | auto with `attn_implementation: flash_attention_2` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
|
||||
| SDPA | `attn_implementation: sdpa` | None | ✅ | Universal fallback |
|
||||
| flex | `attn_implementation: flex_attention` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
|
||||
| eager | `attn_implementation: eager` | None | ✅ | Slowest, always works |
|
||||
|
||||
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
|
||||
| Issue | Fix |
|
||||
|-------|-----|
|
||||
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
|
||||
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` |
|
||||
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `attn_implementation: flash_attention_2` or disable `sample_packing` |
|
||||
| Missing chat template error | Set `chat_template: chatml` explicitly |
|
||||
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
|
||||
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |
|
||||
|
||||
@@ -3,28 +3,71 @@ title: Attention
|
||||
description: Supported attention modules in Axolotl
|
||||
---
|
||||
|
||||
## SDP Attention
|
||||
|
||||
This is the default built-in attention in PyTorch.
|
||||
Axolotl routes attention via a single config field:
|
||||
|
||||
```yaml
|
||||
sdp_attention: true
|
||||
attn_implementation: <backend>
|
||||
```
|
||||
|
||||
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
`attn_implementation` is passed through to `transformers` verbatim (via
|
||||
`model.config._attn_implementation`). Accepted values are the HF-native
|
||||
backends, axolotl-registered backends, or a hub-kernel path.
|
||||
|
||||
## Flash Attention
|
||||
## Backends
|
||||
|
||||
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
|
||||
based on your installed packages and GPU.
|
||||
| `attn_implementation` | Description |
|
||||
|---|---|
|
||||
| `eager` | Plain PyTorch attention. No packing support. |
|
||||
| `sdpa` | PyTorch `scaled_dot_product_attention`. No packing support. |
|
||||
| `flash_attention_2` | Dao-AILab Flash Attention 2. |
|
||||
| `flash_attention_3` | Dao-AILab Flash Attention 3 (Hopper+). |
|
||||
| `flex_attention` | Torch Flex Attention (requires torch ≥ 2.6). |
|
||||
| `xformers` | xFormers memory-efficient attention. |
|
||||
| `sage` | SageAttention (QK int8 / PV fp16). |
|
||||
| `s2` | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). |
|
||||
| `fp8` | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. |
|
||||
| `kernels-community/flash-attn3` | HF hub FA3 kernel. |
|
||||
| `kernels-community/sage-attention` | HF hub SageAttention kernel. |
|
||||
| Other `<org>/<name>` path | Any hub-kernel path supported by `transformers`. |
|
||||
|
||||
Short-form aliases (`flash`, `fa2`, `flex`, `sdp`, etc.) are **not accepted** —
|
||||
set the canonical name above.
|
||||
|
||||
### Capability flags
|
||||
|
||||
Axolotl derives three boolean capability flags from `attn_implementation` and
|
||||
exposes them on the validated config:
|
||||
|
||||
- `cfg.attn_supports_packing` — backend supports varlen sample packing via
|
||||
`position_ids`. Gates multipack patches and `sample_packing_drop_attention_mask`.
|
||||
- `cfg.attn_uses_flash_lib` — backend needs the `flash_attn` (Dao-AILab)
|
||||
monkeypatches (FA4 auto, LLaMA flash hijack, ring-FA).
|
||||
- `cfg.attn_needs_dtype_cast` — backend requires fp16/bf16 embeddings
|
||||
(everything except `eager` and `sdpa`).
|
||||
|
||||
These are **computed** — they cannot be overridden from YAML.
|
||||
|
||||
## Per-backend notes
|
||||
|
||||
### SDPA
|
||||
|
||||
Default PyTorch attention. See
|
||||
[PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
attn_implementation: sdpa
|
||||
```
|
||||
|
||||
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
|
||||
### Flash Attention
|
||||
|
||||
### Flash Attention 2
|
||||
Axolotl supports FA2, FA3, and FA4. The best available version is used
|
||||
automatically based on your installed packages and GPU.
|
||||
|
||||
```yaml
|
||||
attn_implementation: flash_attention_2 # or flash_attention_3
|
||||
```
|
||||
|
||||
#### Flash Attention 2
|
||||
|
||||
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
|
||||
|
||||
@@ -39,20 +82,20 @@ Alternatively, try reinstall or downgrade a version.
|
||||
|
||||
:::
|
||||
|
||||
### Flash Attention 3
|
||||
#### Flash Attention 3
|
||||
|
||||
Requirements: Hopper only and CUDA 12.8 (recommended)
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||
cd flash-attention/hopper
|
||||
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
### Flash Attention 4
|
||||
#### Flash Attention 4
|
||||
|
||||
Requirements: Hopper or Blackwell GPUs
|
||||
Requirements: Hopper or Blackwell GPUs. Auto-applied when `attn_uses_flash_lib`
|
||||
is true and FA4 is importable.
|
||||
|
||||
FA4 is still a pre-release on PyPI, so `--pre` is required:
|
||||
|
||||
@@ -65,7 +108,6 @@ Or from source:
|
||||
```bash
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||
cd flash-attention/flash_attn/cute
|
||||
|
||||
pip install -e .
|
||||
|
||||
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
|
||||
@@ -88,93 +130,113 @@ and falls back to FA2/3.
|
||||
|
||||
:::
|
||||
|
||||
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
|
||||
|
||||
### AMD
|
||||
|
||||
Requirements: ROCm 6.0 and above.
|
||||
Requirements: ROCm 6.0 and above. See
|
||||
[Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
|
||||
|
||||
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
|
||||
|
||||
## Flex Attention
|
||||
|
||||
A flexible PyTorch API for attention used in combination with `torch.compile`.
|
||||
### Flex Attention
|
||||
|
||||
```yaml
|
||||
flex_attention: true
|
||||
|
||||
# recommended
|
||||
torch_compile: true
|
||||
attn_implementation: flex_attention
|
||||
torch_compile: true # recommended
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
Requires torch ≥ 2.6. See [PyTorch docs](https://pytorch.org/blog/flexattention/).
|
||||
|
||||
We recommend using latest stable version of PyTorch for best performance.
|
||||
### SageAttention
|
||||
|
||||
:::
|
||||
|
||||
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
|
||||
|
||||
## SageAttention
|
||||
|
||||
Attention kernels with QK Int8 and PV FP16 accumulator.
|
||||
Requirements: Ampere, Ada, or Hopper GPUs.
|
||||
|
||||
```yaml
|
||||
sage_attention: true
|
||||
attn_implementation: sage
|
||||
```
|
||||
|
||||
Requirements: Ampere, Ada, or Hopper GPUs
|
||||
|
||||
```bash
|
||||
pip install sageattention==2.2.0 --no-build-isolation
|
||||
```
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
|
||||
Only LoRA/QLoRA recommended. Full finetuning has been observed to drop loss to 0. See
|
||||
[GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
|
||||
|
||||
:::
|
||||
|
||||
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention).
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
|
||||
|
||||
:::
|
||||
|
||||
|
||||
## xFormers
|
||||
### xFormers
|
||||
|
||||
```yaml
|
||||
xformers_attention: true
|
||||
attn_implementation: xformers
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
We recommend using with Turing GPUs or below (such as on Colab).
|
||||
Recommended for Turing GPUs or below (e.g. Colab T4).
|
||||
|
||||
:::
|
||||
|
||||
For more details: [xFormers](https://github.com/facebookresearch/xformers)
|
||||
|
||||
## Shifted Sparse Attention
|
||||
### Shifted Sparse Attention
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
|
||||
Planned for deprecation. Prefer one of the backends above.
|
||||
|
||||
:::
|
||||
|
||||
Requirements: LLaMA model architecture
|
||||
Requirements: LLaMA model architecture. Loaded as FA2 under the hood and
|
||||
patched to implement shifted-sparse attention. Does not support sample packing.
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
s2_attention: true
|
||||
attn_implementation: s2
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
### FP8
|
||||
|
||||
No sample packing support!
|
||||
torchao low-precision attention. Loaded as SDPA and patched post-load.
|
||||
|
||||
Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17,
|
||||
flash-attn with FA3. KV caching must be disabled.
|
||||
|
||||
```yaml
|
||||
attn_implementation: fp8
|
||||
```
|
||||
|
||||
### Hub kernels
|
||||
|
||||
```yaml
|
||||
attn_implementation: kernels-community/flash-attn3
|
||||
```
|
||||
|
||||
Passed through to `transformers`; axolotl does not install the kernel itself.
|
||||
For recognized hub paths the capability flags are set automatically; for
|
||||
arbitrary paths axolotl uses conservative defaults (`attn_supports_packing=False`,
|
||||
`attn_uses_flash_lib=False`).
|
||||
|
||||
## Migrating from legacy boolean flags
|
||||
|
||||
The following legacy config fields are **deprecated** and will be removed in a
|
||||
future release. Each emits a `DeprecationWarning` when set and is stripped from
|
||||
the validated config.
|
||||
|
||||
| Legacy | Canonical |
|
||||
|---|---|
|
||||
| `flash_attention: true` | `attn_implementation: flash_attention_2` |
|
||||
| `sdp_attention: true` | `attn_implementation: sdpa` |
|
||||
| `xformers_attention: true` | `attn_implementation: xformers` |
|
||||
| `flex_attention: true` | `attn_implementation: flex_attention` |
|
||||
| `sage_attention: true` | `attn_implementation: sage` |
|
||||
| `s2_attention: true` | `attn_implementation: s2` |
|
||||
| `eager_attention: true` | `attn_implementation: eager` |
|
||||
|
||||
Combining `attn_implementation` with a legacy flag (e.g. `attn_implementation:
|
||||
flash_attention_2` **and** `flash_attention: true`) raises — pick one.
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
Existing example configs under `examples/` still use the legacy flags. They
|
||||
continue to work with a deprecation warning; they will be migrated in a
|
||||
follow-up pass.
|
||||
|
||||
:::
|
||||
|
||||
@@ -129,7 +129,7 @@ gradient_accumulation_steps: 4
|
||||
max_steps: 20
|
||||
learning_rate: 5.0e-6
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gradient_checkpointing: true
|
||||
output_dir: ./outputs/ebft-quickstart
|
||||
```
|
||||
@@ -304,7 +304,7 @@ lora_alpha: 32
|
||||
lora_target_linear: true
|
||||
|
||||
bf16: auto
|
||||
flex_attention: true
|
||||
attn_implementation: flex_attention
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: true # Required with flex_attention
|
||||
|
||||
@@ -154,7 +154,7 @@ lr_scheduler: cosine
|
||||
warmup_steps: 10
|
||||
|
||||
bf16: true
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -22,12 +22,12 @@ Improves GPU utilization by combining multiple short sequences into a single pac
|
||||
|
||||
Using an optimized attention implementation is critical for training speed.
|
||||
|
||||
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
||||
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
|
||||
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
|
||||
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
|
||||
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `attn_implementation: flash_attention_2`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
||||
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `attn_implementation: flex_attention`.
|
||||
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `attn_implementation: sdpa`. PyTorch's native implementation.
|
||||
- **[Xformers](https://github.com/facebookresearch/xformers)**: `attn_implementation: xformers`. Works with FP16.
|
||||
|
||||
*Note: You should only enable one attention backend.*
|
||||
See [Attention](attention.qmd) for the full list of backends and the canonical values.
|
||||
|
||||
### LoRA Optimizations
|
||||
|
||||
|
||||
@@ -1147,8 +1147,7 @@ datasets:
|
||||
type: ebft_strided_structured.transform
|
||||
split: train[:1%]
|
||||
|
||||
flash_attention: false
|
||||
flex_attention: true # Strided mode uses flex_attention
|
||||
attn_implementation: flex_attention # Strided mode uses flex_attention
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: true # Required for flex_attention
|
||||
|
||||
@@ -55,7 +55,7 @@ To use sequence parallelism, you need:
|
||||
|
||||
## Limitations
|
||||
|
||||
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
|
||||
- Flash attention must be enabled for this to work (`attn_implementation: flash_attention_2` in config YAML)
|
||||
- May have a small performance overhead due to communication between GPUs
|
||||
|
||||
## Example
|
||||
|
||||
@@ -245,7 +245,7 @@ For GRPO, also reduce `max_completion_length`. Memory scales quadratically with
|
||||
Reduces attention memory from O(n^2) to O(n):
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
```
|
||||
|
||||
### Step 6: Offload with DeepSpeed
|
||||
|
||||
@@ -39,7 +39,7 @@ tf32: true
|
||||
gradient_checkpointing: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
|
||||
@@ -48,7 +48,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
|
||||
@@ -50,8 +50,7 @@ tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -39,7 +39,7 @@ activation_offloading: legacy
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_steps: 100
|
||||
saves_per_epoch: 1
|
||||
|
||||
@@ -39,7 +39,7 @@ activation_offloading: legacy
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_steps: 100
|
||||
saves_per_epoch: 1
|
||||
|
||||
@@ -55,7 +55,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -55,7 +55,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -59,8 +59,7 @@ gradient_checkpointing: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
flash_attention: true
|
||||
sdp_attention:
|
||||
attn_implementation: flash_attention_2
|
||||
flash_optimum:
|
||||
|
||||
gptq_groupsize:
|
||||
|
||||
@@ -39,8 +39,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
attn_implementation: xformers
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -45,7 +45,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -46,7 +46,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -45,7 +45,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -46,7 +46,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -45,7 +45,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -46,7 +46,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -52,7 +52,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -55,7 +55,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -39,7 +39,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -45,7 +45,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -43,8 +43,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
attn_implementation: xformers
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -73,8 +73,7 @@ early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
auto_resume_from_checkpoints: true
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
attn_implementation: xformers
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -40,8 +40,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
attn_implementation: xformers
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -47,7 +47,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -36,8 +36,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
attn_implementation: xformers
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -37,8 +37,7 @@ bf16: auto
|
||||
tf32: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 5
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
attn_implementation: xformers
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -39,7 +39,6 @@ bf16: auto
|
||||
tf32: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 5
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -39,7 +39,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -47,7 +47,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -40,7 +40,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -47,7 +47,6 @@ tf32: false
|
||||
gradient_checkpointing: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -47,7 +47,6 @@ tf32: false
|
||||
gradient_checkpointing: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -43,7 +43,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -46,7 +46,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -40,7 +40,6 @@ bf16: auto
|
||||
tf32: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 5
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -38,7 +38,6 @@ tf32: true
|
||||
gradient_checkpointing:
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -44,7 +44,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
@@ -47,7 +47,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -47,7 +47,6 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: false
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 0
|
||||
|
||||
@@ -45,7 +45,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -36,7 +36,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -47,7 +47,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -71,8 +71,7 @@ early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
auto_resume_from_checkpoints: true
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
attn_implementation: xformers
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -10,7 +10,7 @@ load_in_4bit: true
|
||||
sequence_len: 1024
|
||||
bf16: auto
|
||||
tf32: false
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
special_tokens:
|
||||
bos_token: "<|startoftext|>"
|
||||
eos_token: "<|endoftext|>"
|
||||
|
||||
@@ -48,7 +48,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -45,7 +45,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -45,7 +45,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -35,7 +35,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
|
||||
@@ -59,7 +59,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
|
||||
@@ -50,7 +50,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
# scaling_softmax: true # needs flex_attention
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
|
||||
@@ -29,7 +29,7 @@ output_dir: ./outputs/ndp-out/
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
|
||||
@@ -26,7 +26,7 @@ output_dir: ./outputs/ndp-out/
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1 # must be 1 when using context parallel
|
||||
|
||||
@@ -65,8 +65,7 @@ early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -46,7 +46,7 @@ lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -66,7 +66,7 @@ lora_target_linear: true
|
||||
|
||||
# --- Hardware ---
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -47,8 +47,7 @@ lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
bf16: auto
|
||||
flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
|
||||
flex_attention: true # fused flex_attention kernel compiles itself; don't set torch_compile: true
|
||||
attn_implementation: flex_attention
|
||||
# (full-model compile conflicts with gradient checkpointing + flex_attention)
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
|
||||
@@ -46,7 +46,6 @@ lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
bf16: auto
|
||||
flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -48,7 +48,6 @@ lora_target_linear: true
|
||||
|
||||
bf16: auto
|
||||
torch_dtype: bfloat16
|
||||
flash_attention: false
|
||||
gradient_checkpointing: true
|
||||
torch_compile: true
|
||||
gradient_checkpointing_kwargs:
|
||||
|
||||
@@ -41,7 +41,6 @@ warmup_steps: 10
|
||||
weight_decay: 0.01
|
||||
|
||||
bf16: auto
|
||||
flash_attention: false # strided EBFT uses flex_attention at runtime
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
|
||||
@@ -72,7 +72,7 @@ lora_dropout: 0.0
|
||||
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
||||
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -63,7 +63,7 @@ lora_dropout: 0.0
|
||||
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
||||
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -68,7 +68,7 @@ lora_dropout: 0.0
|
||||
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
||||
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -61,7 +61,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -53,7 +53,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -43,7 +43,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -58,8 +58,7 @@ gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -55,8 +55,7 @@ gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -84,7 +84,7 @@ activation_offloading: true
|
||||
logging_steps: 1
|
||||
|
||||
# FA2 not supported
|
||||
sdp_attention: true
|
||||
attn_implementation: sdpa
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -62,7 +62,7 @@ activation_offloading: true
|
||||
logging_steps: 1
|
||||
|
||||
# FA not supported
|
||||
flex_attention: true
|
||||
attn_implementation: flex_attention
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -60,7 +60,7 @@ activation_offloading: true
|
||||
logging_steps: 1
|
||||
|
||||
# FA not supported
|
||||
sdp_attention: true
|
||||
attn_implementation: sdpa
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -50,7 +50,7 @@ gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
logging_steps: 1
|
||||
sdp_attention: true
|
||||
attn_implementation: sdpa
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -50,7 +50,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
@@ -55,7 +55,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -45,7 +45,7 @@ gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
logging_steps: 1
|
||||
sdp_attention: true
|
||||
attn_implementation: sdpa
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 0
|
||||
|
||||
@@ -42,7 +42,7 @@ tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
sdp_attention: true
|
||||
attn_implementation: sdpa
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 0
|
||||
|
||||
@@ -58,7 +58,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -57,7 +57,7 @@ tf32: false
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -58,7 +58,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -57,7 +57,7 @@ tf32: false
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -47,7 +47,6 @@ learning_rate: 2e-5
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -43,7 +43,6 @@ learning_rate: 2e-5
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -44,7 +44,6 @@ learning_rate: 2e-5
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -43,7 +43,6 @@ learning_rate: 2e-5
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
|
||||
gradient_checkpointing: true
|
||||
|
||||
@@ -56,7 +56,6 @@ learning_rate: 2e-4
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
|
||||
|
||||
gradient_checkpointing: true
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user