Compare commits

..

2 Commits

Author SHA1 Message Date
NanoCode012
798c8fba89 chore: update docker docs (#3623)
Some checks failed
Publish Docs / build-deploy (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 128, 12.8.1, true, linux/amd64,linux/arm64, 3.12, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-uv (<nil>, 130, 13.0.0, linux/amd64,linux/arm64, 3.12, 2.10.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 128, 12.8.1, true, 3.11, 2.9.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 130, 13.0.0, <nil>, 3.11, 2.9.1) (push) Has been cancelled
Tests Nightly against upstream main / pre-commit (push) Has been cancelled
Tests Nightly against upstream main / Prefetch S3 once to prime the CDN cache (push) Has been cancelled
Tests Nightly against upstream main / PyTest (3.12, 2.10.0) (push) Has been cancelled
Tests Nightly against upstream main / PyTest (3.12, 2.9.1) (push) Has been cancelled
Tests Nightly against upstream main / docker-e2e-tests (<nil>, 128, 12.8.1, 1, 3.11, 2.10.0) (push) Has been cancelled
Tests Nightly against upstream main / docker-e2e-tests (<nil>, 128, 12.8.1, true, 1, 3.11, 2.9.1) (push) Has been cancelled
Tests Nightly against upstream main / docker-e2e-tests (<nil>, 130, 13.0.0, true, 1, 3.12, 2.9.1) (push) Has been cancelled
Tests Nightly against upstream main / docker-e2e-multigpu-tests (<nil>, 128, 12.8.1, true, 2, 3.11, 2.9.1) (push) Has been cancelled
docker-nightlies / build-axolotl (<nil>, 128, 12.8.1, 3.11, 2.9.1) (push) Has been cancelled
docker-nightlies / build-axolotl-cloud (<nil>, 128, 12.8.1, 3.11, 2.9.1) (push) Has been cancelled
docker-multigpu-tests-biweekly / test-axolotl-multigpu (<nil>, 130, 13.0.0, 2, 3.11, 2.9.1) (push) Has been cancelled
docker-multigpu-tests-biweekly / test-axolotl-multigpu (fbgemm-gpu, 128, 12.8.1, 2, 3.11, 2.10.0) (push) Has been cancelled
Pre-commit auto-update / auto-update (push) Has been cancelled
2026-04-24 16:03:21 +07:00
NanoCode012
17fc747f99 fix: docker build failing (#3622)
* fix: uv leftover docs

* fix: docker build failing

* chore: doc

* fix: remove old pytorch build

* fix: stop recommend flash-attn optional, let transformers pull

* fix: remove ring flash attention from image

* fix: quotes [skip ci]

* chore: naming [skip ci]
2026-04-24 14:23:09 +07:00
279 changed files with 621 additions and 1643 deletions

View File

@@ -31,10 +31,11 @@ PRs are **greatly welcome**!
Please run below to setup env
```bash
# Install axolotl + dev and test dependencies from lockfile
# Install axolotl + dev and test dependencies
export UV_TORCH_BACKEND=cu128 # or cu130
uv sync --extra flash-attn --extra deepspeed --group dev --group test
uv venv --no-project --relocatable
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
pre-commit install
# test

View File

@@ -30,14 +30,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -168,14 +160,6 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""

View File

@@ -18,12 +18,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -180,12 +174,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"

View File

@@ -1,142 +0,0 @@
# `attn-implementation-refactor` branch review
Review target: `attn-implementation-refactor` (5 commits ahead of main, merge base `69904781`).
Scope: 16 files, +682 / 96.
## 1. What the branch is trying to do
Collapse seven boolean attention flags (`flash_attention`, `sdp_attention`, `xformers_attention`, `flex_attention`, `sage_attention`, `s2_attention`, `eager_attention`) into a single `attn_implementation` field, with derived capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) for the gates that used to be ad-hoc OR-chains.
Mechanism: `normalize_attn_implementation` (a `@model_validator(mode="before")` on `AxolotlInputConfig`) maps bidirectionally between the new field and the legacy flags, with a priority list for legacy combos (`s2 + flash → s2`), and then computes the three capability flags from frozen sets in `enums.py`.
Adjacent changes: `xformers` and `sage` now register as their own entries in `ALL_ATTENTION_FUNCTIONS` (with FA2 mask behavior) instead of stomping the `flash_attention_2` slot. New `fp8` backend wires `torchao.prototype.attention.apply_low_precision_attention` in `apply_post_model_load_patches`.
## 2. Target design
**`cfg.attn_implementation` is the single source of truth on the validated config.**
- Its type is `str | None`. Accepted values are **canonical names only** — no short-form aliases:
- HF-native: `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`. (`flash_attention_3` is net-new to axolotl — the current branch only encodes `flash_attention_2` under the short name `flash`.)
- Axolotl-owned (registered into `ALL_ATTENTION_FUNCTIONS` under exactly these names): `xformers`, `sage`, `s2`, `fp8`.
- Hub-kernel paths: `kernels-community/sage-attention`, `kernels-community/flash-attn3`, etc. — passthrough. Known-kernel allowlist in `enums.py` classifies the common ones into the capability tables.
Short forms like `flash`, `fa2`, `fa3`, `sdp`, `flex` are rejected (Pydantic validation error with a pointer to the canonical name).
- `model.py:_set_attention_config` passes `cfg.attn_implementation` to HF verbatim — no `_ATTN_IMPL_TO_HF` translation dict needed.
- Legacy booleans (`flash_attention: true`, `sdp_attention: true`, …) are the **only** input aliases, kept for backwards compatibility. The normalizer maps them to the canonical `attn_implementation` value, emits a one-time `DeprecationWarning` per flag, and removes them from `data` so they're never readable on the validated `cfg`. `deprecated=True` on each Field surfaces this in JSON schema. Mapping is 1:1 with the current legacy-flag semantics (`flash_attention → flash_attention_2`, `sdp_attention → sdpa`, `flex_attention → flex_attention`, `xformers_attention → xformers`, `sage_attention → sage`, `s2_attention → s2`, `eager_attention → eager`).
- Capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) are **`@computed_field` on the model**, not settable inputs. Lookup is keyed by the canonical `attn_implementation` string.
- Unknown / user-supplied strings (custom hub kernels) pass through to HF but get **conservative capability defaults** (packing=False, flash-lib=False, dtype-cast=True). Known hub kernels axolotl can classify live in a small allowlist.
- Downstream consumers read *only* `cfg.attn_implementation` and the capability flags. No `cfg.flash_attention`, `cfg.xformers_attention`, etc. anywhere in `src/`.
This is strictly what the branch is already *trying* to do — the gaps below are places it hasn't landed that goal yet.
## 3. Gaps and holes
### A. Legacy flags are still parallel state, not input-only
1. The normalizer *sets* the legacy flags on `data` (`impl_to_flag[attn_impl]` branch). It does not delete them. So `cfg.flash_attention` is still truthy after validation, and downstream code still reads it (see G).
2. Short-form enum values (`flash`, `sdpa`, `fp8`) are persisted as-is on `cfg.attn_implementation`, which is why `model.py` needs `_ATTN_IMPL_TO_HF` to translate before passing to HF. Source-of-truth implies canonicalize at normalize-time, not translate at consume-time.
3. Legacy flag + `attn_implementation` (consistent combo, e.g. `attn_implementation: flash + flash_attention: true`) emits no deprecation warning — only legacy-only path warns.
4. Legacy Field descriptions (`xformers_attention`, `sdp_attention`, etc.) don't have `deprecated=True`, so JSON schema still advertises them as first-class.
### B. Validators that still only check the legacy flag
5. **`check_ebft_activation_offloading`** (`validation.py:1607-1619`) reads only `data.get("flex_attention")`. Users on `attn_implementation: flex_attention` bypass the incompatibility check.
6. **`check_sample_packing_without_attention`** (`validation.py:188-203`) early-returns when `attn_implementation` is set but never validates the chosen backend actually supports packing. `attn_implementation: eager + sample_packing: true` silently passes; the old legacy-flag check warned.
### C. Non-enum strings fall through the capability tables
7. **HF-native `"flash_attention_2"`** is neither in `impl_to_flag` nor `FLASH_ATTN_LIB_IMPLS`. A user copy-pasting from HF docs gets `attn_uses_flash_lib=False`, silently disabling FA4 auto-apply, LLaMA flash hijack, `_patch_attention` (btlm, stablelm_epoch, mistral3, llava), and `_apply_flash_attention_peft_patches`.
8. **Hub kernel strings** (`kernels-community/flash-attn3`, `kernels-community/sage-attention`) default to `attn_supports_packing=True` (silently enters multipack with varlen `position_ids` — correctness depends on the kernel honoring them) and `attn_uses_flash_lib=False` (so `context_parallel_size > 1` raises "requires flash attention" even for FA3 hub kernels).
9. **Conflict trap for hub-kernel + legacy flag** (`config.py:1414-1419`): `attn_implementation: kernels-community/flash-attn3 + flash_attention: true` always raises, because `impl_to_flag.get(custom) is None` and the loop treats `flag != None` as conflict. Common combo in existing YAMLs breaks hard on upgrade.
### D. Silent behaviour change for xformers
10. Old `_apply_flash_attention_patches` did `self.cfg.flash_attention = True` for `xformers + sample_packing`. The new version doesn't, and xformers is not in `FLASH_ATTN_LIB_IMPLS`. Consumers that keyed off `cfg.flash_attention` now see falsy for xformers, silently dropping `_patch_attention` (btlm / stablelm_epoch+packing / mistral3 / llava model-type FA patches). Some of this is arguably correct cleanup (xformers has its own HF registry entry now), but the btlm/stablelm/mistral3 regression is not called out and not tested. Decide consciously, not by omission.
### E. Capability flags are writable Pydantic fields, not computed
11. `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` are declared `bool | None = Field(default=None)` on `AxolotlInputConfig`. YAML is not rejected — a user can set `attn_uses_flash_lib: true` and override the normalizer.
### F. Validator ordering (not covered by tests)
12. `AttentionValidationMixin.check_attention_fields` (inherited, `mode="before"`) and `normalize_attn_implementation` (subclass, `mode="before"`) both run during `model_validator` phase. Pydantic MRO may run the inherited one first. For legacy-only `s2_attention: true + flash_attention: true` (the test `test_s2_plus_flash_maps_to_s2` asserts this maps to `s2`), the inherited check may raise "multiple attention implementations set" before the normalizer runs. The test calls the classmethod directly and does not build the model, so this is unverified either way.
### G. Remaining legacy reads in `src/`
13. `src/axolotl/integrations/lm_eval/cli.py:120` reads `cfg.flash_attention`. Works for `attn_implementation=flash` only.
14. `tests/e2e/multigpu/test_llama.py:524-526` writes `cfg.flash_attention = True` / `cfg.flex_attention = True`. Stale pattern.
15. Dual-check idioms in `config.py` (lines 1464, 1478, 1570, 1586, 1774) and `validation.py` (198, 209, 221, 850, 1586, 1611) — `data.get("x_attention") or data.get("attn_implementation") == "x"` — are redundant once legacy flags are input-only; remove them.
### H. fp8 operational risk
16. The `fp8` docstring documents hard requirements (PyTorch ≥ 2.11, SM90+, flash-attn with FA3, torchao ≥ 0.17.0) and a runtime constraint (`config.use_cache = False`). None are validated — misconfig surfaces as a torchao runtime error. `xformers` and `sage` availability/compute-capability guards exist; `fp8` should match.
### I. Test coverage gaps
17. `test_attn_implementation.py` exercises the classmethod in isolation plus the constant sets. It does **not**:
- Build a full `AxolotlInputConfig(**data)` (so validator ordering — item 12 — is untested).
- Verify capability flags can't be overridden from YAML (item 11).
- Cover `check_sample_packing_without_attention` with `attn_implementation: eager` (item 6).
- Cover `check_ebft_activation_offloading` with `attn_implementation: flex_attention` (item 5).
- Cover hub-kernel + legacy flag combo (item 9).
- Cover `flash_attention_2` canonicalization (item 7).
## 4. Fix plan
Four phases, each commit-sized. Phases 12 are local and low-risk; phase 3 is the behaviour-changing cleanup; phase 4 is tests + docs.
### Phase 1 — Lock down the data model
1. Drop the `AttnImplementation` enum. `attn_implementation` becomes `str | None`, validated against a canonical allowlist (`eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`, `xformers`, `sage`, `s2`, `fp8`) **or** a hub-kernel path (`startswith("kernels-")` / contains `/`). Reject short-form strings like `flash` / `fa2` / `sdp` / `flex` with an explicit error pointing at the canonical name.
2. Rewrite `normalize_attn_implementation` so its only job is mapping **legacy booleans → canonical `attn_implementation`** (for BC). Mapping is fixed:
- `flash_attention → flash_attention_2`
- `sdp_attention → sdpa`
- `flex_attention → flex_attention`
- `xformers_attention → xformers`
- `sage_attention → sage`
- `s2_attention → s2`
- `eager_attention → eager`
Priority for legacy combos stays as in the current branch (`s2 > sage > xformers > flex > flash > sdp > eager`). Emit a one-time `DeprecationWarning` per unique legacy flag seen. After mapping, delete the legacy flag keys from `data` so they never appear on validated `cfg`. If both a canonical `attn_implementation` and any legacy flag are set, raise (no silent precedence).
**Merge `AttentionValidationMixin.check_attention_fields` into this normalizer and delete the mixin method.** Pydantic v2 runs inherited `mode="before"` validators before subclass ones per MRO, so leaving them as siblings causes the inherited check to reject legacy combos like `s2 + flash` before the normalizer can map them. One validator, one source of conflict detection.
**Fix the gemma4-hybrid path**: change `data["attn_implementation"] = "flash"` to `data["attn_implementation"] = "flash_attention_2"` (the short name no longer validates after step 1).
3. Convert `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` to `@computed_field`. The three capability tables move to `enums.py` as module constants keyed by the canonical `attn_implementation` string (including `flash_attention_3` — missing from the current branch — and known hub kernels):
- Packing-capable: `{flash_attention_2, flash_attention_3, flex_attention, xformers, sage, kernels-community/flash-attn3, kernels-community/sage-attention}`.
- Flash-lib (axolotl's monkeypatch targets): `{flash_attention_2, flash_attention_3, s2, kernels-community/flash-attn3}`.
- No-dtype-cast: `{eager, sdpa}`.
Unknown strings: conservative defaults (`packing=False, flash_lib=False, dtype_cast=True`).
4. Delete `_ATTN_IMPL_TO_HF` from `model.py` and pass `cfg.attn_implementation` straight through. The gemma4-hybrid branch continues to override to `flash_attention_2` before passing to HF.
5. `deprecated=True` on each legacy boolean Field so JSON schema + Pydantic surface the deprecation.
### Phase 2 — Fix the validators
6. `check_sample_packing_without_attention`: drop the early-return and gate on `attn_supports_packing`. Warn (or raise — pick one and be consistent) if packing is enabled with a non-packing backend.
7. `check_ebft_activation_offloading`: replace `data.get("flex_attention")` with `attn_implementation == "flex_attention"`.
8. Sweep items (item 15): remove every `data.get("x_attention") or data.get("attn_implementation") == "x"` dual-check — after phase 1 the legacy side is always `None`. Reduces ~10 lines of noise and eliminates the "which side wins" class of bugs.
9. fp8 preflight (item 16): require `env_capabilities.compute_capability ≥ sm_90`, `torch_version ≥ 2.11`, and `torchao_version ≥ 0.17`. Warn if `use_cache` isn't explicitly `False`.
### Phase 3 — Migrate remaining consumers
10. `lm_eval/cli.py:120``flash_attention=cfg.attn_uses_flash_lib`.
11. `lm_eval/__init__.py:26` currently reads `(cfg.attn_implementation == "flash")` — after canonicalization `"flash"` is never stored, so this evaluates `False` for every backend. Change to `cfg.attn_uses_flash_lib`.
12. `validation.py:1137-1142` (NPU check) currently iterates `["flash_attention", "sdp_attention", "s2_attention"]` as string keys. Replace with `cfg.attn_implementation in {"flash_attention_2", "flash_attention_3", "sdpa", "s2"}` or the equivalent canonical-string set.
13. `tests/e2e/multigpu/test_llama.py:524-526``cfg.attn_implementation = "flash_attention_2"` / `"flex_attention"`.
14. **Xformers decision** (item 10): the old `cfg.flash_attention = True` side-effect activated `_patch_attention` for btlm/stablelm_epoch+packing/mistral3/llava. Two choices:
- Add `xformers` to the set that gates `_patch_attention` (restore old behaviour, keeps patches live).
- Document that those patches don't apply to xformers post-refactor and drop the paths if they're dead.
Pick one explicitly and leave a commit note. Do not leave it as silent breakage.
15. Add a repo-level check (`tests/test_no_legacy_attn_reads.py` or a ruff/grep pre-commit) that fails if anything outside `config.py`'s normalizer reads `cfg.flash_attention` / `cfg.sdp_attention` / etc. Keeps the invariant from rotting.
### Phase 4 — Tests + docs
14. Rewrite `test_attn_implementation.py` to build full `AxolotlInputConfig(**data)`, not just the classmethod. Covers validator ordering and the Pydantic-field-override issue.
15. Add one test per gap closed above: `attn_implementation: eager + sample_packing`; `attn_implementation: flex_attention + activation_offloading`; short-form `flash` rejected; `flash_attention_2` passthrough; `kernels-community/flash-attn3` capability lookup; `attn_uses_flash_lib: true` in YAML rejected; legacy boolean emits `DeprecationWarning` and is absent from validated `cfg`; fp8 preflight failures.
16. Update `docs/attention.qmd` for the single `attn_implementation` field + the deprecation table for legacy flags. One-paragraph migration note in the changelog.
17. `examples/` contains ~170 YAML files using legacy flags (`flash_attention: true` etc.). They still validate post-refactor (normalizer maps them with deprecation), but a follow-up sweep to convert them to `attn_implementation: flash_attention_2` is worth scheduling — call this out in the migration note so users know examples will be migrated on a later pass.
## 5. Ordering & risk
- Phase 1 is the keystone: it's the largest diff but each step is mechanical once the alias map is in place. No behaviour change for any consumer that was using `attn_implementation` correctly; behaviour change only for consumers that were reading legacy flags (phase 3 step 13 is the explicit decision point).
- Phase 2 is independent of phase 1 and can land first as a quick safety net.
- Phase 3 step 13 is the only judgment call — flag for review before choosing.
- Total: ~10-13 commits beyond what's on the branch, each scoped and individually revertable.

View File

@@ -1 +1 @@
0.16.0.dev0
0.16.2.dev0

View File

@@ -24,9 +24,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN pip uninstall -y causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
BASE_EXTRAS="optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
BASE_EXTRAS="deepspeed,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \

View File

@@ -58,19 +58,3 @@ RUN git lfs install --skip-repo && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
pip3 install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"

View File

@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,mamba-ssm] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image

View File

@@ -24,9 +24,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN uv pip uninstall causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
BASE_EXTRAS="optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
BASE_EXTRAS="deepspeed,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \

View File

@@ -38,20 +38,3 @@ RUN uv pip install packaging setuptools wheel psutil \
RUN if [ "$TARGETARCH" = "amd64" ]; then \
MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
fi
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
LINUX_TAG="manylinux_" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="2_24_x86_64.manylinux_2_28_x86_64" ;; \
arm64) ARCH_TAG="2_34_aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-${LINUX_TAG}${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
uv pip install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"

View File

@@ -121,11 +121,11 @@ Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2
| Backend | Config | head_dim limit | torch_compile | Notes |
|---------|--------|---------------|---------------|-------|
| FA2 | `attn_implementation: flash_attention_2` | 256 | ✅ | Fastest when supported |
| FA4 | auto with `attn_implementation: flash_attention_2` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
| SDPA | `attn_implementation: sdpa` | None | ✅ | Universal fallback |
| flex | `attn_implementation: flex_attention` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
| eager | `attn_implementation: eager` | None | ✅ | Slowest, always works |
| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
| eager | neither set | None | ✅ | Slowest, always works |
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.

View File

@@ -83,7 +83,7 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
| Issue | Fix |
|-------|-----|
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `attn_implementation: flash_attention_2` or disable `sample_packing` |
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` |
| Missing chat template error | Set `chat_template: chatml` explicitly |
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |

View File

@@ -3,71 +3,28 @@ title: Attention
description: Supported attention modules in Axolotl
---
Axolotl routes attention via a single config field:
## SDP Attention
This is the default built-in attention in PyTorch.
```yaml
attn_implementation: <backend>
sdp_attention: true
```
`attn_implementation` is passed through to `transformers` verbatim (via
`model.config._attn_implementation`). Accepted values are the HF-native
backends, axolotl-registered backends, or a hub-kernel path.
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Backends
## Flash Attention
| `attn_implementation` | Description |
|---|---|
| `eager` | Plain PyTorch attention. No packing support. |
| `sdpa` | PyTorch `scaled_dot_product_attention`. No packing support. |
| `flash_attention_2` | Dao-AILab Flash Attention 2. |
| `flash_attention_3` | Dao-AILab Flash Attention 3 (Hopper+). |
| `flex_attention` | Torch Flex Attention (requires torch ≥ 2.6). |
| `xformers` | xFormers memory-efficient attention. |
| `sage` | SageAttention (QK int8 / PV fp16). |
| `s2` | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). |
| `fp8` | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. |
| `kernels-community/flash-attn3` | HF hub FA3 kernel. |
| `kernels-community/sage-attention` | HF hub SageAttention kernel. |
| Other `<org>/<name>` path | Any hub-kernel path supported by `transformers`. |
Short-form aliases (`flash`, `fa2`, `flex`, `sdp`, etc.) are **not accepted** —
set the canonical name above.
### Capability flags
Axolotl derives three boolean capability flags from `attn_implementation` and
exposes them on the validated config:
- `cfg.attn_supports_packing` — backend supports varlen sample packing via
`position_ids`. Gates multipack patches and `sample_packing_drop_attention_mask`.
- `cfg.attn_uses_flash_lib` — backend needs the `flash_attn` (Dao-AILab)
monkeypatches (FA4 auto, LLaMA flash hijack, ring-FA).
- `cfg.attn_needs_dtype_cast` — backend requires fp16/bf16 embeddings
(everything except `eager` and `sdpa`).
These are **computed** — they cannot be overridden from YAML.
## Per-backend notes
### SDPA
Default PyTorch attention. See
[PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
based on your installed packages and GPU.
```yaml
attn_implementation: sdpa
flash_attention: true
```
### Flash Attention
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
Axolotl supports FA2, FA3, and FA4. The best available version is used
automatically based on your installed packages and GPU.
```yaml
attn_implementation: flash_attention_2 # or flash_attention_3
```
#### Flash Attention 2
### Flash Attention 2
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
@@ -82,20 +39,20 @@ Alternatively, try reinstall or downgrade a version.
:::
#### Flash Attention 3
### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
#### Flash Attention 4
### Flash Attention 4
Requirements: Hopper or Blackwell GPUs. Auto-applied when `attn_uses_flash_lib`
is true and FA4 is importable.
Requirements: Hopper or Blackwell GPUs
```bash
pip install flash-attn-4
@@ -106,6 +63,7 @@ Or from source:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
@@ -128,113 +86,93 @@ and falls back to FA2/3.
:::
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
### AMD
Requirements: ROCm 6.0 and above. See
[Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
Requirements: ROCm 6.0 and above.
### Flex Attention
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
## Flex Attention
A flexible PyTorch API for attention used in combination with `torch.compile`.
```yaml
attn_implementation: flex_attention
torch_compile: true # recommended
flex_attention: true
# recommended
torch_compile: true
```
Requires torch ≥ 2.6. See [PyTorch docs](https://pytorch.org/blog/flexattention/).
::: {.callout-note}
### SageAttention
We recommend using latest stable version of PyTorch for best performance.
Requirements: Ampere, Ada, or Hopper GPUs.
:::
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
## SageAttention
Attention kernels with QK Int8 and PV FP16 accumulator.
```yaml
attn_implementation: sage
sage_attention: true
```
Requirements: Ampere, Ada, or Hopper GPUs
```bash
pip install sageattention==2.2.0 --no-build-isolation
```
::: {.callout-warning}
Only LoRA/QLoRA recommended. Full finetuning has been observed to drop loss to 0. See
[GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention).
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
### xFormers
::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
:::
## xFormers
```yaml
attn_implementation: xformers
xformers_attention: true
```
::: {.callout-tip}
Recommended for Turing GPUs or below (e.g. Colab T4).
We recommend using with Turing GPUs or below (such as on Colab).
:::
### Shifted Sparse Attention
For more details: [xFormers](https://github.com/facebookresearch/xformers)
## Shifted Sparse Attention
::: {.callout-warning}
Planned for deprecation. Prefer one of the backends above.
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
:::
Requirements: LLaMA model architecture. Loaded as FA2 under the hood and
patched to implement shifted-sparse attention. Does not support sample packing.
Requirements: LLaMA model architecture
```yaml
attn_implementation: s2
flash_attention: true
s2_attention: true
```
### FP8
::: {.callout-tip}
torchao low-precision attention. Loaded as SDPA and patched post-load.
Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17,
flash-attn with FA3. KV caching must be disabled.
```yaml
attn_implementation: fp8
```
### Hub kernels
```yaml
attn_implementation: kernels-community/flash-attn3
```
Passed through to `transformers`; axolotl does not install the kernel itself.
For recognized hub paths the capability flags are set automatically; for
arbitrary paths axolotl uses conservative defaults (`attn_supports_packing=False`,
`attn_uses_flash_lib=False`).
## Migrating from legacy boolean flags
The following legacy config fields are **deprecated** and will be removed in a
future release. Each emits a `DeprecationWarning` when set and is stripped from
the validated config.
| Legacy | Canonical |
|---|---|
| `flash_attention: true` | `attn_implementation: flash_attention_2` |
| `sdp_attention: true` | `attn_implementation: sdpa` |
| `xformers_attention: true` | `attn_implementation: xformers` |
| `flex_attention: true` | `attn_implementation: flex_attention` |
| `sage_attention: true` | `attn_implementation: sage` |
| `s2_attention: true` | `attn_implementation: s2` |
| `eager_attention: true` | `attn_implementation: eager` |
Combining `attn_implementation` with a legacy flag (e.g. `attn_implementation:
flash_attention_2` **and** `flash_attention: true`) raises — pick one.
::: {.callout-note}
Existing example configs under `examples/` still use the legacy flags. They
continue to work with a deprecation warning; they will be migrated in a
follow-up pass.
No sample packing support!
:::

View File

@@ -77,8 +77,9 @@ Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/us
```bash
export UV_TORCH_BACKEND=cu128 # or cu130
uv sync --extra flash-attn --extra deepspeed --group dev --group test
uv venv --no-project --relocatable
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
```
#### Remote Hosts
@@ -218,8 +219,9 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
You will now be in the container. Next, install Axolotl with dev dependencies:
```bash
uv sync --extra flash-attn --extra deepspeed --group dev --group test
uv venv --no-project --relocatable
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
```
### Attach To Container

View File

@@ -10,13 +10,16 @@ This section describes the different Docker images that are released by AxolotlA
[Docker Hub](https://hub.docker.com/u/axolotlai).
::: {.callout-important}
For Blackwell GPUs, please use the tags with PyTorch 2.9.1 and CUDA 12.8.
:::
### Switch to the `-uv` images
::: {.callout-tip}
Each image below is available in a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with
a relocatable venv (`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
(e.g. `axolotlai/axolotl-base-uv`). Tags follow the same format. We recommend the uv images for new deployments.
Each image below ships a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with a relocatable venv
(`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
(e.g. `axolotlai/axolotl-uv`, `axolotlai/axolotl-base-uv`, `axolotlai/axolotl-cloud-uv`). Tags follow the
same format as their non-uv counterparts.
**We recommend switching to the `-uv` images early.** In the near future we will publish the uv-based
build to the non-uv tags as well. The non-uv names will continue to work, but they will start serving
the uv image.
:::
## Base
@@ -85,7 +88,7 @@ Tags examples:
- `main-py3.12-cu130-2.10.0`
- `main-latest`
- `main-20260315-py3.11-cu128-2.9.1`
- `0.12.0`
- `0.16.1`
## Cloud

View File

@@ -129,7 +129,7 @@ gradient_accumulation_steps: 4
max_steps: 20
learning_rate: 5.0e-6
bf16: auto
attn_implementation: flash_attention_2
flash_attention: true
gradient_checkpointing: true
output_dir: ./outputs/ebft-quickstart
```
@@ -304,7 +304,7 @@ lora_alpha: 32
lora_target_linear: true
bf16: auto
attn_implementation: flex_attention
flex_attention: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true # Required with flex_attention

View File

@@ -57,7 +57,7 @@ description: Frequently asked questions
**Q: vLLM is not working with Axolotl**
> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
> A: We currently recommend torch 2.10 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.12-cu128-2.10.0` tag (note: torch 2.10 images are built with Python 3.12).
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**

View File

@@ -154,7 +154,7 @@ lr_scheduler: cosine
warmup_steps: 10
bf16: true
attn_implementation: flash_attention_2
flash_attention: true
gradient_checkpointing: true
special_tokens:

View File

@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.11
- PyTorch ≥2.9.0
- PyTorch ≥2.9.1
## Installation {#sec-installation}
@@ -36,9 +36,9 @@ source $HOME/.local/bin/env
Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
```{.bash}
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv --no-project --relocatable
uv venv
source .venv/bin/activate
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
uv pip install --no-build-isolation axolotl[deepspeed]
```
### Edge/Development Build {#sec-edge-build}
@@ -49,12 +49,11 @@ For the latest features between releases:
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
export UV_TORCH_BACKEND=cu128 # or cu130
uv sync --extra flash-attn --extra deepspeed
uv venv
source .venv/bin/activate
uv pip install --no-build-isolation -e '.[deepspeed]'
```
`uv sync` creates a `.venv`, installs exact pinned versions from `uv.lock`, and sets up an editable install automatically.
### Docker {#sec-docker}
```{.bash}
@@ -132,11 +131,11 @@ source $HOME/.local/bin/env
# Create a fresh venv (recommended for a clean start)
export UV_TORCH_BACKEND=cu128 # or cu130
uv venv --no-project --relocatable
uv venv
source .venv/bin/activate
# Reinstall axolotl
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
uv pip install --no-build-isolation axolotl[deepspeed]
```
## Using pip (Alternative) {#sec-pip}
@@ -151,13 +150,13 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://p
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
pip3 install --no-build-isolation axolotl[deepspeed]
```
For editable/development installs:
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
pip3 install --no-build-isolation -e '.[deepspeed]'
```
## Troubleshooting {#sec-troubleshooting}

View File

@@ -22,12 +22,12 @@ Improves GPU utilization by combining multiple short sequences into a single pac
Using an optimized attention implementation is critical for training speed.
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `attn_implementation: flash_attention_2`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `attn_implementation: flex_attention`.
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `attn_implementation: sdpa`. PyTorch's native implementation.
- **[Xformers](https://github.com/facebookresearch/xformers)**: `attn_implementation: xformers`. Works with FP16.
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
See [Attention](attention.qmd) for the full list of backends and the canonical values.
*Note: You should only enable one attention backend.*
### LoRA Optimizations

View File

@@ -1147,7 +1147,8 @@ datasets:
type: ebft_strided_structured.transform
split: train[:1%]
attn_implementation: flex_attention # Strided mode uses flex_attention
flash_attention: false
flex_attention: true # Strided mode uses flex_attention
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true # Required for flex_attention

View File

@@ -55,7 +55,7 @@ To use sequence parallelism, you need:
## Limitations
- Flash attention must be enabled for this to work (`attn_implementation: flash_attention_2` in config YAML)
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
- May have a small performance overhead due to communication between GPUs
## Example

View File

@@ -245,7 +245,7 @@ For GRPO, also reduce `max_completion_length`. Memory scales quadratically with
Reduces attention memory from O(n^2) to O(n):
```yaml
attn_implementation: flash_attention_2
flash_attention: true
```
### Step 6: Offload with DeepSpeed

View File

@@ -15,7 +15,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
uv pip install --no-build-isolation 'axolotl>=0.16.1'
```
2. Run one of the finetuning examples below.

View File

@@ -39,7 +39,7 @@ tf32: true
gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2

View File

@@ -48,7 +48,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2

View File

@@ -50,7 +50,8 @@ tf32: true
gradient_checkpointing: true
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -39,7 +39,7 @@ activation_offloading: legacy
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_steps: 100
saves_per_epoch: 1

View File

@@ -39,7 +39,7 @@ activation_offloading: legacy
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_steps: 100
saves_per_epoch: 1

View File

@@ -11,11 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv pip install --no-build-isolation -e '.[flash-attn]'
uv pip install --no-build-isolation -e '.'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -55,7 +55,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -13,11 +13,11 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
uv pip install --no-build-isolation -e '.[flash-attn]'
uv pip install --no-build-isolation -e '.'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh

View File

@@ -55,7 +55,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -59,7 +59,8 @@ gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
sdp_attention:
flash_optimum:
gptq_groupsize:

View File

@@ -39,7 +39,8 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: xformers
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -45,7 +45,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -46,7 +46,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -45,7 +45,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -46,7 +46,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -45,7 +45,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -46,7 +46,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -52,7 +52,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -55,7 +55,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -39,7 +39,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -45,7 +45,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -43,7 +43,8 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: xformers
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -73,7 +73,8 @@ early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
logging_steps: 1
attn_implementation: xformers
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -40,7 +40,8 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: xformers
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -36,7 +36,8 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: xformers
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -37,7 +37,8 @@ bf16: auto
tf32: true
resume_from_checkpoint:
logging_steps: 5
attn_implementation: xformers
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -39,6 +39,7 @@ bf16: auto
tf32: true
resume_from_checkpoint:
logging_steps: 5
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -39,7 +39,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -40,7 +40,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -47,6 +47,7 @@ tf32: false
gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention:
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -47,6 +47,7 @@ tf32: false
gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention:
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -43,7 +43,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -46,7 +46,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -40,6 +40,7 @@ bf16: auto
tf32: true
resume_from_checkpoint:
logging_steps: 5
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -38,6 +38,7 @@ tf32: true
gradient_checkpointing:
resume_from_checkpoint:
logging_steps: 1
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -44,7 +44,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_mlp: true

View File

@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true

View File

@@ -46,7 +46,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -47,6 +47,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: false
warmup_ratio: 0.1
evals_per_epoch: 0

View File

@@ -45,7 +45,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -36,7 +36,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -47,7 +47,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4

View File

@@ -71,7 +71,8 @@ early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
logging_steps: 1
attn_implementation: xformers
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_ratio: 0.1

View File

@@ -10,7 +10,7 @@ load_in_4bit: true
sequence_len: 1024
bf16: auto
tf32: false
attn_implementation: flash_attention_2
flash_attention: true
special_tokens:
bos_token: "<|startoftext|>"
eos_token: "<|endoftext|>"

View File

@@ -48,7 +48,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -36,12 +36,7 @@
"id": "msOCO4NRmRLa"
},
"outputs": [],
"source": [
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
]
"source": "%%capture\n# This step can take ~5-10 minutes to install dependencies\n!pip install --no-build-isolation \"axolotl>=0.16.1\"\n!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
},
{
"cell_type": "markdown",

View File

@@ -45,7 +45,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -45,7 +45,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -35,7 +35,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2

View File

@@ -59,7 +59,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2

View File

@@ -15,8 +15,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
uv pip install --no-build-isolation 'axolotl>=0.16.1'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage

View File

@@ -51,7 +51,7 @@ tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
scaling_softmax: true
loss_watchdog_threshold: 5.0

View File

@@ -29,7 +29,7 @@ output_dir: ./outputs/ndp-out/
sequence_len: 2048
sample_packing: true
attn_implementation: flash_attention_2
flash_attention: true
gradient_accumulation_steps: 1
micro_batch_size: 1

View File

@@ -26,7 +26,7 @@ output_dir: ./outputs/ndp-out/
sequence_len: 8192
sample_packing: true
attn_implementation: flash_attention_2
flash_attention: true
gradient_accumulation_steps: 1
micro_batch_size: 1 # must be 1 when using context parallel

View File

@@ -65,7 +65,8 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
attn_implementation: flash_attention_2
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
weight_decay: 0.0

View File

@@ -46,7 +46,7 @@ lora_dropout: 0.05
lora_target_linear: true
bf16: auto
attn_implementation: flash_attention_2
flash_attention: true
gradient_checkpointing: true
special_tokens:

View File

@@ -66,7 +66,7 @@ lora_target_linear: true
# --- Hardware ---
bf16: auto
attn_implementation: flash_attention_2
flash_attention: true
gradient_checkpointing: true
special_tokens:

View File

@@ -47,7 +47,8 @@ lora_dropout: 0.05
lora_target_linear: true
bf16: auto
attn_implementation: flex_attention
flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
flex_attention: true # fused flex_attention kernel compiles itself; don't set torch_compile: true
# (full-model compile conflicts with gradient checkpointing + flex_attention)
gradient_checkpointing: true
gradient_checkpointing_kwargs:

View File

@@ -46,6 +46,7 @@ lora_dropout: 0.05
lora_target_linear: true
bf16: auto
flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
gradient_checkpointing: true
special_tokens:

View File

@@ -48,6 +48,7 @@ lora_target_linear: true
bf16: auto
torch_dtype: bfloat16
flash_attention: false
gradient_checkpointing: true
torch_compile: true
gradient_checkpointing_kwargs:

View File

@@ -41,6 +41,7 @@ warmup_steps: 10
weight_decay: 0.01
bf16: auto
flash_attention: false # strided EBFT uses flex_attention at runtime
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false

View File

@@ -72,7 +72,7 @@ lora_dropout: 0.0
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
bf16: auto
attn_implementation: flash_attention_2
flash_attention: true
gradient_checkpointing: true
special_tokens:

View File

@@ -63,7 +63,7 @@ lora_dropout: 0.0
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
bf16: auto
attn_implementation: flash_attention_2
flash_attention: true
gradient_checkpointing: true
special_tokens:

View File

@@ -68,7 +68,7 @@ lora_dropout: 0.0
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
bf16: auto
attn_implementation: flash_attention_2
flash_attention: true
gradient_checkpointing: true
special_tokens:

View File

@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -61,7 +61,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1

View File

@@ -53,7 +53,7 @@ tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -43,7 +43,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:

View File

@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
attn_implementation: flash_attention_2
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:

Some files were not shown because too many files have changed in this diff Show More