Compare commits


35 Commits

Author SHA1 Message Date
salman
e36d3c9f30 Merge branch 'main' into testingci 2025-07-24 09:13:15 +01:00
Salman Mohammadi
53614391ed wip 2025-07-24 09:12:55 +01:00
salman
1407aac779 Skip CI for draft PRs (#2970) 2025-07-24 09:11:46 +01:00
Dan Saunders
b34c3371ed upgrade torchao (#2968) 2025-07-23 10:27:28 -04:00
Wing Lian
5f1a4306b0 don't check dataset labels during preprocess for GRPO (#2952) [skip ci]
* don't check dataset labels during preprocess for GRPO

* use enum check per PR feedback
2025-07-22 20:40:44 -04:00
Wing Lian
93709eb5ce handle refactor upstream for flash attention (#2966) 2025-07-22 20:40:04 -04:00
Dan Saunders
208fb7b8e7 basic torchao fp8 mixed precision training (#2926)
* debug

* debug

* debug

* revert unneeded change

* add accelerator config to base trainer builder

* add back accumulated_cache_size_limit setting

* lint

* accelerator constructor patch for single-GPU torch fp8

* lint

* re-using existing fp8 code

* lint

* remove accelerate patch, now fixed in latest release

* fix

* docs

* add fp8 + fsdp2 example

* remove unused config

* update config

* smoke tests

* add validator

* add 2.7.0 guard for fsdp2

* fix

* add config descriptions

* add FSDP doc link

* nit

* set force_recompute_fp8_weight_in_bwd with enable_fsdp_float8_all_gather

* better cfg for smoke tests

* add test for accelerate patching

* update fp8 validator
2025-07-22 16:27:47 -04:00
Wing Lian
b86a1d47b0 we don't need to call check_dataset_labels when skip_prepare_dataset is set (#2962)
* we don't need to call check_dataset_labels when skip_prepare_dataset is set

* Fix actual bug and revert prior fix

* warn and early return instead of raising an error

* use error
2025-07-22 10:00:53 -04:00
NanoCode012
01d8175d48 fix: revert changing default optimizer to muon (#2965) [skip ci] 2025-07-22 10:00:30 -04:00
NanoCode012
631268a0ca revert renaming of deepspeed stage3 args that use auto (#2964) [skip ci]
* Revert "fix deprecate deepspeed stage3_gather_16bit_weights_on_model_save arg…"

This reverts commit e207762928.

* don't revert the values that don't use 'auto'

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-07-22 09:59:47 -04:00
Wing Lian
3a208cfd84 Autocomplete axolotl CLI (#2955)
* static autocomplete script for axolotl cli

* use list of commands that should autocomplete yaml files

* make sure to chmod the autocomplete script as executable

* shellcheck and fix autocompletion of directory/sub-dirs

* more shellcheck fixes
2025-07-22 08:30:31 -04:00
github-actions[bot]
7267edc168 chore: update pre-commit hooks (#2954) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-07-22 08:30:00 -04:00
NanoCode012
dfba881e99 Feat: add gemma3n support (#2852)
* feat: add gemma3n cce

* feat: add sample config

* feat: add gemma3n multimodal mode

* feat: add audio example

* feat: support audio and return pixel values in collator

* feat: support unmask only assistant region (gemma3n for now)

* feat(doc): add notes for audio loading

* feat: add audio support for gemma3n

* feat: update examples

* feat: add gemma3n to the docs

* fix: add link at top

* feat(doc): clarify additional requirements

* fix: mllama missing aspect ratio

* fix: mllama need attention fixes for fa2

* Partially Revert "fix: mllama need attention fixes for fa2"

This reverts commit a0bfdd1777.

* fix: disable FA2 for mllama in vision mode

* feat: update configs to use proper attention

* fix: support other vision features

* feat(doc): clarify requirements for gemma3n
2025-07-22 16:52:15 +07:00
Wing Lian
d32058e149 include torchvision in build for upstream changes requiring it now (#2953) [skip ci] 2025-07-22 04:19:16 -04:00
NanoCode012
bc1076d8a2 fix: suppress warning if we enabled skip prepare (#2958) 2025-07-21 11:42:04 -04:00
Wing Lian
b7e8f66e5a upstream fixes in cce for dora and tensor parallel support (#2960) [skip ci] 2025-07-21 11:41:53 -04:00
Wing Lian
e207762928 fix deprecate deepspeed stage3_gather_16bit_weights_on_model_save arg (#2956) [skip ci]
* fix deprecate deepspeed stage3_gather_16bit_weights_on_model_save arg

* replace the rest of the migrated deepspeed params
2025-07-21 11:41:31 -04:00
Wing Lian
fefb0797ee better handling for reward function checks for GRPO (#2933) [skip ci]
* better handling for reward function checks for GRPO

* consolidate msg copy
2025-07-21 11:41:15 -04:00
Wing Lian
af8d257aa2 make pad_to_sequence_len default to the same value as sample_packing (#2941) [skip ci]
* make pad_to_sequence_len default to the same value as sample_packing

* remove duplicate validation

* fix test

* update description meta

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-07-21 11:40:56 -04:00
Wing Lian
db5f6f4693 limit num_proc when saving datasets to disk (#2948) [skip ci]
* limit num_proc when saving datasets to disk

* enforce at least 1 in case it rounds down to 0, and sane divisor is at least 8 rows per worker to save

* update fixtures with dataset processes since that should never be NoneType

* improve reusability for tests
2025-07-21 11:39:38 -04:00
Wing Lian
8e5f146701 Fix cloud docker image build and remove apt files for optim (#2961)
* make sure to apt update to install sudo and tmux

* remove apt archives too
2025-07-21 11:05:00 -04:00
Wing Lian
31a15a49b6 add additional packages via apt for better multi-node support (#2949)
* cleanup in Dockerfile and add infiniband packages

* fixes for ci

* fix nightly too
2025-07-20 21:19:23 -04:00
NanoCode012
b986f7c7cb fix: return proper attention for llama4 lora kernel and fsdp2 llama4 example fix (#2943)
* fix: return proper attention for llama4 lora optim

* fix: update fsdp2 llama4 config
2025-07-19 13:54:43 -04:00
salman
e5734e5cf0 adding torchtitan link (#2945) [skip ci] 2025-07-19 13:54:14 -04:00
Wing Lian
109d9c7442 make the initial call to tokenizer.pad not spam the console (#2946) [skip ci]
* make the initial call to tokenizer.pad not spam the console

* add guard from feedback

* make another common console output less verbose

* more logging fixes
2025-07-19 13:53:35 -04:00
Wing Lian
170322a1f0 make sure log level is upper (#2934) 2025-07-17 15:32:55 -04:00
Wing Lian
5f5ae76213 add validation around cce + chunked_ce (#2932) [skip ci]
* add validation around cce + chunked_ce

* return on end of validation method
2025-07-17 15:32:38 -04:00
Wing Lian
a798975b7c coderabbit manual settings (#2940) [skip ci] 2025-07-17 15:32:16 -04:00
Wing Lian
d23f972602 use state for wandb in callbacks (#2930) [skip ci] 2025-07-17 15:31:56 -04:00
Wing Lian
8e41317250 don't use include_tokens_per_second for GRPO (#2931) [skip ci]
* don't use include_tokens_per_second for GRPO

* use blocklist instead
2025-07-17 15:31:21 -04:00
Varun Gumma
9f2bb188a4 Improve Dataset Processing Multiprocessing, Sharding, and Qwen Tokenizer Bug Fix. (#2918)
* Added a feature to save the prepared dataset in specified shards, removed the limiter on multiprocessing during tokenization, and fixed a bug in the qwen tokenizer

* removed limiters and fixed config variable name

* black lint

* chore: lint

* feat: update handling of dataset_processes

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-07-17 09:47:58 -04:00
Wing Lian
9dde9e1b71 misc fixes 202507 (#2937) [skip ci]
* misc fixes 202507

* manually handle attn class for llama4
2025-07-17 09:47:45 -04:00
Wing Lian
f2474ef941 bump accelerate to 1.9.0 (#2936) [skip ci] 2025-07-17 09:46:43 -04:00
Wing Lian
8a4bcacdb2 cu126-torch271 for cloud docker image should be tagged with main-latest (#2935) 2025-07-17 00:01:23 -04:00
Wing Lian
d2c3d5a954 run nightly-vs-upstream-main on 2.7.1 and multi-gpu also (#2929) [skip ci] 2025-07-16 21:45:42 -04:00
164 changed files with 1401 additions and 561 deletions

.axolotl-complete.bash (new file, 41 lines)

@@ -0,0 +1,41 @@
#!/bin/bash
_axolotl_completions() {
local cur prev
COMPREPLY=()
cur="${COMP_WORDS[COMP_CWORD]}"
prev="${COMP_WORDS[COMP_CWORD-1]}"
# If we're completing the first argument (the command)
if [[ $COMP_CWORD -eq 1 ]]; then
mapfile -t COMPREPLY < <(compgen -W "delinearize-llama4 fetch lm-eval merge-sharded-fsdp-weights quantize vllm-serve evaluate inference merge-lora preprocess train" -- "$cur")
return 0
fi
# Commands that should complete with directories and YAML files
local -a yaml_commands=("merge-sharded-fsdp-weights" "quantize" "vllm-serve" "evaluate" "inference" "merge-lora" "preprocess" "train")
# Check if previous word is in our list
if [[ " ${yaml_commands[*]} " =~ (^|[[:space:]])$prev($|[[:space:]]) ]]; then
# Use filename completion which handles directories properly
compopt -o filenames
mapfile -t COMPREPLY < <(compgen -f -- "$cur")
# Filter to only include directories and YAML files
local -a filtered=()
for item in "${COMPREPLY[@]}"; do
if [[ -d "$item" ]] || [[ "$item" == *.yaml ]] || [[ "$item" == *.yml ]]; then
filtered+=("$item")
fi
done
COMPREPLY=("${filtered[@]}")
return 0
fi
# Default: no completion
return 0
}
# Remove the -o nospace option - let filenames handle it
complete -F _axolotl_completions axolotl
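The trickiest part of the script above is the word-boundary membership test against `yaml_commands`: without the space padding and anchors, a prefix such as `pre` would falsely match `preprocess`. A minimal standalone sketch of the same pattern (the `in_list` helper is illustrative, not part of the script):

```shell
#!/bin/bash
# Membership test used by the completion script: pad the joined array with
# spaces and anchor the regex so only whole words match.
yaml_commands=("merge-lora" "preprocess" "train")
in_list() {
  local prev="$1"
  [[ " ${yaml_commands[*]} " =~ (^|[[:space:]])$prev($|[[:space:]]) ]]
}

in_list "train" && echo "train: yaml completion"   # whole word: matches
in_list "pre"   || echo "pre: no completion"       # prefix only: no match
```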

.coderabbit.yaml (new file, 16 lines)

@@ -0,0 +1,16 @@
# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
language: "en-US"
early_access: false
reviews:
profile: "chill"
request_changes_workflow: false
high_level_summary: true
review_status: true
collapse_walkthrough: true
poem: false
sequence_diagrams: false
auto_review:
enabled: true
drafts: false
chat:
auto_reply: true


@@ -17,7 +17,7 @@ on:
jobs:
build-base:
if: github.repository_owner == 'axolotl-ai-cloud'
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners...
runs-on: ubuntu-latest-m
@@ -108,7 +108,7 @@ jobs:
PYTORCH_VERSION=${{ matrix.pytorch }}
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
build-base-uv:
if: github.repository_owner == 'axolotl-ai-cloud'
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480
runs-on: ubuntu-latest-m
strategy:


@@ -3,6 +3,7 @@ on:
# check on PRs, and manual triggers
merge_group:
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '**.py'
- 'requirements.txt'
@@ -16,6 +17,7 @@ jobs:
pre-commit:
name: pre-commit
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5


@@ -87,7 +87,6 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -98,6 +97,7 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"


@@ -21,7 +21,7 @@ concurrency:
jobs:
test-axolotl-multigpu:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
strategy:
fail-fast: false
matrix:


@@ -2,7 +2,7 @@ name: Preview
on:
workflow_dispatch:
pull_request:
types: [opened, synchronize, reopened]
types: [opened, synchronize, reopened, ready_for_review]
# Run the workflow only when one of these files changes
paths:
@@ -25,6 +25,7 @@ permissions:
jobs:
preview:
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
steps:
- name: Check out repository
uses: actions/checkout@v4


@@ -52,7 +52,7 @@ jobs:
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Update requirements.txt
run: |
@@ -106,6 +106,13 @@ jobs:
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
nightly_build: "true"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -130,3 +137,45 @@ jobs:
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
docker-e2e-multigpu-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
needs: [pre-commit, pytest, docker-e2e-tests]
strategy:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 2
axolotl_extras:
nightly_build: "true"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.multigpu


@@ -13,6 +13,7 @@ on:
- 'cicd/cicd.sh'
- 'cicd/Dockerfile.jinja'
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '**.py'
- 'requirements.txt'
@@ -34,6 +35,7 @@ jobs:
pre-commit:
name: pre-commit
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
@@ -47,6 +49,7 @@ jobs:
pytest:
name: PyTest
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
# needs: [preload-cache]
strategy:
fail-fast: false
@@ -78,7 +81,7 @@ jobs:
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Install dependencies
run: |
@@ -121,6 +124,7 @@ jobs:
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
strategy:
fail-fast: false
matrix:
@@ -151,7 +155,7 @@ jobs:
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Install dependencies
run: |
@@ -185,7 +189,7 @@ jobs:
docker-e2e-tests-1st:
# Run this job first as a gate for running the remainder of the test matrix
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
@@ -235,7 +239,7 @@ jobs:
modal run cicd.e2e_tests
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
@@ -289,6 +293,7 @@ jobs:
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [docker-e2e-tests]
if: ${{ !github.event.pull_request.draft }}
strategy:
fail-fast: false


@@ -27,7 +27,7 @@ repos:
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.16.1
rev: v1.17.0
hooks:
- id: mypy
additional_dependencies:


@@ -268,6 +268,8 @@ website:
- docs/batch_vs_grad.qmd
- docs/dataset_preprocessing.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/gradient_accumulation.qmd
- section: "Advanced Features"
contents:


@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace


@@ -12,7 +12,7 @@ ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_PROCESSES="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace


@@ -22,6 +22,7 @@ coverage:
only_pulls: true
flags: null
paths: null
informational: true
patch:
default:
# basic


@@ -7,9 +7,9 @@
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 0,
"stage3_max_reuse_distance": 0,
"stage3_gather_16bit_weights_on_model_save": true
"max_live_parameters": 0,
"max_reuse_distance": 0,
"gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": "auto"


@@ -7,9 +7,9 @@
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 0,
"stage3_max_reuse_distance": 0,
"stage3_gather_16bit_weights_on_model_save": true
"max_live_parameters": 0,
"max_reuse_distance": 0,
"gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true


@@ -17,9 +17,9 @@
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 0,
"stage3_max_reuse_distance": 0,
"stage3_gather_16bit_weights_on_model_save": true
"max_live_parameters": 0,
"max_reuse_distance": 0,
"gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true


@@ -13,9 +13,9 @@
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 0,
"stage3_max_reuse_distance": 0,
"stage3_gather_16bit_weights_on_model_save": true
"max_live_parameters": 0,
"max_reuse_distance": 0,
"gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true


@@ -10,7 +10,9 @@ ARG PYTORCH_VERSION="2.1.2"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/*
WORKDIR /workspace
@@ -23,17 +25,17 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
fi && \
python scripts/unsloth_install.py | sh && \
python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \
pip cache purge
RUN python scripts/unsloth_install.py | sh
RUN python scripts/cutcrossentropy_install.py | sh
# So we can test the Docker image
RUN pip install pytest
# fix so that git fetch/pull from remote works
# fix so that git fetch/pull from remote works with shallow clone
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
git config --get remote.origin.fetch && \
git config --global credential.helper store
# helper for huggingface-login cli
RUN git config --global credential.helper store
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
RUN chmod +x /root/.axolotl-complete.bash && \
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc


@@ -16,12 +16,16 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
&& rm -rf /var/cache/apt/archives \
&& rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
@@ -31,12 +35,14 @@ WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge
RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \


@@ -22,18 +22,22 @@ RUN apt-get update \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge
RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge


@@ -14,7 +14,10 @@ COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \

docs/mixed_precision.qmd (new file, 149 lines)

@@ -0,0 +1,149 @@
---
title: "Mixed Precision Training"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---
Mixed precision training uses lower precision data types to reduce memory usage and increase training speed while maintaining model quality. Axolotl supports several mixed precision formats:
- **FP16** - Half precision 16-bit (Pascal generation+)
- **BF16** - Brain Float 16-bit (Ampere generation+)
- **FP8** - 8-bit floating point (Hopper generation+)
## FP16 Mixed Precision {#sec-fp16}
### Overview {#sec-fp16-overview}
FP16 is the traditional half-precision format. It is supported on older GPUs, but it is less numerically stable than BF16.
### Configuration {#sec-fp16-config}
```{.yaml}
fp16: true
```
### FP16 Considerations {#sec-fp16-considerations}
- May require gradient scaling to prevent underflow
- Less numerically stable than BF16
- Can cause training instability with some model architectures
- Consider using BF16 if your hardware supports it
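The underflow problem behind gradient scaling is easy to demonstrate without a GPU. A small NumPy sketch (illustrative only, not Axolotl code; the `2**16` scale factor is an arbitrary choice):

```python
import numpy as np

# A gradient below FP16's smallest subnormal (~6e-8) flushes to zero:
tiny_grad = np.float16(1e-8)
print(tiny_grad)                     # 0.0

# Scaling the loss before backprop keeps the gradient representable;
# it is divided back out (in FP32) before the optimizer step.
scale = 2.0 ** 16
scaled = np.float16(1e-8 * scale)    # representable in FP16
recovered = float(scaled) / scale    # ~1e-8 again
print(scaled, recovered)
```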
## BF16 Mixed Precision {#sec-bf16}
### Overview {#sec-bf16-overview}
BF16 (Brain Float 16) offers better numerical stability than FP16 and is the recommended mixed precision format for modern GPUs. It provides the same dynamic range as FP32 while using half the memory.
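The tradeoff is visible at the bit level: bfloat16 is simply the top 16 bits of an FP32 value (8 exponent bits, 7 mantissa bits). A pure-Python sketch of the format (not how PyTorch implements it):

```python
import struct

def to_bf16(x: float) -> float:
    """Round an FP32 value to the nearest bfloat16 (round half to even)
    by keeping only the top 16 bits of its FP32 bit pattern."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding = 0x7FFF + ((bits >> 16) & 1)   # round-to-nearest-even bias
    bits = (bits + rounding) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(to_bf16(1.0))     # 1.0 -- exactly representable
print(to_bf16(0.1))     # ~0.1001 -- only ~3 decimal digits of precision
print(to_bf16(1e38))    # huge values survive: same exponent range as FP32
```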
### Configuration {#sec-bf16-config}
```{.yaml}
# Automatic BF16 detection (recommended)
bf16: auto
# Or explicitly enable
bf16: true
# For evaluation with BF16
bf16: full # Equivalent to bf16_full_eval in the HF trainer
```
## FP8 Mixed Precision {#sec-fp8}
::: {.callout-note}
FP8 support is experimental and requires compatible hardware (H100, H200) and recent PyTorch versions with TorchAO.
:::
### What is FP8? {#sec-fp8-overview}
FP8 (8-bit floating point) can provide significant time savings compared to FP16/BF16 while maintaining training stability. Axolotl's implementation uses PyTorch's TorchAO library with a "tensorwise" scaling strategy.
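"Tensorwise" scaling means one dynamically computed scale per tensor, chosen so the tensor's largest magnitude lands at the FP8 E4M3 maximum (448.0). A plain-Python sketch of the idea; TorchAO computes and applies these scales in fused GPU kernels, so this is only a conceptual stand-in:

```python
# One scale per tensor: map the tensor's max magnitude to the E4M3 max.
FP8_E4M3_MAX = 448.0

def tensorwise_scale(values: list[float]) -> float:
    amax = max(abs(v) for v in values)
    return FP8_E4M3_MAX / amax if amax > 0 else 1.0

grads = [0.5, -2.0, 7.0]
scale = tensorwise_scale(grads)
scaled = [v * scale for v in grads]   # all values now fit the E4M3 range
print(scale)                          # 64.0
print(max(abs(v) for v in scaled))    # 448.0
```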
### Requirements {#sec-fp8-software}
- Hopper+ GPUs (H100/H200)
- PyTorch 2.7+ (+ compatible TorchAO version)
- CUDA 12.4+
### Configuration {#sec-fp8-config}
Add to your YAML config:
```{.yaml}
# Enable FP8 mixed precision
fp8: true
# Optional: Enable FP8 for FSDP all-gather operations
fp8_enable_fsdp_float8_all_gather: true
# Enable torch.compile (almost always necessary for FP8 speedups)
torch_compile: true
```
::: {.callout-important}
**torch.compile is critical for FP8 performance**
FP8 training requires `torch_compile: true` to see meaningful speedups. Without compilation, FP8 may actually be slower and use more memory than FP16/BF16.
:::
### Advanced FP8 Configs {#sec-fp8-advanced}
For [FSDP](multi-gpu.qmd#sec-fsdp) (Fully Sharded Data Parallel) training:
```{.yaml}
fp8: true
fp8_enable_fsdp_float8_all_gather: true
torch_compile: true
# FSDP configuration
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
state_dict_type: FULL_STATE_DICT
reshard_after_forward: true
```
## Best Practices {#sec-best-practices}
### Choosing Precision Format {#sec-choosing-format}
- **Start with automatic detection**: `bf16: auto`
- **For Hopper+ (H100/H200)**: Try FP8 + torch.compile for maximum speed
- **For Ampere (A100/RTX 30/40)**: Use BF16
- **For older Pascal/Turing GPUs**: Use FP16 with caution
- **For very old or unsupported GPUs**: Use FP32
### Validation and Testing {#sec-validation}
Always validate your mixed precision setup:
- **Start with a small dataset** to verify stability
- **Monitor loss curves** for irregularities
- **Compare with FP32 baseline** when possible
- **Test evaluation metrics** match expectations
### FP8 Particulars {#sec-fp8-details}
- Use cases:
  - Single-GPU training
  - Multi-GPU training with FSDP2 or DeepSpeed
- Speedups:
  - Refer to the [TorchAO FP8 training benchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling) for expected matmul speedups at various (M, K, N) settings
  - Concrete numbers for LLaMA 3 8B training can be found [here](https://github.com/pytorch/ao/tree/main/torchao/float8#training-benchmarks)
- Known issues:
  - FP8 + DDP + `torch.compile` causes an [error](https://gist.github.com/djsaunde/0c1664c32e44a64d31b5e01b4aafe5c4)
  - FP8 + FSDP2 + `torch.compile` + FSDP2 activation checkpointing tends to be _slower_ than the equivalent BF16 training
  - Flash Attention 2 does not play nicely with `torch.compile`
See `examples/llama-3/3b-fp8-fsdp2.yaml` for an optimized example config. Enabling FP8 mixed precision plus FP8 all-gather yields roughly 10% faster training iterations vs. BF16 for a relatively small (3B-parameter) model.
For more information on multi-GPU training, see our [Multi-GPU guide](multi-gpu.qmd).


@@ -98,8 +98,8 @@ fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
fsdp_state_dict_type | state_dict_type
fsdp_use_orig_params | **REMOVED**
For example, if you were using the following FSDP1 config:
For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). In Axolotl,
if you were using the following FSDP1 config:
```{.yaml}
fsdp_version: 1


@@ -14,6 +14,7 @@ format:
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
@@ -110,6 +111,22 @@ base_model: google/gemma-3-4b-it
chat_template: gemma3
```
### Gemma-3n {#sec-gemma-3n}
::: {.callout-warning}
The model's initial loss and grad norm will be very high. We suspect this is due to the Conv in the vision layers.
:::
::: {.callout-tip}
Please make sure to install `timm` via `pip3 install timm==1.0.17`
:::
```yaml
base_model: google/gemma-3n-E2B-it
chat_template: gemma3n
```
### Qwen2-VL {#sec-qwen2-vl}
```yaml
@@ -132,7 +149,9 @@ For multi-modal datasets, we adopt an extended `chat_template` format similar to
- A message is a list of `role` and `content`.
- `role` can be `system`, `user`, `assistant`, etc.
- `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`).
- `content` is a list of `type` and (`text`, `image`, `path`, `url`, `base64`, or `audio`).
### Image
::: {.callout-note}
For backwards compatibility:
@@ -141,15 +160,29 @@ For backwards compatibility:
- If `content` is a string, it will be converted to a list with `type` as `text`.
:::
::: {.callout-tip}
For image loading, you can use the following keys within `content` alongside `"type": "image"`:
- `"path": "/path/to/image.jpg"`
- `"url": "https://example.com/image.jpg"`
- `"base64": "..."`
- `"image": PIL.Image`
### Audio
For audio loading, you can use the following keys within `content` alongside `"type": "audio"`:
- `"path": "/path/to/audio.mp3"`
- `"url": "https://example.com/audio.mp3"`
- `"audio": np.ndarray`
::: {.callout-tip}
You may need to install `librosa` via `pip3 install librosa==0.11.0`.
:::
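Putting the audio keys together, a single user turn might look like the following (a hedged sketch; the path is a placeholder):

```json
[
  {
    "messages": [
      {
        "role": "user",
        "content": [
          { "type": "audio", "path": "/path/to/audio.mp3" },
          { "type": "text", "text": "Transcribe this recording." }
        ]
      }
    ]
  }
]
```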
### Example
Here is an example of a multi-modal dataset:
```json
[
@@ -178,3 +211,9 @@ Here is an example of a multi-modal dataset:
}
]
```
## FAQ
1. `PIL.UnidentifiedImageError: cannot identify image file ...`
`PIL` could not retrieve the file at `url` using `requests`. Check the URL for typos; the request may also have been blocked by the server.


@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -20,7 +20,7 @@ lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16


@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -20,7 +20,7 @@ lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16


@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -20,7 +20,7 @@ lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16


@@ -21,7 +21,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -25,7 +25,7 @@ lora_target_linear: true
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -16,7 +16,7 @@ output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter:
lora_model_dir:


@@ -19,7 +19,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -19,7 +19,7 @@ lora_model_dir:
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16


@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
adapter: lora


@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -21,7 +21,7 @@ lora_model_dir:
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16


@@ -26,5 +26,3 @@ timeout: 86400
# Preprocess specific configurations
memory_preprocess: 32
timeout_preprocess: 14400
# save_first_step: true # uncomment this to validate checkpoint saving works with your config


@@ -27,7 +27,7 @@ lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@631d646\""
]
},
{


@@ -21,7 +21,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -21,7 +21,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -12,7 +12,7 @@ output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -30,7 +30,7 @@ output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -25,7 +25,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16


@@ -38,7 +38,7 @@ lora_target_modules:
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -38,7 +38,7 @@ lora_target_modules:
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -38,7 +38,7 @@ lora_target_modules:
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -38,7 +38,7 @@ lora_target_modules:
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -38,7 +38,7 @@ lora_target_modules:
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -38,7 +38,7 @@ lora_target_modules:
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -31,7 +31,7 @@ lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -18,7 +18,7 @@ remove_unused_columns: false
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -35,7 +35,7 @@ lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -25,7 +25,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16


@@ -0,0 +1,19 @@
# Gemma-3n
## Requirements
In addition to Axolotl's requirements, Gemma-3n requires:
```
pip3 install timm
```
If you plan to load audio datasets, also install:
```
pip3 install librosa
```
## Usage
See example configs and the [multimodal doc](https://docs.axolotl.ai/docs/multimodal.html).


@@ -0,0 +1,74 @@
base_model: google/gemma-3n-E2B-it
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
load_in_8bit: false
load_in_4bit: true
# for use with fft to only train on language model layers
# unfrozen_parameters:
# - model.language_model.*
# - lm_head
# - embed_tokens
chat_template: gemma3n
eot_tokens:
- <end_of_turn>
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
split: train[:1%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
# lora_target_linear: # Does not work with gemma3n currently
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
# flash_attention: true # Any attention impl does not work with gemma3n now
warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:


@@ -0,0 +1,80 @@
base_model: google/gemma-3n-E2B-it
processor_type: AutoProcessor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
# for use with fft to only train on language model layers
# unfrozen_parameters:
# - model.language_model.*
# - lm_head
# - embed_tokens
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates with images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3n
eot_tokens:
- <end_of_turn>
# sample dataset below requires downloading audio/image in advance
# wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/African_elephant.jpg
# wget https://huggingface.co/datasets/Nanobit/text-vision-audio-2k-test/resolve/main/En-us-African_elephant.oga
datasets:
- path: Nanobit/text-vision-audio-2k-test
type: chat_template
data_files:
- dataset.jsonl
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
lora_model_dir:
sequence_len: 2048
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
# flash_attention: true # Any attention impl does not work with gemma3n now
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0


@@ -0,0 +1,75 @@
base_model: google/gemma-3n-E2B-it
processor_type: AutoProcessor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
# for use with fft to only train on language model layers
# unfrozen_parameters:
# - model.language_model.*
# - lm_head
# - embed_tokens
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates with images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3n
eot_tokens:
- <end_of_turn>
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
lora_model_dir:
sequence_len: 2048
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
# flash_attention: true # Any attention impl does not work with gemma3n now
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0


@@ -17,7 +17,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32


@@ -23,7 +23,7 @@ save_safetensors: true
adapter: qlora
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 16


@@ -18,7 +18,7 @@ output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -14,7 +14,7 @@ output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter:
lora_model_dir:


@@ -14,7 +14,7 @@ output_dir: ./outputs/lisa-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter:
lora_model_dir:


@@ -14,7 +14,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -17,7 +17,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -20,7 +20,7 @@ lora_model_dir:
sequence_len: 512
sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16


@@ -20,7 +20,7 @@ lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16


@@ -18,7 +18,7 @@ lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 8
lora_alpha: 16


@@ -15,8 +15,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/out
@@ -50,8 +49,8 @@ tf32: true
gradient_checkpointing: true
logging_steps: 1
# flash_attention: true # use for text-only mode
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1


@@ -0,0 +1,76 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
output_dir: ./outputs/fp8_out/
sample_packing: true
pad_to_sequence_len: true
sequence_len: 512
flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
torch_compile: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
fp8: true
fp8_enable_fsdp_float8_all_gather: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_steps: 10
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: false
special_tokens:
pad_token: <|end_of_text|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config


@@ -22,7 +22,7 @@ datasets:
output_dir: ./outputs/qat_out/
sample_packing: true
pad_to_sequence_len: true
sequence_len: 512
flex_attention: true


@@ -26,7 +26,7 @@ output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -11,7 +11,7 @@ output_dir: ./outputs/out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -37,7 +37,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -28,7 +28,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -49,7 +49,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -22,7 +22,7 @@ dataset_exact_deduplication: true
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -14,7 +14,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32


@@ -15,7 +15,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32


@@ -24,7 +24,7 @@ sample_packing: true
sample_packing_sequentially: true
curriculum_sampling: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -15,7 +15,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32


@@ -18,7 +18,7 @@ output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:


@@ -18,7 +18,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16


@@ -18,7 +18,7 @@ adapter: qlora
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 16


@@ -20,7 +20,7 @@ lora_model_dir:
sequence_len: 512
sample_packing: false
pad_to_sequence_len: true
lora_r: 8
lora_alpha: 16


@@ -20,7 +20,7 @@ lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16


@@ -16,7 +16,7 @@ output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
wandb_project:


@@ -47,7 +47,7 @@ output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1


@@ -48,7 +48,7 @@ output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -51,7 +51,7 @@ output_dir: ./outputs/out
sequence_len: 4096 # up to 8k will work on a single H100
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -46,7 +46,7 @@ output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 2
@@ -74,7 +74,7 @@ fsdp:
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
# fsdp_cpu_ram_efficient_loading: true # does not work with load_in_8bit/4bit
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
fsdp_state_dict_type: SHARDED_STATE_DICT


@@ -51,7 +51,7 @@ output_dir: ./outputs/out
sequence_len: 4096 # up to 8k will work on a single H100
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1


@@ -11,8 +11,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/out


@@ -23,7 +23,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16


@@ -22,7 +22,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16


@@ -27,7 +27,7 @@ output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1


@@ -14,7 +14,7 @@ output_dir: ./outputs/out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
wandb_project:

Some files were not shown because too many files have changed in this diff Show More