Compare commits
23 Commits
vllm-0191
...
attn-imple
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
464de78f6d | ||
|
|
7a41b47d22 | ||
|
|
6886def92c | ||
|
|
aeca18a8b0 | ||
|
|
434a484fe9 | ||
|
|
39226623d2 | ||
|
|
2d64d009d8 | ||
|
|
a0d24bcc19 | ||
|
|
bce65e3332 | ||
|
|
2579c496d5 | ||
|
|
35d43fe141 | ||
|
|
ff5d6393c8 | ||
|
|
aee8c75d64 | ||
|
|
c4f986874d | ||
|
|
28e89a5c16 | ||
|
|
901f2356bc | ||
|
|
1bf65c500e | ||
|
|
bcbe049c21 | ||
|
|
90090fa9e8 | ||
|
|
7420fd4de6 | ||
|
|
05113bc91a | ||
|
|
e562e149ce | ||
|
|
9de5b76336 |
5
.github/CONTRIBUTING.md
vendored
5
.github/CONTRIBUTING.md
vendored
@@ -31,7 +31,10 @@ PRs are **greatly welcome**!
|
||||
|
||||
Please run below to setup env
|
||||
```bash
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
# Install axolotl + dev and test dependencies from lockfile
|
||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||
uv sync --extra flash-attn --extra deepspeed --group dev --group test
|
||||
source .venv/bin/activate
|
||||
pre-commit install
|
||||
|
||||
# test
|
||||
|
||||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -6,7 +6,7 @@ on:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
- 'pyproject.toml'
|
||||
- '.github/workflows/*.yml'
|
||||
- "*.[q]md"
|
||||
- "examples/**/*.y[a]?ml"
|
||||
|
||||
35
.github/workflows/multi-gpu-e2e.yml
vendored
35
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -3,17 +3,15 @@ name: docker-multigpu-tests-biweekly
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'tests/e2e/multigpu/**.py'
|
||||
- 'requirements.txt'
|
||||
- 'setup.py'
|
||||
- 'pyproject.toml'
|
||||
- '.github/workflows/multi-gpu-e2e.yml'
|
||||
- 'scripts/cutcrossentropy_install.py'
|
||||
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
|
||||
- 'src/axolotl/utils/distributed.py'
|
||||
- "tests/e2e/multigpu/**.py"
|
||||
- "pyproject.toml"
|
||||
- ".github/workflows/multi-gpu-e2e.yml"
|
||||
- "scripts/cutcrossentropy_install.py"
|
||||
- "src/axolotl/core/trainers/mixins/sequence_parallel.py"
|
||||
- "src/axolotl/utils/distributed.py"
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
||||
- cron: "0 0 * * 1,4" # Runs at 00:00 UTC every monday & thursday
|
||||
|
||||
# Cancel jobs on the same ref if a new one is triggered
|
||||
concurrency:
|
||||
@@ -33,19 +31,19 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras: "fbgemm-gpu"
|
||||
# num_gpus: 2
|
||||
# dockerfile: "Dockerfile-uv.jinja"
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras: "fbgemm-gpu"
|
||||
# num_gpus: 2
|
||||
# dockerfile: "Dockerfile-uv.jinja"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
# axolotl_extras: fbgemm-gpu
|
||||
# axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
@@ -53,7 +51,6 @@ jobs:
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras: "fbgemm-gpu"
|
||||
num_gpus: 2
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
steps:
|
||||
@@ -75,7 +72,7 @@ jobs:
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
13
.github/workflows/pypi.yml
vendored
13
.github/workflows/pypi.yml
vendored
@@ -8,6 +8,9 @@ on:
|
||||
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
UV_SYSTEM_PYTHON: "1"
|
||||
|
||||
jobs:
|
||||
setup_release:
|
||||
name: Create Release
|
||||
@@ -41,11 +44,15 @@ jobs:
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install wheel packaging==26.0
|
||||
pip3 install --no-build-isolation -e .
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
uv pip install wheel packaging
|
||||
uv pip install --no-build-isolation -e .
|
||||
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||
|
||||
- name: Extract tag name
|
||||
id: tag
|
||||
|
||||
55
.github/workflows/tests-nightly.yml
vendored
55
.github/workflows/tests-nightly.yml
vendored
@@ -2,15 +2,18 @@ name: Tests Nightly against upstream main
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
|
||||
- cron: "0 0 * * *" # Runs at 00:00 UTC every day
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- '.github/workflows/tests-nightly.yml'
|
||||
- ".github/workflows/tests-nightly.yml"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
UV_SYSTEM_PYTHON: "1"
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
name: pre-commit
|
||||
@@ -20,7 +23,7 @@ jobs:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
cache: "pip" # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
@@ -43,7 +46,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||
pytorch_version: ["2.9.1", "2.10.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
@@ -61,36 +64,34 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} torchvision
|
||||
|
||||
- name: Update requirements.txt
|
||||
run: |
|
||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
|
||||
uv pip install torch==${{ matrix.pytorch_version }} torchvision
|
||||
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
|
||||
python scripts/cutcrossentropy_install.py --uv | sh
|
||||
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||
|
||||
- name: Override with nightly HF packages
|
||||
run: |
|
||||
uv pip install --no-deps \
|
||||
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
|
||||
"peft @ git+https://github.com/huggingface/peft.git@main" \
|
||||
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
||||
"trl @ git+https://github.com/huggingface/trl.git@main" \
|
||||
"datasets @ git+https://github.com/huggingface/datasets.git@main"
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
@@ -102,9 +103,6 @@ jobs:
|
||||
pytest -v --durations=10 tests/patched/
|
||||
pytest -v --durations=10 tests/cli/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
docker-e2e-tests:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
@@ -136,7 +134,6 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
nightly_build: "true"
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -157,7 +154,7 @@ jobs:
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
|
||||
85
.github/workflows/tests.yml
vendored
85
.github/workflows/tests.yml
vendored
@@ -6,21 +6,19 @@ on:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
- '.github/workflows/*.yml'
|
||||
- 'requirements-tests.txt'
|
||||
- 'cicd/cicd.sh'
|
||||
- 'cicd/Dockerfile.jinja'
|
||||
- "**.py"
|
||||
- "pyproject.toml"
|
||||
- ".github/workflows/*.yml"
|
||||
- "cicd/cicd.sh"
|
||||
- "cicd/Dockerfile-uv.jinja"
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
- '.github/workflows/*.yml'
|
||||
- 'requirements-tests.txt'
|
||||
- 'cicd/cicd.sh'
|
||||
- 'cicd/Dockerfile.jinja'
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- "**.py"
|
||||
- "pyproject.toml"
|
||||
- ".github/workflows/*.yml"
|
||||
- "cicd/cicd.sh"
|
||||
- "cicd/Dockerfile-uv.jinja"
|
||||
workflow_dispatch:
|
||||
|
||||
# Cancel jobs on the same ref if a new one is triggered
|
||||
@@ -33,6 +31,7 @@ permissions:
|
||||
|
||||
env:
|
||||
TRANSFORMERS_IS_CI: "yes"
|
||||
UV_SYSTEM_PYTHON: "1"
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
@@ -44,7 +43,7 @@ jobs:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
cache: "pip" # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
@@ -94,32 +93,25 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
|
||||
uv pip install torch==${{ matrix.pytorch_version }} torchvision
|
||||
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install --no-cache-dir --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
|
||||
python scripts/cutcrossentropy_install.py --uv | sh
|
||||
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
@@ -188,33 +180,27 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
|
||||
uv pip install torch==${{ matrix.pytorch_version }} torchvision
|
||||
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
uv pip install packaging setuptools_scm build wheel psutil
|
||||
python -m build --no-isolation --sdist
|
||||
pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
uv pip install --no-build-isolation dist/axolotl*.tar.gz --override /tmp/torch-pin.txt
|
||||
python scripts/cutcrossentropy_install.py --uv | sh
|
||||
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
@@ -291,7 +277,6 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -312,7 +297,7 @@ jobs:
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
@@ -374,7 +359,7 @@ jobs:
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
|
||||
@@ -26,7 +26,7 @@ axolotl config-schema # Dump config JSON schema
|
||||
| Method | Config Key | When to Use |
|
||||
|--------|-----------|-------------|
|
||||
| SFT | *(default)* | Input-output pairs, instruction tuning |
|
||||
| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) |
|
||||
| DPO/IPO | `rl: dpo` / `rl: dpo, dpo_loss_type: ["ipo"]` | Paired preference data (chosen vs rejected) |
|
||||
| KTO | `rl: kto` | Unpaired binary preference labels |
|
||||
| ORPO | `rl: orpo` | Single-stage alignment, no ref model |
|
||||
| GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |
|
||||
|
||||
142
ATTN_REFACTOR_REVIEW.md
Normal file
142
ATTN_REFACTOR_REVIEW.md
Normal file
@@ -0,0 +1,142 @@
|
||||
# `attn-implementation-refactor` branch review
|
||||
|
||||
Review target: `attn-implementation-refactor` (5 commits ahead of main, merge base `69904781`).
|
||||
Scope: 16 files, +682 / −96.
|
||||
|
||||
## 1. What the branch is trying to do
|
||||
|
||||
Collapse seven boolean attention flags (`flash_attention`, `sdp_attention`, `xformers_attention`, `flex_attention`, `sage_attention`, `s2_attention`, `eager_attention`) into a single `attn_implementation` field, with derived capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) for the gates that used to be ad-hoc OR-chains.
|
||||
|
||||
Mechanism: `normalize_attn_implementation` (a `@model_validator(mode="before")` on `AxolotlInputConfig`) maps bidirectionally between the new field and the legacy flags, with a priority list for legacy combos (`s2 + flash → s2`), and then computes the three capability flags from frozen sets in `enums.py`.
|
||||
|
||||
Adjacent changes: `xformers` and `sage` now register as their own entries in `ALL_ATTENTION_FUNCTIONS` (with FA2 mask behavior) instead of stomping the `flash_attention_2` slot. New `fp8` backend wires `torchao.prototype.attention.apply_low_precision_attention` in `apply_post_model_load_patches`.
|
||||
|
||||
## 2. Target design
|
||||
|
||||
**`cfg.attn_implementation` is the single source of truth on the validated config.**
|
||||
|
||||
- Its type is `str | None`. Accepted values are **canonical names only** — no short-form aliases:
|
||||
- HF-native: `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`. (`flash_attention_3` is net-new to axolotl — the current branch only encodes `flash_attention_2` under the short name `flash`.)
|
||||
- Axolotl-owned (registered into `ALL_ATTENTION_FUNCTIONS` under exactly these names): `xformers`, `sage`, `s2`, `fp8`.
|
||||
- Hub-kernel paths: `kernels-community/sage-attention`, `kernels-community/flash-attn3`, etc. — passthrough. Known-kernel allowlist in `enums.py` classifies the common ones into the capability tables.
|
||||
Short forms like `flash`, `fa2`, `fa3`, `sdp`, `flex` are rejected (Pydantic validation error with a pointer to the canonical name).
|
||||
- `model.py:_set_attention_config` passes `cfg.attn_implementation` to HF verbatim — no `_ATTN_IMPL_TO_HF` translation dict needed.
|
||||
- Legacy booleans (`flash_attention: true`, `sdp_attention: true`, …) are the **only** input aliases, kept for backwards compatibility. The normalizer maps them to the canonical `attn_implementation` value, emits a one-time `DeprecationWarning` per flag, and removes them from `data` so they're never readable on the validated `cfg`. `deprecated=True` on each Field surfaces this in JSON schema. Mapping is 1:1 with the current legacy-flag semantics (`flash_attention → flash_attention_2`, `sdp_attention → sdpa`, `flex_attention → flex_attention`, `xformers_attention → xformers`, `sage_attention → sage`, `s2_attention → s2`, `eager_attention → eager`).
|
||||
- Capability flags (`attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast`) are **`@computed_field` on the model**, not settable inputs. Lookup is keyed by the canonical `attn_implementation` string.
|
||||
- Unknown / user-supplied strings (custom hub kernels) pass through to HF but get **conservative capability defaults** (packing=False, flash-lib=False, dtype-cast=True). Known hub kernels axolotl can classify live in a small allowlist.
|
||||
- Downstream consumers read *only* `cfg.attn_implementation` and the capability flags. No `cfg.flash_attention`, `cfg.xformers_attention`, etc. anywhere in `src/`.
|
||||
|
||||
This is strictly what the branch is already *trying* to do — the gaps below are places it hasn't landed that goal yet.
|
||||
|
||||
## 3. Gaps and holes
|
||||
|
||||
### A. Legacy flags are still parallel state, not input-only
|
||||
|
||||
1. The normalizer *sets* the legacy flags on `data` (`impl_to_flag[attn_impl]` branch). It does not delete them. So `cfg.flash_attention` is still truthy after validation, and downstream code still reads it (see G).
|
||||
2. Short-form enum values (`flash`, `sdpa`, `fp8`) are persisted as-is on `cfg.attn_implementation`, which is why `model.py` needs `_ATTN_IMPL_TO_HF` to translate before passing to HF. Source-of-truth implies canonicalize at normalize-time, not translate at consume-time.
|
||||
3. Legacy flag + `attn_implementation` (consistent combo, e.g. `attn_implementation: flash + flash_attention: true`) emits no deprecation warning — only legacy-only path warns.
|
||||
4. Legacy Field descriptions (`xformers_attention`, `sdp_attention`, etc.) don't have `deprecated=True`, so JSON schema still advertises them as first-class.
|
||||
|
||||
### B. Validators that still only check the legacy flag
|
||||
|
||||
5. **`check_ebft_activation_offloading`** (`validation.py:1607-1619`) reads only `data.get("flex_attention")`. Users on `attn_implementation: flex_attention` bypass the incompatibility check.
|
||||
6. **`check_sample_packing_without_attention`** (`validation.py:188-203`) early-returns when `attn_implementation` is set but never validates the chosen backend actually supports packing. `attn_implementation: eager + sample_packing: true` silently passes; the old legacy-flag check warned.
|
||||
|
||||
### C. Non-enum strings fall through the capability tables
|
||||
|
||||
7. **HF-native `"flash_attention_2"`** is neither in `impl_to_flag` nor `FLASH_ATTN_LIB_IMPLS`. A user copy-pasting from HF docs gets `attn_uses_flash_lib=False`, silently disabling FA4 auto-apply, LLaMA flash hijack, `_patch_attention` (btlm, stablelm_epoch, mistral3, llava), and `_apply_flash_attention_peft_patches`.
|
||||
8. **Hub kernel strings** (`kernels-community/flash-attn3`, `kernels-community/sage-attention`) default to `attn_supports_packing=True` (silently enters multipack with varlen `position_ids` — correctness depends on the kernel honoring them) and `attn_uses_flash_lib=False` (so `context_parallel_size > 1` raises "requires flash attention" even for FA3 hub kernels).
|
||||
9. **Conflict trap for hub-kernel + legacy flag** (`config.py:1414-1419`): `attn_implementation: kernels-community/flash-attn3 + flash_attention: true` always raises, because `impl_to_flag.get(custom) is None` and the loop treats `flag != None` as conflict. Common combo in existing YAMLs breaks hard on upgrade.
|
||||
|
||||
### D. Silent behaviour change for xformers
|
||||
|
||||
10. Old `_apply_flash_attention_patches` did `self.cfg.flash_attention = True` for `xformers + sample_packing`. The new version doesn't, and xformers is not in `FLASH_ATTN_LIB_IMPLS`. Consumers that keyed off `cfg.flash_attention` now see falsy for xformers, silently dropping `_patch_attention` (btlm / stablelm_epoch+packing / mistral3 / llava model-type FA patches). Some of this is arguably correct cleanup (xformers has its own HF registry entry now), but the btlm/stablelm/mistral3 regression is not called out and not tested. Decide consciously, not by omission.
|
||||
|
||||
### E. Capability flags are writable Pydantic fields, not computed
|
||||
|
||||
11. `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` are declared `bool | None = Field(default=None)` on `AxolotlInputConfig`. YAML is not rejected — a user can set `attn_uses_flash_lib: true` and override the normalizer.
|
||||
|
||||
### F. Validator ordering (not covered by tests)
|
||||
|
||||
12. `AttentionValidationMixin.check_attention_fields` (inherited, `mode="before"`) and `normalize_attn_implementation` (subclass, `mode="before"`) both run during `model_validator` phase. Pydantic MRO may run the inherited one first. For legacy-only `s2_attention: true + flash_attention: true` (the test `test_s2_plus_flash_maps_to_s2` asserts this maps to `s2`), the inherited check may raise "multiple attention implementations set" before the normalizer runs. The test calls the classmethod directly and does not build the model, so this is unverified either way.
|
||||
|
||||
### G. Remaining legacy reads in `src/`
|
||||
|
||||
13. `src/axolotl/integrations/lm_eval/cli.py:120` reads `cfg.flash_attention`. Works for `attn_implementation=flash` only.
|
||||
14. `tests/e2e/multigpu/test_llama.py:524-526` writes `cfg.flash_attention = True` / `cfg.flex_attention = True`. Stale pattern.
|
||||
15. Dual-check idioms in `config.py` (lines 1464, 1478, 1570, 1586, 1774) and `validation.py` (198, 209, 221, 850, 1586, 1611) — `data.get("x_attention") or data.get("attn_implementation") == "x"` — are redundant once legacy flags are input-only; remove them.
|
||||
|
||||
### H. fp8 operational risk
|
||||
|
||||
16. The `fp8` docstring documents hard requirements (PyTorch ≥ 2.11, SM90+, flash-attn with FA3, torchao ≥ 0.17.0) and a runtime constraint (`config.use_cache = False`). None are validated — misconfig surfaces as a torchao runtime error. `xformers` and `sage` availability/compute-capability guards exist; `fp8` should match.
|
||||
|
||||
### I. Test coverage gaps
|
||||
|
||||
17. `test_attn_implementation.py` exercises the classmethod in isolation plus the constant sets. It does **not**:
|
||||
- Build a full `AxolotlInputConfig(**data)` (so validator ordering — item 12 — is untested).
|
||||
- Verify capability flags can't be overridden from YAML (item 11).
|
||||
- Cover `check_sample_packing_without_attention` with `attn_implementation: eager` (item 6).
|
||||
- Cover `check_ebft_activation_offloading` with `attn_implementation: flex_attention` (item 5).
|
||||
- Cover hub-kernel + legacy flag combo (item 9).
|
||||
- Cover `flash_attention_2` canonicalization (item 7).
|
||||
|
||||
## 4. Fix plan
|
||||
|
||||
Four phases, each commit-sized. Phases 1–2 are local and low-risk; phase 3 is the behaviour-changing cleanup; phase 4 is tests + docs.
|
||||
|
||||
### Phase 1 — Lock down the data model
|
||||
|
||||
1. Drop the `AttnImplementation` enum. `attn_implementation` becomes `str | None`, validated against a canonical allowlist (`eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention`, `xformers`, `sage`, `s2`, `fp8`) **or** a hub-kernel path (`startswith("kernels-")` / contains `/`). Reject short-form strings like `flash` / `fa2` / `sdp` / `flex` with an explicit error pointing at the canonical name.
|
||||
2. Rewrite `normalize_attn_implementation` so its only job is mapping **legacy booleans → canonical `attn_implementation`** (for BC). Mapping is fixed:
|
||||
- `flash_attention → flash_attention_2`
|
||||
- `sdp_attention → sdpa`
|
||||
- `flex_attention → flex_attention`
|
||||
- `xformers_attention → xformers`
|
||||
- `sage_attention → sage`
|
||||
- `s2_attention → s2`
|
||||
- `eager_attention → eager`
|
||||
Priority for legacy combos stays as in the current branch (`s2 > sage > xformers > flex > flash > sdp > eager`). Emit a one-time `DeprecationWarning` per unique legacy flag seen. After mapping, delete the legacy flag keys from `data` so they never appear on validated `cfg`. If both a canonical `attn_implementation` and any legacy flag are set, raise (no silent precedence).
|
||||
|
||||
**Merge `AttentionValidationMixin.check_attention_fields` into this normalizer and delete the mixin method.** Pydantic v2 runs inherited `mode="before"` validators before subclass ones per MRO, so leaving them as siblings causes the inherited check to reject legacy combos like `s2 + flash` before the normalizer can map them. One validator, one source of conflict detection.
|
||||
|
||||
**Fix the gemma4-hybrid path**: change `data["attn_implementation"] = "flash"` to `data["attn_implementation"] = "flash_attention_2"` (the short name no longer validates after step 1).
|
||||
3. Convert `attn_supports_packing`, `attn_uses_flash_lib`, `attn_needs_dtype_cast` to `@computed_field`. The three capability tables move to `enums.py` as module constants keyed by the canonical `attn_implementation` string (including `flash_attention_3` — missing from the current branch — and known hub kernels):
|
||||
- Packing-capable: `{flash_attention_2, flash_attention_3, flex_attention, xformers, sage, kernels-community/flash-attn3, kernels-community/sage-attention}`.
|
||||
- Flash-lib (axolotl's monkeypatch targets): `{flash_attention_2, flash_attention_3, s2, kernels-community/flash-attn3}`.
|
||||
- No-dtype-cast: `{eager, sdpa}`.
|
||||
Unknown strings: conservative defaults (`packing=False, flash_lib=False, dtype_cast=True`).
|
||||
4. Delete `_ATTN_IMPL_TO_HF` from `model.py` and pass `cfg.attn_implementation` straight through. The gemma4-hybrid branch continues to override to `flash_attention_2` before passing to HF.
|
||||
5. `deprecated=True` on each legacy boolean Field so JSON schema + Pydantic surface the deprecation.
|
||||
|
||||
### Phase 2 — Fix the validators
|
||||
|
||||
6. `check_sample_packing_without_attention`: drop the early-return and gate on `attn_supports_packing`. Warn (or raise — pick one and be consistent) if packing is enabled with a non-packing backend.
|
||||
7. `check_ebft_activation_offloading`: replace `data.get("flex_attention")` with `attn_implementation == "flex_attention"`.
|
||||
8. Sweep items (item 15): remove every `data.get("x_attention") or data.get("attn_implementation") == "x"` dual-check — after phase 1 the legacy side is always `None`. Reduces ~10 lines of noise and eliminates the "which side wins" class of bugs.
|
||||
9. fp8 preflight (item 16): require `env_capabilities.compute_capability ≥ sm_90`, `torch_version ≥ 2.11`, and `torchao_version ≥ 0.17`. Warn if `use_cache` isn't explicitly `False`.
|
||||
|
||||
### Phase 3 — Migrate remaining consumers
|
||||
|
||||
10. `lm_eval/cli.py:120` → `flash_attention=cfg.attn_uses_flash_lib`.
|
||||
11. `lm_eval/__init__.py:26` currently reads `(cfg.attn_implementation == "flash")` — after canonicalization `"flash"` is never stored, so this evaluates `False` for every backend. Change to `cfg.attn_uses_flash_lib`.
|
||||
12. `validation.py:1137-1142` (NPU check) currently iterates `["flash_attention", "sdp_attention", "s2_attention"]` as string keys. Replace with `cfg.attn_implementation in {"flash_attention_2", "flash_attention_3", "sdpa", "s2"}` or the equivalent canonical-string set.
|
||||
13. `tests/e2e/multigpu/test_llama.py:524-526` → `cfg.attn_implementation = "flash_attention_2"` / `"flex_attention"`.
|
||||
14. **Xformers decision** (item 10): the old `cfg.flash_attention = True` side-effect activated `_patch_attention` for btlm/stablelm_epoch+packing/mistral3/llava. Two choices:
|
||||
- Add `xformers` to the set that gates `_patch_attention` (restore old behaviour, keeps patches live).
|
||||
- Document that those patches don't apply to xformers post-refactor and drop the paths if they're dead.
|
||||
Pick one explicitly and leave a commit note. Do not leave it as silent breakage.
|
||||
15. Add a repo-level check (`tests/test_no_legacy_attn_reads.py` or a ruff/grep pre-commit) that fails if anything outside `config.py`'s normalizer reads `cfg.flash_attention` / `cfg.sdp_attention` / etc. Keeps the invariant from rotting.
|
||||
|
||||
### Phase 4 — Tests + docs
|
||||
|
||||
14. Rewrite `test_attn_implementation.py` to build full `AxolotlInputConfig(**data)`, not just the classmethod. Covers validator ordering and the Pydantic-field-override issue.
|
||||
15. Add one test per gap closed above: `attn_implementation: eager + sample_packing`; `attn_implementation: flex_attention + activation_offloading`; short-form `flash` rejected; `flash_attention_2` passthrough; `kernels-community/flash-attn3` capability lookup; `attn_uses_flash_lib: true` in YAML rejected; legacy boolean emits `DeprecationWarning` and is absent from validated `cfg`; fp8 preflight failures.
|
||||
16. Update `docs/attention.qmd` for the single `attn_implementation` field + the deprecation table for legacy flags. One-paragraph migration note in the changelog.
|
||||
17. `examples/` contains ~170 YAML files using legacy flags (`flash_attention: true` etc.). They still validate post-refactor (normalizer maps them with deprecation), but a follow-up sweep to convert them to `attn_implementation: flash_attention_2` is worth scheduling — call this out in the migration note so users know examples will be migrated on a later pass.
|
||||
|
||||
## 5. Ordering & risk
|
||||
|
||||
- Phase 1 is the keystone: it's the largest diff but each step is mechanical once the alias map is in place. No behaviour change for any consumer that was using `attn_implementation` correctly; behaviour change only for consumers that were reading legacy flags (phase 3 step 13 is the explicit decision point).
|
||||
- Phase 2 is independent of phase 1 and can land first as a quick safety net.
|
||||
- Phase 3 step 13 is the only judgment call — flag for review before choosing.
|
||||
- Total: ~10-13 commits beyond what's on the branch, each scoped and individually revertable.
|
||||
@@ -1,7 +1,6 @@
|
||||
include requirements.txt
|
||||
include README.md
|
||||
include LICENSE
|
||||
include src/setuptools_axolotl_dynamic_dependencies.py
|
||||
include VERSION
|
||||
include src/axolotl/utils/chat_templates/templates/*.jinja
|
||||
include AGENTS.md
|
||||
recursive-include docs/agents *.md
|
||||
|
||||
26
README.md
26
README.md
@@ -95,14 +95,11 @@ Features:
|
||||
|
||||
### Installation
|
||||
|
||||
#### Using uv (recommended)
|
||||
|
||||
```bash
|
||||
# install uv if you don't already have it installed
|
||||
# install uv if you don't already have it installed (restart shell after)
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
source $HOME/.local/bin/env
|
||||
|
||||
# CUDA 12.8.1 tends to have better package compatibility
|
||||
# change depending on system
|
||||
export UV_TORCH_BACKEND=cu128
|
||||
|
||||
# create a new virtual environment
|
||||
@@ -112,23 +109,6 @@ source .venv/bin/activate
|
||||
uv pip install torch==2.10.0 torchvision
|
||||
uv pip install --no-build-isolation axolotl[deepspeed]
|
||||
|
||||
# recommended - install cut-cross-entropy
|
||||
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
|
||||
|
||||
# (optional) - prefetch flash-attn2 and causal-conv1d kernels
|
||||
uv run --python 3.12 python -c "from kernels import get_kernel; get_kernel('kernels-community/flash-attn2'); get_kernel('kernels-community/causal-conv1d')"
|
||||
|
||||
# Download example axolotl configs, deepspeed configs
|
||||
axolotl fetch examples
|
||||
axolotl fetch deepspeed_configs # OPTIONAL
|
||||
```
|
||||
|
||||
#### Using pip
|
||||
|
||||
```bash
|
||||
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
|
||||
# Download example axolotl configs, deepspeed configs
|
||||
axolotl fetch examples
|
||||
axolotl fetch deepspeed_configs # OPTIONAL
|
||||
@@ -138,7 +118,7 @@ axolotl fetch deepspeed_configs # OPTIONAL
|
||||
|
||||
Installing with Docker can be less error prone than installing in your own environment.
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
||||
docker run --gpus '"all"' --ipc=host --rm -it axolotlai/axolotl:main-latest
|
||||
```
|
||||
|
||||
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
@@ -134,7 +134,6 @@ quartodoc:
|
||||
- monkeypatch.stablelm_attn_hijack_flash
|
||||
- monkeypatch.trainer_fsdp_optim
|
||||
- monkeypatch.transformers_fa_utils
|
||||
- monkeypatch.unsloth_
|
||||
- monkeypatch.data.batch_dataset_fetcher
|
||||
- monkeypatch.mixtral
|
||||
- monkeypatch.gradient_checkpointing.offload_cpu
|
||||
@@ -327,7 +326,6 @@ website:
|
||||
- section: "Advanced Features"
|
||||
contents:
|
||||
- docs/fsdp_qlora.qmd
|
||||
- docs/unsloth.qmd
|
||||
- docs/torchao.qmd
|
||||
- docs/custom_integrations.qmd
|
||||
- docs/sequence_parallelism.qmd
|
||||
|
||||
@@ -22,15 +22,6 @@ WORKDIR /workspace/axolotl
|
||||
RUN git fetch origin +$GITHUB_REF && \
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN uv pip install packaging==26.0 setuptools==78.1.1
|
||||
RUN uv pip install torchvision
|
||||
RUN uv pip uninstall causal_conv1d
|
||||
@@ -40,11 +31,21 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py --uv | sh
|
||||
# Override with nightly HF packages for nightly builds
|
||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
uv pip install --no-deps \
|
||||
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
|
||||
"peft @ git+https://github.com/huggingface/peft.git@main" \
|
||||
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
||||
"trl @ git+https://github.com/huggingface/trl.git@main" \
|
||||
"datasets @ git+https://github.com/huggingface/datasets.git@main"; \
|
||||
fi
|
||||
|
||||
RUN python scripts/cutcrossentropy_install.py --uv | sh
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
|
||||
RUN uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
||||
|
||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
|
||||
ENV CUDA="{{ CUDA }}"
|
||||
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
|
||||
ENV GITHUB_REF="{{ GITHUB_REF }}"
|
||||
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
||||
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
||||
ENV HF_HOME="{{ HF_HOME }}"
|
||||
ENV AXOLOTL_DATASET_NUM_PROC="8"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
RUN git fetch origin +$GITHUB_REF && \
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN pip install packaging==26.0 setuptools==78.1.1 psutil
|
||||
RUN pip uninstall -y causal_conv1d
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
RUN python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch
|
||||
|
||||
# helper for huggingface-login cli
|
||||
RUN git config --global credential.helper store
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__, f'Expected torch $PYTORCH_VERSION but got {torch.__version__}'"
|
||||
|
||||
set -o pipefail
|
||||
for i in 1 2 3; do
|
||||
|
||||
@@ -17,7 +17,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
||||
template_env = jinja2.Environment(
|
||||
loader=template_loader, autoescape=select_autoescape()
|
||||
)
|
||||
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
|
||||
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
|
||||
df_template = template_env.get_template(dockerfile)
|
||||
|
||||
df_args = {
|
||||
|
||||
@@ -16,7 +16,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
||||
template_env = jinja2.Environment(
|
||||
loader=template_loader, autoescape=select_autoescape()
|
||||
)
|
||||
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
|
||||
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
|
||||
df_template = template_env.get_template(dockerfile)
|
||||
|
||||
df_args = {
|
||||
|
||||
@@ -32,7 +32,7 @@ RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||
fi && \ python scripts/unsloth_install.py | sh && \
|
||||
fi && \
|
||||
python scripts/cutcrossentropy_install.py | sh && \
|
||||
pip install pytest && \
|
||||
pip cache purge
|
||||
|
||||
@@ -33,7 +33,6 @@ RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
else \
|
||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||
fi && \
|
||||
python scripts/unsloth_install.py --uv | sh && \
|
||||
python scripts/cutcrossentropy_install.py --uv | sh && \
|
||||
uv pip install pytest && \
|
||||
uv cache clean
|
||||
|
||||
@@ -121,11 +121,11 @@ Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2
|
||||
|
||||
| Backend | Config | head_dim limit | torch_compile | Notes |
|
||||
|---------|--------|---------------|---------------|-------|
|
||||
| FA2 | `flash_attention: true` | 256 | ✅ | Fastest when supported |
|
||||
| FA4 | auto with `flash_attention: true` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
|
||||
| SDPA | `sdp_attention: true` | None | ✅ | Universal fallback |
|
||||
| flex | `flex_attention: true` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
|
||||
| eager | neither set | None | ✅ | Slowest, always works |
|
||||
| FA2 | `attn_implementation: flash_attention_2` | 256 | ✅ | Fastest when supported |
|
||||
| FA4 | auto with `attn_implementation: flash_attention_2` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
|
||||
| SDPA | `attn_implementation: sdpa` | None | ✅ | Universal fallback |
|
||||
| flex | `attn_implementation: flex_attention` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
|
||||
| eager | `attn_implementation: eager` | None | ✅ | Slowest, always works |
|
||||
|
||||
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference da
|
||||
|
||||
1. Paired preference data (chosen + rejected)?
|
||||
- Default → `rl: dpo`
|
||||
- Overfitting → `rl: ipo`
|
||||
- Overfitting → `rl: dpo, dpo_loss_type: ["ipo"]`
|
||||
- VRAM-limited → `rl: orpo` (no ref model)
|
||||
- Length-sensitive → `rl: simpo` (no ref model)
|
||||
2. Only binary labels (good/bad)? → `rl: kto`
|
||||
|
||||
@@ -83,7 +83,7 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
|
||||
| Issue | Fix |
|
||||
|-------|-----|
|
||||
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
|
||||
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` |
|
||||
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `attn_implementation: flash_attention_2` or disable `sample_packing` |
|
||||
| Missing chat template error | Set `chat_template: chatml` explicitly |
|
||||
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
|
||||
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |
|
||||
|
||||
@@ -3,28 +3,71 @@ title: Attention
|
||||
description: Supported attention modules in Axolotl
|
||||
---
|
||||
|
||||
## SDP Attention
|
||||
|
||||
This is the default built-in attention in PyTorch.
|
||||
Axolotl routes attention via a single config field:
|
||||
|
||||
```yaml
|
||||
sdp_attention: true
|
||||
attn_implementation: <backend>
|
||||
```
|
||||
|
||||
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
`attn_implementation` is passed through to `transformers` verbatim (via
|
||||
`model.config._attn_implementation`). Accepted values are the HF-native
|
||||
backends, axolotl-registered backends, or a hub-kernel path.
|
||||
|
||||
## Flash Attention
|
||||
## Backends
|
||||
|
||||
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
|
||||
based on your installed packages and GPU.
|
||||
| `attn_implementation` | Description |
|
||||
|---|---|
|
||||
| `eager` | Plain PyTorch attention. No packing support. |
|
||||
| `sdpa` | PyTorch `scaled_dot_product_attention`. No packing support. |
|
||||
| `flash_attention_2` | Dao-AILab Flash Attention 2. |
|
||||
| `flash_attention_3` | Dao-AILab Flash Attention 3 (Hopper+). |
|
||||
| `flex_attention` | Torch Flex Attention (requires torch ≥ 2.6). |
|
||||
| `xformers` | xFormers memory-efficient attention. |
|
||||
| `sage` | SageAttention (QK int8 / PV fp16). |
|
||||
| `s2` | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). |
|
||||
| `fp8` | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. |
|
||||
| `kernels-community/flash-attn3` | HF hub FA3 kernel. |
|
||||
| `kernels-community/sage-attention` | HF hub SageAttention kernel. |
|
||||
| Other `<org>/<name>` path | Any hub-kernel path supported by `transformers`. |
|
||||
|
||||
Short-form aliases (`flash`, `fa2`, `flex`, `sdp`, etc.) are **not accepted** —
|
||||
set the canonical name above.
|
||||
|
||||
### Capability flags
|
||||
|
||||
Axolotl derives three boolean capability flags from `attn_implementation` and
|
||||
exposes them on the validated config:
|
||||
|
||||
- `cfg.attn_supports_packing` — backend supports varlen sample packing via
|
||||
`position_ids`. Gates multipack patches and `sample_packing_drop_attention_mask`.
|
||||
- `cfg.attn_uses_flash_lib` — backend needs the `flash_attn` (Dao-AILab)
|
||||
monkeypatches (FA4 auto, LLaMA flash hijack, ring-FA).
|
||||
- `cfg.attn_needs_dtype_cast` — backend requires fp16/bf16 embeddings
|
||||
(everything except `eager` and `sdpa`).
|
||||
|
||||
These are **computed** — they cannot be overridden from YAML.
|
||||
|
||||
## Per-backend notes
|
||||
|
||||
### SDPA
|
||||
|
||||
Default PyTorch attention. See
|
||||
[PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
attn_implementation: sdpa
|
||||
```
|
||||
|
||||
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
|
||||
### Flash Attention
|
||||
|
||||
### Flash Attention 2
|
||||
Axolotl supports FA2, FA3, and FA4. The best available version is used
|
||||
automatically based on your installed packages and GPU.
|
||||
|
||||
```yaml
|
||||
attn_implementation: flash_attention_2 # or flash_attention_3
|
||||
```
|
||||
|
||||
#### Flash Attention 2
|
||||
|
||||
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
|
||||
|
||||
@@ -39,20 +82,20 @@ Alternatively, try reinstall or downgrade a version.
|
||||
|
||||
:::
|
||||
|
||||
### Flash Attention 3
|
||||
#### Flash Attention 3
|
||||
|
||||
Requirements: Hopper only and CUDA 12.8 (recommended)
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||
cd flash-attention/hopper
|
||||
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
### Flash Attention 4
|
||||
#### Flash Attention 4
|
||||
|
||||
Requirements: Hopper or Blackwell GPUs
|
||||
Requirements: Hopper or Blackwell GPUs. Auto-applied when `attn_uses_flash_lib`
|
||||
is true and FA4 is importable.
|
||||
|
||||
```bash
|
||||
pip install flash-attn-4
|
||||
@@ -63,7 +106,6 @@ Or from source:
|
||||
```bash
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||
cd flash-attention/flash_attn/cute
|
||||
|
||||
pip install -e .
|
||||
|
||||
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
|
||||
@@ -86,93 +128,113 @@ and falls back to FA2/3.
|
||||
|
||||
:::
|
||||
|
||||
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
|
||||
|
||||
### AMD
|
||||
|
||||
Requirements: ROCm 6.0 and above.
|
||||
Requirements: ROCm 6.0 and above. See
|
||||
[Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
|
||||
|
||||
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
|
||||
|
||||
## Flex Attention
|
||||
|
||||
A flexible PyTorch API for attention used in combination with `torch.compile`.
|
||||
### Flex Attention
|
||||
|
||||
```yaml
|
||||
flex_attention: true
|
||||
|
||||
# recommended
|
||||
torch_compile: true
|
||||
attn_implementation: flex_attention
|
||||
torch_compile: true # recommended
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
Requires torch ≥ 2.6. See [PyTorch docs](https://pytorch.org/blog/flexattention/).
|
||||
|
||||
We recommend using latest stable version of PyTorch for best performance.
|
||||
### SageAttention
|
||||
|
||||
:::
|
||||
|
||||
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
|
||||
|
||||
## SageAttention
|
||||
|
||||
Attention kernels with QK Int8 and PV FP16 accumulator.
|
||||
Requirements: Ampere, Ada, or Hopper GPUs.
|
||||
|
||||
```yaml
|
||||
sage_attention: true
|
||||
attn_implementation: sage
|
||||
```
|
||||
|
||||
Requirements: Ampere, Ada, or Hopper GPUs
|
||||
|
||||
```bash
|
||||
pip install sageattention==2.2.0 --no-build-isolation
|
||||
```
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
|
||||
Only LoRA/QLoRA recommended. Full finetuning has been observed to drop loss to 0. See
|
||||
[GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
|
||||
|
||||
:::
|
||||
|
||||
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention).
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
|
||||
|
||||
:::
|
||||
|
||||
|
||||
## xFormers
|
||||
### xFormers
|
||||
|
||||
```yaml
|
||||
xformers_attention: true
|
||||
attn_implementation: xformers
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
We recommend using with Turing GPUs or below (such as on Colab).
|
||||
Recommended for Turing GPUs or below (e.g. Colab T4).
|
||||
|
||||
:::
|
||||
|
||||
For more details: [xFormers](https://github.com/facebookresearch/xformers)
|
||||
|
||||
## Shifted Sparse Attention
|
||||
### Shifted Sparse Attention
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
|
||||
Planned for deprecation. Prefer one of the backends above.
|
||||
|
||||
:::
|
||||
|
||||
Requirements: LLaMA model architecture
|
||||
Requirements: LLaMA model architecture. Loaded as FA2 under the hood and
|
||||
patched to implement shifted-sparse attention. Does not support sample packing.
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
s2_attention: true
|
||||
attn_implementation: s2
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
### FP8
|
||||
|
||||
No sample packing support!
|
||||
torchao low-precision attention. Loaded as SDPA and patched post-load.
|
||||
|
||||
Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17,
|
||||
flash-attn with FA3. KV caching must be disabled.
|
||||
|
||||
```yaml
|
||||
attn_implementation: fp8
|
||||
```
|
||||
|
||||
### Hub kernels
|
||||
|
||||
```yaml
|
||||
attn_implementation: kernels-community/flash-attn3
|
||||
```
|
||||
|
||||
Passed through to `transformers`; axolotl does not install the kernel itself.
|
||||
For recognized hub paths the capability flags are set automatically; for
|
||||
arbitrary paths axolotl uses conservative defaults (`attn_supports_packing=False`,
|
||||
`attn_uses_flash_lib=False`).
|
||||
|
||||
## Migrating from legacy boolean flags
|
||||
|
||||
The following legacy config fields are **deprecated** and will be removed in a
|
||||
future release. Each emits a `DeprecationWarning` when set and is stripped from
|
||||
the validated config.
|
||||
|
||||
| Legacy | Canonical |
|
||||
|---|---|
|
||||
| `flash_attention: true` | `attn_implementation: flash_attention_2` |
|
||||
| `sdp_attention: true` | `attn_implementation: sdpa` |
|
||||
| `xformers_attention: true` | `attn_implementation: xformers` |
|
||||
| `flex_attention: true` | `attn_implementation: flex_attention` |
|
||||
| `sage_attention: true` | `attn_implementation: sage` |
|
||||
| `s2_attention: true` | `attn_implementation: s2` |
|
||||
| `eager_attention: true` | `attn_implementation: eager` |
|
||||
|
||||
Combining `attn_implementation` with a legacy flag (e.g. `attn_implementation:
|
||||
flash_attention_2` **and** `flash_attention: true`) raises — pick one.
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
Existing example configs under `examples/` still use the legacy flags. They
|
||||
continue to work with a deprecation warning; they will be migrated in a
|
||||
follow-up pass.
|
||||
|
||||
:::
|
||||
|
||||
@@ -76,8 +76,9 @@ datasets:
|
||||
Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) of Axolotl, which ensures that changes you make to the code are reflected at runtime. Run the following commands from the root of this project:
|
||||
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||
uv sync --extra flash-attn --extra deepspeed --group dev --group test
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
#### Remote Hosts
|
||||
@@ -208,17 +209,17 @@ cd axolotl
|
||||
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
|
||||
|
||||
```bash
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl-uv:main-latest
|
||||
```
|
||||
|
||||
>[!Tip]
|
||||
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
|
||||
|
||||
You will now be in the container. Next, perform an editable install of Axolotl:
|
||||
You will now be in the container. Next, install Axolotl with dev dependencies:
|
||||
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
uv sync --extra flash-attn --extra deepspeed --group dev --group test
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
### Attach To Container
|
||||
|
||||
@@ -6,23 +6,30 @@ format:
|
||||
toc-depth: 4
|
||||
---
|
||||
|
||||
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||
This section describes the different Docker images that are released by AxolotlAI at
|
||||
[Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
|
||||
For Blackwell GPUs, please use the tags with PyTorch 2.9.1 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
::: {.callout-tip}
|
||||
Each image below is available in a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with
|
||||
a relocatable venv (`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
|
||||
(e.g. `axolotlai/axolotl-base-uv`). Tags follow the same format. We recommend the uv images for new deployments.
|
||||
:::
|
||||
|
||||
## Base
|
||||
|
||||
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
||||
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image.
|
||||
It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
||||
|
||||
#### Image
|
||||
|
||||
```
|
||||
axolotlai/axolotl-base
|
||||
```
|
||||
|
||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
|
||||
| Variant | Image | Docker Hub |
|
||||
|---------|-------|------------|
|
||||
| pip | `axolotlai/axolotl-base` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base) |
|
||||
| uv | `axolotlai/axolotl-base-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base-uv) |
|
||||
|
||||
#### Tags format
|
||||
|
||||
@@ -32,8 +39,10 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-base-py3.11-cu128-2.8.0`
|
||||
- `main-base-py3.11-cu128-2.9.1`
|
||||
- `main-base-py3.12-cu128-2.10.0`
|
||||
- `main-base-py3.12-cu130-2.9.1`
|
||||
- `main-base-py3.12-cu130-2.10.0`
|
||||
|
||||
## Main
|
||||
|
||||
@@ -41,11 +50,10 @@ The main image is the image that is used to run Axolotl. It is based on the `axo
|
||||
|
||||
#### Image
|
||||
|
||||
```
|
||||
axolotlai/axolotl
|
||||
```
|
||||
|
||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
|
||||
| Variant | Image | Docker Hub |
|
||||
|---------|-------|------------|
|
||||
| pip | `axolotlai/axolotl` | [Link](https://hub.docker.com/r/axolotlai/axolotl) |
|
||||
| uv | `axolotlai/axolotl-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-uv) |
|
||||
|
||||
#### Tags format {#sec-main-tags}
|
||||
|
||||
@@ -53,7 +61,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
|
||||
# on push to main
|
||||
main-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)
|
||||
# latest main (currently torch 2.9.1, python 3.11, cuda 12.8)
|
||||
main-latest
|
||||
|
||||
# nightly build
|
||||
@@ -71,11 +79,12 @@ There may be some extra tags appended to the image, like `-vllm` which installs
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-py3.11-cu128-2.8.0`
|
||||
- `main-py3.11-cu128-2.9.1`
|
||||
- `main-py3.12-cu128-2.10.0`
|
||||
- `main-py3.12-cu130-2.9.1`
|
||||
- `main-py3.12-cu130-2.10.0`
|
||||
- `main-latest`
|
||||
- `main-20250303-py3.11-cu124-2.6.0`
|
||||
- `main-20250303-py3.11-cu126-2.6.0`
|
||||
- `main-20260315-py3.11-cu128-2.9.1`
|
||||
- `0.12.0`
|
||||
|
||||
## Cloud
|
||||
@@ -90,11 +99,10 @@ Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variab
|
||||
|
||||
#### Image
|
||||
|
||||
```
|
||||
axolotlai/axolotl-cloud
|
||||
```
|
||||
|
||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
|
||||
| Variant | Image | Docker Hub |
|
||||
|---------|-------|------------|
|
||||
| pip | `axolotlai/axolotl-cloud` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud) |
|
||||
| uv | `axolotlai/axolotl-cloud-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud-uv) |
|
||||
|
||||
#### Tags format
|
||||
|
||||
|
||||
@@ -129,7 +129,7 @@ gradient_accumulation_steps: 4
|
||||
max_steps: 20
|
||||
learning_rate: 5.0e-6
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gradient_checkpointing: true
|
||||
output_dir: ./outputs/ebft-quickstart
|
||||
```
|
||||
@@ -304,7 +304,7 @@ lora_alpha: 32
|
||||
lora_target_linear: true
|
||||
|
||||
bf16: auto
|
||||
flex_attention: true
|
||||
attn_implementation: flex_attention
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: true # Required with flex_attention
|
||||
|
||||
@@ -154,7 +154,7 @@ lr_scheduler: cosine
|
||||
warmup_steps: 10
|
||||
|
||||
bf16: true
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -15,64 +15,30 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
||||
|
||||
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python ≥3.11
|
||||
- PyTorch ≥2.6.0
|
||||
- PyTorch ≥2.9.0
|
||||
|
||||
## Installation Methods {#sec-installation-methods}
|
||||
|
||||
::: {.callout-important}
|
||||
Please make sure to have Pytorch installed before installing Axolotl in your local environment.
|
||||
|
||||
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
|
||||
:::
|
||||
## Installation {#sec-installation}
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
### PyPI Installation (Recommended) {#sec-pypi}
|
||||
### Quick Install {#sec-uv}
|
||||
|
||||
```{.bash}
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
```
|
||||
Axolotl uses [uv](https://docs.astral.sh/uv/) as its package manager. uv is a fast, reliable Python package installer and resolver built in Rust.
|
||||
|
||||
We use `--no-build-isolation` in order to detect the installed PyTorch version (if
|
||||
installed) in order not to clobber it, and so that we set the correct version of
|
||||
dependencies that are specific to the PyTorch version or other installed
|
||||
co-dependencies.
|
||||
|
||||
### uv Installation {#sec-uv}
|
||||
|
||||
uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
|
||||
|
||||
Install uv if not already installed
|
||||
Install uv if not already installed:
|
||||
```{.bash}
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
source $HOME/.local/bin/env
|
||||
```
|
||||
|
||||
Choose your CUDA version to use with PyTorch; e.g. `cu124`, `cu126`, `cu128`,
|
||||
then create the venv and activate
|
||||
Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
|
||||
```{.bash}
|
||||
export UV_TORCH_BACKEND=cu126
|
||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||
uv venv --no-project --relocatable
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
Install PyTorch
|
||||
- PyTorch 2.6.0 recommended
|
||||
```{.bash}
|
||||
uv pip install packaging setuptools wheel
|
||||
uv pip install torch==2.6.0
|
||||
uv pip install awscli pydantic
|
||||
```
|
||||
|
||||
Install axolotl from PyPi
|
||||
```{.bash}
|
||||
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
|
||||
|
||||
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
|
||||
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
|
||||
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
```
|
||||
|
||||
### Edge/Development Build {#sec-edge-build}
|
||||
@@ -82,14 +48,17 @@ For the latest features between releases:
|
||||
```{.bash}
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||
uv sync --extra flash-attn --extra deepspeed
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
`uv sync` creates a `.venv`, installs exact pinned versions from `uv.lock`, and sets up an editable install automatically.
|
||||
|
||||
### Docker {#sec-docker}
|
||||
|
||||
```{.bash}
|
||||
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
||||
docker run --gpus '"all"' --rm -it --ipc=host axolotlai/axolotl-uv:main-latest
|
||||
```
|
||||
|
||||
For development with Docker:
|
||||
@@ -106,12 +75,12 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
|
||||
--ulimit memlock=-1 --ulimit stack=67108864 \
|
||||
--mount type=bind,src="${PWD}",target=/workspace/axolotl \
|
||||
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
|
||||
axolotlai/axolotl:main-latest
|
||||
axolotlai/axolotl-uv:main-latest
|
||||
```
|
||||
:::
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
|
||||
For Blackwell GPUs, please use `axolotlai/axolotl-uv:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud-uv:main-py3.11-cu128-2.9.1`.
|
||||
:::
|
||||
|
||||
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
||||
@@ -122,7 +91,7 @@ Please refer to the [Docker documentation](docker.qmd) for more information on t
|
||||
|
||||
For providers supporting Docker:
|
||||
|
||||
- Use `axolotlai/axolotl-cloud:main-latest`
|
||||
- Use `axolotlai/axolotl-cloud-uv:main-latest`
|
||||
- Available on:
|
||||
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)
|
||||
@@ -141,7 +110,7 @@ For providers supporting Docker:
|
||||
### macOS {#sec-macos}
|
||||
|
||||
```{.bash}
|
||||
pip3 install --no-build-isolation -e '.'
|
||||
uv pip install --no-build-isolation -e '.'
|
||||
```
|
||||
|
||||
See @sec-troubleshooting for Mac-specific issues.
|
||||
@@ -152,21 +121,44 @@ See @sec-troubleshooting for Mac-specific issues.
|
||||
We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
||||
:::
|
||||
|
||||
## Environment Managers {#sec-env-managers}
|
||||
## Migrating from pip to uv {#sec-migrating}
|
||||
|
||||
### Conda/Pip venv {#sec-conda}
|
||||
If you have an existing pip-based Axolotl installation, you can migrate to uv:
|
||||
|
||||
1. Install Python ≥3.11
|
||||
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
||||
3. Install Axolotl:
|
||||
```{.bash}
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
4. (Optional) Login to Hugging Face:
|
||||
```{.bash}
|
||||
hf auth login
|
||||
```
|
||||
```{.bash}
|
||||
# Install uv
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
source $HOME/.local/bin/env
|
||||
|
||||
# Create a fresh venv (recommended for a clean start)
|
||||
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||
uv venv --no-project --relocatable
|
||||
source .venv/bin/activate
|
||||
|
||||
# Reinstall axolotl
|
||||
uv pip install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
```
|
||||
|
||||
## Using pip (Alternative) {#sec-pip}
|
||||
|
||||
If you are unable to install uv, you can still use pip directly.
|
||||
|
||||
::: {.callout-important}
|
||||
Please make sure to have PyTorch installed before installing Axolotl with pip.
|
||||
|
||||
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
|
||||
:::
|
||||
|
||||
```{.bash}
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
```
|
||||
|
||||
For editable/development installs:
|
||||
```{.bash}
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
## Troubleshooting {#sec-troubleshooting}
|
||||
|
||||
|
||||
@@ -22,12 +22,12 @@ Improves GPU utilization by combining multiple short sequences into a single pac
|
||||
|
||||
Using an optimized attention implementation is critical for training speed.
|
||||
|
||||
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
||||
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
|
||||
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
|
||||
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
|
||||
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `attn_implementation: flash_attention_2`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
||||
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `attn_implementation: flex_attention`.
|
||||
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `attn_implementation: sdpa`. PyTorch's native implementation.
|
||||
- **[Xformers](https://github.com/facebookresearch/xformers)**: `attn_implementation: xformers`. Works with FP16.
|
||||
|
||||
*Note: You should only enable one attention backend.*
|
||||
See [Attention](attention.qmd) for the full list of backends and the canonical values.
|
||||
|
||||
### LoRA Optimizations
|
||||
|
||||
|
||||
@@ -320,8 +320,10 @@ The input format is a simple JSON input with customizable fields based on the ab
|
||||
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
|
||||
|
||||
```yaml
|
||||
rl: ipo
|
||||
rl: dpo
|
||||
dpo_loss_type: ["ipo"]
|
||||
```
|
||||
*Note:* Passing `rl: ipo` directly is still supported, but will soon be deprecated.
|
||||
|
||||
### ORPO
|
||||
|
||||
@@ -1145,8 +1147,7 @@ datasets:
|
||||
type: ebft_strided_structured.transform
|
||||
split: train[:1%]
|
||||
|
||||
flash_attention: false
|
||||
flex_attention: true # Strided mode uses flex_attention
|
||||
attn_implementation: flex_attention # Strided mode uses flex_attention
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: true # Required for flex_attention
|
||||
|
||||
@@ -55,7 +55,7 @@ To use sequence parallelism, you need:
|
||||
|
||||
## Limitations
|
||||
|
||||
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
|
||||
- Flash attention must be enabled for this to work (`attn_implementation: flash_attention_2` in config YAML)
|
||||
- May have a small performance overhead due to communication between GPUs
|
||||
|
||||
## Example
|
||||
|
||||
@@ -245,7 +245,7 @@ For GRPO, also reduce `max_completion_length`. Memory scales quadratically with
|
||||
Reduces attention memory from O(n^2) to O(n):
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
```
|
||||
|
||||
### Step 6: Offload with DeepSpeed
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
---
|
||||
title: "Unsloth"
|
||||
description: "Hyper-optimized QLoRA finetuning for single GPUs"
|
||||
---
|
||||
|
||||
### Overview
|
||||
|
||||
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
|
||||
standard industry baselines.
|
||||
|
||||
::: {.callout-important}
|
||||
Due to breaking changes in transformers `v4.48.0`, users will need to downgrade to `<=v4.47.1` to use this patch.
|
||||
|
||||
This will later be deprecated in favor of [LoRA Optimizations](lora_optims.qmd).
|
||||
:::
|
||||
|
||||
|
||||
### Installation
|
||||
|
||||
The following will install the correct unsloth and extras from source.
|
||||
|
||||
```bash
|
||||
python scripts/unsloth_install.py | sh
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
|
||||
|
||||
Our unsloth integration is currently limited to the following model architectures:
|
||||
- llama
|
||||
|
||||
These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
|
||||
```yaml
|
||||
unsloth_lora_mlp: true
|
||||
unsloth_lora_qkv: true
|
||||
unsloth_lora_o: true
|
||||
```
|
||||
|
||||
These options are composable and can be used with multi-gpu finetuning
|
||||
```yaml
|
||||
unsloth_cross_entropy_loss: true
|
||||
unsloth_rms_norm: true
|
||||
unsloth_rope: true
|
||||
```
|
||||
|
||||
### Limitations
|
||||
|
||||
- Single GPU only; e.g. no multi-gpu support
|
||||
- No deepspeed or FSDP support (requires multi-gpu)
|
||||
- LoRA + QLoRA support only. No full fine tunes or fp8 support.
|
||||
- Limited model architecture support. Llama, Phi, Gemma, Mistral only
|
||||
- No MoE support.
|
||||
@@ -15,8 +15,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
||||
Here is an example of how to install from pip:
|
||||
```bash
|
||||
# Ensure you have a compatible version of Pytorch installed
|
||||
pip3 install packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Run one of the finetuning examples below.
|
||||
@@ -35,7 +34,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
||||
|
||||
**LFM2-MoE**
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
|
||||
uv pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
|
||||
|
||||
# LoRA SFT (1x48GB @ 16.2GiB)
|
||||
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
|
||||
@@ -45,7 +44,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
||||
|
||||
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
|
||||
```bash
|
||||
pip uninstall -y causal-conv1d
|
||||
uv pip uninstall causal-conv1d
|
||||
```
|
||||
|
||||
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
|
||||
@@ -39,7 +39,7 @@ tf32: true
|
||||
gradient_checkpointing: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
|
||||
@@ -48,7 +48,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
|
||||
@@ -50,8 +50,7 @@ tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -39,7 +39,7 @@ activation_offloading: legacy
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_steps: 100
|
||||
saves_per_epoch: 1
|
||||
|
||||
@@ -39,7 +39,7 @@ activation_offloading: legacy
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_steps: 100
|
||||
saves_per_epoch: 1
|
||||
|
||||
@@ -15,8 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
uv pip install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
@@ -31,7 +30,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
# For those using our Docker image, use the below path.
|
||||
export CUDA_HOME=/usr/local/cuda
|
||||
|
||||
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
|
||||
@@ -67,7 +66,7 @@ If those didn't help, please try the below solutions:
|
||||
1. Pass env for CMAKE and try install again:
|
||||
|
||||
```bash
|
||||
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
Python_EXECUTABLE=$(which python) uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
2. Git clone the repo and manually hardcode python path:
|
||||
@@ -92,7 +91,7 @@ If those didn't help, please try the below solutions:
|
||||
```
|
||||
|
||||
```bash
|
||||
pip3 install . --no-build-isolation --no-deps
|
||||
uv pip install . --no-build-isolation --no-deps
|
||||
```
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
@@ -55,7 +55,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -17,8 +17,7 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
uv pip install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
@@ -55,7 +55,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -59,8 +59,7 @@ gradient_checkpointing: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
flash_attention: true
|
||||
sdp_attention:
|
||||
attn_implementation: flash_attention_2
|
||||
flash_optimum:
|
||||
|
||||
gptq_groupsize:
|
||||
|
||||
@@ -39,8 +39,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
attn_implementation: xformers
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -45,7 +45,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -46,7 +46,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -45,7 +45,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -46,7 +46,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -45,7 +45,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -46,7 +46,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -52,7 +52,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -55,7 +55,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -39,7 +39,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -45,7 +45,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -43,8 +43,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
attn_implementation: xformers
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -73,8 +73,7 @@ early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
auto_resume_from_checkpoints: true
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
attn_implementation: xformers
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -40,8 +40,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
attn_implementation: xformers
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -47,7 +47,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -36,8 +36,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
attn_implementation: xformers
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -37,8 +37,7 @@ bf16: auto
|
||||
tf32: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 5
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
attn_implementation: xformers
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -39,7 +39,6 @@ bf16: auto
|
||||
tf32: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 5
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -39,7 +39,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -47,7 +47,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -40,7 +40,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -47,7 +47,6 @@ tf32: false
|
||||
gradient_checkpointing: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -47,7 +47,6 @@ tf32: false
|
||||
gradient_checkpointing: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -43,7 +43,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -46,7 +46,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -40,7 +40,6 @@ bf16: auto
|
||||
tf32: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 5
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -38,7 +38,6 @@ tf32: true
|
||||
gradient_checkpointing:
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -44,7 +44,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
@@ -47,7 +47,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -47,7 +47,6 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: false
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 0
|
||||
|
||||
@@ -45,7 +45,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -36,7 +36,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -47,7 +47,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
|
||||
@@ -71,8 +71,7 @@ early_stopping_patience: 3
|
||||
resume_from_checkpoint:
|
||||
auto_resume_from_checkpoints: true
|
||||
logging_steps: 1
|
||||
xformers_attention: true
|
||||
flash_attention:
|
||||
attn_implementation: xformers
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -10,7 +10,7 @@ load_in_4bit: true
|
||||
sequence_len: 1024
|
||||
bf16: auto
|
||||
tf32: false
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
special_tokens:
|
||||
bos_token: "<|startoftext|>"
|
||||
eos_token: "<|endoftext|>"
|
||||
|
||||
@@ -48,7 +48,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -45,7 +45,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -45,7 +45,7 @@ tf32: true
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
|
||||
@@ -35,7 +35,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
|
||||
@@ -59,7 +59,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
|
||||
@@ -16,8 +16,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
uv pip install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||
|
||||
@@ -51,7 +51,7 @@ tf32: false
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
scaling_softmax: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
|
||||
@@ -29,7 +29,7 @@ output_dir: ./outputs/ndp-out/
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
|
||||
@@ -26,7 +26,7 @@ output_dir: ./outputs/ndp-out/
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1 # must be 1 when using context parallel
|
||||
|
||||
@@ -65,8 +65,7 @@ early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
@@ -46,7 +46,7 @@ lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -66,7 +66,7 @@ lora_target_linear: true
|
||||
|
||||
# --- Hardware ---
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -47,8 +47,7 @@ lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
bf16: auto
|
||||
flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
|
||||
flex_attention: true # fused flex_attention kernel compiles itself; don't set torch_compile: true
|
||||
attn_implementation: flex_attention
|
||||
# (full-model compile conflicts with gradient checkpointing + flex_attention)
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
|
||||
@@ -46,7 +46,6 @@ lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
bf16: auto
|
||||
flash_attention: false # strided EBFT overrides to flex_attention (or eager fallback) at runtime
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -48,7 +48,6 @@ lora_target_linear: true
|
||||
|
||||
bf16: auto
|
||||
torch_dtype: bfloat16
|
||||
flash_attention: false
|
||||
gradient_checkpointing: true
|
||||
torch_compile: true
|
||||
gradient_checkpointing_kwargs:
|
||||
|
||||
@@ -41,7 +41,6 @@ warmup_steps: 10
|
||||
weight_decay: 0.01
|
||||
|
||||
bf16: auto
|
||||
flash_attention: false # strided EBFT uses flex_attention at runtime
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
|
||||
@@ -72,7 +72,7 @@ lora_dropout: 0.0
|
||||
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
||||
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -63,7 +63,7 @@ lora_dropout: 0.0
|
||||
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
||||
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -68,7 +68,7 @@ lora_dropout: 0.0
|
||||
lora_target_modules: ".*\\.layers\\.(3|7|11|15|19|23|27|31)\\.self_attn\\.(q|k|v|o)_proj|.*\\.mlp\\.(gate|up|down)_proj"
|
||||
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -62,7 +62,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
@@ -61,7 +61,7 @@ gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user