Compare commits
1 Commits
v0.16.1
...
transforme
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0aa7c72c59 |
@@ -1,41 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
_axolotl_completions() {
|
||||
local cur prev
|
||||
COMPREPLY=()
|
||||
cur="${COMP_WORDS[COMP_CWORD]}"
|
||||
prev="${COMP_WORDS[COMP_CWORD-1]}"
|
||||
|
||||
# If we're completing the first argument (the command)
|
||||
if [[ $COMP_CWORD -eq 1 ]]; then
|
||||
mapfile -t COMPREPLY < <(compgen -W "delinearize-llama4 fetch lm-eval merge-sharded-fsdp-weights quantize vllm-serve evaluate inference merge-lora preprocess train" -- "$cur")
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Commands that should complete with directories and YAML files
|
||||
local -a yaml_commands=("merge-sharded-fsdp-weights" "quantize" "vllm-serve" "evaluate" "inference" "merge-lora" "preprocess" "train")
|
||||
|
||||
# Check if previous word is in our list
|
||||
if [[ " ${yaml_commands[*]} " =~ (^|[[:space:]])$prev($|[[:space:]]) ]]; then
|
||||
# Use filename completion which handles directories properly
|
||||
compopt -o filenames
|
||||
mapfile -t COMPREPLY < <(compgen -f -- "$cur")
|
||||
|
||||
# Filter to only include directories and YAML files
|
||||
local -a filtered=()
|
||||
for item in "${COMPREPLY[@]}"; do
|
||||
if [[ -d "$item" ]] || [[ "$item" == *.yaml ]] || [[ "$item" == *.yml ]]; then
|
||||
filtered+=("$item")
|
||||
fi
|
||||
done
|
||||
COMPREPLY=("${filtered[@]}")
|
||||
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Default: no completion
|
||||
return 0
|
||||
}
|
||||
|
||||
# Remove the -o nospace option - let filenames handle it
|
||||
complete -F _axolotl_completions axolotl
|
||||
2
.bandit
2
.bandit
@@ -1,3 +1,3 @@
|
||||
[bandit]
|
||||
exclude = tests
|
||||
skips = B101,B615,B102,B110
|
||||
skips = B101
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
|
||||
language: "en-US"
|
||||
early_access: false
|
||||
reviews:
|
||||
profile: "chill"
|
||||
request_changes_workflow: false
|
||||
high_level_summary: true
|
||||
review_status: true
|
||||
collapse_walkthrough: true
|
||||
poem: false
|
||||
sequence_diagrams: false
|
||||
auto_review:
|
||||
enabled: true
|
||||
drafts: false
|
||||
auto_incremental_review: false
|
||||
chat:
|
||||
auto_reply: true
|
||||
14
.coveragerc
14
.coveragerc
@@ -1,14 +0,0 @@
|
||||
[run]
|
||||
source = axolotl
|
||||
omit =
|
||||
*/tests/*
|
||||
setup.py
|
||||
|
||||
[report]
|
||||
exclude_lines =
|
||||
pragma: no cover
|
||||
def __repr__
|
||||
raise NotImplementedError
|
||||
if __name__ == .__main__.:
|
||||
pass
|
||||
raise ImportError
|
||||
5
.flake8
Normal file
5
.flake8
Normal file
@@ -0,0 +1,5 @@
|
||||
[flake8]
|
||||
max-line-length = 88
|
||||
|
||||
select = C,E,F,W,B,B950
|
||||
extend-ignore = E203, E501, W503
|
||||
16
.github/CONTRIBUTING.md
vendored
16
.github/CONTRIBUTING.md
vendored
@@ -57,23 +57,11 @@ We welcome ideas for improvements and new features. To suggest an enhancement, o
|
||||
5. Push your branch to your fork on GitHub.
|
||||
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
|
||||
|
||||
#### Skipping CI Checks
|
||||
|
||||
You can skip certain CI checks by including specific keywords in your commit messages:
|
||||
|
||||
- `[skip ci]` or `skip ci` - Skips all CI checks for that commit
|
||||
- `[skip-e2e]` or `skip-e2e` - Skips only end-to-end tests while running other CI checks. You may also include this in the title of your PR to disable end-to-end tests for the entire PR.
|
||||
|
||||
## Style Guidelines
|
||||
|
||||
### Code Style
|
||||
|
||||
axolotl uses [Ruff](https://docs.astral.sh/ruff/) as its code style guide. Please ensure that your code follows these guidelines.
|
||||
|
||||
Use the pre-commit linter to ensure that your code is formatted consistently.
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.
|
||||
|
||||
### Commit Messages
|
||||
|
||||
@@ -83,6 +71,6 @@ Write clear and concise commit messages that briefly describe the changes made i
|
||||
|
||||
- [GitHub Help](https://help.github.com/)
|
||||
- [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests)
|
||||
- [Ruff](https://docs.astral.sh/ruff/)
|
||||
- [{codestyle}]({URLofCodestyle})
|
||||
|
||||
Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!
|
||||
|
||||
6
.github/FUNDING.yml
vendored
6
.github/FUNDING.yml
vendored
@@ -1,13 +1,13 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
github: [winglian, OpenAccess-AI-Collective] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: # Replace with a single Open Collective username
|
||||
ko_fi: # Replace with a single Ko-fi username
|
||||
ko_fi: axolotl_ai # Replace with a single Ko-fi username
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
otechie: # Replace with a single Otechie username
|
||||
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
||||
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
||||
custom: ['https://quickchart.io/qr?text=bitcoin%3Abc1qxlgwlqwfea5s2cxm42xqsfmwjct0rj8w8ea5np&size=480¢erImageUrl=https%3A%2F%2Fupload.wikimedia.org%2Fwikipedia%2Fcommons%2Fthumb%2F4%2F46%2FBitcoin.svg%2F64px-Bitcoin.svg.png'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
|
||||
|
||||
5
.github/PULL_REQUEST_TEMPLATE.md
vendored
5
.github/PULL_REQUEST_TEMPLATE.md
vendored
@@ -15,11 +15,6 @@
|
||||
<!--- Include details of your testing environment, tests ran to see how -->
|
||||
<!--- your change affects other areas of the code, etc. -->
|
||||
|
||||
## AI Usage Disclaimer
|
||||
|
||||
<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
|
||||
<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
|
||||
|
||||
## Screenshots (if appropriate)
|
||||
|
||||
## Types of changes
|
||||
|
||||
235
.github/workflows/base.yml
vendored
235
.github/workflows/base.yml
vendored
@@ -5,110 +5,59 @@ on:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- 'docker/Dockerfile-base'
|
||||
- 'docker/Dockerfile-uv-base'
|
||||
- 'Dockerfile-base'
|
||||
- '.github/workflows/base.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'docker/Dockerfile-base'
|
||||
- 'docker/Dockerfile-uv-base'
|
||||
- 'Dockerfile-base'
|
||||
- '.github/workflows/base.yml'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build-base:
|
||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||
timeout-minutes: 480
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: ubuntu-latest-m
|
||||
env:
|
||||
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
|
||||
runs-on: axolotl-gpu-runner
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
pytorch: nightly
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: next
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: "129"
|
||||
# cuda_version: 12.9.1
|
||||
# cudnn_version: ""
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
# dockerfile: "Dockerfile-base"
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: "128"
|
||||
# cuda_version: 12.8.1
|
||||
# cudnn_version: ""
|
||||
# python_version: "3.11"
|
||||
# pytorch: nightly
|
||||
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
# dockerfile: "Dockerfile-base-nightly"
|
||||
# # "next" is for release candidates of pytorch
|
||||
# - cuda: "128"
|
||||
# cuda_version: 12.8.1
|
||||
# cudnn_version: ""
|
||||
# python_version: "3.11"
|
||||
# pytorch: next
|
||||
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
# dockerfile: "Dockerfile-base-next"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -117,136 +66,20 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-base
|
||||
axolotlai/axolotl-base
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/${{ matrix.dockerfile }}
|
||||
platforms: ${{ matrix.platforms }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
build-args: |
|
||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||
CUDNN_VERSION=${{ matrix.cudnn_version }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTHON_VERSION=${{ matrix.python_version }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
|
||||
build-base-uv:
|
||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||
timeout-minutes: 480
|
||||
runs-on: ubuntu-latest-m
|
||||
env:
|
||||
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: "129"
|
||||
# cuda_version: 12.9.1
|
||||
# cudnn_version: ""
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
# dockerfile: "Dockerfile-uv-base"
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
axolotlai/axolotl-base-uv
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/${{ matrix.dockerfile }}
|
||||
platforms: ${{ matrix.platforms }}
|
||||
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
5
.github/workflows/docs.yml
vendored
5
.github/workflows/docs.yml
vendored
@@ -12,9 +12,6 @@ jobs:
|
||||
build-deploy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: cleanup node
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Quarto
|
||||
@@ -26,7 +23,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python3 -m pip install jupyter quartodoc
|
||||
python3 -m pip install -e .
|
||||
python3 -m pip install -e . --no-deps
|
||||
- name: Build autodoc
|
||||
run: quartodoc build
|
||||
- name: Publish to GitHub Pages (and render)
|
||||
|
||||
6
.github/workflows/lint.yml
vendored
6
.github/workflows/lint.yml
vendored
@@ -3,24 +3,18 @@ on:
|
||||
# check on PRs, and manual triggers
|
||||
merge_group:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
- '.github/workflows/*.yml'
|
||||
- "*.[q]md"
|
||||
- "examples/**/*.y[a]?ml"
|
||||
- ".pre-commit-config.yaml"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
name: pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
|
||||
262
.github/workflows/main.yml
vendored
262
.github/workflows/main.yml
vendored
@@ -8,9 +8,6 @@ on:
|
||||
- "v*"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build-axolotl:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
@@ -18,43 +15,22 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras: vllm
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras:
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -64,6 +40,7 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl
|
||||
axolotlai/axolotl
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
@@ -80,13 +57,11 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
||||
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
@@ -95,128 +70,29 @@ jobs:
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-uv:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
axolotlai/axolotl-uv
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=pep440,pattern={{version}}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
|
||||
- name: Build and export to Docker
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
||||
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
|
||||
file: ./docker/Dockerfile-uv
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-cloud:
|
||||
needs: build-axolotl
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras:
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -226,6 +102,7 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-cloud
|
||||
axolotlai/axolotl-cloud
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
@@ -241,7 +118,6 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
@@ -252,100 +128,18 @@ jobs:
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-cloud-uv:
|
||||
needs: build-axolotl-uv
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
axolotlai/axolotl-cloud-uv
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=pep440,pattern={{version}}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
file: ./docker/Dockerfile-cloud-uv
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-cloud-no-tmux:
|
||||
needs: build-axolotl
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -355,6 +149,7 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-cloud-term
|
||||
axolotlai/axolotl-cloud-term
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
@@ -370,7 +165,6 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
|
||||
55
.github/workflows/multi-gpu-e2e.yml
vendored
55
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -3,14 +3,11 @@ name: docker-multigpu-tests-biweekly
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'tests/e2e/multigpu/**.py'
|
||||
- 'tests/e2e/multigpu/*.py'
|
||||
- 'requirements.txt'
|
||||
- 'setup.py'
|
||||
- 'pyproject.toml'
|
||||
- '.github/workflows/multi-gpu-e2e.yml'
|
||||
- 'scripts/cutcrossentropy_install.py'
|
||||
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
|
||||
- 'src/axolotl/utils/distributed.py'
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
||||
@@ -20,40 +17,34 @@ concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
MODAL_IMAGE_BUILDER_VERSION: "2025.06"
|
||||
|
||||
jobs:
|
||||
test-axolotl-multigpu:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras: "fbgemm-gpu"
|
||||
# num_gpus: 2
|
||||
# dockerfile: "Dockerfile-uv.jinja"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
# axolotl_extras: fbgemm-gpu
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras: vllm
|
||||
num_gpus: 2
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras: "fbgemm-gpu"
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras: # no vllm support for 2.4.1
|
||||
num_gpus: 2
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras: vllm
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
steps:
|
||||
@@ -66,7 +57,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
pip install modal==0.71.8 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -75,9 +66,7 @@ jobs:
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run -m cicd.multigpu
|
||||
modal run cicd.multigpu
|
||||
|
||||
37
.github/workflows/nightlies.yml
vendored
37
.github/workflows/nightlies.yml
vendored
@@ -5,9 +5,6 @@ on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build-axolotl:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
@@ -15,10 +12,20 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -29,6 +36,7 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl
|
||||
axolotlai/axolotl
|
||||
tags: |
|
||||
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
||||
@@ -62,10 +70,20 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -76,6 +94,7 @@ jobs:
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-cloud
|
||||
axolotlai/axolotl-cloud
|
||||
tags: |
|
||||
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
||||
|
||||
13
.github/workflows/precommit-autoupdate.yml
vendored
13
.github/workflows/precommit-autoupdate.yml
vendored
@@ -2,11 +2,9 @@ name: Pre-commit auto-update
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 1 * *' # Run monthly
|
||||
- cron: '0 0 * * 0' # Run weekly
|
||||
workflow_dispatch: # Manual kickoff
|
||||
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
auto-update:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -27,6 +25,7 @@ jobs:
|
||||
pre-commit autoupdate
|
||||
if [[ -n $(git status --porcelain) ]]; then
|
||||
echo "changes=true" >> $GITHUB_OUTPUT
|
||||
git diff .pre-commit-config.yaml > pre-commit-update.diff
|
||||
fi
|
||||
|
||||
- name: Create Pull Request
|
||||
@@ -40,3 +39,11 @@ jobs:
|
||||
commit-message: "chore: update pre-commit hooks"
|
||||
body: |
|
||||
Automated PR to update pre-commit hooks to their latest versions.
|
||||
|
||||
<details>
|
||||
<summary>Changes:</summary>
|
||||
|
||||
```diff
|
||||
${{ steps.update.outputs.diff }}
|
||||
```
|
||||
</details>
|
||||
|
||||
77
.github/workflows/preview-docs.yml
vendored
77
.github/workflows/preview-docs.yml
vendored
@@ -1,77 +0,0 @@
|
||||
name: Preview
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
|
||||
# Run the workflow only when one of these files changes
|
||||
paths:
|
||||
- '**/*.md' # any Markdown file
|
||||
- '**/*.qmd' # any Quarto file
|
||||
- '_quarto.yml'
|
||||
- docs/scripts/generate_config_docs.py
|
||||
- src/axolotl/utils/schemas/**.py
|
||||
- .github/workflows/preview-docs.yml
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
preview:
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
steps:
|
||||
- name: cleanup node
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
|
||||
- name: Set up Quarto
|
||||
uses: quarto-dev/quarto-actions/setup@v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python3 -m pip install jupyter quartodoc
|
||||
python3 -m pip install -e .
|
||||
|
||||
- name: Build autodoc
|
||||
run: quartodoc build
|
||||
|
||||
- name: Quarto render
|
||||
run: quarto render
|
||||
|
||||
- name: Netlify Publish
|
||||
uses: nwtgck/actions-netlify@v3.0
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
id: netlify
|
||||
with:
|
||||
publish-dir: './_site'
|
||||
enable-pull-request-comment: false
|
||||
enable-github-deployment: false
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
deploy-message: "Deployed On Netlify"
|
||||
github-deployment-environment: 'preview'
|
||||
github-deployment-description: 'Preview Deployment'
|
||||
env:
|
||||
NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}
|
||||
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
|
||||
|
||||
- name: Update PR with preview link
|
||||
if: ${{ steps.netlify.outcome == 'success' }}
|
||||
uses: marocchino/sticky-pull-request-comment@v2
|
||||
with:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
message: |
|
||||
📖 **Documentation Preview**: ${{ steps.netlify.outputs.deploy-url }}
|
||||
|
||||
Deployed on Netlify from commit ${{ github.event.pull_request.head.sha }}
|
||||
15
.github/workflows/pypi.yml
vendored
15
.github/workflows/pypi.yml
vendored
@@ -3,11 +3,9 @@ name: publish pypi
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
setup_release:
|
||||
name: Create Release
|
||||
@@ -30,8 +28,7 @@ jobs:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/axolotl
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
||||
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
@@ -43,17 +40,17 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install wheel packaging==26.0
|
||||
pip3 install wheel packaging==23.2
|
||||
pip3 install --no-build-isolation -e .
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Extract tag name
|
||||
id: tag
|
||||
run: echo "TAG_NAME=$(echo $GITHUB_REF | cut -d / -f 3)" >> "$GITHUB_OUTPUT"
|
||||
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
|
||||
|
||||
- name: Update version in VERSION file
|
||||
- name: Update version in setup.py
|
||||
run: |
|
||||
echo "${{ steps.tag.outputs.TAG_NAME }}" | sed 's/^v//' > VERSION
|
||||
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
|
||||
|
||||
- name: Build a source dist
|
||||
run: |
|
||||
|
||||
123
.github/workflows/tests-nightly.yml
vendored
123
.github/workflows/tests-nightly.yml
vendored
@@ -3,13 +3,6 @@ on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- '.github/workflows/tests-nightly.yml'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
@@ -25,37 +18,29 @@ jobs:
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
prime-cdn-s3-cache:
|
||||
name: Prefetch S3 once to prime the CDN cache
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
curl -v -H "Range: bytes=0-1023" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
|
||||
|
||||
pytest:
|
||||
name: PyTest
|
||||
runs-on: ubuntu-latest
|
||||
needs: [prime-cdn-s3-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||
pytorch_version: ["2.9.1", "2.10.0"]
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p /home/runner/.cache/huggingface/hub
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -66,11 +51,11 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }} torchvision
|
||||
pip3 install torch==${{ matrix.pytorch_version }}
|
||||
|
||||
- name: Update requirements.txt
|
||||
run: |
|
||||
@@ -96,11 +81,15 @@ jobs:
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
- name: Pre-Download dataset fixture
|
||||
run: |
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v --durations=10 tests/patched/
|
||||
pytest -v --durations=10 tests/cli/
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v tests/patched/
|
||||
pytest -v tests/cli/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
@@ -110,33 +99,33 @@ jobs:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
timeout-minutes: 60
|
||||
needs: [pre-commit, pytest]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
pytorch: 2.5.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
nightly_build: "true"
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -148,7 +137,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
pip install modal==0.71.8 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -157,53 +146,7 @@ jobs:
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run cicd.e2e_tests
|
||||
docker-e2e-multigpu-tests:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
needs: [pre-commit, pytest, docker-e2e-tests]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 2
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run cicd.multigpu
|
||||
|
||||
280
.github/workflows/tests.yml
vendored
280
.github/workflows/tests.yml
vendored
@@ -13,7 +13,6 @@ on:
|
||||
- 'cicd/cicd.sh'
|
||||
- 'cicd/Dockerfile.jinja'
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
@@ -28,17 +27,10 @@ concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
TRANSFORMERS_IS_CI: "yes"
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
name: pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
@@ -49,46 +41,29 @@ jobs:
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
prime-cdn-s3-cache:
|
||||
name: Prefetch S3 once to prime the CDN cache
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
curl -v -H "Range: bytes=0-1023" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
|
||||
|
||||
pytest:
|
||||
name: PyTest
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
needs: [prime-cdn-s3-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.12", "3.14"]
|
||||
pytorch_version: ["2.9.1", "2.10.0"]
|
||||
exclude:
|
||||
- python_version: "3.14"
|
||||
pytorch_version: "2.9.1"
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: cleanup node
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p ~/.cache/huggingface/hub
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
ls -ltr ~/.cache/huggingface/hub/
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -99,24 +74,20 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
|
||||
pip3 install torch==${{ matrix.pytorch_version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install --no-cache-dir --no-build-isolation -U -e .
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
@@ -127,62 +98,50 @@ jobs:
|
||||
|
||||
- name: Pre-Download dataset fixture
|
||||
run: |
|
||||
hf download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache ls
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
df -h
|
||||
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
df -h
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
df -h
|
||||
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||
df -h
|
||||
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v tests/patched/
|
||||
pytest -v tests/cli/
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache ls
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
- name: Save HF cache
|
||||
id: hf-cache
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
files: ./coverage.xml
|
||||
flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
||||
fail_ci_if_error: false
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
pytest-sdist:
|
||||
name: PyTest from Source Dist
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
needs: [prime-cdn-s3-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 1
|
||||
matrix:
|
||||
python_version: ["3.12", "3.14"]
|
||||
pytorch_version: ["2.9.1", "2.10.0"]
|
||||
exclude:
|
||||
- python_version: "3.14"
|
||||
pytorch_version: "2.9.1"
|
||||
timeout-minutes: 30
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: cleanup node
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p ~/.cache/huggingface/hub
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
ls -ltr ~/.cache/huggingface/hub/
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -193,25 +152,21 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
|
||||
pip3 install torch==${{ matrix.pytorch_version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
python -m build --no-isolation --sdist
|
||||
pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz
|
||||
pip3 install --no-build-isolation dist/axolotl*.tar.gz
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
@@ -221,67 +176,44 @@ jobs:
|
||||
axolotl --help
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache ls
|
||||
run: huggingface-cli scan-cache
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/cli/
|
||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v tests/patched/
|
||||
pytest -v tests/cli/
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache ls
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
gate-skip-e2e:
|
||||
needs: [pre-commit]
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
skip: ${{ steps.compute.outputs.skip }}
|
||||
steps:
|
||||
- uses: actions/github-script@v7
|
||||
id: compute
|
||||
- name: Save HF cache
|
||||
id: hf-cache
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
script: |
|
||||
const token = /\[skip-e2e\]/i;
|
||||
let msg = '';
|
||||
if (context.eventName === 'push') {
|
||||
msg = context.payload.head_commit?.message || '';
|
||||
} else if (context.eventName === 'pull_request') {
|
||||
const { owner, repo } = context.repo;
|
||||
const prNumber = context.payload.pull_request.number;
|
||||
const commits = await github.paginate(
|
||||
github.rest.pulls.listCommits,
|
||||
{ owner, repo, pull_number: prNumber, per_page: 100 }
|
||||
);
|
||||
msg = commits.at(-1)?.commit?.message || '';
|
||||
}
|
||||
const title = context.payload.pull_request?.title || '';
|
||||
const body = context.payload.pull_request?.body || '';
|
||||
const skip = token.test(msg) || token.test(title) || token.test(body);
|
||||
core.setOutput('skip', String(skip));
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
docker-e2e-tests-1st:
|
||||
# Run this job first as a gate for running the remainder of the test matrix
|
||||
if: >
|
||||
github.repository_owner == 'axolotl-ai-cloud' &&
|
||||
(github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
|
||||
needs.gate-skip-e2e.outputs.skip != 'true'
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
needs: [pre-commit, pytest]
|
||||
timeout-minutes: 90
|
||||
needs: [pre-commit, pytest, pytest-sdist]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
axolotl_extras: vllm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -292,7 +224,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
pip install modal==0.71.8 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -302,47 +234,33 @@ jobs:
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
docker-e2e-tests:
|
||||
if: >
|
||||
github.repository_owner == 'axolotl-ai-cloud' &&
|
||||
(github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
|
||||
needs.gate-skip-e2e.outputs.skip != 'true'
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
# Only run the remainder of the matrix if the first e2e check passed;
|
||||
# this is to save on wasted compute costs for known failures that get caught in the first run
|
||||
needs: [pre-commit, pytest, gate-skip-e2e, docker-e2e-tests-1st]
|
||||
timeout-minutes: 90
|
||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
pytorch: 2.5.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -353,7 +271,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
pip install modal==0.71.8 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -363,50 +281,6 @@ jobs:
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
docker-e2e-cleanup:
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 90
|
||||
needs: [docker-e2e-tests]
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.cleanup
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -190,6 +190,3 @@ out/
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
|
||||
# scm auto-versioning
|
||||
src/axolotl/_version.py
|
||||
|
||||
4
.isort.cfg
Normal file
4
.isort.cfg
Normal file
@@ -0,0 +1,4 @@
|
||||
[settings]
|
||||
profile=black
|
||||
known_third_party=wandb,comet_ml
|
||||
known_local_folder=src,tests
|
||||
@@ -3,21 +3,31 @@ default_language_version:
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v6.0.0
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- id: no-commit-to-branch
|
||||
args: ['--branch', 'main']
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.15.8
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 25.1.0
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 6.0.1
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 7.1.2
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/pylint-dev/pylint
|
||||
rev: v3.3.6
|
||||
hooks:
|
||||
- id: pylint
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.19.1
|
||||
rev: v1.15.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
@@ -26,7 +36,7 @@ repos:
|
||||
'pydantic>=2.5.3',
|
||||
]
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.9.4
|
||||
rev: 1.8.3
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: [
|
||||
|
||||
15
.pylintrc
Normal file
15
.pylintrc
Normal file
@@ -0,0 +1,15 @@
|
||||
[MASTER]
|
||||
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
# List of members which are set dynamically and missed by Pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed.
|
||||
generated-members=numpy.*, torch.*
|
||||
|
||||
|
||||
[pylint.messages_control]
|
||||
disable=missing-function-docstring, line-too-long, import-error,
|
||||
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
||||
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
||||
too-many-positional-arguments, possibly-used-before-assignment
|
||||
161
.runpod/.gitignore
vendored
161
.runpod/.gitignore
vendored
@@ -1,161 +0,0 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
pod/scripts/config.yaml
|
||||
@@ -1,19 +0,0 @@
|
||||
FROM axolotlai/axolotl-cloud:main-py3.11-cu124-2.6.0
|
||||
|
||||
COPY .runpod/requirements.txt /requirements.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install --upgrade pip && \
|
||||
python3 -m pip install --upgrade -r /requirements.txt
|
||||
|
||||
# Environment settings
|
||||
ARG BASE_VOLUME="/runpod-volume"
|
||||
ENV BASE_VOLUME=$BASE_VOLUME
|
||||
ENV HF_DATASETS_CACHE="${BASE_VOLUME}/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
|
||||
ENV HF_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
|
||||
ENV TRANSFORMERS_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
|
||||
|
||||
COPY .runpod/src /src
|
||||
|
||||
WORKDIR /src
|
||||
CMD ["python3", "/src/handler.py"]
|
||||
@@ -1,335 +0,0 @@
|
||||
<h1>LLM Post Training- Full fine-tune, LoRA, QLoRa etc. Llama/Mistral/Gemma and more</h1>
|
||||
|
||||
# Configuration Options
|
||||
|
||||
This document outlines all available configuration options for training models. The configuration can be provided as a JSON request.
|
||||
|
||||
## Usage
|
||||
|
||||
You can use these configuration Options:
|
||||
|
||||
1. As a JSON request body:
|
||||
|
||||
```json
|
||||
{
|
||||
"input": {
|
||||
"user_id": "user",
|
||||
"model_id": "model-name",
|
||||
"run_id": "run-id",
|
||||
"credentials": {
|
||||
"wandb_api_key": "", # add your Weights & biases key. TODO: you will be able to set this in Enviornment variables.
|
||||
"hf_token": "", # add your HF_token. TODO: you will be able to set this in Enviornment variables.
|
||||
},
|
||||
"args": {
|
||||
"base_model": "NousResearch/Llama-3.2-1B",
|
||||
// ... other options
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Model Configuration
|
||||
|
||||
| Option | Description | Default |
|
||||
| ------------------- | --------------------------------------------------------------------------------------------- | -------------------- |
|
||||
| `base_model` | Path to the base model (local or HuggingFace) | Required |
|
||||
| `base_model_config` | Configuration path for the base model | Same as base_model |
|
||||
| `revision_of_model` | Specific model revision from HuggingFace hub | Latest |
|
||||
| `tokenizer_config` | Custom tokenizer configuration path | Optional |
|
||||
| `model_type` | Type of model to load | AutoModelForCausalLM |
|
||||
| `tokenizer_type` | Type of tokenizer to use | AutoTokenizer |
|
||||
| `hub_model_id` | Repository ID where the model will be pushed on Hugging Face Hub (format: username/repo-name) | Optional |
|
||||
|
||||
## Model Family Identification
|
||||
|
||||
| Option | Default | Description |
|
||||
| -------------------------- | ------- | ------------------------------ |
|
||||
| `is_falcon_derived_model` | `false` | Whether model is Falcon-based |
|
||||
| `is_llama_derived_model` | `false` | Whether model is LLaMA-based |
|
||||
| `is_qwen_derived_model` | `false` | Whether model is Qwen-based |
|
||||
| `is_mistral_derived_model` | `false` | Whether model is Mistral-based |
|
||||
|
||||
## Model Configuration Overrides
|
||||
|
||||
| Option | Default | Description |
|
||||
| ----------------------------------------------- | ---------- | ---------------------------------- |
|
||||
| `overrides_of_model_config.rope_scaling.type` | `"linear"` | RoPE scaling type (linear/dynamic) |
|
||||
| `overrides_of_model_config.rope_scaling.factor` | `1.0` | RoPE scaling factor |
|
||||
|
||||
### Model Loading Options
|
||||
|
||||
| Option | Description | Default |
|
||||
| -------------- | ----------------------------- | ------- |
|
||||
| `load_in_8bit` | Load model in 8-bit precision | false |
|
||||
| `load_in_4bit` | Load model in 4-bit precision | false |
|
||||
| `bf16` | Use bfloat16 precision | false |
|
||||
| `fp16` | Use float16 precision | false |
|
||||
| `tf32` | Use tensor float 32 precision | false |
|
||||
|
||||
## Memory and Device Settings
|
||||
|
||||
| Option | Default | Description |
|
||||
| ------------------ | --------- | ----------------------- |
|
||||
| `gpu_memory_limit` | `"20GiB"` | GPU memory limit |
|
||||
| `lora_on_cpu` | `false` | Load LoRA on CPU |
|
||||
| `device_map` | `"auto"` | Device mapping strategy |
|
||||
| `max_memory` | `null` | Max memory per device |
|
||||
|
||||
## Training Hyperparameters
|
||||
|
||||
| Option | Default | Description |
|
||||
| ----------------------------- | --------- | --------------------------- |
|
||||
| `gradient_accumulation_steps` | `1` | Gradient accumulation steps |
|
||||
| `micro_batch_size` | `2` | Batch size per GPU |
|
||||
| `eval_batch_size` | `null` | Evaluation batch size |
|
||||
| `num_epochs` | `4` | Number of training epochs |
|
||||
| `warmup_steps` | `100` | Warmup steps |
|
||||
| `warmup_ratio` | `0.05` | Warmup ratio |
|
||||
| `learning_rate` | `0.00003` | Learning rate |
|
||||
| `lr_quadratic_warmup` | `false` | Quadratic warmup |
|
||||
| `logging_steps` | `null` | Logging frequency |
|
||||
| `eval_steps` | `null` | Evaluation frequency |
|
||||
| `evals_per_epoch` | `null` | Evaluations per epoch |
|
||||
| `save_strategy` | `"epoch"` | Checkpoint saving strategy |
|
||||
| `save_steps` | `null` | Saving frequency |
|
||||
| `saves_per_epoch` | `null` | Saves per epoch |
|
||||
| `save_total_limit` | `null` | Maximum checkpoints to keep |
|
||||
| `max_steps` | `null` | Maximum training steps |
|
||||
|
||||
### Dataset Configuration
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4 # HuggingFace dataset or TODO: You will be able to add the local path.
|
||||
type: alpaca # Format type (alpaca, gpteacher, oasst, etc.)
|
||||
ds_type: json # Dataset type
|
||||
data_files: path/to/data # Source data files
|
||||
train_on_split: train # Dataset split to use
|
||||
```
|
||||
|
||||
## Chat Template Settings
|
||||
|
||||
| Option | Default | Description |
|
||||
| ------------------------ | -------------------------------- | ---------------------- |
|
||||
| `chat_template` | `"tokenizer_default"` | Chat template type |
|
||||
| `chat_template_jinja` | `null` | Custom Jinja template |
|
||||
| `default_system_message` | `"You are a helpful assistant."` | Default system message |
|
||||
|
||||
## Dataset Processing
|
||||
|
||||
| Option | Default | Description |
|
||||
| --------------------------------- | -------------------------- | ----------------------------------- |
|
||||
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
|
||||
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
|
||||
| `dataset_num_proc` | `4` | Number of preprocessing processes |
|
||||
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
|
||||
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
|
||||
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |
|
||||
| `dataset_exact_deduplication` | `true` | Deduplicate datasets |
|
||||
|
||||
## LoRA Configuration
|
||||
|
||||
| Option | Default | Description |
|
||||
| -------------------------- | ---------------------- | ------------------------------ |
|
||||
| `adapter` | `"lora"` | Adapter type (lora/qlora) |
|
||||
| `lora_model_dir` | `""` | Directory with pretrained LoRA |
|
||||
| `lora_r` | `8` | LoRA attention dimension |
|
||||
| `lora_alpha` | `16` | LoRA alpha parameter |
|
||||
| `lora_dropout` | `0.05` | LoRA dropout |
|
||||
| `lora_target_modules` | `["q_proj", "v_proj"]` | Modules to apply LoRA |
|
||||
| `lora_target_linear` | `false` | Target all linear modules |
|
||||
| `peft_layers_to_transform` | `[]` | Layers to transform |
|
||||
| `lora_modules_to_save` | `[]` | Modules to save |
|
||||
| `lora_fan_in_fan_out` | `false` | Fan in/out structure |
|
||||
|
||||
## Optimization Settings
|
||||
|
||||
| Option | Default | Description |
|
||||
| ------------------------- | ------- | -------------------------- |
|
||||
| `train_on_inputs` | `false` | Train on input prompts |
|
||||
| `group_by_length` | `false` | Group by sequence length |
|
||||
| `gradient_checkpointing` | `false` | Use gradient checkpointing |
|
||||
| `early_stopping_patience` | `3` | Early stopping patience |
|
||||
|
||||
## Learning Rate Scheduling
|
||||
|
||||
| Option | Default | Description |
|
||||
| -------------------------- | ---------- | -------------------- |
|
||||
| `lr_scheduler` | `"cosine"` | Scheduler type |
|
||||
| `lr_scheduler_kwargs` | `{}` | Scheduler parameters |
|
||||
| `cosine_min_lr_ratio` | `null` | Minimum LR ratio |
|
||||
| `cosine_constant_lr_ratio` | `null` | Constant LR ratio |
|
||||
| `lr_div_factor` | `null` | LR division factor |
|
||||
|
||||
## Optimizer Settings
|
||||
|
||||
| Option | Default | Description |
|
||||
| ---------------------- | ------------ | ------------------- |
|
||||
| `optimizer` | `"adamw_hf"` | Optimizer choice |
|
||||
| `optim_args` | `{}` | Optimizer arguments |
|
||||
| `optim_target_modules` | `[]` | Target modules |
|
||||
| `weight_decay` | `null` | Weight decay |
|
||||
| `adam_beta1` | `null` | Adam beta1 |
|
||||
| `adam_beta2` | `null` | Adam beta2 |
|
||||
| `adam_epsilon` | `null` | Adam epsilon |
|
||||
| `max_grad_norm` | `null` | Gradient clipping |
|
||||
|
||||
## Attention Implementations
|
||||
|
||||
| Option | Default | Description |
|
||||
| -------------------------- | ------- | ----------------------------- |
|
||||
| `flash_optimum` | `false` | Use better transformers |
|
||||
| `xformers_attention` | `false` | Use xformers |
|
||||
| `flash_attention` | `false` | Use flash attention |
|
||||
| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
|
||||
| `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
|
||||
| `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
|
||||
| `sdp_attention` | `false` | Use scaled dot product |
|
||||
| `s2_attention` | `false` | Use shifted sparse attention |
|
||||
|
||||
## Tokenizer Modifications
|
||||
|
||||
| Option | Default | Description |
|
||||
| ---------------- | ------- | ---------------------------- |
|
||||
| `special_tokens` | - | Special tokens to add/modify |
|
||||
| `tokens` | `[]` | Additional tokens |
|
||||
|
||||
## Distributed Training
|
||||
|
||||
| Option | Default | Description |
|
||||
| ----------------------- | ------- | --------------------- |
|
||||
| `fsdp` | `null` | FSDP configuration |
|
||||
| `fsdp_config` | `null` | FSDP config options |
|
||||
| `deepspeed` | `null` | Deepspeed config path |
|
||||
| `ddp_timeout` | `null` | DDP timeout |
|
||||
| `ddp_bucket_cap_mb` | `null` | DDP bucket capacity |
|
||||
| `ddp_broadcast_buffers` | `null` | DDP broadcast buffers |
|
||||
|
||||
<details>
|
||||
<summary><h3>Example Configuration Request:</h3></summary>
|
||||
|
||||
Here's a complete example for fine-tuning a LLaMA model using LoRA:
|
||||
|
||||
```json
|
||||
{
|
||||
"input": {
|
||||
"user_id": "user",
|
||||
"model_id": "llama-test",
|
||||
"run_id": "test-run",
|
||||
"credentials": {
|
||||
"wandb_api_key": "",
|
||||
"hf_token": ""
|
||||
},
|
||||
"args": {
|
||||
"base_model": "NousResearch/Llama-3.2-1B",
|
||||
"load_in_8bit": false,
|
||||
"load_in_4bit": false,
|
||||
"strict": false,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "teknium/GPT4-LLM-Cleaned",
|
||||
"type": "alpaca"
|
||||
}
|
||||
],
|
||||
"dataset_prepared_path": "last_run_prepared",
|
||||
"val_set_size": 0.1,
|
||||
"output_dir": "./outputs/lora-out",
|
||||
"adapter": "lora",
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": true,
|
||||
"eval_sample_packing": true,
|
||||
"pad_to_sequence_len": true,
|
||||
"lora_r": 16,
|
||||
"lora_alpha": 32,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_modules": [
|
||||
"gate_proj",
|
||||
"down_proj",
|
||||
"up_proj",
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
"k_proj",
|
||||
"o_proj"
|
||||
],
|
||||
"gradient_accumulation_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"num_epochs": 1,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"learning_rate": 0.0002,
|
||||
"train_on_inputs": false,
|
||||
"group_by_length": false,
|
||||
"bf16": "auto",
|
||||
"tf32": false,
|
||||
"gradient_checkpointing": true,
|
||||
"logging_steps": 1,
|
||||
"flash_attention": true,
|
||||
"loss_watchdog_threshold": 5,
|
||||
"loss_watchdog_patience": 3,
|
||||
"warmup_steps": 10,
|
||||
"evals_per_epoch": 4,
|
||||
"saves_per_epoch": 1,
|
||||
"weight_decay": 0,
|
||||
"hub_model_id": "runpod/llama-fr-lora",
|
||||
"wandb_name": "test-run-1",
|
||||
"wandb_project": "test-run-1",
|
||||
"wandb_entity": "axo-test",
|
||||
"special_tokens": {
|
||||
"pad_token": "<|end_of_text|>"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Advanced Features
|
||||
|
||||
#### Wandb Integration
|
||||
|
||||
- `wandb_project`: Project name for Weights & Biases
|
||||
- `wandb_entity`: Team name in W&B
|
||||
- `wandb_watch`: Monitor model with W&B
|
||||
- `wandb_name`: Name of the W&B run
|
||||
- `wandb_run_id`: ID for the W&B run
|
||||
|
||||
#### Performance Optimization
|
||||
|
||||
- `sample_packing`: Enable efficient sequence packing
|
||||
- `eval_sample_packing`: Use sequence packing during evaluation
|
||||
- `torch_compile`: Enable PyTorch 2.0 compilation
|
||||
- `flash_attention`: Use Flash Attention implementation
|
||||
- `xformers_attention`: Use xFormers attention implementation
|
||||
|
||||
### Available Optimizers
|
||||
|
||||
The following optimizers are supported:
|
||||
|
||||
- `adamw_hf`: HuggingFace's AdamW implementation
|
||||
- `adamw_torch`: PyTorch's AdamW
|
||||
- `adamw_torch_fused`: Fused AdamW implementation
|
||||
- `adamw_torch_xla`: XLA-optimized AdamW
|
||||
- `adamw_apex_fused`: NVIDIA Apex fused AdamW
|
||||
- `adafactor`: Adafactor optimizer
|
||||
- `adamw_anyprecision`: Anyprecision AdamW
|
||||
- `adamw_bnb_8bit`: 8-bit AdamW from bitsandbytes
|
||||
- `lion_8bit`: 8-bit Lion optimizer
|
||||
- `lion_32bit`: 32-bit Lion optimizer
|
||||
- `sgd`: Stochastic Gradient Descent
|
||||
- `adagrad`: Adagrad optimizer
|
||||
|
||||
## Notes
|
||||
|
||||
- Set `load_in_8bit: true` or `load_in_4bit: true` for memory-efficient training
|
||||
- Enable `flash_attention: true` for faster training on modern GPUs
|
||||
- Use `gradient_checkpointing: true` to reduce memory usage
|
||||
- Adjust `micro_batch_size` and `gradient_accumulation_steps` based on your GPU memory
|
||||
|
||||
For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config-reference.html).
|
||||
|
||||
### Errors:
|
||||
|
||||
- if you face any issues with the Flash Attention-2, Delete yoor worker and Re-start.
|
||||
@@ -1,93 +0,0 @@
|
||||
{
|
||||
"title": "Axolotl Fine-Tuning",
|
||||
"description": "Serverless fine-tuning of open-source LLMs with Axolotl. Supports LoRA, QLoRA, DPO, and more using Hugging Face models and datasets.",
|
||||
"type": "serverless",
|
||||
"category": "language",
|
||||
"iconUrl": "https://avatars.githubusercontent.com/u/167502477",
|
||||
"config": {
|
||||
"runsOn": "GPU",
|
||||
"containerDiskInGb": 200,
|
||||
"gpuCount": 1,
|
||||
"allowedCudaVersions": [
|
||||
"12.8",
|
||||
"12.7",
|
||||
"12.6",
|
||||
"12.5",
|
||||
"12.4"
|
||||
],
|
||||
"presets": [],
|
||||
"env": [
|
||||
{
|
||||
"key": "TOKENIZER",
|
||||
"input": {
|
||||
"name": "Tokenizer",
|
||||
"type": "string",
|
||||
"description": "Name or path of the Hugging Face tokenizer to use.",
|
||||
"default": "",
|
||||
"advanced": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "MAX_NUM_SEQS",
|
||||
"input": {
|
||||
"name": "Max Num Seqs",
|
||||
"type": "number",
|
||||
"description": "Maximum number of sequences per iteration.",
|
||||
"default": 256,
|
||||
"advanced": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "DISABLE_LOG_STATS",
|
||||
"input": {
|
||||
"name": "Disable Log Stats",
|
||||
"type": "boolean",
|
||||
"description": "Disable logging statistics.",
|
||||
"default": false,
|
||||
"trueValue": "true",
|
||||
"falseValue": "false"
|
||||
}
|
||||
},
|
||||
{
|
||||
"key": "LOAD_FORMAT",
|
||||
"input": {
|
||||
"name": "Load Format",
|
||||
"type": "string",
|
||||
"description": "The format of the model weights to load.",
|
||||
"default": "auto",
|
||||
"options": [
|
||||
{
|
||||
"label": "auto",
|
||||
"value": "auto"
|
||||
},
|
||||
{
|
||||
"label": "pt",
|
||||
"value": "pt"
|
||||
},
|
||||
{
|
||||
"label": "safetensors",
|
||||
"value": "safetensors"
|
||||
},
|
||||
{
|
||||
"label": "npcache",
|
||||
"value": "npcache"
|
||||
},
|
||||
{
|
||||
"label": "dummy",
|
||||
"value": "dummy"
|
||||
},
|
||||
{
|
||||
"label": "tensorizer",
|
||||
"value": "tensorizer"
|
||||
},
|
||||
{
|
||||
"label": "bitsandbytes",
|
||||
"value": "bitsandbytes"
|
||||
}
|
||||
],
|
||||
"advanced": true
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
# Required Python packages get listed here, one per line.
|
||||
# Reccomended to lock the version number to avoid unexpected changes.
|
||||
|
||||
# You can also install packages from a git repository, e.g.:
|
||||
# git+https://github.com/runpod/runpod-python.git
|
||||
# To learn more, see https://pip.pypa.io/en/stable/reference/requirements-file-format/
|
||||
runpod~=1.7.0
|
||||
@@ -1,564 +0,0 @@
|
||||
# # This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# # This can also be a relative path to a model on disk
|
||||
# base_model: ./llama-7b-hf
|
||||
# # You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
# base_model_ignore_patterns:
|
||||
# # If the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# # You can set that here, or leave this empty to default to base_model
|
||||
# base_model_config: ./llama-7b-hf
|
||||
# # You can specify to choose a specific model revision from huggingface hub
|
||||
# model_revision:
|
||||
# # Optional tokenizer configuration override in case you want to use a different tokenizer
|
||||
# # than the one defined in the base model
|
||||
# tokenizer_config:
|
||||
# # If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
|
||||
# model_type: AutoModelForCausalLM
|
||||
# # Corresponding tokenizer for the model AutoTokenizer is a good choice
|
||||
# tokenizer_type: AutoTokenizer
|
||||
# # Trust remote code for untrusted source
|
||||
# trust_remote_code:
|
||||
# # use_fast option for tokenizer loading from_pretrained, default to True
|
||||
# tokenizer_use_fast:
|
||||
# # Whether to use the legacy tokenizer setting, defaults to True
|
||||
# tokenizer_legacy:
|
||||
# # Resize the model embeddings when new tokens are added to multiples of 32
|
||||
# # This is reported to improve training speed on some models
|
||||
# resize_token_embeddings_to_32x:
|
||||
|
||||
# # Used to identify which the model is based on
|
||||
# is_falcon_derived_model:
|
||||
# is_llama_derived_model:
|
||||
# # Please note that if you set this to true, `padding_side` will be set to "left" by default
|
||||
# is_mistral_derived_model:
|
||||
# is_qwen_derived_model:
|
||||
|
||||
# # optional overrides to the base model configuration
|
||||
# model_config:
|
||||
# # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
# rope_scaling:
|
||||
# type: # linear | dynamic
|
||||
# factor: # float
|
||||
|
||||
# # Whether you are training a 4-bit GPTQ quantized model
|
||||
# gptq: true
|
||||
# gptq_groupsize: 128 # group size
|
||||
# gptq_model_v1: false # v1 or v2
|
||||
|
||||
# # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
# load_in_8bit: true
|
||||
# # Use bitsandbytes 4 bit
|
||||
# load_in_4bit:
|
||||
|
||||
# # Use CUDA bf16
|
||||
# bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
|
||||
# # Use CUDA fp16
|
||||
# fp16: true
|
||||
# # Use CUDA tf32
|
||||
# tf32: true # require >=ampere
|
||||
|
||||
# # No AMP (automatic mixed precision)
|
||||
# bfloat16: true # require >=ampere
|
||||
# float16: true
|
||||
|
||||
# # A list of one or more datasets to finetune the model with
|
||||
# datasets:
|
||||
# # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
||||
# - path: vicgalle/alpaca-gpt4
|
||||
# # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
# type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
# ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
|
||||
# data_files: # Optional[str] path to source data files
|
||||
# shards: # Optional[int] number of shards to split data into
|
||||
# name: # Optional[str] name of dataset configuration to load
|
||||
# train_on_split: train # Optional[str] name of dataset split to load from
|
||||
|
||||
# # Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
# conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
# field_human: # Optional[str]. Human key to use for conversation.
|
||||
# field_model: # Optional[str]. Assistant key to use for conversation.
|
||||
|
||||
# # Custom user prompt
|
||||
# - path: repo
|
||||
# type:
|
||||
# # The below are defaults. only set what's needed.
|
||||
# system_prompt: ""
|
||||
# system_format: "{system}"
|
||||
# field_system: system
|
||||
# field_instruction: instruction
|
||||
# field_input: input
|
||||
# field_output: output
|
||||
|
||||
# # Customizable to be single line or multi-line
|
||||
# # 'format' can include {input}
|
||||
# format: |-
|
||||
# User: {instruction} {input}
|
||||
# Assistant:
|
||||
# # 'no_input_format' cannot include {input}
|
||||
# no_input_format: "{instruction} "
|
||||
|
||||
# # For `completion` datasets only, uses the provided field instead of `text` column
|
||||
# field:
|
||||
|
||||
# # Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# # subsequent training attempts load faster, relative path
|
||||
# dataset_prepared_path: data/last_run_prepared
|
||||
# # Push prepared dataset to hub
|
||||
# push_dataset_to_hub: # repo path
|
||||
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||
# # if not set.
|
||||
# dataset_num_proc: # defaults to os.cpu_count() if not set
|
||||
# # push checkpoints to hub
|
||||
# hub_model_id: # repo path to push finetuned model
|
||||
# # how to push checkpoints to hub
|
||||
# # https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
|
||||
# hub_strategy:
|
||||
# # Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# # Required to be true when used in combination with `push_dataset_to_hub`
|
||||
# hf_use_auth_token: # boolean
|
||||
# # How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
|
||||
# val_set_size: 0.04
|
||||
# # Num shards for whole dataset
|
||||
# dataset_shard_num:
|
||||
# # Index of shard to use for whole dataset
|
||||
# dataset_shard_idx:
|
||||
|
||||
# # The maximum length of an input to train with, this should typically be less than 2048
|
||||
# # as most models have a token/context limit of 2048
|
||||
# sequence_len: 2048
|
||||
# # Pad inputs so each step uses constant sized buffers
|
||||
# # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||
# pad_to_sequence_len:
|
||||
# # Max sequence length to concatenate training samples together up to
|
||||
# # Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
# # FutureWarning: This will soon be DEPRECATED
|
||||
# max_packed_sequence_len: 1024
|
||||
# # Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
# sample_packing:
|
||||
# # Set to 'false' if getting errors during eval with sample_packing on.
|
||||
# eval_sample_packing:
|
||||
# # You can set these packing optimizations AFTER starting a training at least once.
|
||||
# # The trainer will provide recommended values for these values.
|
||||
# sample_packing_eff_est:
|
||||
# total_num_tokens:
|
||||
|
||||
# # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
# adapter: lora
|
||||
# # If you already have a lora model trained that you want to load, put that here.
|
||||
# # This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
|
||||
# lora_model_dir:
|
||||
|
||||
# # LoRA hyperparameters
|
||||
# # For more details about the following options, see:
|
||||
# # https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
|
||||
# lora_r: 8
|
||||
# lora_alpha: 16
|
||||
# lora_dropout: 0.05
|
||||
# lora_target_modules:
|
||||
# - q_proj
|
||||
# - v_proj
|
||||
# # - k_proj
|
||||
# # - o_proj
|
||||
# # - gate_proj
|
||||
# # - down_proj
|
||||
# # - up_proj
|
||||
# lora_target_linear: # If true, will target all linear layers
|
||||
|
||||
# # If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
||||
# # For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
||||
# # `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
|
||||
# # https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
|
||||
# lora_modules_to_save:
|
||||
# # - embed_tokens
|
||||
# # - lm_head
|
||||
|
||||
# # Once you complete training, the model will be saved to the following directory.
|
||||
# # If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
|
||||
# # Make sure `lora_model_dir` points to this directory if you want to use the trained model.
|
||||
# lora_out_dir:
|
||||
# lora_fan_in_fan_out: false
|
||||
|
||||
# # ReLoRA configuration
|
||||
# # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||
# relora_steps: # Number of steps per ReLoRA restart
|
||||
# relora_warmup_steps: # Number of per-restart warmup steps
|
||||
# relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||
|
||||
# # wandb configuration if you're using it
|
||||
# wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||
# wandb_project: # Your wandb project name
|
||||
# wandb_entity: # A wandb Team name if using a Team
|
||||
# wandb_watch:
|
||||
# wandb_run_id: # Set the name of your wandb run
|
||||
# wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||
|
||||
# # Where to save the full-finetuned model to
|
||||
# output_dir: ./completed-model
|
||||
|
||||
# # Whether to use torch.compile and which backend to use
|
||||
# torch_compile: # bool
|
||||
# torch_compile_backend: # Optional[str]
|
||||
|
||||
# # Training hyperparameters
|
||||
|
||||
# # If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
||||
# gradient_accumulation_steps: 1
|
||||
# # The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||
# micro_batch_size: 2
|
||||
# eval_batch_size:
|
||||
# num_epochs: 4
|
||||
# warmup_steps: 100 # cannot use with warmup_ratio
|
||||
# warmup_ratio: 0.05 # cannot use with warmup_steps
|
||||
# learning_rate: 0.00003
|
||||
# lr_quadratic_warmup:
|
||||
# logging_steps:
|
||||
# save_strategy: # Set to `no` to skip checkpoint saves
|
||||
# save_steps: # Leave empty to save at each epoch
|
||||
# eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
||||
# save_total_limit: # Checkpoints saved at a time
|
||||
# # Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||
# # if both are set, num_epochs will not be guaranteed.
|
||||
# # e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
|
||||
# max_steps:
|
||||
|
||||
# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
|
||||
# # Whether to mask out or include the human's prompt from the training labels
|
||||
# train_on_inputs: false
|
||||
# # Group similarly sized data to minimize padding.
|
||||
# # May be slower to start, as it must download and sort the entire dataset.
|
||||
# # Note that training loss may have an oscillating pattern with this enabled.
|
||||
# group_by_length: false
|
||||
|
||||
# # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
# gradient_checkpointing: false
|
||||
|
||||
# # Stop training after this many evaluation losses have increased in a row
|
||||
# # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
# early_stopping_patience: 3
|
||||
|
||||
# # Specify a scheduler and kwargs to use with the optimizer
|
||||
# lr_scheduler: # 'one_cycle' | empty for cosine
|
||||
# lr_scheduler_kwargs:
|
||||
|
||||
# # For one_cycle optim
|
||||
# lr_div_factor: # Learning rate div factor
|
||||
|
||||
# # Specify optimizer
|
||||
# # Valid values are driven by the Transformers OptimizerNames class, see:
|
||||
# # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
||||
# #
|
||||
# # Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
||||
# # torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
||||
# # in the examples/ for your model and fine-tuning use case.
|
||||
# #
|
||||
# # Valid values for 'optimizer' include:
|
||||
# # - adamw_hf
|
||||
# # - adamw_torch
|
||||
# # - adamw_torch_fused
|
||||
# # - adamw_torch_xla
|
||||
# # - adamw_apex_fused
|
||||
# # - adafactor
|
||||
# # - adamw_anyprecision
|
||||
# # - sgd
|
||||
# # - adagrad
|
||||
# # - adamw_bnb_8bit
|
||||
# # - lion_8bit
|
||||
# # - lion_32bit
|
||||
# # - paged_adamw_32bit
|
||||
# # - paged_adamw_8bit
|
||||
# # - paged_lion_32bit
|
||||
# # - paged_lion_8bit
|
||||
# optimizer:
|
||||
# # Specify weight decay
|
||||
# weight_decay:
|
||||
# # adamw hyperparams
|
||||
# adam_beta1:
|
||||
# adam_beta2:
|
||||
# adam_epsilon:
|
||||
# # Gradient clipping max norm
|
||||
# max_grad_norm:
|
||||
|
||||
# # Augmentation techniques
|
||||
# # NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
||||
# # currently only supported on Llama and Mistral
|
||||
# noisy_embedding_alpha:
|
||||
|
||||
# # Whether to bettertransformers
|
||||
# flash_optimum:
|
||||
# # Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
# xformers_attention:
|
||||
# # Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
# flash_attention:
|
||||
# flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
# flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||
# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
||||
# # Whether to use scaled-dot-product attention
|
||||
# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
# sdp_attention:
|
||||
# # Landmark attention (only llama)
|
||||
# landmark_attention:
|
||||
# # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
# # LLaMA only
|
||||
# xpos_rope:
|
||||
|
||||
# # Resume from a specific checkpoint dir
|
||||
# resume_from_checkpoint:
|
||||
# # If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||
# # Be careful with this being turned on between different models.
|
||||
# auto_resume_from_checkpoints: false
|
||||
|
||||
# # Don't mess with this, it's here for accelerate and torchrun
|
||||
# local_rank:
|
||||
|
||||
# # Add or change special tokens.
|
||||
# # If you add tokens here, you don't need to add them to the `tokens` list.
|
||||
# special_tokens:
|
||||
# # bos_token: "<s>"
|
||||
# # eos_token: "</s>"
|
||||
# # unk_token: "<unk>"
|
||||
|
||||
# # Add extra tokens.
|
||||
# tokens:
|
||||
|
||||
# # FSDP
|
||||
# fsdp:
|
||||
# fsdp_config:
|
||||
|
||||
# # Deepspeed config path. e.g., deepspeed/zero3.json
|
||||
# deepspeed:
|
||||
|
||||
# # Advanced DDP Arguments
|
||||
# ddp_timeout:
|
||||
# ddp_bucket_cap_mb:
|
||||
# ddp_broadcast_buffers:
|
||||
|
||||
# # Path to torch distx for optim 'adamw_anyprecision'
|
||||
# torchdistx_path:
|
||||
|
||||
# # Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
|
||||
# pretraining_dataset:
|
||||
|
||||
# # Debug mode
|
||||
# debug:
|
||||
|
||||
# # Seed
|
||||
# seed:
|
||||
|
||||
# # Allow overwrite yml config using from cli
|
||||
# strict:
|
||||
|
||||
base_model: ${BASE_MODEL}
|
||||
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
|
||||
base_model_config: ${BASE_MODEL_CONFIG}
|
||||
revision_of_model: ${REVISION_OF_MODEL}
|
||||
tokenizer_config: ${TOKENIZER_CONFIG}
|
||||
model_type: ${MODEL_TYPE}
|
||||
tokenizer_type: ${TOKENIZER_TYPE}
|
||||
trust_remote_code: ${TRUST_REMOTE_CODE}
|
||||
tokenizer_use_fast: ${TOKENIZER_USE_FAST}
|
||||
tokenizer_legacy: ${TOKENIZER_LEGACY}
|
||||
resize_token_embeddings_to_32x: ${RESIZE_TOKEN_EMBEDDINGS_TO_32X}
|
||||
|
||||
is_falcon_derived_model: ${IS_FALCON_DERIVED_MODEL}
|
||||
is_llama_derived_model: ${IS_LLAMA_DERIVED_MODEL}
|
||||
is_qwen_derived_model: ${IS_QWEN_DERIVED_MODEL}
|
||||
is_mistral_derived_model: ${IS_MISTRAL_DERIVED_MODEL}
|
||||
|
||||
overrides_of_model_config:
|
||||
rope_scaling:
|
||||
type: ${ROPE_SCALING_TYPE}
|
||||
factor: ${ROPE_SCALING_FACTOR}
|
||||
|
||||
bnb_config_kwargs:
|
||||
llm_int8_has_fp16_weight: ${BNB_LLM_INT8_HAS_FP16_WEIGHT}
|
||||
bnb_4bit_quant_type: ${BNB_4BIT_QUANT_TYPE}
|
||||
bnb_4bit_use_double_quant: ${BNB_4BIT_USE_DOUBLE_QUANT}
|
||||
|
||||
gptq: ${GPTQ}
|
||||
load_in_8bit: ${LOAD_IN_8BIT}
|
||||
load_in_4bit: ${LOAD_IN_4BIT}
|
||||
bf16: ${BF16}
|
||||
fp16: ${FP16}
|
||||
tf32: ${TF32}
|
||||
bfloat16: ${BFLOAT16}
|
||||
float16: ${FLOAT16}
|
||||
|
||||
gpu_memory_limit: ${GPU_MEMORY_LIMIT}
|
||||
lora_on_cpu: ${LORA_ON_CPU}
|
||||
|
||||
datasets:
|
||||
- path: ${DATASET_PATH}
|
||||
type: ${DATASET_TYPE}
|
||||
ds_type: ${DATASET_DS_TYPE}
|
||||
data_files: ${DATASET_DATA_FILES}
|
||||
shards: ${DATASET_SHARDS}
|
||||
name: ${DATASET_NAME}
|
||||
train_on_split: ${DATASET_TRAIN_ON_SPLIT}
|
||||
revision: ${DATASET_REVISION}
|
||||
trust_remote_code: ${DATASET_TRUST_REMOTE_CODE}
|
||||
|
||||
rl: ${RL}
|
||||
dpo_use_weighting: ${DPO_USE_WEIGHTING}
|
||||
|
||||
chat_template: ${CHAT_TEMPLATE}
|
||||
chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
|
||||
default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
|
||||
dataset_prepared_path: ${DATASET_PREPARED_PATH}
|
||||
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
|
||||
dataset_num_proc: ${DATASET_NUM_PROC}
|
||||
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
|
||||
hub_model_id: ${HUB_MODEL_ID}
|
||||
hub_strategy: ${HUB_STRATEGY}
|
||||
hf_use_auth_token: ${HF_USE_AUTH_TOKEN}
|
||||
val_set_size: ${VAL_SET_SIZE}
|
||||
dataset_shard_num: ${DATASET_SHARD_NUM}
|
||||
dataset_shard_idx: ${DATASET_SHARD_IDX}
|
||||
|
||||
sequence_len: ${SEQUENCE_LEN}
|
||||
pad_to_sequence_len: ${PAD_TO_SEQUENCE_LEN}
|
||||
sample_packing: ${SAMPLE_PACKING}
|
||||
eval_sample_packing: ${EVAL_SAMPLE_PACKING}
|
||||
sample_packing_eff_est: ${SAMPLE_PACKING_EFF_EST}
|
||||
total_num_tokens: ${TOTAL_NUM_TOKENS}
|
||||
sample_packing_group_size: ${SAMPLE_PACKING_GROUP_SIZE}
|
||||
sample_packing_bin_size: ${SAMPLE_PACKING_BIN_SIZE}
|
||||
|
||||
batch_flattening: ${BATCH_FLATTENING}
|
||||
device_map: ${DEVICE_MAP}
|
||||
max_memory: ${MAX_MEMORY}
|
||||
|
||||
adapter: ${ADAPTER}
|
||||
lora_model_dir: ${LORA_MODEL_DIR}
|
||||
|
||||
lora_r: ${LORA_R}
|
||||
lora_alpha: ${LORA_ALPHA}
|
||||
lora_dropout: ${LORA_DROPOUT}
|
||||
lora_target_modules:
|
||||
- ${LORA_TARGET_MODULES}
|
||||
lora_target_linear: ${LORA_TARGET_LINEAR}
|
||||
peft_layers_to_transform: ${PEFT_LAYERS_TO_TRANSFORM}
|
||||
lora_modules_to_save: ${LORA_MODULES_TO_SAVE}
|
||||
lora_fan_in_fan_out: ${LORA_FAN_IN_FAN_OUT}
|
||||
|
||||
loraplus_lr_ratio: ${LORAPLUS_LR_RATIO}
|
||||
loraplus_lr_embedding: ${LORAPLUS_LR_EMBEDDING}
|
||||
|
||||
peft:
|
||||
loftq_config:
|
||||
loftq_bits: ${LOFTQ_BITS}
|
||||
|
||||
relora_steps: ${RELORA_STEPS}
|
||||
relora_warmup_steps: ${RELORA_WARMUP_STEPS}
|
||||
relora_anneal_steps: ${RELORA_ANNEAL_STEPS}
|
||||
relora_prune_ratio: ${RELORA_PRUNE_RATIO}
|
||||
relora_cpu_offload: ${RELORA_CPU_OFFLOAD}
|
||||
|
||||
wandb_mode: ${WANDB_MODE}
|
||||
wandb_project: ${WANDB_PROJECT}
|
||||
wandb_entity: ${WANDB_ENTITY}
|
||||
wandb_watch: ${WANDB_WATCH}
|
||||
wandb_name: ${WANDB_NAME}
|
||||
wandb_run_id: ${WANDB_RUN_ID}
|
||||
wandb_log_model: ${WANDB_LOG_MODEL}
|
||||
|
||||
mlflow_tracking_uri: ${MLFLOW_TRACKING_URI}
|
||||
mlflow_experiment_name: ${MLFLOW_EXPERIMENT_NAME}
|
||||
mlflow_run_name: ${MLFLOW_RUN_NAME}
|
||||
hf_mlflow_log_artifacts: ${HF_MLFLOW_LOG_ARTIFACTS}
|
||||
|
||||
use_comet: ${USE_COMET}
|
||||
comet_api_key: ${COMET_API_KEY}
|
||||
comet_workspace: ${COMET_WORKSPACE}
|
||||
comet_project_name: ${COMET_PROJECT_NAME}
|
||||
comet_experiment_key: ${COMET_EXPERIMENT_KEY}
|
||||
comet_mode: ${COMET_MODE}
|
||||
comet_online: ${COMET_ONLINE}
|
||||
comet_experiment_config: ${COMET_EXPERIMENT_CONFIG}
|
||||
|
||||
output_dir: ${OUTPUT_DIR}
|
||||
|
||||
torch_compile: ${TORCH_COMPILE}
|
||||
torch_compile_backend: ${TORCH_COMPILE_BACKEND}
|
||||
|
||||
gradient_accumulation_steps: ${GRADIENT_ACCUMULATION_STEPS}
|
||||
micro_batch_size: ${MICRO_BATCH_SIZE}
|
||||
eval_batch_size: ${EVAL_BATCH_SIZE}
|
||||
num_epochs: ${NUM_EPOCHS}
|
||||
warmup_steps: ${WARMUP_STEPS}
|
||||
warmup_ratio: ${WARMUP_RATIO}
|
||||
learning_rate: ${LEARNING_RATE}
|
||||
lr_quadratic_warmup: ${LR_QUADRATIC_WARMUP}
|
||||
logging_steps: ${LOGGING_STEPS}
|
||||
eval_steps: ${EVAL_STEPS}
|
||||
evals_per_epoch: ${EVALS_PER_EPOCH}
|
||||
save_strategy: ${SAVE_STRATEGY}
|
||||
save_steps: ${SAVE_STEPS}
|
||||
saves_per_epoch: ${SAVES_PER_EPOCH}
|
||||
save_total_limit: ${SAVE_TOTAL_LIMIT}
|
||||
max_steps: ${MAX_STEPS}
|
||||
|
||||
eval_table_size: ${EVAL_TABLE_SIZE}
|
||||
eval_max_new_tokens: ${EVAL_MAX_NEW_TOKENS}
|
||||
eval_causal_lm_metrics: ${EVAL_CAUSAL_LM_METRICS}
|
||||
|
||||
profiler_steps: ${PROFILER_STEPS}
|
||||
loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
|
||||
loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}
|
||||
|
||||
train_on_inputs: ${TRAIN_ON_INPUTS}
|
||||
group_by_length: ${GROUP_BY_LENGTH}
|
||||
gradient_checkpointing: ${GRADIENT_CHECKPOINTING}
|
||||
early_stopping_patience: ${EARLY_STOPPING_PATIENCE}
|
||||
|
||||
lr_scheduler: ${LR_SCHEDULER}
|
||||
lr_scheduler_kwargs: ${LR_SCHEDULER_KWARGS}
|
||||
cosine_min_lr_ratio: ${COSINE_MIN_LR_RATIO}
|
||||
cosine_constant_lr_ratio: ${COSINE_CONSTANT_LR_RATIO}
|
||||
lr_div_factor: ${LR_DIV_FACTOR}
|
||||
|
||||
optimizer: ${OPTIMIZER}
|
||||
optim_args: ${OPTIM_ARGS}
|
||||
optim_target_modules: ${OPTIM_TARGET_MODULES}
|
||||
weight_decay: ${WEIGHT_DECAY}
|
||||
adam_beta1: ${ADAM_BETA1}
|
||||
adam_beta2: ${ADAM_BETA2}
|
||||
adam_epsilon: ${ADAM_EPSILON}
|
||||
max_grad_norm: ${MAX_GRAD_NORM}
|
||||
|
||||
neftune_noise_alpha: ${NEFTUNE_NOISE_ALPHA}
|
||||
|
||||
flash_optimum: ${FLASH_OPTIMUM}
|
||||
xformers_attention: ${XFORMERS_ATTENTION}
|
||||
flash_attention: ${FLASH_ATTENTION}
|
||||
flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
|
||||
flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
|
||||
flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
|
||||
sdp_attention: ${SDP_ATTENTION}
|
||||
s2_attention: ${S2_ATTENTION}
|
||||
resume_from_checkpoint: ${RESUME_FROM_CHECKPOINT}
|
||||
auto_resume_from_checkpoints: ${AUTO_RESUME_FROM_CHECKPOINTS}
|
||||
|
||||
local_rank: ${LOCAL_RANK}
|
||||
|
||||
special_tokens:
|
||||
bos_token: ${SPECIAL_TOKEN_BOS}
|
||||
eos_token: ${SPECIAL_TOKEN_EOS}
|
||||
unk_token: ${SPECIAL_TOKEN_UNK}
|
||||
pad_token: ${SPECIAL_TOKEN_PAD}
|
||||
|
||||
tokens: ${TOKENS}
|
||||
|
||||
fsdp: ${FSDP}
|
||||
fsdp_config: ${FSDP_CONFIG}
|
||||
deepspeed: ${DEEPSPEED}
|
||||
|
||||
ddp_timeout: ${DDP_TIMEOUT}
|
||||
ddp_bucket_cap_mb: ${DDP_BUCKET_CAP_MB}
|
||||
ddp_broadcast_buffers: ${DDP_BROADCAST_BUFFERS}
|
||||
|
||||
torchdistx_path: ${TORCHDISTX_PATH}
|
||||
pretraining_dataset: ${PRETRAINING_DATASET}
|
||||
debug: ${DEBUG}
|
||||
seed: ${SEED}
|
||||
strict: ${STRICT}
|
||||
@@ -1,66 +0,0 @@
|
||||
"""
|
||||
Runpod serverless entrypoint handler
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import runpod
|
||||
import yaml
|
||||
from huggingface_hub._login import login
|
||||
from train import train
|
||||
from utils import get_output_dir
|
||||
|
||||
BASE_VOLUME = os.environ.get("BASE_VOLUME", "/runpod-volume")
|
||||
if not os.path.exists(BASE_VOLUME):
|
||||
os.makedirs(BASE_VOLUME)
|
||||
|
||||
logger = runpod.RunPodLogger()
|
||||
|
||||
|
||||
async def handler(job):
|
||||
runpod_job_id = job["id"]
|
||||
inputs = job["input"]
|
||||
run_id = inputs.get("run_id", "default_run_id")
|
||||
args = inputs.get("args", {})
|
||||
|
||||
# Set output directory
|
||||
output_dir = os.path.join(BASE_VOLUME, get_output_dir(run_id))
|
||||
args["output_dir"] = output_dir
|
||||
|
||||
# First save args to a temporary config file
|
||||
config_path = "/workspace/test_config.yaml"
|
||||
|
||||
# Add run_name and job_id to args before saving
|
||||
args["run_name"] = run_id
|
||||
args["runpod_job_id"] = runpod_job_id
|
||||
|
||||
yaml_data = yaml.dump(args, default_flow_style=False)
|
||||
with open(config_path, "w", encoding="utf-8") as file:
|
||||
file.write(yaml_data)
|
||||
|
||||
# Handle credentials
|
||||
credentials = inputs.get("credentials", {})
|
||||
|
||||
if "wandb_api_key" in credentials:
|
||||
os.environ["WANDB_API_KEY"] = credentials["wandb_api_key"]
|
||||
if "hf_token" in credentials:
|
||||
os.environ["HF_TOKEN"] = credentials["hf_token"]
|
||||
|
||||
if os.environ.get("HF_TOKEN"):
|
||||
login(token=os.environ["HF_TOKEN"])
|
||||
else:
|
||||
logger.info("No HF_TOKEN provided. Skipping login.")
|
||||
|
||||
logger.info("Starting Training.")
|
||||
async for result in train(config_path): # Pass the config path instead of args
|
||||
logger.info(result)
|
||||
logger.info("Training Complete.")
|
||||
|
||||
# Cleanup
|
||||
if "WANDB_API_KEY" in os.environ:
|
||||
del os.environ["WANDB_API_KEY"]
|
||||
if "HF_TOKEN" in os.environ:
|
||||
del os.environ["HF_TOKEN"]
|
||||
|
||||
|
||||
runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
|
||||
@@ -1,61 +0,0 @@
|
||||
{
|
||||
"input": {
|
||||
"user_id": "user",
|
||||
"model_id": "llama-test",
|
||||
"run_id": "llama-test",
|
||||
"credentials": {
|
||||
"wandb_api_key": "",
|
||||
"hf_token": ""
|
||||
},
|
||||
"args": {
|
||||
"base_model": "NousResearch/Meta-Llama-3-8B",
|
||||
"model_type": "LlamaForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"load_in_8bit": true,
|
||||
"load_in_4bit": false,
|
||||
"strict": false,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca"
|
||||
}
|
||||
],
|
||||
"val_set_size": 0.05,
|
||||
"output_dir": "./outputs/lora-out",
|
||||
"sequence_len": 4096,
|
||||
"sample_packing": true,
|
||||
"eval_sample_packing": false,
|
||||
"pad_to_sequence_len": true,
|
||||
"adapter": "lora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": true,
|
||||
"lora_modules_to_save": [
|
||||
"embed_tokens",
|
||||
"lm_head"
|
||||
],
|
||||
"gradient_accumulation_steps": 4,
|
||||
"micro_batch_size": 2,
|
||||
"num_epochs": 1,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"learning_rate": 0.0002,
|
||||
"train_on_inputs": false,
|
||||
"group_by_length": false,
|
||||
"bf16": "auto",
|
||||
"tf32": false,
|
||||
"gradient_checkpointing": true,
|
||||
"logging_steps": 1,
|
||||
"flash_attention": true,
|
||||
"warmup_steps": 1,
|
||||
"evals_per_epoch": 1,
|
||||
"eval_max_new_tokens": 128,
|
||||
"saves_per_epoch": 1,
|
||||
"weight_decay": 0.0,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|end_of_text|>"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
"""
|
||||
Runpod train entrypoint
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
|
||||
"""
|
||||
Run preprocessing (if enabled) and training with the given config file
|
||||
:param config_path: Path to the YAML config file
|
||||
:param gpu_id: GPU ID to use (default: "0")
|
||||
:param preprocess: Whether to run preprocessing (default: True)
|
||||
|
||||
"""
|
||||
# First check if preprocessing is needed
|
||||
if preprocess:
|
||||
# Preprocess command
|
||||
preprocess_cmd = (
|
||||
f"CUDA_VISIBLE_DEVICES={gpu_id} axolotl preprocess {config_path}"
|
||||
)
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
preprocess_cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
)
|
||||
|
||||
if process.stdout is not None:
|
||||
async for line in process.stdout:
|
||||
yield f"Preprocessing: {line.decode().strip()}"
|
||||
await process.wait()
|
||||
yield "Preprocessing completed."
|
||||
else:
|
||||
yield "Skipping preprocessing step."
|
||||
|
||||
# Training command
|
||||
train_cmd = f"axolotl train {config_path}"
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
train_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
|
||||
)
|
||||
|
||||
if process.stdout is not None:
|
||||
async for line in process.stdout:
|
||||
yield f"Training: {line.decode().strip()}"
|
||||
await process.wait()
|
||||
@@ -1,89 +0,0 @@
|
||||
"""
|
||||
Runpod launcher utils
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def get_output_dir(run_id):
|
||||
path = f"fine-tuning/{run_id}"
|
||||
return path
|
||||
|
||||
|
||||
def make_valid_config(input_args):
|
||||
"""
|
||||
Creates and saves updated config file, returns the path to the new config
|
||||
:param input_args: dict of input args
|
||||
:return: str, path to the updated config file
|
||||
"""
|
||||
# Load default config
|
||||
with open("config/config.yaml", "r", encoding="utf-8") as fin:
|
||||
all_args = yaml.safe_load(fin)
|
||||
|
||||
if not input_args:
|
||||
print("No args provided, using defaults")
|
||||
else:
|
||||
all_args.update(input_args)
|
||||
|
||||
# Create updated config path
|
||||
updated_config_path = "config/updated_config.yaml"
|
||||
|
||||
# Save updated config to new file
|
||||
with open(updated_config_path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(all_args, f)
|
||||
|
||||
return updated_config_path
|
||||
|
||||
|
||||
def set_config_env_vars(args: dict):
|
||||
"""
|
||||
Convert API arguments into environment variables.
|
||||
Handles nested dictionaries, lists, and special values.
|
||||
|
||||
Args:
|
||||
args (dict): The arguments dictionary from the API request
|
||||
"""
|
||||
|
||||
def process_value(value):
|
||||
"""Convert Python values to string format for environment variables"""
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, bool):
|
||||
return str(value).lower()
|
||||
if isinstance(value, (list, dict)):
|
||||
return str(value)
|
||||
return str(value)
|
||||
|
||||
def set_env_vars(data, prefix=""):
|
||||
"""Recursively set environment variables from nested dictionary"""
|
||||
for key, value in data.items():
|
||||
env_key = prefix + key.upper()
|
||||
|
||||
# Handle special cases
|
||||
if isinstance(value, dict):
|
||||
# For nested dictionaries (like special_tokens)
|
||||
set_env_vars(value, f"{env_key}_")
|
||||
elif isinstance(value, list):
|
||||
# Handle list of dictionaries (like datasets)
|
||||
if value and isinstance(value[0], dict):
|
||||
for i, item in enumerate(value):
|
||||
set_env_vars(item, f"{env_key}_{i}_")
|
||||
else:
|
||||
# For simple lists (like lora_target_modules)
|
||||
os.environ[env_key] = process_value(value)
|
||||
else:
|
||||
# Handle all other cases
|
||||
os.environ[env_key] = process_value(value)
|
||||
|
||||
# Clear any existing related environment variables
|
||||
# This prevents old values from persisting
|
||||
for key in list(os.environ.keys()):
|
||||
if key.startswith(
|
||||
("BASE_MODEL", "MODEL_TYPE", "TOKENIZER_TYPE", "DATASET", "LORA_", "WANDB_")
|
||||
):
|
||||
del os.environ[key]
|
||||
|
||||
# Set new environment variables
|
||||
set_env_vars(args)
|
||||
@@ -1,86 +0,0 @@
|
||||
{
|
||||
"input": {
|
||||
"name": "quick_smoke_test_sft",
|
||||
"user_id": "user",
|
||||
"model_id": "llama-test",
|
||||
"run_id": "llama-test",
|
||||
"credentials": {
|
||||
"wandb_api_key": "",
|
||||
"hf_token": ""
|
||||
},
|
||||
"args": {
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"load_in_4bit": true,
|
||||
"strict": false,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]"
|
||||
}
|
||||
],
|
||||
"val_set_size": 0.02,
|
||||
"output_dir": "./outputs/lora-out",
|
||||
"sequence_len": 4096,
|
||||
"sample_packing": true,
|
||||
"eval_sample_packing": false,
|
||||
"pad_to_sequence_len": true,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 64,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": true,
|
||||
"lora_modules_to_save": [
|
||||
"embed_tokens",
|
||||
"lm_head"
|
||||
],
|
||||
"gradient_accumulation_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"num_epochs": 1,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"learning_rate": 0.0002,
|
||||
"train_on_inputs": false,
|
||||
"group_by_length": false,
|
||||
"bf16": "auto",
|
||||
"tf32": true,
|
||||
"gradient_checkpointing": true,
|
||||
"logging_steps": 1,
|
||||
"flash_attention": true,
|
||||
"warmup_steps": 1,
|
||||
"evals_per_epoch": 1,
|
||||
"eval_max_new_tokens": 128,
|
||||
"saves_per_epoch": 1,
|
||||
"weight_decay": 0.0,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>"
|
||||
},
|
||||
"max_steps": 20
|
||||
},
|
||||
"timeout": 100000
|
||||
},
|
||||
"config": {
|
||||
"gpuTypeId": "NVIDIA GeForce RTX 4090",
|
||||
"gpuCount": 1,
|
||||
"containerDiskInGb": 200,
|
||||
"env": [
|
||||
{
|
||||
"key": "TOKENIZER",
|
||||
"value": ""
|
||||
},
|
||||
{
|
||||
"key": "DISABLE_LOG_STATS",
|
||||
"value": "true"
|
||||
}
|
||||
],
|
||||
"allowedCudaVersions": [
|
||||
"12.8",
|
||||
"12.7",
|
||||
"12.6",
|
||||
"12.5",
|
||||
"12.4"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
{
|
||||
"tests": [
|
||||
{
|
||||
"name": "quick_smoke_test_sft",
|
||||
"input": {
|
||||
"user_id": "user",
|
||||
"model_id": "llama-test",
|
||||
"run_id": "llama-test",
|
||||
"credentials": {
|
||||
"wandb_api_key": "",
|
||||
"hf_token": ""
|
||||
},
|
||||
"args": {
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"load_in_4bit": true,
|
||||
"strict": false,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]"
|
||||
}
|
||||
],
|
||||
"val_set_size": 0.02,
|
||||
"output_dir": "./outputs/lora-out",
|
||||
"sequence_len": 4096,
|
||||
"sample_packing": true,
|
||||
"eval_sample_packing": false,
|
||||
"pad_to_sequence_len": true,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 64,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": true,
|
||||
"lora_modules_to_save": [
|
||||
"embed_tokens",
|
||||
"lm_head"
|
||||
],
|
||||
"gradient_accumulation_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"num_epochs": 1,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"learning_rate": 0.0002,
|
||||
"train_on_inputs": false,
|
||||
"group_by_length": false,
|
||||
"bf16": "auto",
|
||||
"tf32": true,
|
||||
"gradient_checkpointing": true,
|
||||
"logging_steps": 1,
|
||||
"flash_attention": true,
|
||||
"warmup_steps": 1,
|
||||
"evals_per_epoch": 1,
|
||||
"eval_max_new_tokens": 128,
|
||||
"saves_per_epoch": 1,
|
||||
"weight_decay": 0.0,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>"
|
||||
},
|
||||
"max_steps": 20
|
||||
}
|
||||
},
|
||||
"timeout": 100000
|
||||
}
|
||||
],
|
||||
"config": {
|
||||
"gpuTypeId": "NVIDIA GeForce RTX 4090",
|
||||
"gpuCount": 1,
|
||||
"containerDiskInGb": 200,
|
||||
"env": [
|
||||
{
|
||||
"key": "TOKENIZER",
|
||||
"value": ""
|
||||
},
|
||||
{
|
||||
"key": "DISABLE_LOG_STATS",
|
||||
"value": "true"
|
||||
}
|
||||
],
|
||||
"allowedCudaVersions": [
|
||||
"12.8",
|
||||
"12.7",
|
||||
"12.6",
|
||||
"12.5",
|
||||
"12.4"
|
||||
]
|
||||
}
|
||||
}
|
||||
94
AGENTS.md
94
AGENTS.md
@@ -1,94 +0,0 @@
|
||||
# Axolotl
|
||||
|
||||
Fine-tuning framework for LLMs. Config-driven: every training run is defined by a single YAML file.
|
||||
|
||||
## Tech Stack
|
||||
|
||||
Python, PyTorch, HuggingFace Transformers, TRL, PEFT (LoRA/QLoRA), DeepSpeed, FSDP, vLLM (for GRPO generation).
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
axolotl train config.yaml # Train (single or multi-GPU, auto-detected)
|
||||
axolotl preprocess config.yaml # Tokenize dataset and validate config
|
||||
axolotl preprocess config.yaml --debug # Inspect tokenized samples and label masking
|
||||
axolotl inference config.yaml # Interactive inference
|
||||
axolotl merge-lora config.yaml # Merge LoRA adapter into base model
|
||||
axolotl vllm-serve config.yaml # Start vLLM server for GRPO/EBFT training
|
||||
axolotl fetch examples # Download example configs
|
||||
```
|
||||
|
||||
## Training Methods
|
||||
|
||||
| Method | Config Key | When to Use |
|
||||
|--------|-----------|-------------|
|
||||
| SFT | *(default)* | Input-output pairs, instruction tuning |
|
||||
| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) |
|
||||
| KTO | `rl: kto` | Unpaired binary preference labels |
|
||||
| ORPO | `rl: orpo` | Single-stage alignment, no ref model |
|
||||
| GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |
|
||||
| EBFT | `rl: ebft` | Feature-matching rewards from internal representations |
|
||||
|
||||
Agent-specific references:
|
||||
- [docs/agents/sft.md](docs/agents/sft.md) — supervised fine-tuning
|
||||
- [docs/agents/preference_tuning.md](docs/agents/preference_tuning.md) — DPO, IPO, KTO, ORPO, SimPO
|
||||
- [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions
|
||||
- [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models
|
||||
- [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining
|
||||
|
||||
## Config Pattern
|
||||
|
||||
All training is config-driven. A YAML file specifies model, adapter, dataset(s), and hyperparameters:
|
||||
|
||||
```yaml
|
||||
base_model: meta-llama/Llama-3.1-8B-Instruct
|
||||
adapter: lora # or qlora, or omit for full fine-tune
|
||||
datasets:
|
||||
- path: my_dataset
|
||||
type: chat_template # prompt strategy (see docs/dataset-formats/)
|
||||
output_dir: ./outputs/lora-out
|
||||
```
|
||||
|
||||
Config schema: `src/axolotl/utils/schemas/config.py` (AxolotlInputConfig).
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
src/axolotl/
|
||||
cli/ # CLI entry points (train, preprocess, inference, merge_lora, vllm_serve)
|
||||
core/
|
||||
builders/ # TrainerBuilder classes (causal.py for SFT, rl.py for RLHF)
|
||||
trainers/ # Trainer classes, mixins (optimizer, scheduler, packing)
|
||||
dpo/ # DPO trainer and config
|
||||
grpo/ # GRPO trainer and sampler
|
||||
loaders/ # Model, tokenizer, adapter, processor loading
|
||||
prompt_strategies/ # Dataset format handlers (chat_template, alpaca, dpo/, kto/, orpo/)
|
||||
utils/schemas/ # Pydantic config schemas (config, model, training, peft, trl, fsdp)
|
||||
integrations/ # Plugins (liger, cut_cross_entropy, swanlab, nemo_gym)
|
||||
monkeypatch/ # Runtime patches for HF transformers
|
||||
|
||||
examples/ # Example YAML configs by model (llama-3/, qwen2/, mistral/, ebft/)
|
||||
deepspeed_configs/ # DeepSpeed JSON configs (zero2, zero3)
|
||||
docs/ # Quarto documentation site
|
||||
```
|
||||
|
||||
## Code Conventions
|
||||
|
||||
- Config-driven: features are toggled via YAML, not code changes
|
||||
- Prompt strategies: `src/axolotl/prompt_strategies/` — each `type:` value maps to a function
|
||||
- Plugin system: `plugins:` list in config loads integration modules
|
||||
- Trainer mixins: `core/trainers/mixins/` for composable trainer behaviors
|
||||
- Schemas: all config validation via Pydantic in `utils/schemas/`
|
||||
|
||||
## Key Documentation
|
||||
|
||||
- [Getting Started](docs/getting-started.qmd) — quickstart tutorial
|
||||
- [Choosing a Method](docs/choosing_method.qmd) — SFT vs DPO vs GRPO decision guide
|
||||
- [Config Reference](docs/config-reference.qmd) — all config options
|
||||
- [Dataset Formats](docs/dataset-formats/) — chat_template, alpaca, input_output, completion
|
||||
- [RLHF](docs/rlhf.qmd) — DPO, KTO, ORPO, GRPO, EBFT configs and dataset formats
|
||||
- [GRPO Deep Dive](docs/grpo.qmd) — async training, custom rewards, scaling
|
||||
- [vLLM Serving](docs/vllm_serving.qmd) — vLLM setup for GRPO/EBFT
|
||||
- [Multi-GPU](docs/multi-gpu.qmd) — FSDP and DeepSpeed
|
||||
- [Training Stability](docs/training_stability.qmd) — debugging loss, NaN, OOM
|
||||
- [Debugging](docs/debugging.qmd) — VSCode setup, Docker debugging
|
||||
10
CITATION.cff
10
CITATION.cff
@@ -1,10 +0,0 @@
|
||||
cff-version: 1.2.0
|
||||
type: software
|
||||
title: "Axolotl: Open Source LLM Post-Training"
|
||||
message: "If you use this software, please cite it as below."
|
||||
authors:
|
||||
- name: "Axolotl maintainers and contributors"
|
||||
repository-code: "https://github.com/axolotl-ai-cloud/axolotl"
|
||||
url: "https://axolotl.ai/"
|
||||
license: Apache-2.0
|
||||
date-released: "2023-05-30"
|
||||
@@ -2,5 +2,4 @@ include requirements.txt
|
||||
include README.md
|
||||
include LICENSE
|
||||
include src/setuptools_axolotl_dynamic_dependencies.py
|
||||
include src/axolotl/utils/chat_templates/templates/*.jinja
|
||||
recursive-include axolotl *.py
|
||||
|
||||
165
README.md
165
README.md
@@ -5,14 +5,10 @@
|
||||
<img alt="Axolotl" src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
|
||||
</picture>
|
||||
</p>
|
||||
<p align="center">
|
||||
<strong>A Free and Open Source LLM Fine-tuning Framework</strong><br>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
|
||||
<a href="https://codecov.io/gh/axolotl-ai-cloud/axolotl"><img src="https://codecov.io/gh/axolotl-ai-cloud/axolotl/branch/main/graph/badge.svg" alt="codecov"></a>
|
||||
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
|
||||
<br/>
|
||||
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors"><img src="https://img.shields.io/github/contributors-anon/axolotl-ai-cloud/axolotl?color=yellow&style=flat-square" alt="contributors" style="height: 20px;"></a>
|
||||
@@ -20,85 +16,46 @@
|
||||
<br/>
|
||||
<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
|
||||
<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
|
||||
<a href="https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google-colab" style="height: 20px;"></a>
|
||||
<br/>
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
|
||||
</p>
|
||||
|
||||
Axolotl is a tool designed to streamline post-training for various AI models.
|
||||
Post-training refers to any modifications or additional training performed on
|
||||
pre-trained models - including full model fine-tuning, parameter-efficient tuning (like
|
||||
LoRA and QLoRA), supervised fine-tuning (SFT), instruction tuning, and alignment
|
||||
techniques. With support for multiple model architectures and training configurations,
|
||||
Axolotl makes it easy to get started with these techniques.
|
||||
|
||||
## 🎉 Latest Updates
|
||||
|
||||
- 2026/03:
|
||||
- New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
|
||||
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
|
||||
- 2026/02:
|
||||
- [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
|
||||
- Axolotl now has support for [SageAttention](https://github.com/axolotl-ai-cloud/axolotl/pull/2823) and [GDPO](https://github.com/axolotl-ai-cloud/axolotl/pull/3353) (Generalized DPO).
|
||||
- 2026/01:
|
||||
- New integration for [EAFT](https://github.com/axolotl-ai-cloud/axolotl/pull/3366) (Entropy-Aware Focal Training), weights loss by entropy of the top-k logit distribution, and [Scalable Softmax](https://github.com/axolotl-ai-cloud/axolotl/pull/3338), improves long context in attention.
|
||||
- 2025/12:
|
||||
- Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
|
||||
- [Distributed Muon Optimizer](https://github.com/axolotl-ai-cloud/axolotl/pull/3264) support has been added for FSDP2 pretraining.
|
||||
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Expand older updates</summary>
|
||||
|
||||
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
|
||||
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
|
||||
- 2025/07:
|
||||
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
|
||||
- Axolotl adds more models: [GPT-OSS](https://docs.axolotl.ai/docs/models/gpt-oss.html), [Gemma 3n](https://docs.axolotl.ai/docs/models/gemma3n.html), [Liquid Foundation Model 2 (LFM2)](https://docs.axolotl.ai/docs/models/LiquidAI.html), and [Arcee Foundation Models (AFM)](https://docs.axolotl.ai/docs/models/arcee.html).
|
||||
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
|
||||
- [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support has been integrated in Axolotl!
|
||||
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
|
||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
||||
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
|
||||
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
|
||||
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
|
||||
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
|
||||
|
||||
</details>
|
||||
|
||||
## ✨ Overview
|
||||
|
||||
Axolotl is a free and open-source tool designed to streamline post-training and fine-tuning for the latest large language models (LLMs).
|
||||
Axolotl is designed to work with YAML config files that contain everything you need to
|
||||
preprocess a dataset, train or fine-tune a model, run model inference or evaluation,
|
||||
and much more.
|
||||
|
||||
Features:
|
||||
|
||||
- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.
|
||||
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
|
||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
|
||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention 2/3/4](https://docs.axolotl.ai/docs/attention.html#flash-attention), [Xformers](https://docs.axolotl.ai/docs/attention.html#xformers), [Flex Attention](https://docs.axolotl.ai/docs/attention.html#flex-attention), [SageAttention](https://docs.axolotl.ai/docs/attention.html#sageattention), [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels), [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
|
||||
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
||||
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
||||
- Train various Huggingface models such as llama, pythia, falcon, mpt
|
||||
- Supports fullfinetune, lora, qlora, relora, and gptq
|
||||
- Customize configurations using a simple yaml file or CLI overwrite
|
||||
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
|
||||
- Integrated with [xformers](https://github.com/facebookresearch/xformers), flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
|
||||
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
|
||||
- Easily run with Docker locally or on the cloud
|
||||
- Log results and optionally checkpoints to wandb, mlflow or Comet
|
||||
- And more!
|
||||
|
||||
|
||||
|
||||
## 🚀 Quick Start - LLM Fine-tuning in Minutes
|
||||
## 🚀 Quick Start
|
||||
|
||||
**Requirements**:
|
||||
|
||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python 3.11
|
||||
- PyTorch ≥2.9.1
|
||||
|
||||
### Google Colab
|
||||
|
||||
[](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
|
||||
- PyTorch ≥2.4.1
|
||||
|
||||
### Installation
|
||||
|
||||
#### Using pip
|
||||
|
||||
```bash
|
||||
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
|
||||
# Download example axolotl configs, deepspeed configs
|
||||
@@ -106,29 +63,8 @@ axolotl fetch examples
|
||||
axolotl fetch deepspeed_configs # OPTIONAL
|
||||
```
|
||||
|
||||
#### Using Docker
|
||||
|
||||
Installing with Docker can be less error prone than installing in your own environment.
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
||||
```
|
||||
|
||||
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
#### Cloud Providers
|
||||
|
||||
<details>
|
||||
|
||||
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=github&utm_medium=developer_community&utm_campaign=template_launch_axolotl&utm_content=readme)
|
||||
- [PRIME Intellect](https://app.primeintellect.ai/dashboard/create-cluster?image=axolotl&location=Cheapest&security=Cheapest&show_spot=true)
|
||||
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl)
|
||||
- [Novita](https://novita.ai/gpus-console?templateId=311)
|
||||
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
|
||||
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||
|
||||
</details>
|
||||
|
||||
### Your First Fine-tune
|
||||
|
||||
```bash
|
||||
@@ -144,12 +80,19 @@ axolotl train examples/llama-3/lora-1b.yml
|
||||
|
||||
That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/getting-started.html) for a more detailed walkthrough.
|
||||
|
||||
## ✨ Key Features
|
||||
|
||||
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more
|
||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, and more
|
||||
- **Easy Configuration**: Simple YAML files to control your training setup
|
||||
- **Performance Optimizations**: Flash Attention, xformers, multi-GPU training
|
||||
- **Flexible Dataset Handling**: Use various formats and custom datasets
|
||||
- **Cloud Ready**: Run on cloud platforms or local hardware
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- [Installation Options](https://docs.axolotl.ai/docs/installation.html) - Detailed setup instructions for different environments
|
||||
- [Configuration Guide](https://docs.axolotl.ai/docs/config-reference.html) - Full configuration options and examples
|
||||
- [Dataset Loading](https://docs.axolotl.ai/docs/dataset_loading.html) - Loading datasets from various sources
|
||||
- [Configuration Guide](https://docs.axolotl.ai/docs/config.html) - Full configuration options and examples
|
||||
- [Dataset Guide](https://docs.axolotl.ai/docs/dataset-formats/) - Supported formats and how to use them
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
@@ -168,31 +111,41 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
|
||||
|
||||
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
|
||||
|
||||
## 📈 Telemetry
|
||||
## Supported Models
|
||||
|
||||
Axolotl has opt-out telemetry that helps us understand how the project is being used
|
||||
and prioritize improvements. We collect basic system information, model types, and
|
||||
error rates—never personal data or file paths. Telemetry is enabled by default. To
|
||||
disable it, set AXOLOTL_DO_NOT_TRACK=1. For more details, see our [telemetry documentation](https://docs.axolotl.ai/docs/telemetry.html).
|
||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Mixtral8X22 | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
|
||||
| Jamba | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
|
||||
|
||||
✅: supported
|
||||
❌: not supported
|
||||
❓: untested
|
||||
|
||||
## ❤️ Sponsors
|
||||
|
||||
Thank you to our sponsors who help make Axolotl possible:
|
||||
|
||||
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) - Modal lets you run
|
||||
jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale,
|
||||
fine-tune large language models, run protein folding simulations, and much more.
|
||||
|
||||
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
|
||||
|
||||
## 📝 Citing Axolotl
|
||||
|
||||
If you use Axolotl in your research or projects, please cite it as follows:
|
||||
|
||||
```bibtex
|
||||
@software{axolotl,
|
||||
title = {Axolotl: Open Source LLM Post-Training},
|
||||
author = {{Axolotl maintainers and contributors}},
|
||||
url = {https://github.com/axolotl-ai-cloud/axolotl},
|
||||
license = {Apache-2.0},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
|
||||
## 📜 License
|
||||
|
||||
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
10
TODO.md
Normal file
10
TODO.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# todo list
|
||||
|
||||
- [] Validation of parameters for combinations that won't work
|
||||
|
||||
|
||||
|
||||
## things that are known not to work
|
||||
|
||||
- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
|
||||
- adamw_bnb_8bit doesn't play well with FSDP offload
|
||||
115
_quarto.yml
115
_quarto.yml
@@ -1,8 +1,5 @@
|
||||
project:
|
||||
type: website
|
||||
pre-render:
|
||||
- docs/scripts/generate_config_docs.py
|
||||
- docs/scripts/generate_examples_docs.py
|
||||
|
||||
quartodoc:
|
||||
dir: docs/api
|
||||
@@ -20,9 +17,7 @@ quartodoc:
|
||||
- convert
|
||||
- prompt_tokenizers
|
||||
- logging_config
|
||||
- core.builders.base
|
||||
- core.builders.causal
|
||||
- core.builders.rl
|
||||
- core.trainer_builder
|
||||
- core.training_args
|
||||
- core.chat.messages
|
||||
- core.chat.format.chatml
|
||||
@@ -37,53 +32,24 @@ quartodoc:
|
||||
- cli.train
|
||||
- cli.evaluate
|
||||
- cli.args
|
||||
- cli.art
|
||||
- cli.checks
|
||||
- cli.config
|
||||
- cli.delinearize_llama4
|
||||
- cli.inference
|
||||
- cli.merge_lora
|
||||
- cli.merge_sharded_fsdp_weights
|
||||
- cli.preprocess
|
||||
- cli.quantize
|
||||
- cli.sweeps
|
||||
- cli.utils
|
||||
- cli.vllm_serve
|
||||
- cli.cloud.base
|
||||
- cli.cloud.modal_
|
||||
- cli.utils
|
||||
- cli.utils.args
|
||||
- cli.utils.fetch
|
||||
- cli.utils.load
|
||||
- cli.utils.sweeps
|
||||
- cli.utils.train
|
||||
- title: Trainers
|
||||
desc: Training implementations
|
||||
contents:
|
||||
- core.trainers.base
|
||||
- core.trainers.trl
|
||||
- core.trainers.mamba
|
||||
- core.trainers.dpo.trainer
|
||||
- core.trainers.grpo.trainer
|
||||
- core.trainers.grpo.sampler
|
||||
- core.trainers.utils
|
||||
- title: Model Loading
|
||||
desc: Functionality for loading and patching models, tokenizers, etc.
|
||||
contents:
|
||||
- loaders.model
|
||||
- loaders.tokenizer
|
||||
- loaders.processor
|
||||
- loaders.adapter
|
||||
- loaders.patch_manager
|
||||
- loaders.constants
|
||||
- title: Mixins
|
||||
desc: Mixin classes for augmenting trainers
|
||||
contents:
|
||||
- core.trainers.mixins.optimizer
|
||||
- core.trainers.mixins.rng_state_loader
|
||||
- core.trainers.mixins.scheduler
|
||||
- title: Context Managers
|
||||
desc: Context managers for altering trainer behaviors
|
||||
contents:
|
||||
- utils.ctx_managers.sequence_parallel
|
||||
- title: Prompt Strategies
|
||||
desc: Prompt formatting strategies
|
||||
contents:
|
||||
@@ -120,7 +86,7 @@ quartodoc:
|
||||
- kernels.swiglu
|
||||
- kernels.quantize
|
||||
- kernels.utils
|
||||
- title: Monkey Patches
|
||||
- title: MonkeyPatches
|
||||
desc: Runtime patches for model optimizations
|
||||
contents:
|
||||
- monkeypatch.llama_attn_hijack_flash
|
||||
@@ -128,23 +94,26 @@ quartodoc:
|
||||
- monkeypatch.mistral_attn_hijack_flash
|
||||
- monkeypatch.multipack
|
||||
- monkeypatch.relora
|
||||
- monkeypatch.llama_expand_mask
|
||||
- monkeypatch.lora_kernels
|
||||
- monkeypatch.utils
|
||||
- monkeypatch.btlm_attn_hijack_flash
|
||||
- monkeypatch.llama_patch_multipack
|
||||
- monkeypatch.stablelm_attn_hijack_flash
|
||||
- monkeypatch.trainer_fsdp_optim
|
||||
- monkeypatch.transformers_fa_utils
|
||||
- monkeypatch.unsloth_
|
||||
- monkeypatch.attention.mllama
|
||||
- monkeypatch.data.batch_dataset_fetcher
|
||||
- monkeypatch.mixtral
|
||||
- monkeypatch.gradient_checkpointing.offload_cpu
|
||||
- monkeypatch.gradient_checkpointing.offload_disk
|
||||
- title: Utils
|
||||
desc: Utility functions
|
||||
contents:
|
||||
- utils.models
|
||||
- utils.tokenization
|
||||
- utils.chat_templates
|
||||
- utils.lora
|
||||
- utils.lora_embeddings
|
||||
- utils.model_shard_quant
|
||||
- utils.bench
|
||||
- utils.freeze
|
||||
@@ -153,9 +122,9 @@ quartodoc:
|
||||
- utils.distributed
|
||||
- utils.dict
|
||||
- utils.optimizers.adopt
|
||||
- utils.data.streaming
|
||||
- utils.data.pretraining
|
||||
- utils.data.sft
|
||||
- utils.quantization
|
||||
- utils.gradient_checkpointing.unsloth
|
||||
- title: Schemas
|
||||
desc: Pydantic data models for Axolotl config
|
||||
contents:
|
||||
@@ -205,14 +174,12 @@ quartodoc:
|
||||
- utils.callbacks.lisa
|
||||
- utils.callbacks.mlflow_
|
||||
- utils.callbacks.comet_
|
||||
- utils.callbacks.qat
|
||||
|
||||
website:
|
||||
title: "Axolotl"
|
||||
description: "We make fine-tuning accessible, scalable, and fun"
|
||||
favicon: favicon.jpg
|
||||
|
||||
google-analytics: "G-9KYCVJBNMQ"
|
||||
|
||||
navbar:
|
||||
logo: image/axolotl_logo_digital_white.svg
|
||||
title: false
|
||||
@@ -238,52 +205,10 @@ website:
|
||||
- section: "Getting Started"
|
||||
contents:
|
||||
- docs/getting-started.qmd
|
||||
- docs/choosing_method.qmd
|
||||
- docs/installation.qmd
|
||||
- docs/inference.qmd
|
||||
- section: "Model Guides"
|
||||
contents:
|
||||
- docs/models/kimi-linear.qmd
|
||||
- docs/models/plano.qmd
|
||||
- docs/models/mimo.qmd
|
||||
- docs/models/internvl3_5.qmd
|
||||
- docs/models/olmo3.qmd
|
||||
- docs/models/trinity.qmd
|
||||
- docs/models/arcee.qmd
|
||||
- section: "Ministral3"
|
||||
contents:
|
||||
- docs/models/ministral3.qmd
|
||||
- docs/models/ministral3/think.qmd
|
||||
- docs/models/ministral3/vision.qmd
|
||||
- section: "Magistral"
|
||||
contents:
|
||||
- docs/models/magistral.qmd
|
||||
- docs/models/magistral/think.qmd
|
||||
- docs/models/magistral/vision.qmd
|
||||
- docs/models/ministral.qmd
|
||||
- docs/models/mistral-small.qmd
|
||||
- docs/models/voxtral.qmd
|
||||
- docs/models/devstral.qmd
|
||||
- docs/models/mistral.qmd
|
||||
- docs/models/llama-4.qmd
|
||||
- docs/models/llama-2.qmd
|
||||
- docs/models/qwen3-next.qmd
|
||||
- docs/models/qwen3.qmd
|
||||
- docs/models/gemma3n.qmd
|
||||
- docs/models/apertus.qmd
|
||||
- docs/models/gpt-oss.qmd
|
||||
- docs/models/seed-oss.qmd
|
||||
- docs/models/phi.qmd
|
||||
- docs/models/smolvlm2.qmd
|
||||
- docs/models/granite4.qmd
|
||||
- docs/models/LiquidAI.qmd
|
||||
- docs/models/hunyuan.qmd
|
||||
- docs/models/jamba.qmd
|
||||
- docs/models/orpheus.qmd
|
||||
|
||||
- docs/cli.qmd
|
||||
- docs/telemetry.qmd
|
||||
- docs/config-reference.qmd
|
||||
- docs/config.qmd
|
||||
- text: "API Reference"
|
||||
href: docs/api
|
||||
|
||||
@@ -303,26 +228,16 @@ website:
|
||||
contents:
|
||||
- docs/multimodal.qmd
|
||||
- docs/rlhf.qmd
|
||||
- docs/grpo.qmd
|
||||
- docs/ebft.qmd
|
||||
- docs/vllm_serving.qmd
|
||||
- docs/reward_modelling.qmd
|
||||
- docs/lr_groups.qmd
|
||||
- docs/lora_optims.qmd
|
||||
- docs/dataset_loading.qmd
|
||||
- docs/qat.qmd
|
||||
- docs/quantize.qmd
|
||||
- docs/optimizations.qmd
|
||||
|
||||
- section: "Core Concepts"
|
||||
contents:
|
||||
- docs/batch_vs_grad.qmd
|
||||
- docs/dataset_preprocessing.qmd
|
||||
- docs/streaming.qmd
|
||||
- docs/multipack.qmd
|
||||
- docs/mixed_precision.qmd
|
||||
- docs/optimizers.qmd
|
||||
- docs/attention.qmd
|
||||
|
||||
- section: "Advanced Features"
|
||||
contents:
|
||||
@@ -331,14 +246,10 @@ website:
|
||||
- docs/torchao.qmd
|
||||
- docs/custom_integrations.qmd
|
||||
- docs/sequence_parallelism.qmd
|
||||
- docs/gradient_checkpointing.qmd
|
||||
- docs/nd_parallelism.qmd
|
||||
- docs/expert_quantization.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
contents:
|
||||
- docs/faq.qmd
|
||||
- docs/training_stability.qmd
|
||||
- docs/debugging.qmd
|
||||
- docs/nccl.qmd
|
||||
|
||||
|
||||
@@ -1,208 +0,0 @@
|
||||
"""Benchmark for entropy_from_logits Triton kernel vs original chunked implementation.
|
||||
|
||||
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_entropy.py
|
||||
"""
|
||||
|
||||
import gc
|
||||
import statistics
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
|
||||
|
||||
V = 151936 # Qwen vocab
|
||||
WARMUP = 5
|
||||
BENCH_ITERS = 20
|
||||
MEM_ITERS = 10
|
||||
|
||||
|
||||
def entropy_from_logits_original(logits: torch.Tensor, chunk_size: int = 128):
|
||||
"""Original chunked implementation (reference)."""
|
||||
original_shape = logits.shape[:-1]
|
||||
num_classes = logits.shape[-1]
|
||||
flat_logits = logits.reshape(-1, num_classes)
|
||||
entropies = []
|
||||
for chunk in flat_logits.split(chunk_size, dim=0):
|
||||
logps = F.log_softmax(chunk, dim=-1)
|
||||
chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
|
||||
entropies.append(chunk_entropy)
|
||||
return torch.cat(entropies, dim=0).reshape(original_shape)
|
||||
|
||||
|
||||
def _clean_gpu():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_accumulated_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def profile_time(fn, logits, n_iters=BENCH_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
out = fn(logits, chunk_size=128)
|
||||
del out
|
||||
torch.cuda.synchronize()
|
||||
|
||||
times = []
|
||||
for _ in range(n_iters):
|
||||
s = torch.cuda.Event(enable_timing=True)
|
||||
e = torch.cuda.Event(enable_timing=True)
|
||||
s.record()
|
||||
out = fn(logits, chunk_size=128)
|
||||
e.record()
|
||||
torch.cuda.synchronize()
|
||||
times.append(s.elapsed_time(e))
|
||||
del out
|
||||
return times
|
||||
|
||||
|
||||
def profile_memory(fn, logits, n_iters=MEM_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
out = fn(logits, chunk_size=128)
|
||||
del out
|
||||
torch.cuda.synchronize()
|
||||
|
||||
peaks = []
|
||||
for _ in range(n_iters):
|
||||
_clean_gpu()
|
||||
base = torch.cuda.max_memory_allocated()
|
||||
out = fn(logits, chunk_size=128)
|
||||
torch.cuda.synchronize()
|
||||
peaks.append(torch.cuda.max_memory_allocated() - base)
|
||||
del out
|
||||
return [p / 1e6 for p in peaks]
|
||||
|
||||
|
||||
def fmt(values, unit=""):
|
||||
mean = statistics.mean(values)
|
||||
std = statistics.stdev(values) if len(values) > 1 else 0.0
|
||||
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
|
||||
|
||||
|
||||
def benchmark_contiguous():
|
||||
print("=" * 60)
|
||||
print(
|
||||
f"CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
|
||||
)
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(1, 2048),
|
||||
(1, 8192),
|
||||
(1, 16384),
|
||||
(4, 4096),
|
||||
(8, 2048),
|
||||
(16, 2048),
|
||||
(16, 4096),
|
||||
]
|
||||
|
||||
for B, L in configs:
|
||||
mem_gb = B * L * V * 2 / 1e9
|
||||
if mem_gb > 28:
|
||||
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
torch.manual_seed(42)
|
||||
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
t_orig = profile_time(entropy_from_logits_original, logits)
|
||||
t_triton = profile_time(entropy_from_logits, logits)
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" TIME (ms):")
|
||||
print(f" original: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton: {fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(entropy_from_logits_original, logits)
|
||||
m_triton = profile_memory(entropy_from_logits, logits)
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" MEMORY (peak overhead):")
|
||||
print(f" original: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton: {fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del logits
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
def benchmark_noncontiguous():
|
||||
print("\n" + "=" * 60)
|
||||
print(
|
||||
f"NON-CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
|
||||
)
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(4, 2048, "transpose"),
|
||||
(4, 8192, "transpose"),
|
||||
(8, 2048, "transpose"),
|
||||
(4, 4096, "slice_batch"),
|
||||
]
|
||||
|
||||
for B, L, method in configs:
|
||||
torch.manual_seed(42)
|
||||
|
||||
if method == "transpose":
|
||||
raw = torch.randn(L, B, V, device="cuda", dtype=torch.bfloat16)
|
||||
logits_nc = raw.transpose(0, 1)
|
||||
raw_gb = L * B * V * 2 / 1e9
|
||||
elif method == "slice_batch":
|
||||
raw = torch.randn(B * 2, L, V, device="cuda", dtype=torch.bfloat16)
|
||||
logits_nc = raw[::2]
|
||||
raw_gb = B * 2 * L * V * 2 / 1e9
|
||||
else:
|
||||
continue
|
||||
|
||||
if raw_gb > 28:
|
||||
print(f"\n skip B={B}, L={L}, {method} ({raw_gb:.1f} GB)")
|
||||
del raw, logits_nc
|
||||
torch.cuda.empty_cache()
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B}, L={L} {method} ({N} rows, raw {raw_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
def original_with_copy(logits, chunk_size=128):
|
||||
return entropy_from_logits_original(
|
||||
logits.contiguous(), chunk_size=chunk_size
|
||||
)
|
||||
|
||||
t_orig = profile_time(original_with_copy, logits_nc)
|
||||
t_triton = profile_time(entropy_from_logits, logits_nc)
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" TIME (ms):")
|
||||
print(f" orig+copy: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton-strided:{fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(original_with_copy, logits_nc)
|
||||
m_triton = profile_memory(entropy_from_logits, logits_nc)
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" MEMORY (peak overhead):")
|
||||
print(f" orig+copy: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton-strided:{fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del raw, logits_nc
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark_contiguous()
|
||||
benchmark_noncontiguous()
|
||||
@@ -1,284 +0,0 @@
|
||||
"""Benchmark for ScatterMoE LoRA Triton kernels.
|
||||
|
||||
Measures forward, backward dX, and backward dA/dB kernels at common MoE
|
||||
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
|
||||
and full fwd+bwd autograd throughput.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
||||
lora_ops,
|
||||
ops as base_ops,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
||||
flatten_sort_count,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
|
||||
ScatterMoELoRA,
|
||||
)
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.bfloat16
|
||||
WARMUP = 5
|
||||
ITERS = 20
|
||||
|
||||
# ─── Model configs ──────────────────────────────────────────────────────────
|
||||
|
||||
BUILTIN_CONFIGS = {
|
||||
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
|
||||
"Qwen3-30B-A3B": (128, 2048, 768, 8),
|
||||
"OLMoE-1B-7B": (64, 2048, 1024, 8),
|
||||
"Mixtral-8x7B": (8, 4096, 14336, 2),
|
||||
}
|
||||
|
||||
|
||||
def _resolve_config(spec):
|
||||
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
|
||||
key = spec.lower().replace("/", "-")
|
||||
for name, cfg in BUILTIN_CONFIGS.items():
|
||||
if key in name.lower() or name.lower() in key:
|
||||
return name, cfg
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
|
||||
if callable(getattr(hf_cfg, "get_text_config", None)):
|
||||
tc = hf_cfg.get_text_config()
|
||||
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
|
||||
hf_cfg = tc
|
||||
hidden = hf_cfg.hidden_size
|
||||
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
|
||||
experts = (
|
||||
getattr(hf_cfg, "num_experts", None)
|
||||
or getattr(hf_cfg, "num_local_experts", None)
|
||||
or getattr(hf_cfg, "n_routed_experts", None)
|
||||
)
|
||||
top_k = (
|
||||
getattr(hf_cfg, "num_experts_per_tok", None)
|
||||
or getattr(hf_cfg, "num_experts_per_token", None)
|
||||
or 2
|
||||
)
|
||||
name = spec.split("/")[-1]
|
||||
return name, (experts, hidden, inter, top_k)
|
||||
|
||||
|
||||
# ─── Benchmark helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _clean():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def _bench(fn, warmup=WARMUP, iters=ITERS):
|
||||
for _ in range(warmup):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
times = []
|
||||
for _ in range(iters):
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
times.append((time.perf_counter() - t0) * 1000)
|
||||
times.sort()
|
||||
return times[len(times) // 2]
|
||||
|
||||
|
||||
def _setup(num_experts, K, N, T, top_k, R):
|
||||
torch.manual_seed(42)
|
||||
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
|
||||
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
logits = torch.randn(T, num_experts, device=DEVICE)
|
||||
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
|
||||
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
|
||||
gx = base_ops.group(x, ssi, fan_out=top_k)
|
||||
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
|
||||
|
||||
|
||||
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
|
||||
|
||||
|
||||
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
|
||||
return lora_ops.scatter2scatter_lora(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=top_k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
|
||||
def _call_base(x, W, sei, ssi, top_k):
|
||||
return base_ops.scatter2scatter(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=top_k,
|
||||
)
|
||||
|
||||
|
||||
def _call_dx(dy, W, sei, ssi, lA, lB):
|
||||
return lora_ops.scatter2scatter_lora_dX(
|
||||
DY=dy,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=1,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=2.0,
|
||||
dy_grouped=True,
|
||||
dx_grouped=False,
|
||||
)
|
||||
|
||||
|
||||
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
|
||||
return lora_ops.group_bwd_lora(
|
||||
DY=dy,
|
||||
X=gx,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
expert_offsets=eo,
|
||||
E=num_experts,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
|
||||
# ─── Main ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
"-m",
|
||||
nargs="+",
|
||||
help="Model names or HF IDs (default: all builtins)",
|
||||
)
|
||||
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
|
||||
parser.add_argument("--seq-len", "-T", type=int, default=2048)
|
||||
args = parser.parse_args()
|
||||
|
||||
T = args.seq_len
|
||||
print(f"GPU: {torch.cuda.get_device_name()}")
|
||||
print(f"T={T}, ranks={args.ranks}\n")
|
||||
|
||||
if args.models:
|
||||
configs = [_resolve_config(m) for m in args.models]
|
||||
else:
|
||||
configs = list(BUILTIN_CONFIGS.items())
|
||||
|
||||
for model_name, (num_experts, hidden, inter, top_k) in configs:
|
||||
print(f"{'=' * 70}")
|
||||
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
|
||||
print(f"{'=' * 70}")
|
||||
|
||||
for R in args.ranks:
|
||||
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
|
||||
_clean()
|
||||
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
|
||||
num_experts, K, N, T, top_k, R
|
||||
)
|
||||
|
||||
# Forward with LoRA (auto-dispatched: fused or split)
|
||||
dispatch = (
|
||||
"split"
|
||||
if (
|
||||
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
|
||||
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
|
||||
)
|
||||
else "fused"
|
||||
)
|
||||
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
|
||||
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
|
||||
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
|
||||
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
|
||||
|
||||
total = t_fwd + t_dx + t_bwd
|
||||
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
|
||||
|
||||
print(
|
||||
f" R={R:>2} {proj:<8} "
|
||||
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
|
||||
f"base={t_base:>6.2f}ms "
|
||||
f"(+{overhead * 100:.0f}%) "
|
||||
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
|
||||
f"total={total:>6.2f}ms"
|
||||
)
|
||||
|
||||
# Full autograd fwd+bwd with memory measurement
|
||||
x_ag = x.clone().requires_grad_(True)
|
||||
lA_ag = lA.clone().requires_grad_(True)
|
||||
lB_ag = lB.clone().requires_grad_(True)
|
||||
|
||||
def _run_autograd(
|
||||
_x=x_ag,
|
||||
_W=W,
|
||||
_k=top_k,
|
||||
_sei=sei,
|
||||
_ssi=ssi,
|
||||
_eo=eo,
|
||||
_lA=lA_ag,
|
||||
_lB=lB_ag,
|
||||
):
|
||||
out = ScatterMoELoRA.apply(
|
||||
_x,
|
||||
_W,
|
||||
_k,
|
||||
_sei,
|
||||
_ssi,
|
||||
_eo,
|
||||
_lA,
|
||||
_lB,
|
||||
2.0,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
out.sum().backward()
|
||||
_x.grad = None
|
||||
_lA.grad = None
|
||||
_lB.grad = None
|
||||
|
||||
t_full = _bench(_run_autograd)
|
||||
|
||||
_clean()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
mem_before = torch.cuda.memory_allocated()
|
||||
_run_autograd()
|
||||
torch.cuda.synchronize()
|
||||
mem_peak = torch.cuda.max_memory_allocated() - mem_before
|
||||
|
||||
print(
|
||||
f" full_fwd_bwd={t_full:>6.2f}ms "
|
||||
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,191 +0,0 @@
|
||||
"""Benchmark for selective_log_softmax Triton kernel vs original implementation.
|
||||
|
||||
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_selective_logsoftmax.py
|
||||
"""
|
||||
|
||||
import gc
|
||||
import statistics
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.trainer.utils import (
|
||||
selective_log_softmax,
|
||||
selective_log_softmax_original,
|
||||
)
|
||||
|
||||
V = 151936 # Qwen vocab
|
||||
WARMUP = 5
|
||||
BENCH_ITERS = 20
|
||||
MEM_ITERS = 10
|
||||
|
||||
|
||||
def _clean_gpu():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_accumulated_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def profile_time(fn, args, n_iters=BENCH_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
fn(*args)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
times = []
|
||||
for _ in range(n_iters):
|
||||
s = torch.cuda.Event(enable_timing=True)
|
||||
e = torch.cuda.Event(enable_timing=True)
|
||||
s.record()
|
||||
fn(*args)
|
||||
e.record()
|
||||
torch.cuda.synchronize()
|
||||
times.append(s.elapsed_time(e))
|
||||
return times
|
||||
|
||||
|
||||
def profile_memory(fn, args, n_iters=MEM_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
out = fn(*args)
|
||||
del out
|
||||
torch.cuda.synchronize()
|
||||
|
||||
peaks = []
|
||||
for _ in range(n_iters):
|
||||
_clean_gpu()
|
||||
base = torch.cuda.max_memory_allocated()
|
||||
out = fn(*args)
|
||||
torch.cuda.synchronize()
|
||||
peaks.append(torch.cuda.max_memory_allocated() - base)
|
||||
del out
|
||||
return [p / 1e6 for p in peaks]
|
||||
|
||||
|
||||
def fmt(values, unit=""):
|
||||
mean = statistics.mean(values)
|
||||
std = statistics.stdev(values) if len(values) > 1 else 0.0
|
||||
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
|
||||
|
||||
|
||||
def benchmark_forward():
|
||||
print("=" * 60)
|
||||
print(f"FORWARD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(1, 2048),
|
||||
(1, 8192),
|
||||
(4, 4096),
|
||||
(8, 2048),
|
||||
(16, 2048),
|
||||
(16, 4096),
|
||||
]
|
||||
|
||||
for B, L in configs:
|
||||
mem_gb = B * L * V * 2 / 1e9
|
||||
if mem_gb > 28:
|
||||
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
torch.manual_seed(42)
|
||||
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
|
||||
index = torch.randint(0, V, (B, L), device="cuda")
|
||||
|
||||
t_orig = profile_time(selective_log_softmax_original, (logits, index))
|
||||
t_triton = profile_time(selective_log_softmax, (logits, index))
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" TIME (ms):")
|
||||
print(f" original: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton: {fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(selective_log_softmax_original, (logits, index))
|
||||
m_triton = profile_memory(selective_log_softmax, (logits, index))
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" MEMORY (peak overhead):")
|
||||
print(f" original: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton: {fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del logits, index
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
def benchmark_backward():
|
||||
print("\n" + "=" * 60)
|
||||
print(f"FWD+BWD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(1, 2048),
|
||||
(1, 8192),
|
||||
(4, 4096),
|
||||
(8, 2048),
|
||||
(16, 2048),
|
||||
(16, 4096),
|
||||
]
|
||||
|
||||
def fwd_bwd_original(logits, index):
|
||||
logits.grad = None
|
||||
out = selective_log_softmax_original(logits, index)
|
||||
out.sum().backward()
|
||||
|
||||
def fwd_bwd_triton(logits, index):
|
||||
logits.grad = None
|
||||
out = selective_log_softmax(logits, index)
|
||||
out.sum().backward()
|
||||
|
||||
for B, L in configs:
|
||||
mem_gb = B * L * V * 2 / 1e9
|
||||
if mem_gb > 20:
|
||||
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB, need room for grads)")
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
torch.manual_seed(42)
|
||||
logits_orig = torch.randn(
|
||||
B, L, V, device="cuda", dtype=torch.bfloat16, requires_grad=True
|
||||
)
|
||||
logits_tri = logits_orig.detach().clone().requires_grad_(True)
|
||||
index = torch.randint(0, V, (B, L), device="cuda")
|
||||
|
||||
t_orig = profile_time(fwd_bwd_original, (logits_orig, index))
|
||||
t_triton = profile_time(fwd_bwd_triton, (logits_tri, index))
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" FWD+BWD TIME (ms):")
|
||||
print(f" original: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton: {fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(fwd_bwd_original, (logits_orig, index))
|
||||
m_triton = profile_memory(fwd_bwd_triton, (logits_tri, index))
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" FWD+BWD MEMORY (peak overhead):")
|
||||
print(f" original: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton: {fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del logits_orig, logits_tri, index
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark_forward()
|
||||
benchmark_backward()
|
||||
@@ -1,54 +0,0 @@
|
||||
FROM axolotlai/axolotl-base-uv:{{ BASE_TAG }}
|
||||
|
||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
|
||||
ENV CUDA="{{ CUDA }}"
|
||||
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
|
||||
ENV GITHUB_REF="{{ GITHUB_REF }}"
|
||||
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
||||
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
||||
ENV HF_HOME="{{ HF_HOME }}"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
RUN git fetch origin +$GITHUB_REF && \
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN uv pip install packaging==26.0 setuptools==78.1.1
|
||||
RUN uv pip install torchvision
|
||||
RUN uv pip uninstall causal_conv1d
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py --uv | sh
|
||||
RUN python scripts/cutcrossentropy_install.py --uv | sh
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch
|
||||
|
||||
# helper for huggingface-login cli
|
||||
RUN git config --global credential.helper store
|
||||
@@ -1,6 +1,6 @@
|
||||
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
||||
|
||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
|
||||
ENV CUDA="{{ CUDA }}"
|
||||
@@ -9,10 +9,9 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
|
||||
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
||||
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
||||
ENV HF_HOME="{{ HF_HOME }}"
|
||||
ENV AXOLOTL_DATASET_NUM_PROC="8"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
@@ -32,8 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN pip install packaging==26.0 setuptools==78.1.1 psutil
|
||||
RUN pip uninstall -y causal_conv1d
|
||||
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
75
cicd/cicd.sh
75
cicd/cicd.sh
@@ -3,71 +3,10 @@ set -e
|
||||
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
set -o pipefail
|
||||
for i in 1 2 3; do
|
||||
if curl --silent --show-error --fail -L \
|
||||
https://axolotl-ci.b-cdn.net/hf-cache.tar.zst \
|
||||
| tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1; then
|
||||
echo "HF cache extracted successfully"
|
||||
break
|
||||
fi
|
||||
echo "Attempt $i failed, cleaning up and retrying in 15s..."
|
||||
rm -rf "${HF_HOME}/hub/"*
|
||||
sleep 15
|
||||
done
|
||||
# hf download "NousResearch/Meta-Llama-3-8B"
|
||||
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
# hf download "microsoft/Phi-4-reasoning"
|
||||
# hf download "microsoft/Phi-3.5-mini-instruct"
|
||||
# hf download "microsoft/Phi-3-medium-128k-instruct"
|
||||
|
||||
# Run unit tests with initial coverage report
|
||||
pytest -v --durations=10 -n8 \
|
||||
--ignore=tests/e2e/ \
|
||||
--ignore=tests/patched/ \
|
||||
--ignore=tests/cli \
|
||||
/workspace/axolotl/tests/ \
|
||||
--cov=axolotl
|
||||
|
||||
# Run lora kernels tests with coverage append
|
||||
pytest -v --durations=10 \
|
||||
/workspace/axolotl/tests/e2e/patched/lora_kernels \
|
||||
--cov=axolotl \
|
||||
--cov-append
|
||||
|
||||
# Run patched tests excluding lora kernels with coverage append
|
||||
pytest --full-trace -vvv --durations=10 \
|
||||
--ignore=tests/e2e/patched/lora_kernels \
|
||||
/workspace/axolotl/tests/e2e/patched \
|
||||
--cov=axolotl \
|
||||
--cov-append
|
||||
|
||||
# Run solo tests with coverage append
|
||||
pytest -v --durations=10 -n1 \
|
||||
/workspace/axolotl/tests/e2e/solo/ \
|
||||
--cov=axolotl \
|
||||
--cov-append
|
||||
|
||||
# Run integration tests with coverage append
|
||||
pytest -v --durations=10 \
|
||||
/workspace/axolotl/tests/e2e/integrations/ \
|
||||
--cov=axolotl \
|
||||
--cov-append
|
||||
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/cli \
|
||||
--cov=axolotl \
|
||||
--cov-append
|
||||
|
||||
# Run remaining e2e tests with coverage append and final report
|
||||
pytest -v --durations=10 \
|
||||
--ignore=tests/e2e/solo/ \
|
||||
--ignore=tests/e2e/patched/ \
|
||||
--ignore=tests/e2e/multigpu/ \
|
||||
--ignore=tests/e2e/integrations/ \
|
||||
--ignore=tests/cli \
|
||||
/workspace/axolotl/tests/e2e/ \
|
||||
--cov=axolotl \
|
||||
--cov-append \
|
||||
--cov-report=xml:e2e-coverage.xml
|
||||
|
||||
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true
|
||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
|
||||
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
|
||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest -v --durations=10 /workspace/axolotl/tests/cli
|
||||
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Modal app to run axolotl GPU cleanup"""
|
||||
|
||||
from .single_gpu import VOLUME_CONFIG, app, cicd_image, run_cmd
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cicd_image,
|
||||
timeout=60 * 60,
|
||||
cpu=8.0,
|
||||
memory=131072,
|
||||
volumes=VOLUME_CONFIG,
|
||||
)
|
||||
def cleanup():
|
||||
run_cmd("./cicd/cleanup.sh", "/workspace/axolotl")
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main():
|
||||
cleanup.remote()
|
||||
@@ -1,6 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# cleanup old cache files for datasets processing and intermediate mappings
|
||||
find /workspace/data/huggingface-cache/hub/datasets -name "cache-*" -type f -mtime +1 -exec rm {} \;
|
||||
find /workspace/data/huggingface-cache/hub/datasets -name "*.lock" -type f -mtime +1 -exec rm {} \;
|
||||
@@ -1,12 +1,74 @@
|
||||
"""Modal app to run axolotl GPU tests"""
|
||||
|
||||
from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
|
||||
import jinja2
|
||||
import modal
|
||||
from jinja2 import select_autoescape
|
||||
from modal import App, Image
|
||||
|
||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
||||
template_env = jinja2.Environment(
|
||||
loader=template_loader, autoescape=select_autoescape()
|
||||
)
|
||||
df_template = template_env.get_template("Dockerfile.jinja")
|
||||
|
||||
df_args = {
|
||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
|
||||
"CUDA": os.environ.get("CUDA", "121"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||
}
|
||||
|
||||
dockerfile_contents = df_template.render(**df_args)
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
||||
f.write(dockerfile_contents)
|
||||
|
||||
cicd_image = Image.from_dockerfile(
|
||||
pathlib.Path(temp_dir) / "Dockerfile",
|
||||
context_mount=None,
|
||||
force_build=True,
|
||||
gpu="A10G",
|
||||
).env(df_args)
|
||||
|
||||
app = App("Axolotl CI/CD", secrets=[])
|
||||
|
||||
hf_cache_volume = modal.Volume.from_name(
|
||||
"axolotl-ci-hf-hub-cache", create_if_missing=True
|
||||
)
|
||||
VOLUME_CONFIG = {
|
||||
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
|
||||
}
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
|
||||
|
||||
|
||||
def run_cmd(cmd: str, run_folder: str):
|
||||
import subprocess # nosec
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=120 * 60, # 90 min
|
||||
timeout=60 * 60,
|
||||
cpu=8.0,
|
||||
memory=131072,
|
||||
volumes=VOLUME_CONFIG,
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
modal application to run axolotl gpu tests in Modal
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
@@ -17,22 +19,17 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
||||
template_env = jinja2.Environment(
|
||||
loader=template_loader, autoescape=select_autoescape()
|
||||
)
|
||||
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
|
||||
df_template = template_env.get_template(dockerfile)
|
||||
df_template = template_env.get_template("Dockerfile.jinja")
|
||||
|
||||
df_args = {
|
||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
|
||||
"CUDA": os.environ.get("CUDA", "126"),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
|
||||
"CUDA": os.environ.get("CUDA", "121"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
|
||||
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
|
||||
}
|
||||
|
||||
dockerfile_contents = df_template.render(**df_args)
|
||||
@@ -57,7 +54,7 @@ VOLUME_CONFIG = {
|
||||
}
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 2))
|
||||
GPU_CONFIG = f"H100:{N_GPUS}"
|
||||
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
|
||||
|
||||
|
||||
def run_cmd(cmd: str, run_folder: str):
|
||||
@@ -65,14 +62,14 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
|
||||
exit(exit_code)
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=120 * 60,
|
||||
cpu=16.0,
|
||||
timeout=90 * 60,
|
||||
cpu=8.0,
|
||||
memory=131072 * N_GPUS,
|
||||
volumes=VOLUME_CONFIG,
|
||||
)
|
||||
|
||||
@@ -1,25 +1,6 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
|
||||
pytest -v --durations=10 -n2 --maxfail=3 \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
|
||||
/workspace/axolotl/tests/e2e/multigpu/ \
|
||||
--cov=axolotl
|
||||
|
||||
# Run solo tests with coverage append
|
||||
pytest -v --durations=10 -n1 \
|
||||
/workspace/axolotl/tests/e2e/multigpu/solo/ \
|
||||
--cov=axolotl \
|
||||
--cov-append
|
||||
|
||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
|
||||
--cov=axolotl \
|
||||
--cov-append \
|
||||
--cov-report=xml:multigpu-coverage.xml
|
||||
|
||||
# Upload coverage to Codecov if CODECOV_TOKEN is available
|
||||
if [ -n "$CODECOV_TOKEN" ]; then
|
||||
codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true
|
||||
fi
|
||||
# only run one test at a time so as not to OOM the GPU
|
||||
pytest -v --durations=10 -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
|
||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Modal app to run axolotl GPU tests"""
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
|
||||
import jinja2
|
||||
import modal
|
||||
import modal.experimental
|
||||
from jinja2 import select_autoescape
|
||||
from modal import App
|
||||
|
||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
||||
template_env = jinja2.Environment(
|
||||
loader=template_loader, autoescape=select_autoescape()
|
||||
)
|
||||
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
|
||||
df_template = template_env.get_template(dockerfile)
|
||||
|
||||
df_args = {
|
||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
|
||||
"CUDA": os.environ.get("CUDA", "126"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
|
||||
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
|
||||
}
|
||||
|
||||
dockerfile_contents = df_template.render(**df_args)
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
||||
f.write(dockerfile_contents)
|
||||
|
||||
cicd_image = modal.experimental.raw_dockerfile_image(
|
||||
pathlib.Path(temp_dir) / "Dockerfile",
|
||||
# context_mount=None,
|
||||
force_build=True,
|
||||
# gpu="A10G",
|
||||
).env(df_args)
|
||||
|
||||
app = App("Axolotl CI/CD", secrets=[])
|
||||
|
||||
hf_cache_volume = modal.Volume.from_name(
|
||||
"axolotl-ci-hf-hub-cache", create_if_missing=True
|
||||
)
|
||||
VOLUME_CONFIG = {
|
||||
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
|
||||
}
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||
GPU_TYPE = os.environ.get("GPU_TYPE", "L40S")
|
||||
GPU_CONFIG = f"{GPU_TYPE}:{N_GPUS}"
|
||||
|
||||
|
||||
def run_cmd(cmd: str, run_folder: str):
|
||||
import subprocess # nosec
|
||||
|
||||
sp_env = os.environ.copy()
|
||||
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
|
||||
if exit_code:
|
||||
raise RuntimeError(f"Command '{cmd}' failed with exit code {exit_code}")
|
||||
58
codecov.yml
58
codecov.yml
@@ -1,58 +0,0 @@
|
||||
codecov:
|
||||
require_ci_to_pass: yes
|
||||
notify:
|
||||
wait_for_ci: true
|
||||
|
||||
coverage:
|
||||
precision: 2
|
||||
round: down
|
||||
range: "70...100"
|
||||
status:
|
||||
project:
|
||||
default:
|
||||
# basic
|
||||
target: auto
|
||||
threshold: 1%
|
||||
base: auto
|
||||
# advanced
|
||||
branches: null
|
||||
if_no_uploads: error
|
||||
if_not_found: success
|
||||
if_ci_failed: error
|
||||
only_pulls: true
|
||||
flags: null
|
||||
paths: null
|
||||
informational: true
|
||||
patch:
|
||||
default:
|
||||
# basic
|
||||
target: auto
|
||||
threshold: 1%
|
||||
base: auto
|
||||
# advanced
|
||||
branches: null
|
||||
if_no_uploads: error
|
||||
if_not_found: success
|
||||
if_ci_failed: error
|
||||
only_pulls: false
|
||||
flags: null
|
||||
paths: null
|
||||
informational: true
|
||||
|
||||
parsers:
|
||||
gcov:
|
||||
branch_detection:
|
||||
conditional: yes
|
||||
loop: yes
|
||||
method: no
|
||||
macro: no
|
||||
|
||||
comment:
|
||||
layout: "reach,diff,flags,files,footer"
|
||||
behavior: default
|
||||
require_changes: no
|
||||
require_base: no
|
||||
require_head: yes
|
||||
|
||||
github_checks:
|
||||
annotations: false
|
||||
@@ -1,31 +0,0 @@
|
||||
{
|
||||
"compile": {
|
||||
"disable": false,
|
||||
"backend": "inductor"
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu"
|
||||
},
|
||||
"contiguous_gradients": true,
|
||||
"overlap_comm": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -7,9 +7,9 @@
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"max_live_parameters": 0,
|
||||
"max_reuse_distance": 0,
|
||||
"gather_16bit_weights_on_model_save": true
|
||||
"stage3_max_live_parameters": 0,
|
||||
"stage3_max_reuse_distance": 0,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
|
||||
@@ -7,9 +7,9 @@
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"max_live_parameters": 0,
|
||||
"max_reuse_distance": 0,
|
||||
"gather_16bit_weights_on_model_save": true
|
||||
"stage3_max_live_parameters": 0,
|
||||
"stage3_max_reuse_distance": 0,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
|
||||
@@ -17,9 +17,9 @@
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"max_live_parameters": 0,
|
||||
"max_reuse_distance": 0,
|
||||
"gather_16bit_weights_on_model_save": true
|
||||
"stage3_max_live_parameters": 0,
|
||||
"stage3_max_reuse_distance": 0,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
|
||||
@@ -13,9 +13,9 @@
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"max_live_parameters": 0,
|
||||
"max_reuse_distance": 0,
|
||||
"gather_16bit_weights_on_model_save": true
|
||||
"stage3_max_live_parameters": 0,
|
||||
"stage3_max_reuse_distance": 0,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
|
||||
@@ -13,7 +13,7 @@ datasets:
|
||||
val_set_size: 0
|
||||
output_dir: temp_debug/axolotl_outputs/model
|
||||
dataset_prepared_path: temp_debug/axolotl_outputs/data
|
||||
dataset_num_proc: 1
|
||||
dataset_processes: 1
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: false
|
||||
|
||||
@@ -6,14 +6,11 @@ ARG AXOLOTL_EXTRAS=""
|
||||
ARG AXOLOTL_ARGS=""
|
||||
ARG CUDA="118"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG TARGETARCH
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
|
||||
rm -rf /var/cache/apt/archives && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
@@ -21,27 +18,22 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||
RUN pip uninstall -y causal_conv1d
|
||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
fi && \
|
||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||
fi && \ python scripts/unsloth_install.py | sh && \
|
||||
python scripts/cutcrossentropy_install.py | sh && \
|
||||
pip install pytest && \
|
||||
pip cache purge
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
# fix so that git fetch/pull from remote works with shallow clone
|
||||
RUN python scripts/unsloth_install.py | sh
|
||||
RUN python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN pip install pytest
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch && \
|
||||
git config --global credential.helper store
|
||||
git config --get remote.origin.fetch
|
||||
|
||||
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
|
||||
RUN chmod +x /root/.axolotl-complete.bash && \
|
||||
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc
|
||||
# helper for huggingface-login cli
|
||||
RUN git config --global credential.helper store
|
||||
|
||||
@@ -2,75 +2,38 @@ ARG CUDA_VERSION="11.8.0"
|
||||
ARG CUDNN_VERSION="8"
|
||||
ARG UBUNTU_VERSION="22.04"
|
||||
ARG MAX_JOBS=4
|
||||
ARG TARGETARCH
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG TARGETARCH
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTHON_VERSION="3.10"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG CUDA="128"
|
||||
ARG CUDA="118"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
|
||||
ibverbs-providers ibverbs-utils infiniband-diags \
|
||||
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
|
||||
&& rm -rf /var/cache/apt/archives \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
MINICONDA_ARCH="x86_64"; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
MINICONDA_ARCH="aarch64"; \
|
||||
else \
|
||||
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
|
||||
fi \
|
||||
&& wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
|
||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel psutil && \
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||
python3 -m pip cache purge
|
||||
|
||||
RUN if [ "$CUDA" != "130" ] ; then \
|
||||
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@v1.5.4"; \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
|
||||
python3 -m pip cache purge; \
|
||||
fi
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||
|
||||
RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||
pip3 cache purge
|
||||
|
||||
# Map Python version (e.g., 3.12 -> cp312)
|
||||
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
||||
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
||||
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
||||
# Map architecture
|
||||
case "$TARGETARCH" in \
|
||||
amd64) ARCH_TAG="x86_64" ;; \
|
||||
arm64) ARCH_TAG="aarch64" ;; \
|
||||
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
||||
esac && \
|
||||
WHL_VERSION="v0.7.16" && \
|
||||
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
|
||||
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
||||
pip3 install --no-cache-dir "${WHL_FILE}" && \
|
||||
rm "${WHL_FILE}"
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
@@ -29,7 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir -U torch==2.7.0 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||
|
||||
|
||||
@@ -22,22 +22,18 @@ RUN apt-get update \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
|
||||
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
|
||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel && \
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
|
||||
python3 -m pip cache purge
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||
|
||||
RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||
pip3 cache purge
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
@@ -14,10 +14,7 @@ COPY scripts/motd /etc/motd
|
||||
|
||||
RUN pip install jupyterlab notebook ipywidgets && \
|
||||
jupyter lab clean
|
||||
RUN apt update && \
|
||||
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
||||
rm -rf /var/cache/apt/archives && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
||||
mkdir -p ~/.ssh && \
|
||||
chmod 700 ~/.ssh && \
|
||||
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
||||
|
||||
@@ -9,15 +9,13 @@ ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||
EXPOSE 8888
|
||||
EXPOSE 22
|
||||
|
||||
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
|
||||
COPY scripts/cloud-entrypoint-term.sh /root/cloud-entrypoint.sh
|
||||
COPY scripts/motd /etc/motd
|
||||
|
||||
RUN pip install jupyterlab notebook ipywidgets && \
|
||||
jupyter lab clean
|
||||
RUN apt update && \
|
||||
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \
|
||||
rm -rf /var/cache/apt/archives && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
RUN apt install --yes --no-install-recommends openssh-server tmux sudo && \
|
||||
pip3 install -U --no-cache-dir grpcio ray[default]==2.9.3 && \
|
||||
mkdir -p ~/.ssh && \
|
||||
chmod 700 ~/.ssh && \
|
||||
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
ARG BASE_TAG=main
|
||||
FROM axolotlai/axolotl-uv:$BASE_TAG
|
||||
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||
|
||||
EXPOSE 8888
|
||||
EXPOSE 22
|
||||
|
||||
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
|
||||
COPY scripts/motd /etc/motd
|
||||
|
||||
RUN uv pip install jupyterlab notebook ipywidgets && \
|
||||
jupyter lab clean
|
||||
RUN apt update && \
|
||||
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
||||
rm -rf /var/cache/apt/archives && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
mkdir -p ~/.ssh && \
|
||||
chmod 700 ~/.ssh && \
|
||||
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
||||
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
|
||||
printf "source /workspace/axolotl-venv/bin/activate\n" >> ~/.bashrc && \
|
||||
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
|
||||
chmod +x /root/cloud-entrypoint.sh && \
|
||||
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
|
||||
|
||||
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
|
||||
CMD ["sleep", "infinity"]
|
||||
@@ -1,48 +0,0 @@
|
||||
ARG BASE_TAG=main-base
|
||||
FROM axolotlai/axolotl-base-uv:$BASE_TAG
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
ARG AXOLOTL_ARGS=""
|
||||
ARG CUDA="118"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG TARGETARCH
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
|
||||
rm -rf /var/cache/apt/archives && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||
RUN uv pip uninstall causal_conv1d
|
||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
else \
|
||||
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
fi && \
|
||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||
fi && \
|
||||
python scripts/unsloth_install.py --uv | sh && \
|
||||
python scripts/cutcrossentropy_install.py --uv | sh && \
|
||||
uv pip install pytest && \
|
||||
uv cache clean
|
||||
|
||||
# fix so that git fetch/pull from remote works with shallow clone
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch && \
|
||||
git config --global credential.helper store
|
||||
|
||||
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
|
||||
RUN chmod +x /root/.axolotl-complete.bash && \
|
||||
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc
|
||||
@@ -1,57 +0,0 @@
|
||||
ARG CUDA_VERSION="12.6.3"
|
||||
ARG CUDNN_VERSION=""
|
||||
ARG UBUNTU_VERSION="22.04"
|
||||
ARG MAX_JOBS=4
|
||||
ARG TARGETARCH
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ARG TARGETARCH
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="2.6.0"
|
||||
ARG CUDA="126"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
ENV UV_TORCH_BACKEND="cu${CUDA}"
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
|
||||
&& git lfs install --skip-repo \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
ENV PATH="/root/.local/bin:${PATH}"
|
||||
|
||||
RUN uv python install ${PYTHON_VERSION}
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN uv venv --no-project --relocatable axolotl-venv
|
||||
|
||||
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
|
||||
|
||||
RUN uv pip install packaging setuptools wheel psutil \
|
||||
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
|
||||
&& uv pip install awscli pydantic
|
||||
|
||||
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
|
||||
fi
|
||||
|
||||
# Map Python version (e.g., 3.12 -> cp312)
|
||||
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
||||
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
||||
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
||||
LINUX_TAG="manylinux_" && \
|
||||
# Map architecture
|
||||
case "$TARGETARCH" in \
|
||||
amd64) ARCH_TAG="2_24_x86_64.manylinux_2_28_x86_64" ;; \
|
||||
arm64) ARCH_TAG="2_34_aarch64" ;; \
|
||||
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
||||
esac && \
|
||||
WHL_VERSION="v0.7.16" && \
|
||||
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-${LINUX_TAG}${ARCH_TAG}.whl" && \
|
||||
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
||||
uv pip install --no-cache-dir "${WHL_FILE}" && \
|
||||
rm "${WHL_FILE}"
|
||||
3
docs/.gitignore
vendored
3
docs/.gitignore
vendored
@@ -2,6 +2,3 @@
|
||||
_site/
|
||||
/api/*.qmd
|
||||
/api/*.html
|
||||
config-reference.qmd
|
||||
models/**/*.qmd
|
||||
models/**/*.html
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
# GRPO — Agent Reference
|
||||
|
||||
Online RL with verifiable reward functions. For full config reference, async features, and scaling, see [grpo.qmd](../grpo.qmd). For vLLM setup, see [vllm_serving.qmd](../vllm_serving.qmd).
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Terminal 1 (GPU 0) Terminal 2 (GPU 1)
|
||||
┌──────────────────────┐ ┌──────────────────────────────────┐
|
||||
│ vLLM Server │ HTTP │ Trainer │
|
||||
│ Serves base model │◄────────────►│ 1. Send prompts to vLLM │
|
||||
│ + LoRA adapter │ /generate │ 2. Score completions (rewards) │
|
||||
│ │ /set_lora │ 3. Compute advantages │
|
||||
│ Punica kernels for │ │ 4. PPO-clip gradient update │
|
||||
│ LoRA inference │ │ 5. Sync LoRA weights to vLLM │
|
||||
└──────────────────────┘ └──────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Components Required
|
||||
|
||||
1. A YAML config with `rl: grpo`
|
||||
2. A reward module (Python file with reward functions)
|
||||
3. A running vLLM server (`axolotl vllm-serve config.yaml`)
|
||||
|
||||
## Reward Function Signature
|
||||
|
||||
```python
|
||||
def my_reward(completions, **kwargs) -> list[float]:
|
||||
# completions[i][0]["content"] = text of i-th completion
|
||||
# **kwargs contains dataset columns not removed by transform
|
||||
return [score_for_each_completion]
|
||||
```
|
||||
|
||||
Multiple rewards: `reward_funcs: [r1, r2]` with `reward_weights: [1.0, 0.5]`.
|
||||
|
||||
## Key Async Features
|
||||
|
||||
| Feature | Config | Purpose |
|
||||
|---------|--------|---------|
|
||||
| Async prefetch | `async_prefetch: true` | Overlap generation with training |
|
||||
| LoRA sync | `vllm_lora_sync: true` | Fast adapter sync via filesystem |
|
||||
| Streaming scoring | `streaming_partial_batch: true` | Score one group at a time |
|
||||
| Zero-adv skip | `skip_zero_advantage_batches: true` | Skip batches with no learning signal |
|
||||
| Replay buffer | `replay_buffer_size: 100` | Cache high-signal groups |
|
||||
| IS correction | `vllm_importance_sampling_correction: true` | Fix off-policy distribution shift |
|
||||
|
||||
## Health Checks
|
||||
|
||||
- `rewards/*/mean` > 0.15 within 20 steps (else: test reward function standalone)
|
||||
- `reward_std` > 0 on most steps (else: no learning signal)
|
||||
- `entropy` 0.05-0.5 (< 0.01 = mode collapse)
|
||||
- `grad_norm` 0.001-1.0 (> 10 = unstable, 0.0 = zero-advantage skip)
|
||||
|
||||
See [training_stability.qmd](../training_stability.qmd) for detailed diagnostics.
|
||||
|
||||
## File Map
|
||||
|
||||
```
|
||||
src/axolotl/
|
||||
cli/train.py # Entry point
|
||||
cli/vllm_serve.py # Entry point for vLLM server
|
||||
core/trainers/grpo/
|
||||
trainer.py # AxolotlGRPOTrainer
|
||||
sampler.py # Sampling utilities
|
||||
core/builders/rl.py # HFRLTrainerBuilder — routes rl type → trainer
|
||||
scripts/vllm_serve_lora.py # vLLM serve script with LoRA sync support
|
||||
utils/schemas/trl.py # TRL config schema (all trl: options)
|
||||
|
||||
docs/grpo.qmd # Full user docs: async, rewards, scaling, config reference
|
||||
docs/vllm_serving.qmd # vLLM server modes, LoRA sync, weight sync
|
||||
```
|
||||
@@ -1,121 +0,0 @@
|
||||
# Preference Learning (RLHF) — Agent Reference
|
||||
|
||||
Reference for DPO, IPO, KTO, ORPO, and SimPO. For config templates and dataset format examples, see [rlhf.qmd](../rlhf.qmd). For GRPO, see [grpo.qmd](../grpo.qmd). For EBFT, see [ebft.qmd](../ebft.qmd).
|
||||
|
||||
## Method Overview
|
||||
|
||||
| Method | Data Requirement | Key Idea | Best For |
|
||||
|--------|-----------------|----------|----------|
|
||||
| **DPO** | Paired (chosen + rejected) | Implicit reward via preference pairs | General alignment, most common |
|
||||
| **IPO** | Paired (chosen + rejected) | DPO with different loss (avoids overfitting) | When DPO overfits |
|
||||
| **KTO** | Unpaired (completion + binary label) | Kahneman-Tversky loss, no pairs needed | When you only have thumbs-up/down |
|
||||
| **ORPO** | Paired (chosen + rejected) | Combined SFT + preference, no ref model | Single-stage alignment, saves VRAM |
|
||||
| **SimPO** | Paired (chosen + rejected) | Length-normalized, no ref model | Simple setup, length-robust |
|
||||
|
||||
Default: start with DPO. All methods require `sample_packing: false`.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌──────────────┐ ┌───────────────┐ ┌───────────────┐
|
||||
│ Policy Model │ │ Reference │ │ Preference │
|
||||
│ (trainable) │ │ Model (frozen)│ │ Dataset │
|
||||
└──────┬───────┘ └──────┬────────┘ └──────┬────────┘
|
||||
└──────────┬───────┘ │
|
||||
v │
|
||||
Forward pass on chosen + rejected <─────┘
|
||||
│
|
||||
Preference Loss (DPO/IPO/KTO/...)
|
||||
│
|
||||
Backprop + Update
|
||||
|
||||
Exception: ORPO and SimPO do NOT use a reference model (~50% less VRAM).
|
||||
```
|
||||
|
||||
No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference data.
|
||||
|
||||
## Method Selection
|
||||
|
||||
1. Paired preference data (chosen + rejected)?
|
||||
- Default → `rl: dpo`
|
||||
- Overfitting → `rl: ipo`
|
||||
- VRAM-limited → `rl: orpo` (no ref model)
|
||||
- Length-sensitive → `rl: simpo` (no ref model)
|
||||
2. Only binary labels (good/bad)? → `rl: kto`
|
||||
3. Single-stage training (no separate SFT)? → `rl: orpo`
|
||||
|
||||
| | DPO | IPO | KTO | ORPO | SimPO |
|
||||
|---|---|---|---|---|---|
|
||||
| **Reference model** | Yes | Yes | Yes | No | No |
|
||||
| **VRAM overhead** | ~2x model | ~2x model | ~2x model | ~1x model | ~1x model |
|
||||
| **TRL trainer class** | DPOTrainer | DPOTrainer | KTOTrainer | ORPOTrainer | CPOTrainer |
|
||||
|
||||
## Prompt Strategy Resolution
|
||||
|
||||
The `type` field resolves to a Python function:
|
||||
|
||||
```
|
||||
type: "chatml.intel"
|
||||
→ axolotl.prompt_strategies.dpo.chatml.intel(cfg, **kwargs)
|
||||
→ returns transform_fn(sample) → {"prompt", "chosen", "rejected"}
|
||||
|
||||
type: "chat_template.default"
|
||||
→ axolotl.prompt_strategies.dpo.chat_template.default(cfg, dataset_idx, **kwargs)
|
||||
|
||||
type: {"field_prompt": "prompt", ...} (dict)
|
||||
→ axolotl.prompt_strategies.dpo.user_defined.default(...)
|
||||
```
|
||||
|
||||
Module base: `axolotl.prompt_strategies.{rl_method}` — replace `dpo` with `kto` or `orpo`.
|
||||
|
||||
## Healthy Training Indicators
|
||||
|
||||
| Metric | Healthy Range | Problem |
|
||||
|--------|--------------|---------|
|
||||
| `train/loss` | Decreasing, 0.3-0.7 | Flat or increasing = broken data or too high LR |
|
||||
| `rewards/chosen` | Increasing | Flat = model not learning preferences |
|
||||
| `rewards/rejected` | Decreasing | Increasing = model prefers wrong responses |
|
||||
| `rewards/margins` | Positive and increasing | Negative = prefers rejected over chosen |
|
||||
| `rewards/accuracies` | > 0.5, toward 0.7+ | < 0.5 = worse than random |
|
||||
| `logps/rejected` | Decreasing | Increasing = reward hacking |
|
||||
| `grad_norm` | 0.01 - 10.0 | > 100 = exploding gradients |
|
||||
|
||||
Method-specific: DPO/IPO watch `rewards/margins`; KTO loss is noisier; ORPO monitor SFT + odds ratio components; SimPO check length-normalized reward separation.
|
||||
|
||||
## Known Issues
|
||||
|
||||
| Issue | Fix |
|
||||
|-------|-----|
|
||||
| Sample packing crash | Set `sample_packing: false` (required for all preference methods) |
|
||||
| KTO `KeyError: 'label'` | Ensure dataset has boolean `label` column |
|
||||
| ORPO/KTO `KeyError` during tokenization | Add `remove_unused_columns: false` |
|
||||
| ORPO template not applied | ORPO requires explicit `chat_template` setting |
|
||||
| OOM with ref model (DPO/IPO/KTO) | Use LoRA/QLoRA, or switch to ORPO/SimPO (no ref model) |
|
||||
| IPO + label_smoothing | Do not set `dpo_label_smoothing` when `rl: ipo` |
|
||||
|
||||
Full troubleshooting: [training_stability.qmd](../training_stability.qmd)
|
||||
|
||||
## File Map
|
||||
|
||||
```
|
||||
src/axolotl/
|
||||
core/trainers/dpo/ # DPO trainer, args, strategy
|
||||
core/builders/rl.py # HFRLTrainerBuilder — routes rl type → trainer class
|
||||
core/training_args.py # AxolotlKTOConfig, AxolotlORPOConfig, AxolotlCPOConfig
|
||||
prompt_strategies/
|
||||
dpo/ # DPO/IPO/SimPO dataset strategies
|
||||
chat_template.py # chat_template.default, chat_template.argilla_chat
|
||||
chatml.py # chatml.default/intel/icr/argilla_chat/prompt_pairs/ultra
|
||||
llama3.py # llama3 variants (same subtypes as chatml)
|
||||
user_defined.py # Custom field mapping
|
||||
passthrough.py # No transform
|
||||
kto/ # KTO dataset strategies (chatml, llama3, user_defined)
|
||||
orpo/ # ORPO dataset strategies (chat_template.argilla)
|
||||
utils/schemas/enums.py # RLType enum (dpo, ipo, kto, orpo, simpo, grpo, gdpo, ebft)
|
||||
utils/schemas/config.py # All rl/dpo/kto/orpo/simpo config fields
|
||||
|
||||
docs/rlhf.qmd # Full user docs: all dataset formats, config templates
|
||||
docs/choosing_method.qmd # SFT vs DPO vs GRPO decision guide
|
||||
examples/qwen2/dpo.yaml # DPO example
|
||||
examples/llama-3/qlora-1b-kto.yaml # KTO example
|
||||
```
|
||||
@@ -1,75 +0,0 @@
|
||||
# Pretraining / Continual Pretraining — Agent Reference
|
||||
|
||||
Train on raw text with no input masking. Two approaches depending on dataset size.
|
||||
|
||||
## When to Use
|
||||
|
||||
- Continual pretraining on domain-specific corpora
|
||||
- Adapting a base model to a new language or domain before fine-tuning
|
||||
- Pretraining-style data where the entire text is the training signal
|
||||
|
||||
## Choosing an Approach
|
||||
|
||||
| | Non-streaming (`type: completion`) | Streaming (`pretraining_dataset`) |
|
||||
|---|---|---|
|
||||
| **Dataset size** | Fits in memory | Too large to fit in memory |
|
||||
| **Tokenization** | Pre-tokenized before training | On-demand during training |
|
||||
| **Config key** | `datasets:` | `pretraining_dataset:` |
|
||||
| **Long text handling** | Splits texts exceeding `sequence_len` | Concatenates into fixed-length sequences |
|
||||
| **Benefit** | Can preprocess on CPU, transfer to GPU | Start training immediately, no preprocessing |
|
||||
|
||||
## Non-Streaming: `type: completion`
|
||||
|
||||
For smaller datasets that fit in memory. Pre-tokenizes the entire dataset.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: my_corpus
|
||||
type: completion
|
||||
# field: text # Column name (default: "text")
|
||||
```
|
||||
|
||||
## Streaming: `pretraining_dataset`
|
||||
|
||||
For large corpora. Streams data on-demand without loading everything into memory.
|
||||
|
||||
```yaml
|
||||
pretraining_dataset:
|
||||
- path: HuggingFaceFW/fineweb-edu
|
||||
type: pretrain
|
||||
text_column: text
|
||||
split: train
|
||||
|
||||
max_steps: 1000 # Required — axolotl can't infer dataset size
|
||||
streaming_multipack_buffer_size: 10000 # Buffer for sample packing
|
||||
pretrain_multipack_attn: true # Prevent cross-attention between packed samples
|
||||
```
|
||||
|
||||
`max_steps` is required for streaming — one step = `sequence_len * micro_batch_size * gradient_accumulation_steps * num_gpus` tokens.
|
||||
|
||||
Full streaming docs: [streaming.qmd](../streaming.qmd)
|
||||
|
||||
## Dataset Format
|
||||
|
||||
```json
|
||||
{"text": "The complete document text goes here."}
|
||||
```
|
||||
|
||||
## Key Settings
|
||||
|
||||
- `sample_packing: true` + `pad_to_sequence_len: true` — pack documents into fixed-length sequences
|
||||
- `flash_attention: true` — required for sample packing
|
||||
- No adapter — typically full fine-tune for pretraining
|
||||
- `train_on_inputs: true` — default for completion (all tokens trained on)
|
||||
|
||||
## File Map
|
||||
|
||||
```
|
||||
src/axolotl/
|
||||
prompt_strategies/completion.py # Non-streaming: completion prompt strategy (no masking)
|
||||
utils/data/sft.py # Non-streaming: dataset loading and processing
|
||||
utils/data/streaming.py # Streaming: encode_streaming(), wrap_streaming_dataset()
|
||||
utils/schemas/config.py # Config fields: pretraining_dataset, pretrain_multipack_attn, etc.
|
||||
|
||||
examples/streaming/pretrain.yaml # Full streaming pretraining example config
|
||||
```
|
||||
@@ -1,48 +0,0 @@
|
||||
# Reward Modelling — Agent Reference
|
||||
|
||||
Train models to score responses for use as reward signals in RL. For full docs, see [reward_modelling.qmd](../reward_modelling.qmd).
|
||||
|
||||
## Types
|
||||
|
||||
### Outcome Reward Models (ORM)
|
||||
|
||||
Train a classifier to predict preference over entire interactions. Uses `AutoModelForSequenceClassification`.
|
||||
|
||||
```yaml
|
||||
base_model: google/gemma-2-2b
|
||||
model_type: AutoModelForSequenceClassification
|
||||
num_labels: 1
|
||||
reward_model: true
|
||||
chat_template: gemma
|
||||
datasets:
|
||||
- path: argilla/distilabel-intel-orca-dpo-pairs
|
||||
type: bradley_terry.chat_template
|
||||
```
|
||||
|
||||
Dataset format: `{"system": "...", "input": "...", "chosen": "...", "rejected": "..."}`
|
||||
|
||||
### Process Reward Models (PRM)
|
||||
|
||||
Train a token classifier to score each reasoning step. Uses `AutoModelForTokenClassification`.
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen2.5-3B
|
||||
model_type: AutoModelForTokenClassification
|
||||
num_labels: 2
|
||||
process_reward_model: true
|
||||
datasets:
|
||||
- path: trl-lib/math_shepherd
|
||||
type: stepwise_supervised
|
||||
```
|
||||
|
||||
Dataset format: see [stepwise_supervised.qmd](../dataset-formats/stepwise_supervised.qmd).
|
||||
|
||||
## File Map
|
||||
|
||||
```
|
||||
src/axolotl/
|
||||
core/builders/causal.py # Handles reward_model flag in trainer builder
|
||||
prompt_strategies/bradley_terry/ # Bradley-Terry prompt strategies
|
||||
prompt_strategies/stepwise_supervised.py # PRM dataset strategy
|
||||
utils/schemas/config.py # reward_model, process_reward_model config fields
|
||||
```
|
||||
@@ -1,115 +0,0 @@
|
||||
# SFT — Agent Reference
|
||||
|
||||
Supervised fine-tuning pipeline reference. For config templates and dataset format examples, see [getting-started.qmd](../getting-started.qmd) and [dataset-formats/](../dataset-formats/).
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
YAML Config → axolotl train config.yaml
|
||||
|
||||
1. Load base model (+ quantization if QLoRA/8-bit)
|
||||
2. Apply adapter layers (LoRA/QLoRA) if configured
|
||||
3. Load + tokenize dataset(s)
|
||||
- Apply prompt template (chat_template / alpaca / custom)
|
||||
- Mask inputs (train_on_inputs: false)
|
||||
- Pack samples into sequences (sample_packing: true)
|
||||
4. Training loop (HuggingFace Trainer)
|
||||
- forward → loss → backward → optimizer step → lr scheduler step
|
||||
5. Save model / adapter weights + tokenizer
|
||||
|
||||
Multi-GPU: FSDP or DeepSpeed shards model across GPUs automatically.
|
||||
```
|
||||
|
||||
## Components Required
|
||||
|
||||
1. A YAML config — model, dataset(s), adapter settings, hyperparameters
|
||||
2. A dataset — HuggingFace Hub, local JSONL/JSON/Parquet, or S3/GCS path
|
||||
3. (Optional) A custom prompt strategy — for non-standard dataset formats
|
||||
|
||||
No external server processes needed (unlike GRPO which requires vLLM).
|
||||
|
||||
## Dataset Format Decision Tree
|
||||
|
||||
```
|
||||
Is your data in chat/message format?
|
||||
├─ YES: OpenAI message format (role/content)?
|
||||
│ ├─ YES ──────────────────────> type: chat_template (recommended)
|
||||
│ └─ NO (custom field names) ──> type: chat_template + message_property_mappings
|
||||
└─ NO: Instruction/response pairs?
|
||||
├─ YES ──> type: alpaca (instruction, input, output)
|
||||
└─ NO: Raw text?
|
||||
├─ YES with segments ─────> type: input_output (template-free masking)
|
||||
└─ YES continuous ────────> type: completion (pretraining-style)
|
||||
```
|
||||
|
||||
Full format specs: [dataset-formats/](../dataset-formats/)
|
||||
|
||||
## Model Size to Adapter Choice
|
||||
|
||||
| Model Size | LoRA | QLoRA (4-bit) | Full Fine-Tune | VRAM (approx) |
|
||||
|-----------|------|---------------|----------------|---------------|
|
||||
| 1-3B | Preferred | Low-budget option | Single GPU OK | 8-16 GB (LoRA) |
|
||||
| 7-8B | Preferred | Good balance | Needs multi-GPU | 16-24 GB (LoRA) |
|
||||
| 13-14B | Preferred | Good balance | Multi-GPU required | 24-40 GB (LoRA) |
|
||||
| 30-70B | LoRA or QLoRA | Preferred for single GPU | Multi-node | 40-80 GB (QLoRA) |
|
||||
|
||||
## Hyperparameter Ranges
|
||||
|
||||
| Parameter | LoRA | QLoRA | Full FT |
|
||||
|-----------|------|-------|---------|
|
||||
| `learning_rate` | 1e-4 to 3e-4 | 1e-4 to 3e-4 | 1e-5 to 5e-5 |
|
||||
| `lora_r` | 16-64 | 16-64 | N/A |
|
||||
| `lora_alpha` | 1-2x `lora_r` | 1-2x `lora_r` | N/A |
|
||||
| `micro_batch_size` | 2-8 | 2-4 | 1-2 |
|
||||
| `gradient_accumulation_steps` | 2-8 | 4-16 | 4-16 |
|
||||
| `num_epochs` | 1-3 | 1-3 | 1-3 |
|
||||
| `optimizer` | `adamw_8bit` | `adamw_bnb_8bit` | `adamw_torch_fused` |
|
||||
|
||||
Effective batch = micro_batch * grad_accum * num_gpus. Lower LR for larger models.
|
||||
|
||||
## Healthy Training Indicators
|
||||
|
||||
| Metric | Healthy | Problem |
|
||||
|--------|---------|---------|
|
||||
| `train_loss` | Decreasing, starting ~2-4 for chat models | Flat or increasing from step 1 — data or LR issue |
|
||||
| `eval_loss` | Decreasing, tracks train_loss | Increasing while train_loss decreases — overfitting |
|
||||
| `grad_norm` | 0.1-10, relatively stable | Spikes >100 — instability. 0.0 — frozen weights |
|
||||
| `learning_rate` | Follows scheduler curve | Flat or NaN — config issue |
|
||||
|
||||
Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss goes to 0 quickly (overfitting), eval_loss diverging (reduce epochs, add regularization). See [training_stability.qmd](../training_stability.qmd).
|
||||
|
||||
## Known Issues
|
||||
|
||||
| Issue | Fix |
|
||||
|-------|-----|
|
||||
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
|
||||
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` |
|
||||
| Missing chat template error | Set `chat_template: chatml` explicitly |
|
||||
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
|
||||
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |
|
||||
| Tokenizer pad token / infinite loss | Set `special_tokens: pad_token: "<\|end_of_text\|>"` |
|
||||
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
|
||||
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
|
||||
|
||||
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)
|
||||
|
||||
## File Map
|
||||
|
||||
```
|
||||
src/axolotl/
|
||||
cli/train.py # Entry point for `axolotl train`
|
||||
cli/preprocess.py # Entry point for `axolotl preprocess`
|
||||
core/builders/causal.py # HFCausalTrainerBuilder — wires config → SFT trainer
|
||||
core/trainers/base.py # AxolotlTrainer — base trainer class
|
||||
core/trainers/mixins/ # Packing, optimizer, scheduler, checkpoints
|
||||
prompt_strategies/ # Format handlers: chat_template, alpaca, completion, input_output
|
||||
utils/schemas/config.py # AxolotlInputConfig — main config schema
|
||||
utils/schemas/datasets.py # SFTDataset, DatasetConfig
|
||||
utils/schemas/peft.py # LoraConfig — LoRA parameters
|
||||
integrations/liger/ # Liger kernel plugin
|
||||
|
||||
examples/llama-3/ # LoRA, QLoRA, full FT example configs
|
||||
docs/getting-started.qmd # Quickstart with config templates
|
||||
docs/optimizations.qmd # Flash attention, gradient checkpointing, sample packing
|
||||
docs/multi-gpu.qmd # FSDP and DeepSpeed setup
|
||||
```
|
||||
@@ -86,7 +86,7 @@ export HF_DATASETS_OFFLINE=1
|
||||
Download a base model using the Hugging Face CLI:
|
||||
|
||||
```bash
|
||||
hf download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
|
||||
huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
|
||||
```
|
||||
|
||||
### 10. Create Axolotl Configuration
|
||||
|
||||
@@ -1,178 +0,0 @@
|
||||
---
|
||||
title: Attention
|
||||
description: Supported attention modules in Axolotl
|
||||
---
|
||||
|
||||
## SDP Attention
|
||||
|
||||
This is the default built-in attention in PyTorch.
|
||||
|
||||
```yaml
|
||||
sdp_attention: true
|
||||
```
|
||||
|
||||
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
|
||||
## Flash Attention
|
||||
|
||||
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
|
||||
based on your installed packages and GPU.
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
```
|
||||
|
||||
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
|
||||
|
||||
### Flash Attention 2
|
||||
|
||||
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
|
||||
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl.
|
||||
Alternatively, try reinstall or downgrade a version.
|
||||
|
||||
:::
|
||||
|
||||
### Flash Attention 3
|
||||
|
||||
Requirements: Hopper only and CUDA 12.8 (recommended)
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||
cd flash-attention/hopper
|
||||
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
### Flash Attention 4
|
||||
|
||||
Requirements: Hopper or Blackwell GPUs
|
||||
|
||||
```bash
|
||||
pip install flash-attn-4
|
||||
```
|
||||
|
||||
Or from source:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||
cd flash-attention/flash_attn/cute
|
||||
|
||||
pip install -e .
|
||||
|
||||
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
|
||||
# Remove it so Python can find the real FA4 module:
|
||||
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
|
||||
for training on Hopper, install from source using the instructions above.
|
||||
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
|
||||
also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions
|
||||
and falls back to FA2/3.
|
||||
|
||||
:::
|
||||
|
||||
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
|
||||
|
||||
### AMD
|
||||
|
||||
Requirements: ROCm 6.0 and above.
|
||||
|
||||
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
|
||||
|
||||
## Flex Attention
|
||||
|
||||
A flexible PyTorch API for attention used in combination with `torch.compile`.
|
||||
|
||||
```yaml
|
||||
flex_attention: true
|
||||
|
||||
# recommended
|
||||
torch_compile: true
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
We recommend using latest stable version of PyTorch for best performance.
|
||||
|
||||
:::
|
||||
|
||||
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
|
||||
|
||||
## SageAttention
|
||||
|
||||
Attention kernels with QK Int8 and PV FP16 accumulator.
|
||||
|
||||
```yaml
|
||||
sage_attention: true
|
||||
```
|
||||
|
||||
Requirements: Ampere, Ada, or Hopper GPUs
|
||||
|
||||
```bash
|
||||
pip install sageattention==2.2.0 --no-build-isolation
|
||||
```
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
|
||||
|
||||
:::
|
||||
|
||||
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
|
||||
|
||||
:::
|
||||
|
||||
|
||||
## xFormers
|
||||
|
||||
```yaml
|
||||
xformers_attention: true
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
We recommend using with Turing GPUs or below (such as on Colab).
|
||||
|
||||
:::
|
||||
|
||||
For more details: [xFormers](https://github.com/facebookresearch/xformers)
|
||||
|
||||
## Shifted Sparse Attention
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
|
||||
|
||||
:::
|
||||
|
||||
Requirements: LLaMA model architecture
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
s2_attention: true
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
No sample packing support!
|
||||
|
||||
:::
|
||||
@@ -1,86 +0,0 @@
|
||||
---
|
||||
title: "Checkpoint Saving"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
toc-depth: 2
|
||||
number-sections: true
|
||||
execute:
|
||||
enabled: false
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
Axolotl supports on-demand checkpoint saving during training. You can trigger checkpoints via file-based triggers (for programmatic control) or Control+C (for interactive use).
|
||||
|
||||
## File-Based Checkpoint Trigger
|
||||
|
||||
### Configuration
|
||||
|
||||
Enable in your config:
|
||||
|
||||
```yaml
|
||||
dynamic_checkpoint:
|
||||
enabled: true
|
||||
check_interval: 100 # Optional: check every N steps (default: 100)
|
||||
trigger_file_path: "axolotl_checkpoint.save" # Optional: custom filename
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `enabled`: `true` to enable (required)
|
||||
- `check_interval`: Steps between file checks. Default: 100. Lower = faster response, higher I/O overhead.
|
||||
- `trigger_file_path`: Custom trigger filename. Default: `axolotl_checkpoint.save`
|
||||
|
||||
### How It Works
|
||||
|
||||
1. Rank 0 checks for trigger file every `check_interval` steps in `output_dir`
|
||||
2. When detected, file is deleted and checkpoint is saved
|
||||
3. In distributed training, rank 0 broadcasts to synchronize all ranks
|
||||
|
||||
### Usage
|
||||
|
||||
**Command line:**
|
||||
```bash
|
||||
touch /path/to/output_dir/axolotl_checkpoint.save
|
||||
```
|
||||
|
||||
**Programmatic:**
|
||||
```python
|
||||
from pathlib import Path
|
||||
Path("/path/to/output_dir/axolotl_checkpoint.save").touch()
|
||||
```
|
||||
|
||||
Checkpoint saves within the next `check_interval` steps. The trigger file is auto-deleted after detection, so you can create it multiple times.
|
||||
|
||||
**Custom filename:**
|
||||
```yaml
|
||||
dynamic_checkpoint:
|
||||
enabled: true
|
||||
trigger_file_path: "my_trigger.save"
|
||||
```
|
||||
```bash
|
||||
touch /path/to/output_dir/my_trigger.save
|
||||
```
|
||||
|
||||
## Control+C (SIGINT) Checkpoint
|
||||
|
||||
Pressing `Ctrl+C` during training saves the model state and exits gracefully. **Note:** This saves only the model weights, not optimizer state. For resumable checkpoints, use the file-based trigger.
|
||||
|
||||
## Best Practices
|
||||
|
||||
- **Check interval**: Lower values (10-50) for fast training, default 100 for slower training
|
||||
- **Distributed training**: Create trigger file once; rank 0 handles synchronization
|
||||
- **Resume**: Dynamic checkpoints can be resumed like regular checkpoints via `resume_from_checkpoint`
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
output_dir: ./outputs/lora-out
|
||||
save_steps: 500 # Scheduled checkpoints
|
||||
|
||||
dynamic_checkpoint:
|
||||
enabled: true
|
||||
check_interval: 50
|
||||
```
|
||||
|
||||
This enables scheduled checkpoints every 500 steps plus on-demand saves via file trigger (checked every 50 steps).
|
||||
@@ -1,206 +0,0 @@
|
||||
---
|
||||
title: "Which Fine-Tuning Method Should I Use?"
|
||||
description: "A decision guide for choosing the right fine-tuning method, adapter, and hardware configuration in Axolotl."
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
toc-depth: 3
|
||||
number-sections: true
|
||||
execute:
|
||||
enabled: false
|
||||
---
|
||||
|
||||
## Overview {#sec-overview}
|
||||
|
||||
Axolotl supports four broad categories of fine-tuning, each suited to different data types, objectives, and resource constraints.
|
||||
|
||||
| Method | What It Does | Data You Need |
|
||||
|--------|-------------|---------------|
|
||||
| **Supervised Fine-Tuning (SFT)** | Teaches the model to produce specific outputs given inputs | Input-output pairs (instructions, conversations, completions) |
|
||||
| **Preference Learning (DPO/KTO/ORPO)** | Steers the model toward preferred outputs and away from dispreferred ones | Chosen/rejected response pairs (DPO, ORPO) or binary labels (KTO) |
|
||||
| **Reinforcement Learning (GRPO)** | Optimizes the model against a reward signal through online generation | A reward function (code or model-based) and a prompt dataset |
|
||||
| **Reward Modeling** | Trains a model to score responses, for use as a reward signal in RL | Preference pairs ranked by quality |
|
||||
|
||||
Each method is configured through a YAML file with `rl: <method>` (or omitted for SFT). All methods support LoRA, QLoRA, and full fine-tuning unless otherwise noted.
|
||||
|
||||
## Decision Tree {#sec-decision-tree}
|
||||
|
||||
Use the following flowchart to choose your method. Start at the top and follow the path that matches your situation.
|
||||
|
||||
```
|
||||
Do you have a reward function (code-based or model-based)?
|
||||
├── YES
|
||||
│ └── Use GRPO (rl: grpo)
|
||||
│ The model generates its own completions and learns from reward scores.
|
||||
│ Best for: math, code, reasoning, tasks with verifiable answers.
|
||||
│ See: rlhf.qmd#grpo
|
||||
│
|
||||
└── NO
|
||||
│
|
||||
Do you have preference pairs (chosen vs. rejected responses)?
|
||||
├── YES
|
||||
│ │
|
||||
│ Are they paired (same prompt, one chosen, one rejected)?
|
||||
│ ├── YES → Use DPO (rl: dpo)
|
||||
│ │ Direct optimization without a separate reward model.
|
||||
│ │ See: rlhf.qmd#dpo
|
||||
│ │
|
||||
│ └── NO (only binary good/bad labels)
|
||||
│ └── Use KTO (rl: kto)
|
||||
│ Works with unpaired preference data.
|
||||
│ See: rlhf.qmd#kto
|
||||
│
|
||||
└── NO
|
||||
│
|
||||
Do you have input-output examples?
|
||||
├── YES → Use SFT
|
||||
│ The simplest and most common method.
|
||||
│ See: getting-started.qmd
|
||||
│
|
||||
└── NO
|
||||
└── You need to create training data first.
|
||||
Consider generating preference pairs with an LLM judge,
|
||||
or writing a reward function for GRPO.
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
**When in doubt, start with SFT.** It is the most straightforward method and works well for most tasks. You can always move to preference learning or RL later to further refine behavior.
|
||||
:::
|
||||
|
||||
### Method Comparison at a Glance
|
||||
|
||||
| Criterion | SFT | DPO | KTO | GRPO |
|
||||
|-----------|-----|-----|-----|------|
|
||||
| Data complexity | Low (input-output pairs) | Medium (preference pairs) | Medium (binary labels) | Low (prompts + reward code) |
|
||||
| Compute cost | Low | Medium | Medium | High (requires vLLM server) |
|
||||
| Learning signal | Supervised | Contrastive | Contrastive | Online reward |
|
||||
| Online generation | No | No | No | Yes |
|
||||
| Reward model needed | No | No | No | No (uses reward functions) |
|
||||
| Best for | Task adaptation, instruction following | Safety, style alignment | Unpaired preference data | Reasoning, math, code |
|
||||
|
||||
::: {.callout-note}
|
||||
**ORPO** is an alternative to DPO that combines SFT and preference optimization in a single training stage, removing the need for a separate SFT step. Configure with `rl: orpo`. See [rlhf.qmd](rlhf.qmd) for details.
|
||||
:::
|
||||
|
||||
## Adapter Selection {#sec-adapter-selection}
|
||||
|
||||
Once you have chosen a method, decide how to apply the parameter updates. The three main options trade off VRAM usage against model quality.
|
||||
|
||||
### QLoRA
|
||||
|
||||
- **How it works**: The base model is loaded in 4-bit (NF4) quantization. Small low-rank adapter matrices are trained in higher precision on top.
|
||||
- **VRAM savings**: Roughly 4x reduction in model memory compared to full fine-tuning.
|
||||
- **Quality**: Slight degradation due to quantization noise, but often negligible for task-specific fine-tuning.
|
||||
- **When to use**: When your GPU cannot fit the model in full precision, or when you want fast experimentation.
|
||||
|
||||
```yaml
|
||||
adapter: qlora
|
||||
load_in_4bit: true
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_target_linear: true
|
||||
```
|
||||
|
||||
### LoRA
|
||||
|
||||
- **How it works**: The base model is loaded at full precision (or 8-bit). Low-rank adapter matrices are trained alongside.
|
||||
- **VRAM savings**: Roughly 2-3x reduction compared to full fine-tuning (model weights are frozen, only adapters + optimizer states for adapters are stored).
|
||||
- **Quality**: Very close to full fine-tuning for most tasks, especially with higher rank values.
|
||||
- **When to use**: When you have enough VRAM for the base model but not for full optimizer states.
|
||||
|
||||
```yaml
|
||||
adapter: lora
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_target_linear: true
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
For GRPO training, LoRA is strongly recommended. The vLLM server needs to sync weights from the trainer, and LoRA sync (`trl.vllm_lora_sync: true`) is far more efficient than syncing full merged weights. See [vLLM Serving](vllm_serving.qmd) for details.
|
||||
:::
|
||||
|
||||
### Full Fine-Tuning
|
||||
|
||||
- **How it works**: All model parameters are updated during training. No adapters.
|
||||
- **VRAM savings**: None. Requires memory for model weights, gradients, and optimizer states (roughly 4x model size in bf16 with AdamW).
|
||||
- **Quality**: Highest potential quality, especially for large distribution shifts.
|
||||
- **When to use**: When you have ample GPU memory or multi-GPU setups, and need maximum performance. Also required for pre-training.
|
||||
|
||||
```yaml
|
||||
# No adapter or load_in_* lines needed
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 16
|
||||
```
|
||||
|
||||
### Quick Comparison
|
||||
|
||||
| | QLoRA | LoRA | Full |
|
||||
|---|---|---|---|
|
||||
| Trainable params | ~0.1-1% | ~0.1-1% | 100% |
|
||||
| Model memory | ~25% of full | ~50-100% of full | 100% |
|
||||
| Optimizer memory | Tiny (adapters only) | Tiny (adapters only) | 2x model size (AdamW) |
|
||||
| Training speed | Slower (dequantization overhead) | Baseline | Faster per-step (no adapter overhead) |
|
||||
| Inference | Merge or serve with adapter | Merge or serve with adapter | Direct |
|
||||
| Multi-GPU required? | Rarely | For 13B+ models | For 7B+ models |
|
||||
|
||||
## Hardware Mapping {#sec-hardware-mapping}
|
||||
|
||||
The tables below provide approximate GPU memory requirements. Actual usage depends on context length, batch size, and optimizer choice.
|
||||
|
||||
### SFT / Preference Learning
|
||||
|
||||
| Model Size | QLoRA (4-bit) | LoRA (bf16) | Full (bf16 + AdamW) |
|
||||
|------------|--------------|-------------|---------------------|
|
||||
| 1-3B | 6-8 GB | 8-12 GB | 24-32 GB |
|
||||
| 7-8B | 10-14 GB | 16-24 GB | 60-80 GB |
|
||||
| 13-14B | 16-20 GB | 28-40 GB | 120+ GB |
|
||||
| 30-34B | 24-32 GB | 64-80 GB | 2-4x 80 GB |
|
||||
| 70-72B | 40-48 GB | 2x 80 GB | 4-8x 80 GB |
|
||||
|
||||
::: {.callout-important}
|
||||
These estimates assume a short context length (512-2048 tokens) and micro_batch_size of 1-2. Longer sequences and larger batches increase memory significantly due to activations. Use [gradient checkpointing](gradient_checkpointing.qmd) to reduce activation memory at the cost of ~30% slower training.
|
||||
:::
|
||||
|
||||
### GRPO (RL Training)
|
||||
|
||||
GRPO requires additional GPU(s) for the vLLM generation server. Plan for at least two GPUs: one for training, one for vLLM.
|
||||
|
||||
| Model Size | Training GPU (LoRA, bf16) | vLLM GPU | Total GPUs |
|
||||
|------------|--------------------------|----------|------------|
|
||||
| 0.5-3B | 1x 24 GB | 1x 24 GB | 2x 24 GB |
|
||||
| 7-8B | 1x 80 GB | 1x 80 GB | 2x 80 GB |
|
||||
| 13-14B | 1-2x 80 GB | 1-2x 80 GB | 2-4x 80 GB |
|
||||
| 30-72B | 2-4x 80 GB (FSDP/DeepSpeed) | 2-4x 80 GB (tensor parallel) | 4-8x 80 GB |
|
||||
|
||||
::: {.callout-tip}
|
||||
For single-GPU GRPO, use `vllm_mode: colocate` with `vllm_enable_sleep_mode: true`. The vLLM engine shares the GPU and offloads VRAM when not generating. This works for smaller models (up to ~3B on a 24 GB GPU) but is slower than the two-GPU server mode.
|
||||
:::
|
||||
|
||||
### Multi-GPU Threshold
|
||||
|
||||
You need multi-GPU training when:
|
||||
|
||||
- **Full fine-tuning** of models 7B+ (use FSDP or DeepSpeed ZeRO)
|
||||
- **LoRA** of models 30B+ (or 13B+ with long contexts)
|
||||
- **GRPO** almost always (separate vLLM server), unless using colocate mode
|
||||
|
||||
See [Multi-GPU Training](multi-gpu.qmd) for FSDP and DeepSpeed configuration.
|
||||
|
||||
## Quick Links {#sec-quick-links}
|
||||
|
||||
| Method | Config Key | Documentation | Example Config |
|
||||
|--------|-----------|---------------|----------------|
|
||||
| SFT | *(default, no `rl:` key)* | [Getting Started](getting-started.qmd) | `examples/llama-3/lora-1b.yml` |
|
||||
| DPO | `rl: dpo` | [RLHF - DPO](rlhf.qmd#dpo) | See rlhf.qmd |
|
||||
| KTO | `rl: kto` | [RLHF - KTO](rlhf.qmd#kto) | See rlhf.qmd |
|
||||
| ORPO | `rl: orpo` | [RLHF - ORPO](rlhf.qmd#orpo) | See rlhf.qmd |
|
||||
| GRPO | `rl: grpo` | [RLHF - GRPO](rlhf.qmd#grpo), [vLLM Serving](vllm_serving.qmd) | See rlhf.qmd |
|
||||
| Reward Modeling | `rl: reward_trainer` | [Reward Modelling](reward_modelling.qmd) | See reward_modelling.qmd |
|
||||
|
||||
### Related Guides
|
||||
|
||||
- [Configuration Reference](config-reference.qmd) -- Full list of all config options
|
||||
- [Dataset Formats](dataset-formats) -- How to structure your training data
|
||||
- [Optimizations](optimizations.qmd) -- Flash attention, gradient checkpointing, mixed precision
|
||||
- [Multi-GPU Training](multi-gpu.qmd) -- FSDP and DeepSpeed setup
|
||||
- [vLLM Serving](vllm_serving.qmd) -- Setting up vLLM for GRPO training
|
||||
51
docs/cli.qmd
51
docs/cli.qmd
@@ -23,20 +23,6 @@ axolotl <command> [config.yml] [options]
|
||||
|
||||
The config file can be local or a URL to a raw YAML file.
|
||||
|
||||
### Launcher Arguments
|
||||
|
||||
For commands that support multi-GPU (`train`, `evaluate`, ...), you can pass launcher-specific arguments using the `--` separator:
|
||||
|
||||
```bash
|
||||
# Pass torchrun arguments
|
||||
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
|
||||
|
||||
# Pass accelerate arguments
|
||||
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml --num_processes=4
|
||||
```
|
||||
|
||||
Arguments after `--` are passed directly to the launcher (torchrun, accelerate launch, etc.).
|
||||
|
||||
## Command Reference
|
||||
|
||||
### fetch
|
||||
@@ -94,11 +80,7 @@ axolotl train config.yml \
|
||||
--num-epochs 3
|
||||
|
||||
# Training without accelerate
|
||||
axolotl train config.yml --launcher python
|
||||
|
||||
# Pass launcher-specific arguments using -- separator
|
||||
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
|
||||
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml
|
||||
axolotl train config.yml --no-accelerate
|
||||
|
||||
# Resume training from checkpoint
|
||||
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
|
||||
@@ -193,9 +175,6 @@ Evaluates a model's performance (loss etc) on the train and eval datasets.
|
||||
```bash
|
||||
# Basic evaluation
|
||||
axolotl evaluate config.yml
|
||||
|
||||
# Evaluation with launcher arguments
|
||||
axolotl evaluate config.yml --launcher torchrun -- --nproc_per_node=2
|
||||
```
|
||||
|
||||
### lm-eval
|
||||
@@ -210,8 +189,6 @@ axolotl lm-eval config.yml
|
||||
Configuration options:
|
||||
|
||||
```yaml
|
||||
lm_eval_model: # model to evaluate (local or hf path)
|
||||
|
||||
# List of tasks to evaluate
|
||||
lm_eval_tasks:
|
||||
- arc_challenge
|
||||
@@ -220,28 +197,7 @@ lm_eval_batch_size: # Batch size for evaluation
|
||||
output_dir: # Directory to save evaluation results
|
||||
```
|
||||
|
||||
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
|
||||
|
||||
### delinearize-llama4
|
||||
|
||||
Delinearizes a Llama 4 linearized model into a regular HuggingFace Llama 4 model. This only works with the non-quantized linearized model.
|
||||
|
||||
```bash
|
||||
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
|
||||
```
|
||||
|
||||
This would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.
|
||||
|
||||
### quantize
|
||||
|
||||
Quantizes a model using the quantization configuration specified in your YAML file.
|
||||
|
||||
```bash
|
||||
axolotl quantize config.yml
|
||||
```
|
||||
|
||||
See [Quantization](./quantize.qmd) for more details.
|
||||
|
||||
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
|
||||
|
||||
## Legacy CLI Usage
|
||||
|
||||
@@ -310,6 +266,9 @@ axolotl preprocess config.yml --cloud cloud_config.yml
|
||||
# Train on cloud
|
||||
axolotl train config.yml --cloud cloud_config.yml
|
||||
|
||||
# Train without accelerate on cloud
|
||||
axolotl train config.yml --cloud cloud_config.yml --no-accelerate
|
||||
|
||||
# Run lm-eval on cloud
|
||||
axolotl lm-eval config.yml --cloud cloud_config.yml
|
||||
```
|
||||
|
||||
711
docs/config.qmd
Normal file
711
docs/config.qmd
Normal file
@@ -0,0 +1,711 @@
|
||||
---
|
||||
title: Config Reference
|
||||
description: A complete list of all configuration options.
|
||||
---
|
||||
|
||||
```yaml
|
||||
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# This can also be a relative path to a model on disk
|
||||
base_model: ./llama-7b-hf
|
||||
# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
base_model_ignore_patterns:
|
||||
# If the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# You can set that here, or leave this empty to default to base_model
|
||||
base_model_config: ./llama-7b-hf
|
||||
# You can specify to choose a specific model revision from huggingface hub
|
||||
revision_of_model:
|
||||
# Optional tokenizer configuration path in case you want to use a different tokenizer
|
||||
# than the one defined in the base model
|
||||
tokenizer_config:
|
||||
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
|
||||
model_type: AutoModelForCausalLM
|
||||
# Corresponding tokenizer for the model AutoTokenizer is a good choice
|
||||
tokenizer_type: AutoTokenizer
|
||||
# Trust remote code for untrusted source
|
||||
trust_remote_code:
|
||||
# use_fast option for tokenizer loading from_pretrained, default to True
|
||||
tokenizer_use_fast:
|
||||
# Whether to use the legacy tokenizer setting, defaults to True
|
||||
tokenizer_legacy:
|
||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
||||
# This is reported to improve training speed on some models
|
||||
resize_token_embeddings_to_32x:
|
||||
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||
shrink_embeddings:
|
||||
# Whether to load the model with randomly initialized weights. Useful for
|
||||
# pre-training a model from scratch or debugging purposes.
|
||||
random_init_weights:
|
||||
|
||||
# (Internal use only)
|
||||
# Used to identify which the model is based on
|
||||
is_falcon_derived_model:
|
||||
is_llama_derived_model:
|
||||
is_qwen_derived_model:
|
||||
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
||||
is_mistral_derived_model:
|
||||
|
||||
# optional overrides to the base model configuration
|
||||
overrides_of_model_config:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# optional overrides the base model loading from_pretrained
|
||||
overrides_of_model_kwargs:
|
||||
# use_cache: False
|
||||
|
||||
# optional overrides to the bnb 4bit quantization configuration
|
||||
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
|
||||
bnb_config_kwargs:
|
||||
# These are default values
|
||||
llm_int8_has_fp16_weight: false
|
||||
bnb_4bit_quant_type: nf4
|
||||
bnb_4bit_use_double_quant: true
|
||||
|
||||
|
||||
# Whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
|
||||
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
load_in_8bit: true
|
||||
# Use bitsandbytes 4 bit
|
||||
load_in_4bit:
|
||||
|
||||
# Use CUDA bf16
|
||||
bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
|
||||
# Use CUDA fp16
|
||||
fp16: true
|
||||
# Use CUDA tf32
|
||||
tf32: true # require >=ampere
|
||||
|
||||
# No AMP (automatic mixed precision)
|
||||
bfloat16: true # require >=ampere
|
||||
float16: true
|
||||
|
||||
# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset
|
||||
gpu_memory_limit: 20GiB
|
||||
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
|
||||
lora_on_cpu: true
|
||||
|
||||
# List[str]. Add plugins to extend the pipeline.
|
||||
# See `src/axolotl/integrations` for the available plugins or doc below for more details.
|
||||
# https://docs.axolotl.ai/docs/custom_integrations.html
|
||||
plugins:
|
||||
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
# A list of one or more datasets to finetune the model with
|
||||
datasets:
|
||||
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
|
||||
data_files: # Optional[str] path to source data files
|
||||
|
||||
shards: # Optional[int] split dataset into N pieces (use with shards_idx)
|
||||
shards_idx: # Optional[int] = 0 the index of sharded dataset to use
|
||||
|
||||
preprocess_shards: # Optional[int] process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)
|
||||
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
split: train # Optional[str] name of dataset split to load from
|
||||
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
||||
trust_remote_code: # Optional[bool] Trust remote code for untrusted source
|
||||
|
||||
# Custom user instruction prompt
|
||||
- path: repo
|
||||
type:
|
||||
# The below are defaults. only set what's needed if you use a different column name.
|
||||
system_prompt: ""
|
||||
system_format: "{system}"
|
||||
field_system: system
|
||||
field_instruction: instruction
|
||||
field_input: input
|
||||
field_output: output
|
||||
|
||||
# Customizable to be single line or multi-line
|
||||
# Use {instruction}/{input} as key to be replaced
|
||||
# 'format' can include {input}
|
||||
format: |-
|
||||
User: {instruction} {input}
|
||||
Assistant:
|
||||
# 'no_input_format' cannot include {input}
|
||||
no_input_format: "{instruction} "
|
||||
|
||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||
field:
|
||||
|
||||
# Using chat template
|
||||
- path: ...
|
||||
# Set type to `chat_template` to use this strategy
|
||||
type: chat_template
|
||||
# Specify the name of the chat template to use
|
||||
# The name of the chat template to use for training, following values are supported:
|
||||
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
|
||||
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
|
||||
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
|
||||
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
|
||||
chat_template: tokenizer_default
|
||||
|
||||
# Custom jinja chat template. Used only if `chat_template: jinja` or empty.
|
||||
chat_template_jinja:
|
||||
|
||||
# Key containing the messages (default: "messages")
|
||||
field_messages: messages
|
||||
|
||||
# Mapping of properties from the input dataset to the chat template.
|
||||
# (default: message_property_mappings={'role':'role', 'content':'content'})
|
||||
# If a property exists in the template but not in this mapping, the system will attempt
|
||||
# to load it directly from the message using the property name as the key.
|
||||
# Example: In the mapping below, 'from' is loaded from input dataset and used as 'role',
|
||||
# while 'value' is loaded and used as 'content' in the chat template.
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
# ...
|
||||
|
||||
# Optional[Dict[str, List]]. Roles mapping in the messages.
|
||||
# The format is {target_role: [source_roles]}. All source roles will be mapped to the target role.
|
||||
# The default is:
|
||||
roles:
|
||||
user: ["human", "user"]
|
||||
assistant: ["gpt", "assistant"]
|
||||
system: ["system"]
|
||||
tool: ["tool"]
|
||||
|
||||
# Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
|
||||
# This does not drop the default system message from chat_template if it exists. If you wish to,
|
||||
# we recommend using a custom jinja template with the default system message removed or
|
||||
# adding a system turn with empty content.
|
||||
drop_system_message:
|
||||
|
||||
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
||||
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
|
||||
# See examples at `docs/dataset-formats/conversation.qmd`
|
||||
# Note: If the below 4 fields are set to empty, defaults to training only on the last message.
|
||||
|
||||
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
|
||||
roles_to_train: ["assistant"] # default
|
||||
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
|
||||
# - all: train on all EOS tokens
|
||||
# - turn (default): train on the EOS token at the end of each trainable turn
|
||||
# - last: train on the last EOS token in the conversation
|
||||
# TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
|
||||
train_on_eos: last
|
||||
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
|
||||
message_field_training: training
|
||||
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
|
||||
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
|
||||
message_field_training_detail: train_detail
|
||||
|
||||
|
||||
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
|
||||
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
|
||||
shuffle_merged_datasets: true
|
||||
|
||||
Deduplicates datasets and test_datasets with identical entries.
|
||||
dataset_exact_deduplication: true
|
||||
|
||||
# A list of one or more datasets to eval the model with.
|
||||
# You can use either test_datasets, or val_set_size, but not both.
|
||||
test_datasets:
|
||||
- path: /workspace/data/eval.jsonl
|
||||
ds_type: json
|
||||
# You need to specify a split. For "json" datasets the default split is called "train".
|
||||
split: train
|
||||
type: completion
|
||||
data_files:
|
||||
- /workspace/data/eval.jsonl
|
||||
|
||||
# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
|
||||
rl:
|
||||
rl_beta: # Optional[float]. The beta parameter for the RL training.
|
||||
|
||||
# dpo
|
||||
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
|
||||
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
|
||||
|
||||
# orpo
|
||||
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
|
||||
|
||||
# kto
|
||||
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
|
||||
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
|
||||
|
||||
# simpo
|
||||
cpo_alpha: 1.0 # Weight of the BC regularizer
|
||||
simpo_gamma: 0.5 # Target reward margin for the SimPO loss
|
||||
|
||||
# grpo
|
||||
trl:
|
||||
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
|
||||
vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
|
||||
vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
|
||||
vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
|
||||
vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.
|
||||
|
||||
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
|
||||
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
|
||||
|
||||
reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
|
||||
reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
|
||||
|
||||
num_generations: # Optional[int]. Number of generations to sample.
|
||||
log_completions: # Optional[bool]. Whether to log completions.
|
||||
|
||||
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
|
||||
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
|
||||
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
|
||||
|
||||
|
||||
# reward modelling: `True` or `False`
|
||||
reward_model:
|
||||
|
||||
# process reward modelling: `True` or `False`
|
||||
process_reward_model:
|
||||
|
||||
# The name of the chat template to use for training, following values are supported:
|
||||
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
|
||||
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
|
||||
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
|
||||
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
|
||||
# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
|
||||
# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
|
||||
chat_template: tokenizer_default
|
||||
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
||||
chat_template_jinja: null
|
||||
# Changes the default system message. Currently only supports chatml.
|
||||
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
|
||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
# Push prepared dataset to hub
|
||||
push_dataset_to_hub: # Optional[str] repo_org/repo_name
|
||||
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||
# if not set.
|
||||
dataset_processes: # defaults to os.cpu_count() if not set
|
||||
# Keep dataset in memory while preprocessing
|
||||
# Only needed if cached dataset is taking too much storage
|
||||
dataset_keep_in_memory:
|
||||
# push checkpoints to hub
|
||||
hub_model_id: # private repo path to push finetuned model
|
||||
# how to push checkpoints to hub
|
||||
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
|
||||
hub_strategy:
|
||||
# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# Required to be true when used in combination with `push_dataset_to_hub`
|
||||
hf_use_auth_token: # boolean
|
||||
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
|
||||
val_set_size: 0.04
|
||||
# Num shards for whole dataset
|
||||
dataset_shard_num:
|
||||
# Index of shard to use for whole dataset
|
||||
dataset_shard_idx:
|
||||
|
||||
# The maximum length of an input to train with, this should typically be less than 2048
|
||||
# as most models have a token/context limit of 2048
|
||||
sequence_len: 2048
|
||||
# Pad inputs so each step uses constant sized buffers
|
||||
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||
pad_to_sequence_len:
|
||||
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
sample_packing:
|
||||
# Set to 'false' if getting errors during eval with sample_packing on.
|
||||
eval_sample_packing:
|
||||
# You can set these packing optimizations AFTER starting a training at least once.
|
||||
# The trainer will provide recommended values for these values.
|
||||
sample_packing_eff_est:
|
||||
total_num_tokens:
|
||||
# Increasing the following values helps with packing, but usually only slightly (<%1.)
|
||||
# The number of samples packed at a time.
|
||||
sample_packing_group_size: 100000
|
||||
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
|
||||
sample_packing_bin_size: 200
|
||||
sample_pack_sequentially: # Optional[bool]. Whether to pack samples sequentially.
|
||||
|
||||
# whether to concatenate samples during pretraining
|
||||
pretraining_sample_concatenation:
|
||||
|
||||
curriculum_sampling: # Optional[bool]. Whether to use sequential sampling for curriculum learning
|
||||
|
||||
# Use batch flattening for speedups when not using sample_packing
|
||||
batch_flattening:
|
||||
|
||||
# Passed through to transformers when loading the model when launched without accelerate
|
||||
# Use `sequential` when training w/ model parallelism to limit memory
|
||||
device_map:
|
||||
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
|
||||
max_memory:
|
||||
|
||||
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
# If you already have a lora model trained that you want to load, put that here.
|
||||
# This means after training, if you want to test the model, you should set this to the value of `output_dir`.
|
||||
# Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`.
|
||||
lora_model_dir:
|
||||
|
||||
# LoRA hyperparameters
|
||||
# For more details about the following options, see:
|
||||
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
# - o_proj
|
||||
# - gate_proj
|
||||
# - down_proj
|
||||
# - up_proj
|
||||
lora_target_linear: # If true, will target all linear modules
|
||||
|
||||
# List[int] | int. # The layer indices to transform, otherwise, apply to all layers
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.layers_to_transform
|
||||
peft_layers_to_transform:
|
||||
|
||||
# Optional[bool]. Whether to use DoRA.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#weight-decomposed-low-rank-adaptation-dora
|
||||
peft_use_dora:
|
||||
|
||||
# Optional[bool]. Whether to use RSLoRA.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#rank-stabilized-lora
|
||||
peft_use_rslora:
|
||||
|
||||
# Optional[list[tuple[int, int]]]. List of layer indices to replicate.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#memory-efficient-layer-replication-with-lora
|
||||
peft_layer_replication:
|
||||
|
||||
# bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]
|
||||
# How to initialize LoRA weights. Default to True which is MS original implementation.
|
||||
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#initialization
|
||||
peft_init_lora_weights:
|
||||
|
||||
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
||||
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
||||
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
|
||||
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
|
||||
lora_modules_to_save:
|
||||
# - embed_tokens
|
||||
# - lm_head
|
||||
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# Apply custom LoRA autograd functions and activation function Triton kernels for
|
||||
# speed and memory savings
|
||||
# See: https://docs.axolotl.ai/docs/lora_optims.html
|
||||
lora_mlp_kernel: true
|
||||
lora_qkv_kernel: true
|
||||
lora_o_kernel: true
|
||||
|
||||
# LoRA+ hyperparameters
|
||||
# For more details about the following options, see:
|
||||
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
|
||||
loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.
|
||||
loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6.
|
||||
|
||||
peft:
|
||||
# Configuration options for loftq initialization for LoRA
|
||||
# https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
|
||||
loftq_config:
|
||||
loftq_bits: # typically 4 bits
|
||||
|
||||
# ReLoRA configuration
|
||||
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||
relora_steps: # Number of steps per ReLoRA restart
|
||||
relora_warmup_steps: # Number of per-restart warmup steps
|
||||
relora_anneal_steps: # Number of anneal steps for each relora cycle
|
||||
relora_prune_ratio: # threshold for optimizer magnitude when pruning
|
||||
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||
|
||||
# wandb configuration if you're using it
|
||||
# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||
wandb_project: # Your wandb project name
|
||||
wandb_entity: # A wandb Team name if using a Team
|
||||
wandb_watch:
|
||||
wandb_name: # Set the name of your wandb run
|
||||
wandb_run_id: # Set the ID of your wandb run
|
||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||
|
||||
# mlflow configuration if you're using it
|
||||
mlflow_tracking_uri: # URI to mlflow
|
||||
mlflow_experiment_name: # Your experiment name
|
||||
mlflow_run_name: # Your run name
|
||||
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
||||
|
||||
# Comet configuration if you're using it
|
||||
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
|
||||
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
|
||||
use_comet: # Enable or disable Comet integration.
|
||||
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
|
||||
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
|
||||
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
|
||||
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
|
||||
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
|
||||
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
|
||||
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
|
||||
|
||||
# Tensorboard
|
||||
use_tensorboard: # Optional[bool]
|
||||
|
||||
# Where to save the full-finetuned model to
|
||||
output_dir: ./completed-model
|
||||
|
||||
# Whether to use torch.compile and which backend to use
|
||||
# setting to `auto` will enable torch compile when torch>=2.5.1
|
||||
torch_compile: # Optional[Union[Literal["auto"], bool]]
|
||||
torch_compile_backend: # Optional[str]
|
||||
|
||||
# Training hyperparameters
|
||||
|
||||
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
||||
gradient_accumulation_steps: 1
|
||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||
# Batch size per gpu = micro_batch_size * gradient_accumulation_steps
|
||||
micro_batch_size: 2
|
||||
eval_batch_size:
|
||||
num_epochs: 4
|
||||
warmup_steps: 100 # cannot use with warmup_ratio
|
||||
warmup_ratio: 0.05 # cannot use with warmup_steps
|
||||
learning_rate: 0.00003
|
||||
lr_quadratic_warmup:
|
||||
logging_steps:
|
||||
eval_steps: # Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps
|
||||
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
|
||||
eval_strategy: # Set to `"no"` to skip evaluation, `"epoch"` at end of each epoch, leave empty to infer from `eval_steps`.
|
||||
save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of each epoch, `"best"` when better result is achieved, leave empty to infer from `save_steps`.
|
||||
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
|
||||
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
||||
save_total_limit: # Checkpoints saved at a time
|
||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||
# if both are set, num_epochs will not be guaranteed.
|
||||
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
|
||||
max_steps:
|
||||
|
||||
# bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time.
|
||||
include_tokens_per_second: # Optional[bool]
|
||||
|
||||
# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
|
||||
auto_find_batch_size: # Optional[bool]
|
||||
|
||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
do_causal_lm_eval: # Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`.
|
||||
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
|
||||
|
||||
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
|
||||
# see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information
|
||||
# snapshots can be visualized @ https://pytorch.org/memory_viz
|
||||
|
||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
||||
|
||||
# Save model as safetensors (require safetensors package)
|
||||
save_safetensors:
|
||||
|
||||
# Whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# Group similarly sized data to minimize padding.
|
||||
# May be slower to start, as it must download and sort the entire dataset.
|
||||
# Note that training loss may have an oscillating pattern with this enabled.
|
||||
group_by_length: false
|
||||
|
||||
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
|
||||
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
# additional kwargs to pass to the trainer for gradient checkpointing
|
||||
# gradient_checkpointing_kwargs:
|
||||
# use_reentrant: true
|
||||
|
||||
# Stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
early_stopping_patience: 3
|
||||
|
||||
# Specify a scheduler and kwargs to use with the optimizer
|
||||
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
|
||||
lr_scheduler_kwargs:
|
||||
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
|
||||
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
|
||||
|
||||
# For one_cycle optim
|
||||
lr_div_factor: # Learning rate div factor
|
||||
|
||||
# Specify optimizer
|
||||
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||
# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189
|
||||
#
|
||||
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
||||
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
||||
# in the examples/ for your model and fine-tuning use case.
|
||||
#
|
||||
# Valid values for 'optimizer' include:
|
||||
# - adamw_torch
|
||||
# - adamw_torch_fused
|
||||
# - adamw_torch_xla
|
||||
# - adamw_torch_npu_fused
|
||||
# - adamw_apex_fused
|
||||
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
||||
# - adafactor
|
||||
# - adamw_anyprecision
|
||||
# - adamw_torch_4bit
|
||||
# - ademamix
|
||||
# - sgd
|
||||
# - adagrad
|
||||
# - adamw_bnb_8bit
|
||||
# - adamw_8bit # alias for adamw_bnb_8bit
|
||||
# - ademamix_8bit
|
||||
# - lion_8bit
|
||||
# - lion_32bit
|
||||
# - paged_adamw_32bit
|
||||
# - paged_adamw_8bit
|
||||
# - paged_ademamix_32bit
|
||||
# - paged_ademamix_8bit
|
||||
# - paged_lion_32bit
|
||||
# - paged_lion_8bit
|
||||
# - rmsprop
|
||||
# - rmsprop_bnb
|
||||
# - rmsprop_bnb_8bit
|
||||
# - rmsprop_bnb_32bit
|
||||
# - galore_adamw
|
||||
# - galore_adamw_8bit
|
||||
# - galore_adafactor
|
||||
# - galore_adamw_layerwise
|
||||
# - galore_adamw_8bit_layerwise
|
||||
# - galore_adafactor_layerwise
|
||||
# - lomo
|
||||
# - adalomo
|
||||
# - grokadamw
|
||||
# - schedule_free_adamw
|
||||
# - schedule_free_sgd
|
||||
# - apollo_adamw
|
||||
# - apollo_adamw_layerwise
|
||||
#
|
||||
# Additional custom optimizers include:
|
||||
# - optimi_adamw
|
||||
# - ao_adamw_8bit
|
||||
# - ao_adamw_fp8
|
||||
optimizer:
|
||||
# Dictionary of arguments to pass to the optimizer
|
||||
optim_args:
|
||||
# For Galore Optimizers the following optim_args are available
|
||||
# rank: # type: int
|
||||
# update_proj_gap # type: int
|
||||
# scale # type: float
|
||||
# proj_type: # type: str, default = std
|
||||
|
||||
# The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
|
||||
optim_target_modules:
|
||||
# - self_attn # for llama
|
||||
# - mlp
|
||||
|
||||
# Specify weight decay
|
||||
weight_decay:
|
||||
# adamw hyperparams
|
||||
adam_beta1:
|
||||
adam_beta2:
|
||||
adam_epsilon:
|
||||
# Gradient clipping max norm
|
||||
max_grad_norm:
|
||||
|
||||
# Augmentation techniques
|
||||
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
||||
# currently only supported on Llama and Mistral
|
||||
neftune_noise_alpha:
|
||||
|
||||
# Optional[bool]. Whether to bettertransformers
|
||||
flash_optimum:
|
||||
|
||||
# Note: Only one of the following attention patches can be used at a time.
|
||||
# For example, if you set `xformers_attention` to `true`, do not set `flash_attention` to `true`.
|
||||
|
||||
# Optional[bool]. Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# Optional[bool]. Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
flash_attention:
|
||||
flash_attn_cross_entropy: # Optional[bool]. Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
flash_attn_rms_norm: # Optional[bool]. Whether to use flash-attention rms norm implementation - advanced use only
|
||||
flash_attn_fuse_qkv: # Optional[bool]. Whether to fuse QKV into a single operation
|
||||
flash_attn_fuse_mlp: # Optional[bool]. Whether to fuse part of the MLP into a single operation
|
||||
# Optional[bool]. Whether to use scaled-dot-product attention
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
sdp_attention:
|
||||
# Optional[bool]. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
|
||||
s2_attention:
|
||||
|
||||
# Optional[bool]. Whether to use low_cpu_mem_usage
|
||||
low_cpu_mem_usage:
|
||||
# Optional[str]. Resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
# Optional[bool]. If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||
# Be careful with this being turned on between different models.
|
||||
auto_resume_from_checkpoints: false
|
||||
|
||||
## Multimodal section
|
||||
# int | tuple[int, int] | None . Size to resize images to, width x height.
|
||||
# Will read from model/processor config if not set.
|
||||
image_size:
|
||||
# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear".
|
||||
image_resize_algorithm: 'bilinear'
|
||||
## End of multimodal section
|
||||
|
||||
# Don't mess with this, it's here for accelerate and torchrun
|
||||
local_rank:
|
||||
|
||||
# Add or change special tokens.
|
||||
# If you add tokens here, you don't need to add them to the `tokens` list.
|
||||
special_tokens:
|
||||
# bos_token: "<s>"
|
||||
# eos_token: "</s>"
|
||||
# unk_token: "<unk>"
|
||||
# pad_token: "[PAD]"
|
||||
|
||||
# Add extra tokens.
|
||||
tokens:
|
||||
|
||||
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
|
||||
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
|
||||
# Can be checked if they exist in tokenizer.json added_tokens.
|
||||
added_tokens_overrides: # Dict[int, str]
|
||||
# 128041: "<|im_start|>"
|
||||
# 128042: "<|im_end|>"
|
||||
|
||||
# FSDP
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
|
||||
# Deepspeed config path. e.g., deepspeed_configs/zero3.json
|
||||
deepspeed:
|
||||
|
||||
# Advanced DDP Arguments
|
||||
ddp_timeout:
|
||||
ddp_bucket_cap_mb:
|
||||
ddp_broadcast_buffers:
|
||||
|
||||
# Sequence parallelism
|
||||
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
|
||||
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
|
||||
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
||||
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
||||
# See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details.
|
||||
sequence_parallel_degree:
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
# Must evenly divide the number of KV heads in your model.
|
||||
heads_k_stride: 1
|
||||
|
||||
# Path to torch distx for optim 'adamw_anyprecision'
|
||||
torchdistx_path:
|
||||
|
||||
# Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
|
||||
pretraining_dataset:
|
||||
|
||||
# Debug mode
|
||||
debug:
|
||||
|
||||
# Seed
|
||||
seed:
|
||||
|
||||
# Allow overwrite yml config using from cli
|
||||
strict:
|
||||
```
|
||||
@@ -7,7 +7,6 @@ toc-depth: 3
|
||||
```{python}
|
||||
#| echo: false
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
def process_readme(integration_name):
|
||||
@@ -50,28 +49,9 @@ sections = [
|
||||
("Knowledge Distillation (KD)", "kd"),
|
||||
("Liger Kernels", "liger"),
|
||||
("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
|
||||
("Spectrum", "spectrum"),
|
||||
("LLMCompressor", "llm_compressor")
|
||||
("Spectrum", "spectrum")
|
||||
]
|
||||
|
||||
for folder_name in os.listdir("../src/axolotl/integrations/"):
|
||||
if folder_name in [path for name, path in sections]:
|
||||
# skip if already in sections
|
||||
continue
|
||||
if os.path.exists(f"../src/axolotl/integrations/{folder_name}/README.md"):
|
||||
# grab the first heading in README.md as the section name
|
||||
with open(f"../src/axolotl/integrations/{folder_name}/README.md", "r") as f:
|
||||
txt = f.read()
|
||||
matches = re.search(r'^# (.*)\n?', txt, flags=re.MULTILINE)
|
||||
if matches:
|
||||
name = matches.group(1)
|
||||
else:
|
||||
continue
|
||||
sections.append((name, folder_name))
|
||||
|
||||
# sort sections by name
|
||||
sections = sorted(sections, key=lambda x: x[0])
|
||||
|
||||
for section_name, folder_name in sections:
|
||||
print(print_section(section_name, folder_name))
|
||||
```
|
||||
|
||||
@@ -4,15 +4,27 @@ description: Conversation format for supervised fine-tuning.
|
||||
order: 3
|
||||
---
|
||||
|
||||
## sharegpt
|
||||
|
||||
::: {.callout-important}
|
||||
ShareGPT is deprecated!. Please see [chat_template](#chat_template) section below.
|
||||
:::
|
||||
|
||||
## pygmalion
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
## chat_template
|
||||
|
||||
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"messages": [{"role": "...", "content": "..."}, {"role": "...", "content": "..."}, ...]}
|
||||
{"conversations": [{"role": "...", "content": "..."}]}
|
||||
```
|
||||
|
||||
See [configs](../config-reference.qmd) for full configs and supported templates.
|
||||
See [configs](../config.qmd) for full configs and supported templates.
|
||||
|
||||
### Migrating from sharegpt
|
||||
|
||||
@@ -52,9 +64,7 @@ We recommend checking the below examples for other usecases.
|
||||
|
||||
### Examples
|
||||
|
||||
#### Training on last message
|
||||
|
||||
(Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
||||
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
@@ -68,9 +78,7 @@ datasets:
|
||||
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
|
||||
:::
|
||||
|
||||
#### Overriding default chat template
|
||||
|
||||
Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
chat_template: gemma # this overwrites the tokenizer's chat_template
|
||||
@@ -80,13 +88,7 @@ datasets:
|
||||
roles_to_train: ["assistant"] # default value
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
If you want to use built-in chat_template, use `chat_template: tokenizer_default` (this is set by default).
|
||||
:::
|
||||
|
||||
#### Using default chat template with fallback
|
||||
|
||||
Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
|
||||
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
|
||||
@@ -95,9 +97,7 @@ datasets:
|
||||
type: chat_template
|
||||
```
|
||||
|
||||
#### Custom Jinja template
|
||||
|
||||
Using a custom jinja template on OpenAI messages format, training on all assistant messages.
|
||||
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
|
||||
@@ -109,141 +109,10 @@ datasets:
|
||||
```
|
||||
|
||||
::: {.callout-important}
|
||||
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
|
||||
Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
|
||||
:::
|
||||
|
||||
#### Using template with different token for EOT and EOS
|
||||
|
||||
- If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
|
||||
|
||||
```yaml
|
||||
eot_tokens:
|
||||
- "[/INST]"
|
||||
# - "[/SYSTEM_PROMPT]"
|
||||
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
|
||||
# optional
|
||||
train_on_eot: turn # defaults read from train_on_eos (which defaults to turn)
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
See [config documentation](../config-reference.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens.
|
||||
:::
|
||||
|
||||
::: {.callout-note}
|
||||
Using `eot_tokens` requires each token that exists in `chat_template` to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior.
|
||||
|
||||
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config-reference.qmd) for more details.
|
||||
:::
|
||||
|
||||
- Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
|
||||
|
||||
```yaml
|
||||
eot_tokens:
|
||||
- "[/INST]"
|
||||
# ...
|
||||
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
|
||||
train_on_eos: last
|
||||
train_on_eot: turn
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
If EOS token only appears at the end of a prompt, `train_on_eos: last` is equivalent to `train_on_eos: turn`. Therefore, generally, you can leave them to their defaults and omit them.
|
||||
:::
|
||||
|
||||
|
||||
#### Using tool use
|
||||
|
||||
Instead of passing `tools` via the system prompt, an alternative method would be to have the `tools` in a separate column and loaded via `chat_template` to let the template dynamically build it.
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"type": "...",
|
||||
"function": {
|
||||
"name": "...",
|
||||
"description": "...",
|
||||
"parameters": {
|
||||
"type": "...",
|
||||
"properties": {
|
||||
// ...
|
||||
},
|
||||
"required": ["..."],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
"messages": [
|
||||
// ...
|
||||
{
|
||||
"role": "assistant", // call the function via assistant
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "...", // required only for mistral
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "...",
|
||||
"arguments": {
|
||||
"...": "...",
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "...", // required only for mistral
|
||||
"name": "...",
|
||||
"content": "..."
|
||||
},
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
If you have tool arguments with same name but different dtypes (like `"time": string` and `"time": number`), please save `arguments: ` as JSON string to prevent `datasets` from having casting issues.
|
||||
|
||||
```
|
||||
"arguments": "{\"...\": \"...\"}"
|
||||
```
|
||||
|
||||
The same is applicable for tool parameters.
|
||||
|
||||
```
|
||||
"parameters": "{\"...\": \"...\"}"
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
Example config for Llama4:
|
||||
```yaml
|
||||
chat_template: llama4
|
||||
datasets:
|
||||
- path: Nanobit/text-tools-2k-test
|
||||
type: chat_template
|
||||
# field_tools: tools # default is `tools`
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
Look into the `chat_template` you are using to see if it supports `tools` and what the expected role is for the tool answer. In the example above, the tool answer is expected to be in the `tool` or `ipython` role for `llama4` template.
|
||||
:::
|
||||
|
||||
|
||||
#### Using fine-grained control over token masking
|
||||
|
||||
(Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
|
||||
For a data sample that looks like:
|
||||
|
||||
@@ -293,45 +162,3 @@ datasets:
|
||||
::: {.callout-tip}
|
||||
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
||||
:::
|
||||
|
||||
#### Reasoning split
|
||||
|
||||
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
chat_template: qwen3
|
||||
split_thinking: true
|
||||
```
|
||||
|
||||
For example, a content can look like:
|
||||
|
||||
```json
|
||||
{
|
||||
"content": "<think>Some thinking outputs</think>Output after thinking."
|
||||
}
|
||||
```
|
||||
|
||||
After split, it will look like:
|
||||
|
||||
```json
|
||||
{
|
||||
"reasoning_content": "Some thinking outputs",
|
||||
"content": "Output after thinking..."
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## sharegpt
|
||||
|
||||
::: {.callout-important}
|
||||
ShareGPT is deprecated!. Please see [chat_template](#chat_template) section.
|
||||
:::
|
||||
|
||||
## pygmalion
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
@@ -22,46 +22,90 @@ For `pretraining_dataset:` specifically, please refer to the [Pre-training secti
|
||||
|
||||
## Pre-training
|
||||
|
||||
Pre-training trains on raw text corpora with no input masking. The dataset format is simple:
|
||||
When aiming to train on large corpora of text datasets, pre-training is your go-to choice. Due to the size of these datasets, downloading the entire-datasets before beginning training would be prohibitively time-consuming. Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to only load batches into memory at a time.
|
||||
|
||||
A sample format for a pre-training dataset is as follows:
|
||||
|
||||
```json
|
||||
{"text": "first row"}
|
||||
{"text": "second row"}
|
||||
...
|
||||
```
|
||||
|
||||
Axolotl supports two approaches:
|
||||
It is typically recommended to save your dataset as `.jsonl` due to its flexibility and simplicity.
|
||||
|
||||
### Streaming (large datasets)
|
||||
Axolotl supports loading from a Hugging Face hub repo or from local files.
|
||||
|
||||
For large corpora that don't fit in memory, use `pretraining_dataset` with [streaming](../streaming.qmd). Data is tokenized on-demand during training.
|
||||
::: {.callout-important}
|
||||
For pre-training only, Axolotl would split texts if it exceeds the context length into multiple smaller prompts.
|
||||
:::
|
||||
|
||||
### Pre-training from Hugging Face hub datasets
|
||||
|
||||
As an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:
|
||||
|
||||
```yaml
|
||||
pretraining_dataset: hf_org/name
|
||||
```
|
||||
|
||||
### Pre-training from local dataset files
|
||||
|
||||
Given a few corpus files: `A.jsonl`, `B.jsonl`, and `C.jsonl`, your config will look like the below:
|
||||
|
||||
```yaml
|
||||
pretraining_dataset:
|
||||
- path: HuggingFaceFW/fineweb-edu
|
||||
type: pretrain
|
||||
text_column: text
|
||||
split: train
|
||||
- path: json
|
||||
data_files:
|
||||
- A.jsonl
|
||||
- B.jsonl
|
||||
- C.jsonl
|
||||
```
|
||||
|
||||
::: {.callout-important}
|
||||
Streaming requires `max_steps` in your config — Axolotl cannot infer the dataset size. One step = `sequence_len * micro_batch_size * gradient_accumulation_steps * num_gpus` tokens.
|
||||
:::
|
||||
While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet`, `arrow`, `SQL`, `Webdataset`) that are supported by [`Dataset.load_dataset`](https://huggingface.co/docs/datasets/loading#local-and-remote-files)
|
||||
|
||||
See [Streaming Datasets](../streaming.qmd) for full configuration details.
|
||||
### Pre-training without streaming
|
||||
|
||||
### Non-streaming (smaller datasets)
|
||||
On the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
|
||||
|
||||
For datasets that fit in memory, use `type: completion` under `datasets:`. The entire dataset is pre-tokenized before training, which can be done on a CPU-only machine.
|
||||
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.
|
||||
|
||||
From Hugging Face:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: my_corpus
|
||||
- path: hf_org/name
|
||||
type: completion
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
With `completion`, texts exceeding `sequence_len` are split into multiple samples automatically.
|
||||
:::
|
||||
From local files (either example works):
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: A.jsonl
|
||||
type: completion
|
||||
|
||||
- path: json
|
||||
data_files: ["A.jsonl", "B.jsonl", "C.jsonl"]
|
||||
type: completion
|
||||
```
|
||||
|
||||
### Pre-training dataset configuration tips
|
||||
|
||||
#### Setting max_steps
|
||||
|
||||
When using streaming for large datasets, Axolotl does not know in advance how large the dataset is and does not know when to stop.
|
||||
|
||||
Therefore, it is necessary to set `max_steps: int` in your config for pre-training to run, so that Axolotl knows when to stop training.
|
||||
|
||||
One step is equal to `sequence_len * micro_batch_size * gradient_accumulation_steps * total_num_gpus` tokens.
|
||||
|
||||
#### Group_by_length
|
||||
|
||||
It is recommended to leave this off if downloading from Hugging Face hub as it would download the entire dataset which can be very large.
|
||||
|
||||
### Reference
|
||||
|
||||
Please see docs [here](pretraining.qmd).
|
||||
|
||||
## Supervised fine-tuning (SFT)
|
||||
|
||||
|
||||
@@ -186,4 +186,4 @@ datasets:
|
||||
no_input_format: "[INST] {instruction} [/INST]"
|
||||
```
|
||||
|
||||
See full config options under [here](../config-reference.qmd).
|
||||
See full config options under [here](../config.qmd).
|
||||
|
||||
@@ -4,9 +4,29 @@ description: Data format for a pre-training completion task.
|
||||
order: 1
|
||||
---
|
||||
|
||||
::: {.callout-note}
|
||||
Pre-training documentation has been consolidated:
|
||||
For pretraining, there is no prompt template or roles. The only required field is `text`:
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"text": "first row"}
|
||||
{"text": "second row"}
|
||||
...
|
||||
```
|
||||
|
||||
:::{.callout-note}
|
||||
|
||||
### Streaming is recommended for large datasets
|
||||
|
||||
Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:
|
||||
|
||||
```{.yaml filename="config.yaml"}
|
||||
pretraining_dataset:
|
||||
- name:
|
||||
path:
|
||||
split:
|
||||
text_column: # column in dataset with the data, usually `text`
|
||||
type: pretrain
|
||||
trust_remote_code:
|
||||
skip: # number of rows of data to skip over from the beginning
|
||||
```
|
||||
|
||||
- **Streaming pretraining** (large datasets): See [Streaming Datasets](../streaming.qmd#pretraining-with-streaming)
|
||||
- **Non-streaming pretraining** (`type: completion`): See [Dataset Formats](index.qmd#pre-training)
|
||||
:::
|
||||
|
||||
@@ -36,7 +36,7 @@ This matches the API of [`datasets.load_dataset`](https://github.com/huggingface
|
||||
|
||||
For HuggingFace's guide to load different dataset types, see [here](https://huggingface.co/docs/datasets/loading).
|
||||
|
||||
For full details on the config, see [config-reference.qmd](config-reference.qmd).
|
||||
For full details on the config, see [config.qmd](config.qmd).
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
@@ -54,7 +54,7 @@ datasets:
|
||||
|
||||
#### Files
|
||||
|
||||
To load a JSON file, you would do something like this:
|
||||
Usually, to load a JSON file, you would do something like this:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
@@ -66,11 +66,19 @@ Which translates to the following config:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: data.json
|
||||
ds_type: json
|
||||
- path: json
|
||||
data_files: /path/to/your/file.jsonl
|
||||
```
|
||||
|
||||
In the example above, it can be seen that we can just point the `path` to the file or directory along with the `ds_type` to load the dataset.
|
||||
However, to make things easier, we have added a few shortcuts for loading local dataset files.
|
||||
|
||||
You can just point the `path` to the file or directory along with the `ds_type` to load the dataset. The below example shows for a JSON file:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: /path/to/your/file.jsonl
|
||||
ds_type: json
|
||||
```
|
||||
|
||||
This works for CSV, JSON, Parquet, and Arrow files.
|
||||
|
||||
|
||||
@@ -6,10 +6,6 @@ description: How to debug Axolotl
|
||||
|
||||
This document provides some tips and tricks for debugging Axolotl. It also provides an example configuration for debugging with VSCode. A good debugging setup is essential to understanding how Axolotl code works behind the scenes.
|
||||
|
||||
::: {.callout-tip}
|
||||
For training-specific debugging (loss spikes, NaN gradients, OOM errors, RL training stability), see [Training Stability & Debugging](training_stability.qmd).
|
||||
:::
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [General Tips](#general-tips)
|
||||
@@ -33,7 +29,7 @@ While debugging it's helpful to simplify your test scenario as much as possible.
|
||||
1. **Make sure you are using the latest version of axolotl**: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from `main`.
|
||||
1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing:
|
||||
- Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
|
||||
- Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`.
|
||||
- Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
|
||||
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config):
|
||||
|
||||
```yaml
|
||||
@@ -89,7 +85,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
|
||||
|
||||
The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
|
||||
|
||||
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 axolotl train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
|
||||
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
|
||||
|
||||
```json
|
||||
// .vscode/launch.json
|
||||
@@ -105,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 axolotl
|
||||
"-m", "axolotl.cli.train", "dev_chat_template.yml",
|
||||
// The flags below simplify debugging by overriding the axolotl config
|
||||
// with the debugging tips above. Modify as needed.
|
||||
"--dataset_num_proc=1", // limits data preprocessing to one process
|
||||
"--dataset_processes=1", // limits data preprocessing to one process
|
||||
"--max_steps=1", // limits training to just one step
|
||||
"--batch_size=1", // minimizes batch size
|
||||
"--micro_batch_size=1", // minimizes batch size
|
||||
@@ -246,6 +242,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
|
||||
</div>
|
||||
<br>
|
||||
|
||||
[^1]: The VSCode config uses `accelerate.commands.launch` as the Python module entry point, which is what `axolotl train` invokes under the hood.
|
||||
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
|
||||
|
||||
[^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).
|
||||
|
||||
@@ -8,10 +8,6 @@ format:
|
||||
|
||||
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
## Base
|
||||
|
||||
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
||||
@@ -32,8 +28,9 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-base-py3.11-cu128-2.8.0`
|
||||
- `main-base-py3.11-cu128-2.9.1`
|
||||
- `main-base-py3.11-cu124-2.6.0`
|
||||
- `main-base-py3.11-cu124-2.5.1`
|
||||
- `main-base-py3.11-cu124-2.4.1`
|
||||
|
||||
## Main
|
||||
|
||||
@@ -53,7 +50,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
|
||||
# on push to main
|
||||
main-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)
|
||||
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
|
||||
main-latest
|
||||
|
||||
# nightly build
|
||||
@@ -71,12 +68,14 @@ There may be some extra tags appended to the image, like `-vllm` which installs
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-py3.11-cu128-2.8.0`
|
||||
- `main-py3.11-cu128-2.9.1`
|
||||
- `main-py3.11-cu124-2.6.0`
|
||||
- `main-py3.11-cu124-2.5.1`
|
||||
- `main-py3.11-cu124-2.4.1`
|
||||
- `main-latest`
|
||||
- `main-20250303-py3.11-cu124-2.6.0`
|
||||
- `main-20250303-py3.11-cu126-2.6.0`
|
||||
- `0.12.0`
|
||||
- `main-20250303-py3.11-cu124-2.5.1`
|
||||
- `main-20250303-py3.11-cu124-2.4.1`
|
||||
- `0.7.1`
|
||||
|
||||
## Cloud
|
||||
|
||||
|
||||
556
docs/ebft.qmd
556
docs/ebft.qmd
@@ -1,556 +0,0 @@
|
||||
---
|
||||
title: "EBFT Training"
|
||||
description: "Energy-Based Fine-Tuning uses feature-matching rewards from internal representations to train language models without external reward functions."
|
||||
order: 9
|
||||
back-to-top-navigation: true
|
||||
toc: true
|
||||
toc-expand: 2
|
||||
toc-depth: 4
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
Energy-Based Fine-Tuning (EBFT) is a training method that optimizes language models by matching the **internal feature representations** of generated text to those of ground-truth completions. Instead of relying on external reward models or hand-crafted reward functions, EBFT extracts hidden states from intermediate layers of a frozen copy of the model and uses cosine similarity between generated and reference features as the reward signal.
|
||||
|
||||
Paper: ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026)
|
||||
|
||||
### How EBFT Differs from Other RL Methods
|
||||
|
||||
| Method | Reward Signal | Requires | Best For |
|
||||
|--------|--------------|----------|----------|
|
||||
| **GRPO** | External reward function(s) | Custom reward code or reward model | Tasks with verifiable answers (math, code) |
|
||||
| **DPO** | Preference pairs (chosen vs rejected) | Paired preference data | Alignment with human preferences |
|
||||
| **EBFT** | Feature similarity to ground truth | Ground-truth completions | Any task with reference outputs |
|
||||
|
||||
EBFT's key advantage is that it needs only ground-truth completions -- no reward engineering, no preference annotation, and no reward model training. The model's own internal representations serve as the reward signal. This makes it particularly effective for:
|
||||
|
||||
- Code generation (match features of known-good solutions)
|
||||
- Instruction following with reference outputs
|
||||
- Continual pretraining on unstructured text (strided mode)
|
||||
- Multi-turn dialogue with reference conversations
|
||||
|
||||
### Reward Formulation
|
||||
|
||||
The EBFT reward for each generated completion is:
|
||||
|
||||
```
|
||||
reward = alignment_coef * cosine_similarity(gen_features, gt_features)
|
||||
- diversity_coef * mean_pairwise_similarity(gen_features)
|
||||
```
|
||||
|
||||
- **Alignment**: How closely the generated output's internal representations match the ground truth. Higher is better.
|
||||
- **Diversity**: Penalizes generated samples that are too similar to each other (prevents mode collapse). Lower is better.
|
||||
- **CFM loss** (Cross-Feature Matching): Tracks `||mean(gen_features) - gt_features||^2` as a diagnostic. This is the quantity that EBFT ultimately minimizes.
|
||||
|
||||
## Modes
|
||||
|
||||
EBFT supports three operational modes, each suited to different use cases.
|
||||
|
||||
### Structured Mode (Sync)
|
||||
|
||||
Uses vLLM on a separate GPU for generation, with sequential generate-score-train steps. This is the simplest mode and recommended for getting started.
|
||||
|
||||
```
|
||||
GPU 0: vLLM Server (generates completions, receives weight syncs)
|
||||
GPU 1: Trainer (feature extraction, reward computation, GRPO training)
|
||||
```
|
||||
|
||||
**When to use**: Standard instruction-following or QA datasets where you have prompt/completion pairs. Requires 2 GPUs.
|
||||
|
||||
### Structured Mode (Async)
|
||||
|
||||
Same architecture as sync, but overlaps generation of the next batch with training on the current batch. Faster throughput at the cost of slightly stale weights during generation.
|
||||
|
||||
**When to use**: Same data as sync mode, but when you want faster training and can tolerate weight staleness (controlled by `vllm_sync_interval`).
|
||||
|
||||
### Strided Mode
|
||||
|
||||
Runs entirely on a single GPU with no vLLM dependency. Places anchor points throughout a document and generates short rollouts at each anchor using block-parallel attention patterns.
|
||||
|
||||
```
|
||||
Single GPU: Base model + LoRA adapter
|
||||
- Strided block-parallel generation (flex_attention)
|
||||
- Feature extraction via disable_adapter()
|
||||
- No vLLM needed
|
||||
```
|
||||
|
||||
**When to use**: Unstructured text data (raw code, prose, documents) where there is no natural prompt/completion split. Also works with structured data that includes prompt boundaries. Requires only 1 GPU.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Structured Mode
|
||||
|
||||
This minimal example fine-tunes Qwen2-0.5B on code data using EBFT with vLLM generation.
|
||||
|
||||
**Step 1**: Create a config file `ebft_quickstart.yaml`:
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen2-0.5B-Instruct
|
||||
|
||||
rl: ebft
|
||||
|
||||
ebft:
|
||||
feature_layers: [0.25, 0.5, 0.75]
|
||||
embed_method: last_token
|
||||
alignment_coef: 1.0
|
||||
diversity_coef: 1.0
|
||||
|
||||
trl:
|
||||
num_generations: 4
|
||||
max_completion_length: 256
|
||||
temperature: 0.7
|
||||
use_vllm: true
|
||||
vllm_server_host: 0.0.0.0
|
||||
vllm_server_port: 8000
|
||||
vllm_lora_sync: true
|
||||
vllm_sync_interval: 3
|
||||
use_data_producer: true
|
||||
async_prefetch: false
|
||||
scale_rewards: true
|
||||
loss_type: grpo
|
||||
|
||||
vllm:
|
||||
gpu_memory_utilization: 0.5
|
||||
max_model_len: 1024
|
||||
|
||||
datasets:
|
||||
- path: nvidia/OpenCodeInstruct
|
||||
type: ebft_opencode.transform
|
||||
split: train[:500]
|
||||
|
||||
# Standard training settings (see getting-started.qmd for details)
|
||||
adapter: lora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_target_linear: true
|
||||
sequence_len: 1024
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 4
|
||||
max_steps: 20
|
||||
learning_rate: 5.0e-6
|
||||
bf16: auto
|
||||
flash_attention: true
|
||||
gradient_checkpointing: true
|
||||
output_dir: ./outputs/ebft-quickstart
|
||||
```
|
||||
|
||||
**Step 2**: Start vLLM on GPU 0:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve ebft_quickstart.yaml
|
||||
```
|
||||
|
||||
**Step 3**: Wait approximately 30 seconds for vLLM to initialize, then start training on GPU 1:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=1 axolotl train ebft_quickstart.yaml
|
||||
```
|
||||
|
||||
::: {.callout-important}
|
||||
The `micro_batch_size` must be divisible by `num_generations`. For example, with `num_generations: 4`, valid values are 4, 8, 12, etc.
|
||||
:::
|
||||
|
||||
### Dataset Format
|
||||
|
||||
Structured mode datasets must produce two fields after the transform:
|
||||
|
||||
- `prompt`: Either a string or a list of chat messages (`[{"role": "user", "content": "..."}]`)
|
||||
- `ground_truth`: A string containing the reference completion
|
||||
|
||||
Example raw dataset row:
|
||||
|
||||
```json
|
||||
{
|
||||
"input": "Write a function to compute fibonacci numbers.",
|
||||
"output": "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)"
|
||||
}
|
||||
```
|
||||
|
||||
The `ebft_opencode.transform` converts this to the required `{prompt, ground_truth}` format automatically.
|
||||
|
||||
## Feature Extraction
|
||||
|
||||
EBFT extracts hidden states from intermediate transformer layers and pools them into per-sequence embeddings. These embeddings are compared between generated and ground-truth completions to compute rewards.
|
||||
|
||||
### Feature Layers
|
||||
|
||||
The `feature_layers` parameter specifies which layers to extract, as fractions of total model depth:
|
||||
|
||||
```yaml
|
||||
ebft:
|
||||
feature_layers: [0.25, 0.5, 0.75] # Quarter, middle, three-quarter depth
|
||||
```
|
||||
|
||||
For a 32-layer model, this extracts layers 8, 16, and 24. The hidden states from all selected layers are concatenated along the feature dimension, producing embeddings of size `num_layers * hidden_dim`.
|
||||
|
||||
::: {.callout-tip}
|
||||
Using multiple layers captures both low-level syntactic features (early layers) and high-level semantic features (later layers). The default `[0.25, 0.5, 0.75]` works well across model sizes.
|
||||
:::
|
||||
|
||||
### Embed Methods
|
||||
|
||||
The `embed_method` controls how per-token hidden states are pooled into a single vector per sequence:
|
||||
|
||||
| Method | Description | Output Shape | Notes |
|
||||
|--------|-------------|-------------|-------|
|
||||
| `last_token` | Hidden state at the last non-padding token | `(B, D)` | Default. Good for autoregressive models where the last token summarizes the sequence. |
|
||||
| `mean_pooling` | Mean of all non-padding token states | `(B, D)` | Considers the entire sequence equally. |
|
||||
| `completion_mean` | Mean over completion tokens only (excludes prompt) | `(B, D)` | Focuses reward signal on generated content. Requires prompt length information. |
|
||||
| `concat` | Concatenation of states at 25%, 50%, 75% positions | `(B, 3*D)` | Captures positional structure. Higher dimensional. |
|
||||
|
||||
```yaml
|
||||
ebft:
|
||||
embed_method: completion_mean # Focus on completion features
|
||||
```
|
||||
|
||||
### SVD Whitening
|
||||
|
||||
Whitening decorrelates the feature dimensions so that no single direction dominates the feature-matching loss. This is computed via SVD on the generated embeddings, with the same transform applied to the ground-truth embeddings.
|
||||
|
||||
```yaml
|
||||
ebft:
|
||||
use_whitening: true
|
||||
```
|
||||
|
||||
When whitening is enabled, the reward computation applies a whitening matrix `W = U @ diag(1/S) @ U^T` derived from the SVD of generated embeddings. This ensures all feature dimensions contribute equally to the alignment reward.
|
||||
|
||||
::: {.callout-note}
|
||||
Singular values scale with `sqrt(batch_size)`, so reward magnitudes are batch-size dependent. This is acceptable because the number of samples per prompt (`n_samples_per_prompt` or `num_generations`) is fixed during training.
|
||||
:::
|
||||
|
||||
### Alignment and Diversity Coefficients
|
||||
|
||||
The two reward components are weighted by coefficients:
|
||||
|
||||
```yaml
|
||||
ebft:
|
||||
alignment_coef: 1.0 # Weight for cosine similarity with ground truth
|
||||
diversity_coef: 1.0 # Weight for pairwise similarity penalty
|
||||
```
|
||||
|
||||
Both values are scaled by 2 internally (per paper equation 7). The final reward per sample is:
|
||||
|
||||
```
|
||||
reward_j = 2 * alignment_coef * cos(gen_j, gt)
|
||||
- 2 * diversity_coef * (1/(n-1)) * sum_{j' != j} dot(gen_j, gen_j')
|
||||
```
|
||||
|
||||
Setting `diversity_coef: 0.0` disables the diversity penalty entirely, which may be appropriate when `num_generations` is small (e.g., 2).
|
||||
|
||||
## Strided Mode
|
||||
|
||||
Strided mode is designed for training on unstructured text data where there is no natural prompt/completion boundary. Instead of generating full completions with vLLM, it places **anchor points** at regular intervals throughout each document and generates short rollouts at each anchor using block-parallel attention.
|
||||
|
||||
### How Block-Parallel Generation Works
|
||||
|
||||
Given a document of length `S` tokens:
|
||||
|
||||
1. **Anchor placement**: Starting at position `anchor_offset`, place anchors every `stride` tokens. Each anchor defines a block.
|
||||
2. **Context window**: Each block sees `context_length` tokens of preceding context from the original document.
|
||||
3. **Generation**: At each anchor, generate `generate_max_len` tokens autoregressively, conditioned only on the context window.
|
||||
4. **Parallelism**: All blocks are processed in a single forward pass using a specialized attention mask that prevents information leakage between blocks.
|
||||
|
||||
```
|
||||
Document: [tok0, tok1, ..., tok_S]
|
||||
| | |
|
||||
anchor_0 anchor_1 anchor_2
|
||||
| | |
|
||||
[ctx][gen] [ctx][gen] [ctx][gen]
|
||||
```
|
||||
|
||||
The attention mask ensures:
|
||||
|
||||
- Prompt tokens use standard causal attention
|
||||
- Each generated block attends to its own context window and its own preceding generated tokens
|
||||
- Blocks do not attend to each other's generated tokens
|
||||
|
||||
When `flex_attention` is available (PyTorch >= 2.5), the mask is compiled into efficient fused kernels. Otherwise, a dense 4D attention mask is used as a fallback.
|
||||
|
||||
### Strided Mode Configuration
|
||||
|
||||
```yaml
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
rl: ebft
|
||||
|
||||
ebft:
|
||||
mode: strided
|
||||
stride: 8 # Tokens between anchor points
|
||||
context_length: 8 # Context window per block
|
||||
generate_max_len: 8 # Tokens to generate per block
|
||||
n_samples_per_prompt: 4 # Independent rollouts per document
|
||||
temperature: 0.6
|
||||
feature_layers: [0.25, 0.5, 0.75]
|
||||
embed_method: last_token
|
||||
use_whitening: true
|
||||
alignment_coef: 1.0
|
||||
diversity_coef: 1.0
|
||||
rl_coef: 1.0 # RL policy gradient loss weight
|
||||
ce_coef: 0.03 # Cross-entropy loss on GT tokens
|
||||
advantage_estimator: rloo # rloo, group_norm, or reinforce
|
||||
min_completion_prefix: 8 # Skip anchors in prompt region
|
||||
|
||||
datasets:
|
||||
- path: nvidia/OpenCodeInstruct
|
||||
type: ebft_strided_structured.transform
|
||||
split: train[:1%]
|
||||
|
||||
sequence_len: 2048
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
|
||||
adapter: lora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_target_linear: true
|
||||
|
||||
bf16: auto
|
||||
flex_attention: true
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: true # Required with flex_attention
|
||||
```
|
||||
|
||||
Run with a single command (no vLLM needed):
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
|
||||
```
|
||||
|
||||
### Advantage Estimators
|
||||
|
||||
Strided mode supports three advantage estimation methods:
|
||||
|
||||
| Estimator | Formula | Requirements |
|
||||
|-----------|---------|-------------|
|
||||
| `rloo` | Leave-one-out baseline: `reward_j - mean(rewards_{-j})` | `n_samples_per_prompt >= 2` |
|
||||
| `group_norm` | Group normalization: `(reward_j - mean) / std` | `n_samples_per_prompt >= 2` |
|
||||
| `reinforce` | Raw reward as advantage (no baseline) | Works with `n_samples_per_prompt = 1` |
|
||||
|
||||
::: {.callout-warning}
|
||||
When `n_samples_per_prompt: 1`, the trainer automatically falls back to `reinforce` and disables the diversity penalty (which requires multiple samples).
|
||||
:::
|
||||
|
||||
### Strided Mode Constraints
|
||||
|
||||
- **`flex_attention: true`** is strongly recommended. Without it, dense 4D masks consume significantly more memory.
|
||||
- **`torch_compile: true`** must NOT be set. `flex_attention` compiles its own kernels internally; adding `torch_compile` causes conflicts and OOM.
|
||||
- **Gradient checkpointing** must use `use_reentrant: true`. Non-reentrant checkpointing causes `CheckpointError` with `flex_attention` block masks.
|
||||
- **`activation_offloading`** is incompatible with `flex_attention`.
|
||||
|
||||
### Cross-Entropy Loss
|
||||
|
||||
Strided mode supports an optional cross-entropy loss term on ground-truth tokens. This acts as a regularizer to prevent the model from drifting too far from the original distribution:
|
||||
|
||||
```yaml
|
||||
ebft:
|
||||
ce_coef: 0.03 # Small CE coefficient
|
||||
rl_coef: 1.0 # RL loss coefficient
|
||||
```
|
||||
|
||||
The total loss is `rl_coef * rl_loss + ce_coef * ce_loss`. For structured mode, `ce_coef` is typically `0.0` since vLLM generation provides sufficient learning signal.
|
||||
|
||||
## Dataset Formats
|
||||
|
||||
EBFT provides several built-in dataset transforms in `src/axolotl/prompt_strategies/ebft/`.
|
||||
|
||||
### Built-In Transforms
|
||||
|
||||
| Transform | Input Format | Output Fields | Use Case |
|
||||
|-----------|-------------|---------------|----------|
|
||||
| `ebft_opencode.transform` | `{input, output}` | `{prompt, ground_truth}` | OpenCodeInstruct, structured QA |
|
||||
| `ebft_strided_structured.transform` | `{input, output}` | `{input_ids, labels, prompt_length}` | Strided mode with structured data |
|
||||
| `ebft_strided_chat.transform` | `{messages: [...]}` | `{input_ids, labels, prompt_length}` | Strided mode with chat data |
|
||||
| `ebft_chat_multiturn.transform` | `{messages: [...]}` | `{prompt, ground_truth, remaining_turns}` | Multi-turn: first-turn target |
|
||||
| `ebft_chat_multiturn.transform_last_turn` | `{messages: [...]}` | `{prompt, ground_truth}` | Multi-turn: last-turn target |
|
||||
| `ebft_chat_multiturn.transform_all_turns` | `{messages: [...]}` | `{prompt[], ground_truth[]}` | Multi-turn: one example per turn |
|
||||
| `ebft_reasoning.transform` | `{messages: [...]}` (with `<think>`) | `{prompt, ground_truth}` | Reasoning/thinking datasets |
|
||||
|
||||
### Structured Mode Datasets
|
||||
|
||||
For structured (sync/async) mode, the transform must produce `prompt` and `ground_truth` fields:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: nvidia/OpenCodeInstruct
|
||||
type: ebft_opencode.transform
|
||||
split: train[:500]
|
||||
```
|
||||
|
||||
### Multi-Turn Datasets
|
||||
|
||||
Multi-turn transforms extract conversation data for sequential rollout. The `transform` variant targets the first assistant turn, while `transform_last_turn` targets the final turn:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: your/multiturn-dataset
|
||||
type: ebft_chat_multiturn.transform
|
||||
```
|
||||
|
||||
When `remaining_turns` is present in the dataset output, the trainer performs sequential rollouts: it generates the first assistant turn with vLLM, then continues generating subsequent turns by building up the conversation history.
|
||||
|
||||
### Strided Mode Datasets
|
||||
|
||||
Strided transforms tokenize the full document and produce `input_ids`, `labels`, and `prompt_length`:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: nvidia/OpenCodeInstruct
|
||||
type: ebft_strided_structured.transform
|
||||
split: train[:1%]
|
||||
```
|
||||
|
||||
### Custom Transforms
|
||||
|
||||
To use your own dataset format, write a transform function:
|
||||
|
||||
```python
|
||||
def transform(cfg, **kwargs):
|
||||
def transform_fn(example, tokenizer=None):
|
||||
return {
|
||||
"prompt": [{"role": "user", "content": example["question"]}],
|
||||
"ground_truth": example["answer"],
|
||||
}
|
||||
return transform_fn, {"remove_columns": "__all__"}
|
||||
```
|
||||
|
||||
The `"__all__"` sentinel removes all original dataset columns after the mapping step. Reference this transform in your config:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: your/dataset
|
||||
type: your_module.transform
|
||||
```
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
### Common Parameters (All Modes)
|
||||
|
||||
These parameters are set under the `ebft:` key in the YAML config.
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `mode` | `"structured"` or `"strided"` | `"structured"` | EBFT operating mode |
|
||||
| `feature_layers` | `list[float]` | `[0.25, 0.5, 0.75]` | Fractional layer depths for feature extraction |
|
||||
| `embed_method` | `string` | `"last_token"` | Pooling method: `last_token`, `mean_pooling`, `completion_mean`, or `concat` |
|
||||
| `use_whitening` | `bool` | `false` | Apply SVD whitening to feature embeddings before reward computation |
|
||||
| `alignment_coef` | `float` | `1.0` | Weight for alignment reward (cosine similarity with ground truth) |
|
||||
| `diversity_coef` | `float` | `1.0` | Weight for diversity penalty (pairwise dot product between samples) |
|
||||
| `ce_coef` | `float` | `0.0` | Cross-entropy loss coefficient on ground-truth tokens |
|
||||
| `adaptive_max_tokens` | `bool` | `true` | Dynamically set vLLM `max_tokens` based on ground-truth length (structured mode) |
|
||||
| `gt_length_multiplier` | `float` | `1.5` | Multiplier for ground-truth token count when computing adaptive max tokens (min 0.1) |
|
||||
|
||||
### Strided Mode Parameters
|
||||
|
||||
These additional parameters apply only when `mode: strided`.
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `stride` | `int` | `8` | Number of tokens between anchor points (must be >= 1) |
|
||||
| `context_length` | `int` | `8` | Context window size for each generated block (must be >= 1) |
|
||||
| `generate_max_len` | `int` | `8` | Number of tokens to generate per block (must be >= 1) |
|
||||
| `n_samples_per_prompt` | `int` | `4` | Number of independent rollouts per document (must be >= 1) |
|
||||
| `temperature` | `float` | `0.6` | Sampling temperature for strided generation |
|
||||
| `top_p` | `float` | `1.0` | Top-p nucleus sampling threshold |
|
||||
| `rl_coef` | `float` | `1.0` | RL policy gradient loss coefficient |
|
||||
| `advantage_estimator` | `string` | `"rloo"` | Advantage estimation method: `rloo`, `group_norm`, or `reinforce` |
|
||||
| `min_completion_prefix` | `int` | `0` | Minimum tokens into the completion span before placing anchors |
|
||||
|
||||
### Structured Mode TRL Parameters
|
||||
|
||||
These are set under the `trl:` key and control the GRPO training loop.
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `num_generations` | `int` | -- | Number of completions generated per prompt |
|
||||
| `max_completion_length` | `int` | -- | Maximum tokens per generated completion |
|
||||
| `temperature` | `float` | `0.7` | Sampling temperature for vLLM generation |
|
||||
| `use_vllm` | `bool` | -- | Enable vLLM generation backend |
|
||||
| `vllm_lora_sync` | `bool` | `false` | Sync LoRA adapters via filesystem (recommended) |
|
||||
| `vllm_sync_interval` | `int` | `1` | Steps between weight syncs to vLLM |
|
||||
| `use_data_producer` | `bool` | -- | Required for sync mode with LoRA sync |
|
||||
| `async_prefetch` | `bool` | `false` | Enable async generation (overlaps with training) |
|
||||
| `streaming_partial_batch` | `bool` | `false` | Score groups incrementally (async mode) |
|
||||
| `skip_zero_advantage_batches` | `bool` | `false` | Skip micro-batches where all advantages are zero |
|
||||
| `scale_rewards` | `bool` | -- | Normalize rewards within each prompt group |
|
||||
| `loss_type` | `string` | `"grpo"` | Loss type for policy optimization |
|
||||
| `epsilon` | `float` | `0.2` | Clipping parameter for importance sampling |
|
||||
|
||||
### Stop Tokens
|
||||
|
||||
vLLM needs explicit stop token IDs for generation. Common configurations:
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
generation_kwargs:
|
||||
stop_token_ids: [151645, 151643] # Qwen: <|im_end|>, <|endoftext|>
|
||||
```
|
||||
|
||||
### Multi-Turn Chat Settings
|
||||
|
||||
For multi-turn conversations with Qwen3.5, disable thinking mode to prevent `<think>` tags in completions:
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
chat_template_kwargs:
|
||||
enable_thinking: false
|
||||
```
|
||||
|
||||
## Monitoring
|
||||
|
||||
### Key Metrics
|
||||
|
||||
EBFT logs several custom metrics to wandb and the training console. Here is what to watch for:
|
||||
|
||||
| Metric | Healthy Range | Interpretation |
|
||||
|--------|--------------|----------------|
|
||||
| `ebft/alignment` | 0.3 -- 0.9, trending upward | Cosine similarity between generated and ground-truth features. Higher means the model is learning to produce representations that match the reference. |
|
||||
| `ebft/diversity` | 0.01 -- 0.1 | Mean pairwise similarity between different generations for the same prompt. Values above 1.0 indicate mode collapse. |
|
||||
| `ebft/cfm_loss` | Below 10, trending downward | Cross-Feature Matching loss. This is the core quantity being minimized. Consistently above 100 indicates instability. |
|
||||
| `ebft/reward` | Trending upward (may start negative) | Combined reward signal. If stuck at -1.0, the diversity penalty is dominating alignment. |
|
||||
| `grad_norm` | 0.1 -- 3.0 | Gradient magnitude. Values of 0.0 indicate zero-advantage skip (normal). Values above 10 suggest instability. |
|
||||
| `entropy` | 0.05 -- 0.5 | Policy entropy. Values below 0.01 suggest mode collapse. |
|
||||
| `IS ratio min` | Above 0.1 | Importance sampling ratio minimum. Near-zero values mean the policy is too far off-policy; increase `vllm_sync_interval`. |
|
||||
|
||||
### Console Log Example
|
||||
|
||||
During training, you will see periodic EBFT reward logs:
|
||||
|
||||
```
|
||||
ebft reward | align +0.412 ^ | divers +0.023 v | cfm 4.231 v | reward +0.389 ^
|
||||
```
|
||||
|
||||
The arrows indicate the desired direction: alignment and reward should trend upward, while diversity and CFM loss should trend downward.
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
|---------|-------------|-----|
|
||||
| `alignment` stays below 0.1 | Feature layers not capturing useful information | Try different `feature_layers` or `embed_method` |
|
||||
| `diversity` exceeds 1.0 | Mode collapse -- generations are too similar | Increase `diversity_coef` or `temperature` |
|
||||
| `reward` stuck at -1.0 | Diversity penalty dominates alignment | Reduce `diversity_coef` or increase `alignment_coef` |
|
||||
| `grad_norm` consistently 0.0 | All micro-batches have zero advantage | Increase `num_generations` or check data quality |
|
||||
| `CheckpointError` in strided mode | Incompatible gradient checkpointing settings | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
|
||||
| OOM during training | Logits tensor too large | Reduce `sequence_len` or `micro_batch_size`; strided mode uses chunked lm_head to mitigate this |
|
||||
| vLLM 500 errors | `truncate_prompt_tokens` not supported | Ensure you are using `axolotl vllm-serve` (not `trl vllm-serve`) |
|
||||
|
||||
### Feature Network Memory
|
||||
|
||||
In PEFT (LoRA) mode, the feature network shares base weights with the actor model by using the `disable_adapter()` context manager. This saves an entire model copy in VRAM (approximately 1--16 GB depending on model size). For non-PEFT training, a separate frozen deepcopy is created.
|
||||
|
||||
::: {.callout-note}
|
||||
The `disable_adapter()` approach relies on an invariant: `merge_adapter()` is never called on the base weights. All weight sync paths (LoRA sync, HTTP, NCCL) compute merged weights as new tensors or save the adapter to the filesystem, leaving base weights unmodified.
|
||||
:::
|
||||
|
||||
## Examples
|
||||
|
||||
Complete example configurations are available in `examples/ebft/`:
|
||||
|
||||
| Config | Model | Mode | Description |
|
||||
|--------|-------|------|-------------|
|
||||
| `llama-1b-ebft-strided-structured.yaml` | Llama 3.2 1B | Strided | Single-GPU strided training on code data |
|
||||
| `qwen3-4b-ebft-structured.yaml` | Qwen3 4B | Structured (sync) | Two-GPU structured training |
|
||||
| `qwen3-4b-ebft-structured-async.yaml` | Qwen3 4B | Structured (async) | Two-GPU async training with prefetch |
|
||||
| `qwen3-8b-ebft-structured.yaml` | Qwen3 8B | Structured (sync) | Two-GPU structured training for larger model |
|
||||
| `qwen35-4b-ebft-structured.yaml` | Qwen3.5 4B | Structured (sync) | Two-GPU with Qwen3.5 |
|
||||
| `qwen35-4b-ebft-structured-async.yaml` | Qwen3.5 4B | Structured (async) | Two-GPU async with Qwen3.5 |
|
||||
| `qwen35-9b-ebft-structured.yaml` | Qwen3.5 9B | Structured (sync) | Two-GPU structured for 9B model |
|
||||
@@ -1,67 +0,0 @@
|
||||
---
|
||||
title: "MoE Expert Quantization"
|
||||
description: "Reduce VRAM usage when training MoE model adapters by quantizing expert weights on load"
|
||||
---
|
||||
|
||||
Transformers v5 changed MoE expert layers from `nn.Linear` to fused `nn.Parameter` (3D+ tensors).
|
||||
This means `bitsandbytes` can no longer quantize them during model loading, resulting in all expert
|
||||
weights being loaded in full bf16 precision and causing massive VRAM usage.
|
||||
|
||||
`quantize_moe_experts` solves this by quantizing expert weights during model loading.
|
||||
It intercepts the weight loading process, quantizes each expert tensor on the fly, and
|
||||
immediately frees the original bf16 tensor from VRAM. This dramatically reduces peak memory.
|
||||
For example, GLM-4.7-Flash QLoRA drops from ~127GiB to ~23GiB reserved memory.
|
||||
|
||||
## Usage
|
||||
|
||||
Enable expert quantization in your Axolotl config:
|
||||
|
||||
```yaml
|
||||
quantize_moe_experts: true
|
||||
```
|
||||
|
||||
This works with both 4-bit (QLoRA) and 8-bit (LoRA) quantization.
|
||||
|
||||
### Expert LoRA targeting
|
||||
|
||||
You can optionally apply LoRA adapters directly to expert weights using `lora_target_parameters`:
|
||||
|
||||
```yaml
|
||||
lora_target_parameters:
|
||||
- mlp.experts.gate_up_proj
|
||||
- mlp.experts.down_proj
|
||||
# - mlp.gate.weight # router
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
`lora_dropout` must be `0` when using `lora_target_parameters`.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- Requires (`adapter: lora` and `load_in_8bit: true`) or (`adapter: qlora` and `load_in_4bit: true`)
|
||||
- CUDA GPUs only (not tested with ROCm or other backends)
|
||||
- FSDP2 compatible for distributed training
|
||||
|
||||
## Limitations
|
||||
|
||||
- `lora_target_linear` is not compatible with `quantize_moe_experts`. See [Expert LoRA targeting](#expert-lora-targeting) instead.
|
||||
- `cpu_ram_efficient_loading` hangs / takes long time with FSDP2 + QLoRA.
|
||||
- Total model parameter count may display incorrectly (trainable param count is correct).
|
||||
- FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps, which then drops. QLoRA does not exhibit this.
|
||||
- FSDP2 may use more VRAM per GPU than single GPU training due to not all layers being properly sharded across ranks.
|
||||
- Model loading takes longer due to on-demand quantization, even on consecutive runs.
|
||||
- DeepSpeed has not been tested.
|
||||
|
||||
## Implementation details
|
||||
|
||||
The quantization is applied by patching transformers to intercept weight loading.
|
||||
When a 3D+ CUDA tensor with "expert" in its name is detected:
|
||||
|
||||
- **4-bit mode:** Uses bitsandbytes NF4 parametrization (configurable via `bnb_4bit_quant_type`).
|
||||
- **8-bit mode:** Uses a custom row-wise int8 parametrization with bitsandbytes dequantization.
|
||||
|
||||
The original bf16 tensor is freed immediately after quantization. Multiple sub-patches are applied to
|
||||
transformers, PEFT and accelerate FSDP2 to support these parametrized expert modules.
|
||||
|
||||
For full implementation details, see [PR #3439](https://github.com/axolotl-ai-cloud/axolotl/pull/3439).
|
||||
80
docs/faq.qmd
80
docs/faq.qmd
@@ -9,11 +9,11 @@ description: Frequently asked questions
|
||||
|
||||
> A: Usually an issue with the GPUs communicating with each other. See the [NCCL doc](nccl.qmd)
|
||||
|
||||
**Q: exitcode: -9**
|
||||
**Q: Exitcode -9**
|
||||
|
||||
> A: This usually happens when you run out of system RAM.
|
||||
|
||||
**Q: exitcode: -7 while using deepspeed**
|
||||
**Q: Exitcode -7 while using deepspeed**
|
||||
|
||||
> A: Try upgrading deepspeed w: `pip install -U deepspeed`
|
||||
|
||||
@@ -51,26 +51,6 @@ description: Frequently asked questions
|
||||
> pad_token: "..."
|
||||
> ```
|
||||
|
||||
**Q: `IterableDataset error` or `KeyError: 'input_ids'` when using `preprocess` CLI**
|
||||
|
||||
> A: This is because you may be using `preprocess` CLI with `pretraining_dataset:` or `skip_prepare_dataset: true` respectively. Please use `axolotl train` CLI directly instead as these datasets are prepared on demand.
|
||||
|
||||
**Q: vLLM is not working with Axolotl**
|
||||
|
||||
> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
|
||||
|
||||
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
|
||||
|
||||
> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
|
||||
|
||||
**Q: Can we mix text and text+image datasets for VLM training?**
|
||||
|
||||
> A: Yes, you can for newer VLM arch. The ones that would not work are LLaVA / Pixtral arch. If you notice one not working, please let us know!
|
||||
|
||||
**Q: Why is `memory/max_*` different from `nvidia-smi`?**
|
||||
|
||||
> A: We use `torch` APIs to retrieve this information. You can see https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management for more information.
|
||||
|
||||
### Chat templates
|
||||
|
||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||
@@ -93,62 +73,10 @@ description: Frequently asked questions
|
||||
|
||||
> A: This is likely an empty turn.
|
||||
|
||||
**Q: The EOS token is incorrectly being masked or not being masked / `EOS token __ not found in chat template`.**
|
||||
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
|
||||
|
||||
> A: There can be two reasons:
|
||||
|
||||
> 1. This is because of the mismatch between `tokenizer.eos_token` and EOS token in template. Please make sure to set `eos_token: ` under `special_tokens: ` to the same EOS token as in template.
|
||||
|
||||
> 2. The EOS token is not in the template. Please check if your template is correct. As an example, `phi_35` template does not use its dedicated EOS token `<|endoftext|>` at the end.
|
||||
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
|
||||
|
||||
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
|
||||
|
||||
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
|
||||
|
||||
**Q: The EOT token(s) are incorrectly being masked or not being masked / `EOT token __ not found in chat template`.**
|
||||
|
||||
> A: There can be two reasons:
|
||||
|
||||
> 1. The EOT token is different from the EOS token and was not specified under `eot_tokens: `. Please set `eot_tokens: ` to the same EOT token(s) as in template.
|
||||
|
||||
> 2. There is more than one EOT token per turn in the template. Please raise an issue with examples as we recognize this as an edge case.
|
||||
|
||||
**Q: `EOT token encoding failed. Please check if the token is valid and can be encoded.`**
|
||||
|
||||
> A: There could be some issue with the tokenizer or unicode encoding. Please raise an issue with examples with the EOT token & tokenizer causing the issue.
|
||||
|
||||
**Q: `EOT token __ is encoded as multiple tokens.`**
|
||||
|
||||
> A: This is because the EOT token is encoded as multiple tokens which can cause unexpected behavior. Please add it under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `.
|
||||
|
||||
**Q: `Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot`**
|
||||
|
||||
> A: This is because the EOS token is in the `eot_tokens: ` while mismatch between `train_on_eos: ` and `train_on_eot: `. This will cause one to override the other. Please ensure that `train_on_eos: ` and `train_on_eot: ` are the same or remove the EOS token from `eot_tokens: `.
|
||||
|
||||
**Q: If `eot_tokens: ` is not provided, what happens?**
|
||||
|
||||
> A: If `eot_tokens: ` is not provided, the default behavior is the same as before. EOS tokens used to delimit turns are masked/unmasked depending on whether the turn is trainable.
|
||||
|
||||
> Internally, `eot_tokens: tokenizer.eos_token` and `train_on_eot: train_on_eos` (which defaults to `turn`). This transition helps clarify the naming and behavior of EOT/EOS tokens.
|
||||
|
||||
**Q: `Data processing error: CAS service error`**
|
||||
|
||||
> A: Try disabling XET with `export HF_HUB_DISABLE_XET=1`
|
||||
|
||||
**Q: `torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. `**
|
||||
|
||||
> A: Depending on the version of torch, you may need to include this in your YAML:
|
||||
|
||||
> ```yaml
|
||||
> flex_attn_compile_kwargs:
|
||||
> dynamic: false
|
||||
> mode: max-autotune-no-cudagraphs
|
||||
> ```
|
||||
|
||||
**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`
|
||||
|
||||
> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.
|
||||
|
||||
**Q: `Error parsing tool_calls arguments as JSON.`
|
||||
|
||||
> A: There is an error parsing string arguments to a dict. Please check your dataset and the error message for more details.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "FSDP + QLoRA"
|
||||
title: "FDSP + QLoRA"
|
||||
description: Use FSDP with QLoRA to fine-tune large LLMs on consumer GPUs.
|
||||
format:
|
||||
html:
|
||||
@@ -20,15 +20,9 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
||||
> See the [example config](#example-config) file in addition to reading these instructions.
|
||||
|
||||
1. Set `adapter: qlora` in your axolotl config file.
|
||||
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
|
||||
2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
|
||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
||||
|
||||
## Enabling Swap for FSDP2
|
||||
|
||||
If available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting `cpu_offload_pin_memory: false` alongside `offload_params: true` in FSDP config.
|
||||
|
||||
This disables memory pinning, allowing FSDP to use disk swap space as fallback. Disabling memory pinning itself incurs performance overhead, and actually having to use swap adds more, but it may enable training larger models that would otherwise cause OOM errors on resource constrained systems.
|
||||
|
||||
## Example Config
|
||||
|
||||
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.
|
||||
|
||||
@@ -55,7 +55,7 @@ output_dir: ./outputs/lora-out
|
||||
- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
|
||||
:::
|
||||
|
||||
See our [config options](config-reference.qmd) for more details.
|
||||
See our [Config options](config.qmd) for more details.
|
||||
|
||||
### Training {#sec-training}
|
||||
|
||||
@@ -104,7 +104,7 @@ the `alpaca` dataset format, which has the following format:
|
||||
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
|
||||
format them.
|
||||
|
||||
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca`
|
||||
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
|
||||
format):
|
||||
|
||||
```json
|
||||
@@ -120,12 +120,6 @@ axolotl train my_training.yml
|
||||
|
||||
## Common Tasks {#sec-common-tasks}
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
The same yaml file is used for training, inference, and merging.
|
||||
|
||||
:::
|
||||
|
||||
### Testing Your Model {#sec-testing}
|
||||
|
||||
After training, test your model:
|
||||
@@ -134,16 +128,6 @@ After training, test your model:
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||
```
|
||||
|
||||
More details can be found in [Inference](inference.qmd).
|
||||
|
||||
### Using a UI {#sec-ui}
|
||||
|
||||
Launch a Gradio interface:
|
||||
|
||||
```bash
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
|
||||
```
|
||||
|
||||
### Preprocessing Data {#sec-preprocessing}
|
||||
|
||||
For large datasets, preprocess first:
|
||||
@@ -152,44 +136,26 @@ For large datasets, preprocess first:
|
||||
axolotl preprocess my_training.yml
|
||||
```
|
||||
|
||||
Please make sure to set `dataset_prepared_path: ` in your config to set the path to save the prepared dataset.
|
||||
### Using a UI {#sec-ui}
|
||||
|
||||
More details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd).
|
||||
|
||||
### Merging LoRA weights {#sec-merging-lora}
|
||||
|
||||
To merge the LoRA weights back into the base model, run:
|
||||
Launch a Gradio interface:
|
||||
|
||||
```bash
|
||||
axolotl merge-lora my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
|
||||
```
|
||||
|
||||
The merged model will be saved in the `{output_dir}/merged` directory.
|
||||
|
||||
More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).
|
||||
|
||||
## Next Steps {#sec-next-steps}
|
||||
|
||||
Now that you have the basics, explore these guides based on what you want to do:
|
||||
Now that you have the basics, you might want to:
|
||||
|
||||
**Choose your path:**
|
||||
- Try different model architectures
|
||||
- Experiment with hyperparameters
|
||||
- Use more advanced training methods
|
||||
- Scale up to larger models
|
||||
|
||||
- [Choosing a Fine-Tuning Method](choosing_method.qmd) — SFT vs LoRA vs QLoRA vs GRPO vs DPO, with hardware recommendations
|
||||
Check our other guides for details on these topics:
|
||||
|
||||
**Core guides:**
|
||||
|
||||
- [Dataset Loading](dataset_loading.qmd) — Loading datasets from various sources
|
||||
- [Dataset Formats](dataset-formats) — Working with different data formats
|
||||
- [Optimizations](optimizations.qmd) — Flash attention, gradient checkpointing, sample packing
|
||||
- [Training Stability & Debugging](training_stability.qmd) — Monitoring metrics, fixing NaN, OOM debugging
|
||||
|
||||
**Advanced training methods:**
|
||||
|
||||
- [RLHF / Preference Learning](rlhf.qmd) — DPO, KTO, GRPO, EBFT
|
||||
- [GRPO Training](grpo.qmd) — RL with custom rewards and vLLM generation
|
||||
- [vLLM Serving](vllm_serving.qmd) — Setting up vLLM for GRPO
|
||||
|
||||
**Scaling up:**
|
||||
|
||||
- [Multi-GPU Training](multi-gpu.qmd) — DeepSpeed, FSDP, DDP
|
||||
- [Multi-Node Training](multi-node.qmd) — Distributed training across machines
|
||||
- [Configuration Guide](config.qmd) - Full configuration options
|
||||
- [Dataset Formats](dataset-formats) - Working with different data formats
|
||||
- [Multi-GPU Training](multi-gpu.qmd)
|
||||
- [Multi-Node Training](multi-node.qmd)
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
---
|
||||
title: Gradient Checkpointing, Activation Offloading, and Layer Offloading
|
||||
---
|
||||
|
||||
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
|
||||
models by reducing the memory footprint and improving computational efficiency.
|
||||
|
||||
### Enabling Gradient Checkpointing
|
||||
|
||||
```yaml
|
||||
gradient_checkpointing: true
|
||||
```
|
||||
|
||||
### Enabling Activation Offloading
|
||||
|
||||
```yaml
|
||||
gradient_checkpointing: true # required for activation offloading
|
||||
activation_offloading: true
|
||||
```
|
||||
|
||||
Activation offloading variants:
|
||||
|
||||
The default `activation_offloading: true` offloads activations to CPU and uses CUDA streams
|
||||
to overlap the communications and computations when offloading.
|
||||
|
||||
The `activation_offloading: legacy` naively offloads activations to CPU and without additional optimizations.
|
||||
|
||||
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
|
||||
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
|
||||
|
||||
### Enabling Layer Offloading
|
||||
|
||||
```yaml
|
||||
layer_offloading: true
|
||||
```
|
||||
|
||||
Layer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU
|
||||
and streaming them back to GPU one layer at a time during the forward and backward passes. This is
|
||||
particularly useful for LoRA/QLoRA training where most of the model's parameters are frozen — only the
|
||||
trainable adapter weights stay on GPU permanently.
|
||||
|
||||
During training, forward and backward hooks on each decoder layer handle the transfer automatically:
|
||||
|
||||
- **Forward pass:** Before a layer executes, its frozen params are loaded to GPU. The next layer is
|
||||
prefetched asynchronously on a separate CUDA stream for overlap.
|
||||
- **Backward pass:** Same pattern in reverse — the current layer's frozen params are loaded and the
|
||||
previous layer is prefetched.
|
||||
|
||||
After each layer finishes, its frozen params are offloaded back to CPU pinned memory.
|
||||
|
||||
This approach trades some CPU-GPU transfer overhead for significant GPU memory savings — the freed memory
|
||||
is roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth
|
||||
that is kept on GPU at any given time.
|
||||
|
||||
**Requirements:**
|
||||
|
||||
- CUDA GPU (CPU-only training is not supported for this feature)
|
||||
- Works with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)
|
||||
- Best combined with LoRA/QLoRA where most parameters are frozen
|
||||
611
docs/grpo.qmd
611
docs/grpo.qmd
@@ -1,611 +0,0 @@
|
||||
---
|
||||
title: "GRPO Training"
|
||||
description: "Group Relative Policy Optimization — a reinforcement learning method for training language models with verifiable reward functions."
|
||||
order: 8
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
Group Relative Policy Optimization (GRPO) is a reinforcement learning method that improves language models by generating multiple completions per prompt, scoring them with reward functions, and using the relative ranking within each group to compute advantage estimates. Unlike DPO, which requires pre-collected preference pairs, GRPO generates its own training data online and can work with any programmatic reward signal (math correctness, format compliance, code execution results, etc.).
|
||||
|
||||
Use GRPO when you have a task with a verifiable reward signal and want the model to discover solution strategies on its own. Use DPO when you already have human preference data. Use SFT when you have gold-standard completions to imitate directly.
|
||||
|
||||
Axolotl's GRPO implementation builds on TRL and adds async generation, streaming scoring, importance sampling correction, replay buffers, and multi-GPU scaling via FSDP and DeepSpeed.
|
||||
|
||||
|
||||
## Architecture
|
||||
|
||||
GRPO training uses a two-process architecture: a vLLM server for fast generation and a trainer process for scoring and gradient updates.
|
||||
|
||||
```
|
||||
Terminal 1 (GPU 0) Terminal 2 (GPU 1)
|
||||
┌──────────────────────┐ ┌──────────────────────────────────┐
|
||||
│ vLLM Server │ │ Trainer │
|
||||
│ │ HTTP │ │
|
||||
│ Serves base model │◄────────────►│ Background thread: │
|
||||
│ + LoRA adapter │ /generate │ Send prompts to vLLM │
|
||||
│ │ /set_lora │ Pad & collate completions │
|
||||
│ Punica kernels for │ │ │
|
||||
│ LoRA inference │ │ Main thread: │
|
||||
│ │ │ Score completions (rewards) │
|
||||
└──────────────────────┘ │ Compute policy log-probs │
|
||||
│ Calculate advantages │
|
||||
│ PPO-clip gradient update │
|
||||
│ Sync LoRA weights to vLLM │
|
||||
└──────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Data flow for each training step:**
|
||||
|
||||
1. The background thread sends prompts to vLLM, which generates `num_generations` completions per prompt.
|
||||
2. The main thread scores completions using your reward functions.
|
||||
3. Advantages are computed within each prompt group (group-relative normalization).
|
||||
4. Policy log-probabilities are computed by running a forward pass on the training model.
|
||||
5. The PPO-clip loss is computed and gradients are applied.
|
||||
6. Periodically, LoRA adapter weights are synced back to vLLM so future generations reflect the updated policy.
|
||||
|
||||
With async prefetch enabled, step 1 for the *next* batch runs concurrently with steps 2-6 for the *current* batch.
|
||||
|
||||
|
||||
## Quick Start
|
||||
|
||||
A GRPO training run requires three components: a YAML config, a reward module (Python file), and a running vLLM server.
|
||||
|
||||
### 1. Write a reward module
|
||||
|
||||
Create a file called `rewards.py` in your working directory:
|
||||
|
||||
```python
|
||||
# rewards.py
|
||||
import re
|
||||
|
||||
|
||||
def accuracy_reward(completions, answer, **kwargs) -> list[float]:
|
||||
"""Check if the completion contains the correct numerical answer."""
|
||||
rewards = []
|
||||
for completion, correct in zip(completions, answer):
|
||||
text = completion[0]["content"]
|
||||
# Extract the last number from the completion
|
||||
numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
|
||||
predicted = numbers[-1] if numbers else ""
|
||||
rewards.append(1.0 if predicted == str(correct) else 0.0)
|
||||
return rewards
|
||||
|
||||
|
||||
def format_reward(completions, **kwargs) -> list[float]:
|
||||
"""Reward completions that use a structured thinking format."""
|
||||
rewards = []
|
||||
for completion in completions:
|
||||
text = completion[0]["content"]
|
||||
has_think = "<think>" in text and "</think>" in text
|
||||
has_answer = "<answer>" in text and "</answer>" in text
|
||||
rewards.append(1.0 if has_think and has_answer else 0.0)
|
||||
return rewards
|
||||
|
||||
|
||||
def prompt_transform(cfg, *args, **kwargs):
|
||||
"""Convert GSM8K dataset rows into chat prompts."""
|
||||
def transform_fn(example, tokenizer=None):
|
||||
label = example["answer"].split("####")[-1].strip().replace(",", "")
|
||||
return {
|
||||
"prompt": [
|
||||
{"role": "system", "content": "Solve the math problem. Show your reasoning in <think> tags and your final numerical answer in <answer> tags."},
|
||||
{"role": "user", "content": example["question"]},
|
||||
],
|
||||
"answer": label,
|
||||
}
|
||||
return transform_fn, {"remove_columns": ["question"]}
|
||||
```
|
||||
|
||||
### 2. Write the config
|
||||
|
||||
Create `config.yaml`:
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen2.5-1.5B-Instruct
|
||||
|
||||
rl: grpo
|
||||
chat_template: tokenizer_default
|
||||
|
||||
vllm:
|
||||
host: 0.0.0.0
|
||||
port: 8000
|
||||
gpu_memory_utilization: 0.85
|
||||
dtype: auto
|
||||
max_model_len: 2048
|
||||
|
||||
adapter: lora
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_target_linear: true
|
||||
|
||||
trl:
|
||||
use_vllm: true
|
||||
use_data_producer: true
|
||||
vllm_server_host: 0.0.0.0
|
||||
vllm_server_port: 8000
|
||||
vllm_server_timeout: 300
|
||||
vllm_lora_sync: true
|
||||
num_generations: 8
|
||||
max_completion_length: 512
|
||||
temperature: 0.7
|
||||
reward_funcs:
|
||||
- rewards.accuracy_reward
|
||||
- rewards.format_reward
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.5
|
||||
|
||||
datasets:
|
||||
- path: openai/gsm8k
|
||||
name: main
|
||||
type: rewards.prompt_transform
|
||||
split: train
|
||||
|
||||
skip_prepare_dataset: true
|
||||
val_set_size: 0.0
|
||||
sequence_len: 512
|
||||
micro_batch_size: 2
|
||||
gradient_accumulation_steps: 4
|
||||
max_steps: 200
|
||||
learning_rate: 5.0e-6
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
warmup_steps: 10
|
||||
|
||||
bf16: true
|
||||
flash_attention: true
|
||||
gradient_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
|
||||
output_dir: ./grpo-output
|
||||
logging_steps: 1
|
||||
```
|
||||
|
||||
### 3. Start vLLM and train
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start vLLM server on GPU 0
|
||||
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
|
||||
|
||||
# Wait 30-90 seconds for model loading and CUDA graph capture
|
||||
|
||||
# Terminal 2: Train on GPU 1
|
||||
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
|
||||
```
|
||||
|
||||
:::{.callout-tip}
|
||||
Use `tmux` or separate terminal sessions to manage the two processes. The vLLM server must remain running for the entire training duration.
|
||||
:::
|
||||
|
||||
|
||||
## Custom Reward Functions
|
||||
|
||||
### Function signature
|
||||
|
||||
TRL calls reward functions with this signature:
|
||||
|
||||
```python
|
||||
def my_reward(completions, **kwargs) -> list[float]:
|
||||
```
|
||||
|
||||
- `completions` is a list of single-element lists, where each element is a dict `{"role": "assistant", "content": "..."}`. So `completions[i][0]["content"]` gives you the text of the i-th completion.
|
||||
- `**kwargs` contains all dataset columns that were *not* removed by the dataset transform. This is how you pass ground truth answers, metadata, or any other information to your reward function.
|
||||
- Return a `list[float]` with the same length as `completions`. You may return `None` for individual elements to exclude them from aggregation.
|
||||
|
||||
### Example: accuracy reward with answer extraction
|
||||
|
||||
```python
|
||||
def accuracy_reward(completions, answer, **kwargs) -> list[float]:
|
||||
rewards = []
|
||||
for completion, correct_answer in zip(completions, answer):
|
||||
text = completion[0]["content"]
|
||||
# Extract answer from <answer>...</answer> tags
|
||||
match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
|
||||
predicted = match.group(1).strip() if match else ""
|
||||
rewards.append(1.0 if predicted == str(correct_answer) else 0.0)
|
||||
return rewards
|
||||
```
|
||||
|
||||
### Example: length penalty
|
||||
|
||||
```python
|
||||
def length_penalty(completions, **kwargs) -> list[float]:
|
||||
"""Penalize very short or very long completions."""
|
||||
rewards = []
|
||||
for completion in completions:
|
||||
length = len(completion[0]["content"])
|
||||
if length < 50:
|
||||
rewards.append(-0.5)
|
||||
elif length > 2000:
|
||||
rewards.append(-0.2)
|
||||
else:
|
||||
rewards.append(0.0)
|
||||
return rewards
|
||||
```
|
||||
|
||||
### Multiple rewards and weighting
|
||||
|
||||
You can combine multiple reward functions with different weights:
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
reward_funcs:
|
||||
- rewards.accuracy_reward
|
||||
- rewards.format_reward
|
||||
- rewards.length_penalty
|
||||
reward_weights:
|
||||
- 1.0 # accuracy is most important
|
||||
- 0.5 # format compliance
|
||||
- 0.1 # mild length preference
|
||||
```
|
||||
|
||||
Rewards are combined by the `multi_objective_aggregation` strategy:
|
||||
|
||||
- `sum_then_normalize` (default): weights and sums all rewards first, then normalizes across the group.
|
||||
- `normalize_then_sum` (GDPO): normalizes each reward independently, then sums. This prevents one reward from dominating and is recommended when using multiple reward functions with different scales.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
multi_objective_aggregation: normalize_then_sum
|
||||
```
|
||||
|
||||
### Dataset transforms
|
||||
|
||||
The dataset transform converts raw HuggingFace dataset rows into chat-format prompts:
|
||||
|
||||
```python
|
||||
def prompt_transform(cfg, *args, **kwargs):
|
||||
def map_fn(example, tokenizer=None):
|
||||
return {
|
||||
"prompt": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": example["question"]},
|
||||
],
|
||||
# Keep 'answer' column for the reward function
|
||||
"answer": example["answer"],
|
||||
}
|
||||
# Remove columns consumed by the transform; keep columns needed by rewards
|
||||
return map_fn, {"remove_columns": ["question"]}
|
||||
```
|
||||
|
||||
The transform returns a tuple of `(map_function, kwargs_dict)`. The `remove_columns` in the kwargs dict removes columns that are no longer needed. Columns that your reward functions reference via `**kwargs` (like `answer`) must *not* be removed.
|
||||
|
||||
:::{.callout-warning}
|
||||
The reward module must be importable from the directory where you run `axolotl train`. If your reward file is `rewards.py`, the import path is `rewards.accuracy_reward`. If it is inside a package `my_rewards/scoring.py`, use `my_rewards.scoring.accuracy_reward`.
|
||||
:::
|
||||
|
||||
### Reward models (neural network rewards)
|
||||
|
||||
Instead of a Python function, you can pass a HuggingFace model path as a reward function. TRL will load it as a reward model and use its scalar output as the reward:
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
reward_funcs:
|
||||
- OpenAssistant/reward-model-deberta-v3-large-v2
|
||||
- rewards.format_reward
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.3
|
||||
```
|
||||
|
||||
### Using math_verify
|
||||
|
||||
The `math_verify` library provides robust mathematical answer verification but uses `signal.alarm()` internally, which only works in the main thread. If you use `math_verify` in a reward function, set `reward_num_workers` to use subprocess workers:
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
reward_num_workers: 4
|
||||
```
|
||||
|
||||
Each worker runs in its own subprocess with its own main thread, so `signal.alarm()` works correctly.
|
||||
|
||||
|
||||
## vLLM Setup
|
||||
|
||||
GRPO requires a running vLLM server for generation. For a complete guide on server modes, LoRA sync, weight synchronization, and restart procedures, see [vLLM Serving](vllm_serving.qmd).
|
||||
|
||||
The minimal setup:
|
||||
|
||||
```yaml
|
||||
vllm:
|
||||
host: 0.0.0.0
|
||||
port: 8000
|
||||
gpu_memory_utilization: 0.85
|
||||
|
||||
trl:
|
||||
use_vllm: true
|
||||
vllm_lora_sync: true # Recommended with LoRA — faster sync, no NCCL contention
|
||||
vllm_sync_interval: 5 # Sync weights every 5 steps
|
||||
```
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml # GPU 0: vLLM
|
||||
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml # GPU 1: training
|
||||
```
|
||||
|
||||
:::{.callout-warning}
|
||||
vLLM must be restarted between experiments — stale weight syncs corrupt server state. See [Restart Requirements](vllm_serving.qmd#sec-restart).
|
||||
:::
|
||||
|
||||
|
||||
## Async Training Features
|
||||
|
||||
Async GRPO overlaps generation and training to reduce wall-clock time. While the model trains on the current batch, the next batch is already being generated by vLLM.
|
||||
|
||||
### Enabling async prefetch
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
use_data_producer: true
|
||||
async_prefetch: true
|
||||
prefetch_depth: 1
|
||||
vllm_sync_interval: 2
|
||||
```
|
||||
|
||||
- `use_data_producer: true` enables the data producer protocol (required for all async features).
|
||||
- `async_prefetch: true` runs generation in a background thread.
|
||||
- `prefetch_depth` controls how many batches to prefetch ahead (1 is usually sufficient).
|
||||
- `vllm_sync_interval` controls how often LoRA weights are synced to vLLM (every N optimizer steps). Lower values mean fresher generations but more sync overhead.
|
||||
|
||||
:::{.callout-tip}
|
||||
Because the background thread generates with slightly stale model weights, async mode benefits from importance sampling correction (see next section). Enable `vllm_importance_sampling_correction: true` when using `async_prefetch: true`.
|
||||
:::
|
||||
|
||||
### Streaming partial batch
|
||||
|
||||
Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This reduces peak memory during scoring and enables finer-grained zero-advantage skipping.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
streaming_partial_batch: true
|
||||
streaming_min_groups: 1
|
||||
```
|
||||
|
||||
`streaming_min_groups` controls the minimum number of prompt groups scored per chunk. Setting it to 1 gives maximum granularity.
|
||||
|
||||
### Zero-advantage batch skipping
|
||||
|
||||
When all advantages in a micro-batch are zero (every completion in the group got the same reward), there is no learning signal. This feature skips the forward/backward pass entirely for such micro-batches.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
skip_zero_advantage_batches: true # default
|
||||
```
|
||||
|
||||
This is enabled by default and logged as `skipped_zero_adv_batches` in training metrics. It is a safety net, not a major optimization -- it only saves significant time when the model cannot solve any prompts in the batch.
|
||||
|
||||
### Replay buffer
|
||||
|
||||
The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and replaces zero-signal groups in later batches. This improves data utilization when many prompts yield no reward variance.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
replay_buffer_size: 100
|
||||
replay_recompute_logps: true
|
||||
```
|
||||
|
||||
:::{.callout-warning}
|
||||
When `replay_recompute_logps: false`, replayed data uses stale log-probabilities which creates an IS mismatch. Keep the default `true` unless you have a specific reason to disable it.
|
||||
:::
|
||||
|
||||
### Deferred re-rolling
|
||||
|
||||
Prompts where the model gets zero reward for all generations are buffered and re-injected into later batches, when the model may have improved enough to produce useful completions.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
reroll_start_fraction: 0.5 # Start re-rolling after 50% of training
|
||||
reroll_max_groups: 1 # Max groups to replace per batch
|
||||
```
|
||||
|
||||
Set `reroll_start_fraction: 1.0` to disable. This is most useful for tasks where the model starts weak but steadily improves.
|
||||
|
||||
### Parallel reward workers
|
||||
|
||||
Reward functions that use `signal.alarm()` (like `math_verify`) only work in the main thread. Parallel reward workers run each function in its own subprocess:
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
reward_num_workers: 4
|
||||
```
|
||||
|
||||
Work is sharded across workers by prompt group. For simple reward functions, a single worker is usually sufficient -- the overhead of IPC can exceed the computation time.
|
||||
|
||||
|
||||
## Importance Sampling and Off-Policy Correction
|
||||
|
||||
When using async prefetch, completions are generated from a slightly older policy. IS correction adjusts the gradient to account for this mismatch.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
vllm_importance_sampling_correction: true
|
||||
importance_sampling_level: token # 'token' recommended (especially with Liger kernel)
|
||||
off_policy_mask_threshold: 0.5 # KL threshold — masks sequences that are too off-policy
|
||||
```
|
||||
|
||||
Use `token` level IS. Sequence-level has numerical issues with Liger's chunked computation. The `off_policy_mask_threshold` (OPSM) is a safety net that drops sequences where KL divergence exceeds the threshold — 0.5 is a reasonable starting point.
|
||||
|
||||
For detailed coverage of IS modes (`token_mask`, `token_truncate`, etc.), capping, and bias-corrected KL, see [vLLM Serving — IS Correction](vllm_serving.qmd#sec-weight-sync).
|
||||
|
||||
|
||||
## Scaling
|
||||
|
||||
### FP8 training
|
||||
|
||||
FP8 quantization halves model VRAM usage with minimal impact on training quality. It does not significantly speed up computation for small models but allows larger models to fit in memory.
|
||||
|
||||
```yaml
|
||||
fp8: true
|
||||
torch_compile: true
|
||||
```
|
||||
|
||||
:::{.callout-warning}
|
||||
FP8 requires patching for zero-padding edge cases. The `act_quant_kernel` can produce NaN when input is all zeros (padding positions). If you see NaN in grad norms, check whether your padding token embedding is non-zero.
|
||||
:::
|
||||
|
||||
### FSDP (Fully Sharded Data Parallel)
|
||||
|
||||
FSDP distributes model parameters across multiple GPUs for training while vLLM runs on a separate GPU:
|
||||
|
||||
```yaml
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
# GPU 0: vLLM
|
||||
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
|
||||
|
||||
# GPUs 0,1: Training (FSDP will use both visible GPUs)
|
||||
CUDA_VISIBLE_DEVICES=0,1 axolotl train config.yaml
|
||||
```
|
||||
|
||||
:::{.callout-warning}
|
||||
`async_prefetch: true` can deadlock with FSDP because background threads perform unsynchronized FSDP collectives across ranks. With multi-GPU FSDP, only rank 0 generates in the background thread and results are broadcast to all ranks. If you still see hangs, set `async_prefetch: false`.
|
||||
:::
|
||||
|
||||
### DeepSpeed ZeRO-3
|
||||
|
||||
```yaml
|
||||
deepspeed: deepspeed_configs/zero3_bf16.json
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: true # Required -- non-reentrant causes CheckpointError with ZeRO-3
|
||||
```
|
||||
|
||||
:::{.callout-note}
|
||||
DeepSpeed ZeRO-3 requires `use_reentrant: true` for gradient checkpointing. This is the opposite of the FSDP recommendation. Non-reentrant checkpointing causes tensor metadata mismatches during recomputation with ZeRO-3's parameter partitioning.
|
||||
:::
|
||||
|
||||
### Multi-GPU considerations
|
||||
|
||||
| Concern | Recommendation |
|
||||
|---------|---------------|
|
||||
| vLLM GPU allocation | Dedicate one or more GPUs to vLLM; do not share with trainer GPUs |
|
||||
| Weight sync contention | Use `vllm_lora_sync: true` to avoid NCCL contention between training and vLLM |
|
||||
| FSDP + async | Use `async_prefetch: false` or rely on rank-0-only background generation |
|
||||
| DeepSpeed + gradient checkpoint | Must use `use_reentrant: true` |
|
||||
| OOM during scoring | Reduce `micro_batch_size` or `num_generations`. The logits tensor scales with `batch_size * vocab_size` |
|
||||
|
||||
|
||||
## Monitoring and Debugging
|
||||
|
||||
For detailed metric ranges, failure diagnosis, and OOM debugging, see [Training Stability & Debugging](training_stability.qmd).
|
||||
|
||||
Quick health checks during GRPO training:
|
||||
|
||||
- `rewards/*/mean` should be > 0.15 within 20 steps — if it stays at 0, test your reward function standalone
|
||||
- `reward_std` should be > 0 on most steps — all-zero means no learning signal
|
||||
- `entropy` in 0.05-0.5 — below 0.01 suggests mode collapse
|
||||
- `grad_norm` in 0.001-1.0 — > 10 is unstable, 0.0 is expected when zero-advantage skip fires
|
||||
|
||||
:::{.callout-tip}
|
||||
Pipe training output to a log file: `axolotl train config.yaml 2>&1 | tee /tmp/training.log`
|
||||
:::
|
||||
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
All GRPO-specific options live under the `trl:` key in your config. Standard training options (`learning_rate`, `micro_batch_size`, etc.) are set at the top level as usual.
|
||||
|
||||
### Core GRPO
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `use_vllm` | bool | `false` | Enable vLLM for generation |
|
||||
| `vllm_mode` | `"server"` or `"colocate"` | `null` | vLLM deployment mode |
|
||||
| `vllm_server_host` | str | `"0.0.0.0"` | vLLM server hostname |
|
||||
| `vllm_server_port` | int | `8000` | vLLM server port |
|
||||
| `vllm_server_timeout` | int | `null` | Timeout (seconds) for vLLM responses |
|
||||
| `num_generations` | int | `null` | Completions generated per prompt |
|
||||
| `generation_batch_size` | int | `null` | Number of unique prompts per generation step |
|
||||
| `max_completion_length` | int | `null` | Maximum tokens per completion |
|
||||
| `beta` | float | `null` | KL penalty coefficient |
|
||||
| `num_iterations` | int | `null` | Iterations per batch (mu in the GRPO paper) |
|
||||
| `epsilon` | float | `null` | PPO clipping lower bound |
|
||||
| `epsilon_high` | float | `null` | PPO clipping upper bound |
|
||||
| `loss_type` | str | `null` | Loss formulation: `grpo`, `bnpo`, or `dr_grpo` |
|
||||
| `scale_rewards` | bool | `true` | Normalize rewards by standard deviation |
|
||||
| `mask_truncated_completions` | bool | `false` | Exclude truncated completions from loss |
|
||||
|
||||
### Reward functions
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `reward_funcs` | list[str] | `null` | Import paths to reward functions or HF model IDs |
|
||||
| `reward_weights` | list[float] | `null` | Relative weights for each reward function |
|
||||
| `multi_objective_aggregation` | str | `null` | `"sum_then_normalize"` (GRPO) or `"normalize_then_sum"` (GDPO) |
|
||||
| `rollout_func` | str | `null` | Import path to custom rollout function for OpenEnv-style tasks |
|
||||
|
||||
### Generation parameters
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `temperature` | float | `null` | Sampling temperature |
|
||||
| `top_p` | float | `null` | Nucleus sampling probability |
|
||||
| `top_k` | int | `null` | Top-k sampling |
|
||||
| `min_p` | float | `null` | Minimum probability threshold |
|
||||
| `repetition_penalty` | float | `null` | Penalty for repeated tokens |
|
||||
| `generation_kwargs` | dict | `null` | Additional vLLM SamplingParams (e.g., `stop_token_ids`) |
|
||||
| `chat_template_kwargs` | dict | `null` | Chat template kwargs (e.g., `{enable_thinking: false}`) |
|
||||
| `vllm_guided_decoding_regex` | str | `null` | Regex constraint for guided decoding |
|
||||
|
||||
### Async pipeline
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `use_data_producer` | bool | `false` | Enable data producer protocol (required for async features) |
|
||||
| `async_prefetch` | bool | `false` | Generate next batch in background thread |
|
||||
| `prefetch_depth` | int | `null` | Number of batches to prefetch ahead |
|
||||
| `vllm_sync_interval` | int | `null` | Sync LoRA weights to vLLM every N steps |
|
||||
| `vllm_lora_sync` | bool | `false` | Use filesystem LoRA sync instead of NCCL merge |
|
||||
| `streaming_partial_batch` | bool | `null` | Score prompt groups incrementally |
|
||||
| `streaming_min_groups` | int | `null` | Minimum groups per streaming chunk |
|
||||
| `skip_zero_advantage_batches` | bool | `true` | Skip micro-batches with zero learning signal |
|
||||
| `reward_num_workers` | int | `1` | Subprocess workers for reward computation |
|
||||
| `vllm_enable_sleep_mode` | bool | `null` | Offload vLLM weights when idle (colocate mode) |
|
||||
|
||||
### Importance sampling
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `vllm_importance_sampling_correction` | bool | `null` | Enable IS correction for async distribution shift |
|
||||
| `importance_sampling_level` | `"token"` or `"sequence"` | `null` | Granularity of IS ratios. Use `token` with Liger |
|
||||
| `vllm_importance_sampling_mode` | str | `null` | `token_mask`, `token_truncate`, `sequence_mask`, or `sequence_truncate` |
|
||||
| `vllm_importance_sampling_cap` | float | `null` | Cap C for IS ratio clipping/masking |
|
||||
| `off_policy_mask_threshold` | float | `null` | KL threshold for off-policy sequence masking (OPSM) |
|
||||
| `use_bias_correction_kl` | bool | `null` | Apply IS correction to KL divergence term |
|
||||
|
||||
### Replay and re-roll
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `replay_buffer_size` | int | `0` | Max cached high-signal groups. 0 = disabled |
|
||||
| `replay_recompute_logps` | bool | `true` | Recompute log-probs for replayed data with current model |
|
||||
| `reroll_start_fraction` | float | `1.0` | Start re-rolling failed prompts after this fraction of training. 1.0 = disabled |
|
||||
| `reroll_max_groups` | int | `1` | Max prompt groups to replace with re-rolls per batch |
|
||||
|
||||
### Reference model
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `sync_ref_model` | bool | `false` | Periodically sync reference model with training model |
|
||||
| `ref_model_mixup_alpha` | float | `0.9` | EMA coefficient for reference model sync |
|
||||
| `ref_model_sync_steps` | int | `64` | Sync reference model every N steps |
|
||||
|
||||
### Logging
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `log_completions` | bool | `false` | Log sample completions to W&B |
|
||||
| `num_completions_to_print` | int | `null` | Number of completions to print per step |
|
||||
| `use_liger_loss` | bool | `null` | Use Liger fused kernel for GRPO loss (reduces VRAM) |
|
||||
@@ -14,21 +14,11 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
||||
## Requirements {#sec-requirements}
|
||||
|
||||
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python ≥3.11
|
||||
- PyTorch ≥2.6.0
|
||||
- Python ≥3.10
|
||||
- PyTorch ≥2.4.1
|
||||
|
||||
## Installation Methods {#sec-installation-methods}
|
||||
|
||||
::: {.callout-important}
|
||||
Please make sure to have Pytorch installed before installing Axolotl in your local environment.
|
||||
|
||||
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
|
||||
:::
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
### PyPI Installation (Recommended) {#sec-pypi}
|
||||
|
||||
```{.bash}
|
||||
@@ -41,40 +31,6 @@ installed) in order not to clobber it, and so that we set the correct version of
|
||||
dependencies that are specific to the PyTorch version or other installed
|
||||
co-dependencies.
|
||||
|
||||
### uv Installation {#sec-uv}
|
||||
|
||||
uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
|
||||
|
||||
Install uv if not already installed
|
||||
```{.bash}
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
source $HOME/.local/bin/env
|
||||
```
|
||||
|
||||
Choose your CUDA version to use with PyTorch; e.g. `cu124`, `cu126`, `cu128`,
|
||||
then create the venv and activate
|
||||
```{.bash}
|
||||
export UV_TORCH_BACKEND=cu126
|
||||
uv venv --no-project --relocatable
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
Install PyTorch
|
||||
- PyTorch 2.6.0 recommended
|
||||
```{.bash}
|
||||
uv pip install packaging setuptools wheel
|
||||
uv pip install torch==2.6.0
|
||||
uv pip install awscli pydantic
|
||||
```
|
||||
|
||||
Install axolotl from PyPi
|
||||
```{.bash}
|
||||
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
|
||||
|
||||
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
|
||||
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
|
||||
```
|
||||
|
||||
### Edge/Development Build {#sec-edge-build}
|
||||
|
||||
For the latest features between releases:
|
||||
@@ -110,10 +66,6 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
|
||||
```
|
||||
:::
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
|
||||
:::
|
||||
|
||||
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
||||
|
||||
## Cloud Environments {#sec-cloud}
|
||||
@@ -124,17 +76,14 @@ For providers supporting Docker:
|
||||
|
||||
- Use `axolotlai/axolotl-cloud:main-latest`
|
||||
- Available on:
|
||||
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)
|
||||
- [PRIME Intellect](https://app.primeintellect.ai/dashboard/create-cluster?image=axolotl&location=Cheapest&security=Cheapest&show_spot=true)
|
||||
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl)
|
||||
- [Novita](https://novita.ai/gpus-console?templateId=311)
|
||||
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
|
||||
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
|
||||
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
- [Novita](https://novita.ai/gpus-console?templateId=311)
|
||||
|
||||
### Google Colab {#sec-colab}
|
||||
|
||||
[](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
|
||||
Use our [example notebook](../examples/colab-notebooks/colab-axolotl-example.ipynb).
|
||||
|
||||
## Platform-Specific Instructions {#sec-platform-specific}
|
||||
|
||||
@@ -156,7 +105,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
||||
|
||||
### Conda/Pip venv {#sec-conda}
|
||||
|
||||
1. Install Python ≥3.11
|
||||
1. Install Python ≥3.10
|
||||
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
||||
3. Install Axolotl:
|
||||
```{.bash}
|
||||
@@ -165,7 +114,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
||||
```
|
||||
4. (Optional) Login to Hugging Face:
|
||||
```{.bash}
|
||||
hf auth login
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
## Troubleshooting {#sec-troubleshooting}
|
||||
|
||||
@@ -5,11 +5,10 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi
|
||||
|
||||
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
|
||||
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
|
||||
(including the DDP, DeepSpeed, and FSDP2 settings) training. These include (1) SwiGLU
|
||||
and GEGLU activation function Triton kernels, and (2) LoRA MLP and attention custom
|
||||
autograd functions. Our goal was to leverage operator fusion and tensor re-use in order
|
||||
to improve speed and reduce memory usage during the forward and backward passes of
|
||||
these calculations.
|
||||
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
|
||||
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
|
||||
to leverage operator fusion and tensor re-use in order to improve speed and reduce
|
||||
memory usage during the forward and backward passes of these calculations.
|
||||
|
||||
We currently support several common model architectures, including (but not limited to):
|
||||
|
||||
@@ -85,14 +84,6 @@ lora_qkv_kernel: true
|
||||
lora_o_kernel: true
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
Currently, LoRA kernels are not supported for RLHF training, only SFT.
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
LoRA kernels do not support remote modeling code.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|
||||
@@ -136,5 +127,6 @@ computation path.
|
||||
## Future Work
|
||||
|
||||
- Support for additional model architectures
|
||||
- Support for the FSDP setting
|
||||
- Support for dropout and bias
|
||||
- Additional operator fusions
|
||||
|
||||
@@ -27,9 +27,3 @@ learning_rate: 2e-5
|
||||
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
|
||||
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's
|
||||
self attention `q_proj` module.
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
We currently only support varying `lr` for now. If you're interested in adding support for others (`weight_decay`), we welcome PRs. See https://github.com/axolotl-ai-cloud/axolotl/blob/613bcf90e58f3ab81d3827e7fc572319908db9fb/src/axolotl/core/trainers/mixins/optimizer.py#L17
|
||||
|
||||
:::
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user