Compare commits
1 Commits
upgrade-to
...
coderabbit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0fccbadb79 |
43
.github/workflows/base.yml
vendored
43
.github/workflows/base.yml
vendored
@@ -25,18 +25,32 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
pytorch: 2.7.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
@@ -107,6 +121,20 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -121,13 +149,6 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
|
||||
3
.github/workflows/docs.yml
vendored
3
.github/workflows/docs.yml
vendored
@@ -12,9 +12,6 @@ jobs:
|
||||
build-deploy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: cleanup node
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Quarto
|
||||
|
||||
64
.github/workflows/main.yml
vendored
64
.github/workflows/main.yml
vendored
@@ -15,6 +15,21 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
@@ -31,11 +46,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
# - cuda: 130
|
||||
# cuda_version: 13.0.0
|
||||
# python_version: "3.11"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -82,6 +92,27 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
@@ -98,11 +129,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
# - cuda: 130
|
||||
# cuda_version: 13.0.0
|
||||
# python_version: "3.11"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -144,18 +170,24 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
16
.github/workflows/multi-gpu-e2e.yml
vendored
16
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -19,9 +19,6 @@ concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
env:
|
||||
MODAL_IMAGE_BUILDER_VERSION: "2025.06"
|
||||
|
||||
jobs:
|
||||
test-axolotl-multigpu:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||
@@ -29,6 +26,13 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
@@ -39,7 +43,7 @@ jobs:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.9.0
|
||||
axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
@@ -55,7 +59,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
pip install modal==1.0.2 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -68,4 +72,4 @@ jobs:
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run -m cicd.multigpu
|
||||
modal run cicd.multigpu
|
||||
|
||||
16
.github/workflows/nightlies.yml
vendored
16
.github/workflows/nightlies.yml
vendored
@@ -12,15 +12,15 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -64,15 +64,15 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
|
||||
5
.github/workflows/preview-docs.yml
vendored
5
.github/workflows/preview-docs.yml
vendored
@@ -11,7 +11,6 @@ on:
|
||||
- '_quarto.yml'
|
||||
- docs/scripts/generate_config_docs.py
|
||||
- src/axolotl/utils/schemas/**.py
|
||||
- .github/workflows/preview-docs.yml
|
||||
|
||||
permissions:
|
||||
checks: write
|
||||
@@ -28,10 +27,6 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
steps:
|
||||
- name: cleanup node
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
|
||||
20
.github/workflows/tests-nightly.yml
vendored
20
.github/workflows/tests-nightly.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||
pytorch_version: ["2.7.1", "2.8.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -99,17 +99,17 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
pytorch: 2.7.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.8.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
@@ -123,7 +123,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
pip install modal==1.0.2 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -148,10 +148,10 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.7.1
|
||||
num_gpus: 2
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
@@ -165,7 +165,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
pip install modal==1.0.2 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
|
||||
69
.github/workflows/tests.yml
vendored
69
.github/workflows/tests.yml
vendored
@@ -55,7 +55,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -66,13 +66,12 @@ jobs:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p ~/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
ls -ltr ~/.cache/huggingface/hub/
|
||||
|
||||
# - name: Restore Cache from S3
|
||||
# id: hf-cache-restore-s3
|
||||
# run: |
|
||||
# mkdir -p ~/.cache/huggingface/hub
|
||||
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
#
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -112,13 +111,10 @@ jobs:
|
||||
run: |
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache scan
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
df -h
|
||||
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
df -h
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
df -h
|
||||
@@ -126,9 +122,6 @@ jobs:
|
||||
df -h
|
||||
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache scan
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
@@ -145,7 +138,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -156,13 +149,12 @@ jobs:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p ~/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
ls -ltr ~/.cache/huggingface/hub/
|
||||
|
||||
# - name: Restore Cache from S3
|
||||
# id: hf-cache-restore-s3
|
||||
# run: |
|
||||
# mkdir -p ~/.cache/huggingface/hub
|
||||
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
#
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -204,13 +196,10 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/cli/
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache scan
|
||||
|
||||
gate-skip-e2e:
|
||||
needs: [pre-commit, pytest, pytest-sdist]
|
||||
runs-on: ubuntu-latest
|
||||
@@ -271,7 +260,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
pip install modal==1.0.2 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -303,6 +292,18 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
# - cuda: 128
|
||||
# cuda_version: 12.8.1
|
||||
# python_version: "3.11"
|
||||
# pytorch: 2.7.1
|
||||
# num_gpus: 1
|
||||
# axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
@@ -313,7 +314,7 @@ jobs:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.9.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
@@ -326,7 +327,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
pip install modal==1.0.2 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -353,10 +354,10 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.7.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
@@ -369,7 +370,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
pip install modal==1.0.2 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
|
||||
@@ -11,13 +11,13 @@ repos:
|
||||
- id: no-commit-to-branch
|
||||
args: ['--branch', 'main']
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.10
|
||||
rev: v0.14.7
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.19.1
|
||||
rev: v1.19.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
|
||||
14
README.md
14
README.md
@@ -29,15 +29,15 @@
|
||||
|
||||
## 🎉 Latest Updates
|
||||
|
||||
- 2025/12: Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
|
||||
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).
|
||||
- 2025/12: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3), [Trinity](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/trinity), and [Ministral3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/ministral3).
|
||||
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
|
||||
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
|
||||
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
|
||||
- 2025/07:
|
||||
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
|
||||
- Axolotl adds more models: [GPT-OSS](https://docs.axolotl.ai/docs/models/gpt-oss.html), [Gemma 3n](https://docs.axolotl.ai/docs/models/gemma3n.html), [Liquid Foundation Model 2 (LFM2)](https://docs.axolotl.ai/docs/models/LiquidAI.html), and [Arcee Foundation Models (AFM)](https://docs.axolotl.ai/docs/models/arcee.html).
|
||||
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
|
||||
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
|
||||
- [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support has been integrated in Axolotl!
|
||||
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
|
||||
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||
|
||||
@@ -46,8 +46,8 @@
|
||||
<summary>Expand older updates</summary>
|
||||
|
||||
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
|
||||
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
|
||||
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
|
||||
@@ -77,7 +77,7 @@ Features:
|
||||
|
||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python 3.11
|
||||
- PyTorch ≥2.8.0
|
||||
- PyTorch ≥2.7.1
|
||||
|
||||
### Google Colab
|
||||
|
||||
|
||||
44
_quarto.yml
44
_quarto.yml
@@ -1,8 +1,6 @@
|
||||
project:
|
||||
type: website
|
||||
pre-render:
|
||||
- docs/scripts/generate_config_docs.py
|
||||
- docs/scripts/generate_examples_docs.py
|
||||
pre-render: docs/scripts/generate_config_docs.py
|
||||
|
||||
quartodoc:
|
||||
dir: docs/api
|
||||
@@ -242,46 +240,6 @@ website:
|
||||
- docs/getting-started.qmd
|
||||
- docs/installation.qmd
|
||||
- docs/inference.qmd
|
||||
- section: "Model Guides"
|
||||
contents:
|
||||
- docs/models/kimi-linear.qmd
|
||||
- docs/models/plano.qmd
|
||||
- docs/models/mimo.qmd
|
||||
- docs/models/internvl3_5.qmd
|
||||
- docs/models/olmo3.qmd
|
||||
- docs/models/trinity.qmd
|
||||
- docs/models/arcee.qmd
|
||||
- docs/models/mistral.qmd
|
||||
- section: "Ministral3"
|
||||
contents:
|
||||
- docs/models/ministral3.qmd
|
||||
- docs/models/ministral3/think.qmd
|
||||
- docs/models/ministral3/vision.qmd
|
||||
- section: "Magistral"
|
||||
contents:
|
||||
- docs/models/magistral.qmd
|
||||
- docs/models/magistral/think.qmd
|
||||
- docs/models/magistral/vision.qmd
|
||||
- docs/models/ministral.qmd
|
||||
- docs/models/mistral-small.qmd
|
||||
- docs/models/voxtral.qmd
|
||||
- docs/models/devstral.qmd
|
||||
- docs/models/llama-4.qmd
|
||||
- docs/models/llama-2.qmd
|
||||
- docs/models/qwen3-next.qmd
|
||||
- docs/models/qwen3.qmd
|
||||
- docs/models/gemma3n.qmd
|
||||
- docs/models/apertus.qmd
|
||||
- docs/models/gpt-oss.qmd
|
||||
- docs/models/seed-oss.qmd
|
||||
- docs/models/phi.qmd
|
||||
- docs/models/smolvlm2.qmd
|
||||
- docs/models/granite4.qmd
|
||||
- docs/models/LiquidAI.qmd
|
||||
- docs/models/hunyuan.qmd
|
||||
- docs/models/jamba.qmd
|
||||
- docs/models/orpheus.qmd
|
||||
|
||||
- docs/cli.qmd
|
||||
- docs/telemetry.qmd
|
||||
- docs/config-reference.qmd
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
set -e
|
||||
|
||||
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
|
||||
pytest -v --durations=10 -n2 --maxfail=4 \
|
||||
pytest -v --durations=10 -n2 \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
|
||||
/workspace/axolotl/tests/e2e/multigpu/ \
|
||||
|
||||
@@ -51,7 +51,7 @@ RUN git lfs install --skip-repo && \
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||
pip3 cache purge
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" =~ ^2\.9\.[0-9]+$ ] && [ "$CUDA" = "128" ] ; then \
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.9.1" ] && [ "$CUDA" = "128" ] ; then \
|
||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
|
||||
2
docs/.gitignore
vendored
2
docs/.gitignore
vendored
@@ -3,5 +3,3 @@ _site/
|
||||
/api/*.qmd
|
||||
/api/*.html
|
||||
config-reference.qmd
|
||||
models/**/*.qmd
|
||||
models/**/*.html
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
---
|
||||
title: "Checkpoint Saving"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
toc-depth: 2
|
||||
number-sections: true
|
||||
execute:
|
||||
enabled: false
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
Axolotl supports on-demand checkpoint saving during training. You can trigger checkpoints via file-based triggers (for programmatic control) or Control+C (for interactive use).
|
||||
|
||||
## File-Based Checkpoint Trigger
|
||||
|
||||
### Configuration
|
||||
|
||||
Enable in your config:
|
||||
|
||||
```yaml
|
||||
dynamic_checkpoint:
|
||||
enabled: true
|
||||
check_interval: 100 # Optional: check every N steps (default: 100)
|
||||
trigger_file_path: "axolotl_checkpoint.save" # Optional: custom filename
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `enabled`: `true` to enable (required)
|
||||
- `check_interval`: Steps between file checks. Default: 100. Lower = faster response, higher I/O overhead.
|
||||
- `trigger_file_path`: Custom trigger filename. Default: `axolotl_checkpoint.save`
|
||||
|
||||
### How It Works
|
||||
|
||||
1. Rank 0 checks for trigger file every `check_interval` steps in `output_dir`
|
||||
2. When detected, file is deleted and checkpoint is saved
|
||||
3. In distributed training, rank 0 broadcasts to synchronize all ranks
|
||||
|
||||
### Usage
|
||||
|
||||
**Command line:**
|
||||
```bash
|
||||
touch /path/to/output_dir/axolotl_checkpoint.save
|
||||
```
|
||||
|
||||
**Programmatic:**
|
||||
```python
|
||||
from pathlib import Path
|
||||
Path("/path/to/output_dir/axolotl_checkpoint.save").touch()
|
||||
```
|
||||
|
||||
Checkpoint saves within the next `check_interval` steps. The trigger file is auto-deleted after detection, so you can create it multiple times.
|
||||
|
||||
**Custom filename:**
|
||||
```yaml
|
||||
dynamic_checkpoint:
|
||||
enabled: true
|
||||
trigger_file_path: "my_trigger.save"
|
||||
```
|
||||
```bash
|
||||
touch /path/to/output_dir/my_trigger.save
|
||||
```
|
||||
|
||||
## Control+C (SIGINT) Checkpoint
|
||||
|
||||
Pressing `Ctrl+C` during training saves the model state and exits gracefully. **Note:** This saves only the model weights, not optimizer state. For resumable checkpoints, use the file-based trigger.
|
||||
|
||||
## Best Practices
|
||||
|
||||
- **Check interval**: Lower values (10-50) for fast training, default 100 for slower training
|
||||
- **Distributed training**: Create trigger file once; rank 0 handles synchronization
|
||||
- **Resume**: Dynamic checkpoints can be resumed like regular checkpoints via `resume_from_checkpoint`
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
output_dir: ./outputs/lora-out
|
||||
save_steps: 500 # Scheduled checkpoints
|
||||
|
||||
dynamic_checkpoint:
|
||||
enabled: true
|
||||
check_interval: 50
|
||||
```
|
||||
|
||||
This enables scheduled checkpoints every 500 steps plus on-demand saves via file trigger (checked every 50 steps).
|
||||
@@ -32,8 +32,11 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-base-py3.11-cu128-2.8.0`
|
||||
- `main-base-py3.11-cu128-2.9.1`
|
||||
- `main-base-py3.11-cu128-2.7.1`
|
||||
- `main-base-py3.11-cu126-2.7.1`
|
||||
- `main-base-py3.11-cu126-2.7.0`
|
||||
- `main-base-py3.11-cu126-2.6.0`
|
||||
- `main-base-py3.11-cu124-2.6.0`
|
||||
|
||||
## Main
|
||||
|
||||
@@ -71,12 +74,15 @@ There may be some extra tags appended to the image, like `-vllm` which installs
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-py3.11-cu128-2.8.0`
|
||||
- `main-py3.11-cu128-2.9.1`
|
||||
- `main-py3.11-cu128-2.7.1`
|
||||
- `main-py3.11-cu126-2.7.1`
|
||||
- `main-py3.11-cu126-2.7.0`
|
||||
- `main-py3.11-cu126-2.6.0`
|
||||
- `main-py3.11-cu124-2.6.0`
|
||||
- `main-latest`
|
||||
- `main-20250303-py3.11-cu124-2.6.0`
|
||||
- `main-20250303-py3.11-cu126-2.6.0`
|
||||
- `0.12.0`
|
||||
- `0.10.1`
|
||||
|
||||
## Cloud
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://p
|
||||
:::
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
|
||||
For Blackwell GPUs, please use Pytorch 2.7.0 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
### PyPI Installation (Recommended) {#sec-pypi}
|
||||
@@ -111,7 +111,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
|
||||
:::
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
|
||||
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.7.0` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.7.0`.
|
||||
:::
|
||||
|
||||
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
||||
|
||||
@@ -21,7 +21,6 @@ format:
|
||||
- [Qwen2.5-VL](#sec-qwen25-vl)
|
||||
- [SmolVLM2](#sec-smolvlm2)
|
||||
- [LFM2-VL](#sec-lfm2-vl)
|
||||
- [Intern-VL](#sec-intern-vl)
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -203,16 +202,6 @@ Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
|
||||
base_model: LiquidAI/LFM2-VL-450M
|
||||
```
|
||||
|
||||
### Intern-VL {#sec-intern-vl}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install `timm` via `pip3 install timm==1.0.19`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: OpenGVLab/InternVL3_5-8B
|
||||
```
|
||||
|
||||
## Dataset Format
|
||||
|
||||
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
examples:
|
||||
# December 2025
|
||||
- name: kimi-linear
|
||||
title: Kimi Linear
|
||||
- name: plano
|
||||
title: Plano Orchestrator
|
||||
- name: mimo
|
||||
title: MiMo
|
||||
- name: internvl3_5
|
||||
title: InternVL 3.5
|
||||
|
||||
# AllenAI
|
||||
- name: olmo3
|
||||
title: OLMo 3
|
||||
|
||||
# ArceeAI
|
||||
- name: trinity
|
||||
title: Trinity
|
||||
- name: arcee
|
||||
title: Arcee AFM
|
||||
|
||||
# MistralAI
|
||||
- name: ministral3/think
|
||||
title: Ministral 3 Thinking
|
||||
- name: ministral3/vision
|
||||
title: Ministral 3 Vision
|
||||
- name: magistral/think
|
||||
title: Magistral Thinking
|
||||
- name: magistral/vision
|
||||
title: Magistral Vision
|
||||
- name: ministral
|
||||
title: Ministral
|
||||
- name: mistral-small
|
||||
title: Mistral Small 3.1/3.2
|
||||
- name: voxtral
|
||||
title: Voxtral
|
||||
- name: devstral
|
||||
title: Devstral
|
||||
- name: mistral
|
||||
title: Mistral 7B
|
||||
|
||||
# Meta
|
||||
- name: llama-4
|
||||
title: Llama 4
|
||||
- name: llama-2
|
||||
title: Llama 2
|
||||
|
||||
# Alibaba
|
||||
- name: qwen3-next
|
||||
title: Qwen 3 Next
|
||||
- name: qwen3
|
||||
title: Qwen 3
|
||||
|
||||
# Google
|
||||
- name: gemma3n
|
||||
title: Gemma 3n
|
||||
|
||||
# Swiss AI
|
||||
- name: apertus
|
||||
title: Apertus
|
||||
|
||||
# GPT-OSS
|
||||
- name: gpt-oss
|
||||
title: GPT-OSS
|
||||
- name: seed-oss
|
||||
title: Seed-OSS
|
||||
|
||||
# Microsoft
|
||||
- name: phi
|
||||
title: Phi
|
||||
|
||||
# SmolVLM
|
||||
- name: smolvlm2
|
||||
title: SmolVLM 2
|
||||
|
||||
# IBM
|
||||
- name: granite4
|
||||
title: Granite 4
|
||||
|
||||
# LiquidAI
|
||||
- name: LiquidAI
|
||||
title: Liquid Foundation Models 2
|
||||
|
||||
# Other
|
||||
- name: hunyuan
|
||||
title: Hunyuan
|
||||
- name: jamba
|
||||
title: Jamba
|
||||
- name: orpheus
|
||||
title: Orpheus
|
||||
@@ -1,424 +0,0 @@
|
||||
"""
|
||||
auto generate example docs from allowlist
|
||||
"""
|
||||
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
# Paths
|
||||
THIS = Path(__file__).resolve()
|
||||
ROOT = THIS.parents[2] # repo root (docs/scripts -> docs -> ROOT)
|
||||
EXAMPLES_DIR = ROOT / "examples"
|
||||
OUTPUT_DIR = ROOT / "docs" / "models"
|
||||
ALLOWLIST_YML = THIS.parent / "examples-allowlist.yml"
|
||||
|
||||
|
||||
def slugify(name: str) -> str:
|
||||
"""Convert a name to a slug (lowercase, hyphens for spaces)."""
|
||||
s = re.sub(r"[^a-zA-Z0-9\s\-]+", "", name.strip())
|
||||
s = re.sub(r"\s+", "-", s).strip("-").lower()
|
||||
return s or "example"
|
||||
|
||||
|
||||
def read_allowlist():
|
||||
with open(ALLOWLIST_YML, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
items = data.get("examples", [])
|
||||
if not isinstance(items, list):
|
||||
raise ValueError("`examples` must be a list in examples-allowlist.yml")
|
||||
return items
|
||||
|
||||
|
||||
def find_readme(folder: Path) -> Path | None:
|
||||
for name in ("README.md", "Readme.md", "readme.md"):
|
||||
p = folder / name
|
||||
if p.exists():
|
||||
return p
|
||||
return None
|
||||
|
||||
|
||||
def remove_first_h1(md: str) -> tuple[str, str | None]:
|
||||
"""
|
||||
Remove the first H1 from markdown and return (modified_md, h1_title).
|
||||
The H1 is removed since we use the frontmatter title instead.
|
||||
"""
|
||||
lines = md.splitlines()
|
||||
result = []
|
||||
h1_title = None
|
||||
skipped_first = False
|
||||
|
||||
for line in lines:
|
||||
if not skipped_first and line.startswith("# "):
|
||||
h1_title = line[2:].strip()
|
||||
skipped_first = True
|
||||
continue
|
||||
result.append(line)
|
||||
|
||||
return "\n".join(result), h1_title
|
||||
|
||||
|
||||
IMG_RE = re.compile(r"!\[[^\]]*\]\(([^)]+)\)")
|
||||
LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
|
||||
|
||||
|
||||
def rewrite_and_copy_assets(md: str, src_dir: Path, dest_assets_root: Path) -> str:
|
||||
"""
|
||||
Copy local image assets referenced in markdown to
|
||||
docs/examples/assets/... and rewrite the links.
|
||||
"""
|
||||
dest_assets = dest_assets_root / "assets"
|
||||
|
||||
def repl(m):
|
||||
url = m.group(1).strip()
|
||||
if re.match(r"^(https?:)?//", url):
|
||||
return m.group(0) # leave remote URLs
|
||||
src_path = (src_dir / url).resolve()
|
||||
if not src_path.exists():
|
||||
return m.group(0) # leave as-is if not found
|
||||
rel = src_path.relative_to(src_dir)
|
||||
# Create a unique asset path based on source directory name
|
||||
asset_name = src_dir.name.replace("/", "-")
|
||||
dest_path = dest_assets / asset_name / rel
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_path, dest_path)
|
||||
new_rel = f"assets/{asset_name}/{rel.as_posix()}"
|
||||
return m.group(0).replace(url, new_rel)
|
||||
|
||||
return IMG_RE.sub(repl, md)
|
||||
|
||||
|
||||
def rewrite_readme_links(
|
||||
md: str,
|
||||
src_dir: Path,
|
||||
examples_dir: Path,
|
||||
parent_index_only: set,
|
||||
current_src_path: str,
|
||||
allowlist_entries: set,
|
||||
current_output_path: str,
|
||||
) -> str:
|
||||
"""
|
||||
Rewrite links between README.md files to point to the correct .qmd files.
|
||||
"""
|
||||
|
||||
def repl(m):
|
||||
text = m.group(1)
|
||||
url = m.group(2).strip()
|
||||
|
||||
# Skip remote URLs and anchor links
|
||||
if re.match(r"^(https?:)?//", url) or url.startswith("#"):
|
||||
return m.group(0)
|
||||
|
||||
# Skip non-markdown files
|
||||
if not url.lower().endswith(".md"):
|
||||
return m.group(0)
|
||||
|
||||
# Resolve the target path
|
||||
try:
|
||||
target_path = (src_dir / url).resolve()
|
||||
|
||||
# Check if target is outside examples_dir
|
||||
try:
|
||||
rel_path = target_path.relative_to(examples_dir)
|
||||
except ValueError:
|
||||
# Target is outside examples_dir, leave as-is
|
||||
return m.group(0)
|
||||
|
||||
parts = list(rel_path.parts)
|
||||
|
||||
# Determine the output path for the target
|
||||
if len(parts) > 0 and parts[-1].lower() in ("readme.md", "readme"):
|
||||
# This is a README link
|
||||
if len(parts) == 1:
|
||||
# Link to root README -> index.qmd
|
||||
target_output = "index.qmd"
|
||||
elif len(parts) == 2:
|
||||
if parts[0] == ".":
|
||||
# Current directory README
|
||||
target_output = "index.qmd"
|
||||
else:
|
||||
# subdir/README.md
|
||||
parent_dir = parts[0]
|
||||
if parent_dir in parent_index_only:
|
||||
target_output = f"{parent_dir}/index.qmd"
|
||||
else:
|
||||
target_output = f"{parent_dir}.qmd"
|
||||
else:
|
||||
# Deeper nesting: parent/subdir/README.md
|
||||
# Build the full path like "parent/subdir"
|
||||
full_path = "/".join(parts[:-1]) # Remove README.md
|
||||
# Check if this exact path is in allowlist
|
||||
if full_path in allowlist_entries:
|
||||
# This is a sub-entry with its own entry -> use .qmd
|
||||
target_output = f"{full_path}.qmd"
|
||||
elif parts[0] == ".":
|
||||
# ./subdir/README.md -> check if subdir has own entry
|
||||
subdir = parts[1]
|
||||
if subdir in parent_index_only:
|
||||
target_output = f"{subdir}/index.qmd"
|
||||
else:
|
||||
target_output = f"{subdir}.qmd"
|
||||
else:
|
||||
# parent/subdir where parent doesn't have own entry
|
||||
target_output = f"{full_path}/index.qmd"
|
||||
else:
|
||||
# Regular .md file -> convert to .qmd, keep path structure
|
||||
target_output = "/".join(parts)[:-2] + "qmd"
|
||||
|
||||
# Compute relative path from current output file to target
|
||||
current_parts = current_output_path.split("/")
|
||||
target_parts = target_output.split("/")
|
||||
|
||||
# Special case: if current is a subdir file and target is a single-component file at root
|
||||
# Example: current="magistral/vision", target="magistral.qmd"
|
||||
if len(current_parts) > 1 and len(target_parts) == 1:
|
||||
# Current is in subdir, target is at root level
|
||||
# Go up to root: ../ for each level
|
||||
up_count = len(current_parts) - 1
|
||||
rel_parts = [".."] * up_count + [target_parts[0]]
|
||||
new_url = "/".join(rel_parts)
|
||||
else:
|
||||
# Find common prefix
|
||||
i = 0
|
||||
while (
|
||||
i < min(len(current_parts) - 1, len(target_parts))
|
||||
and current_parts[i] == target_parts[i]
|
||||
):
|
||||
i += 1
|
||||
|
||||
# Build relative path: go up (../) then down to target
|
||||
up_count = len(current_parts) - 1 - i
|
||||
rel_parts = [".."] * up_count + target_parts[i:]
|
||||
|
||||
if not rel_parts or rel_parts == [".."]:
|
||||
# Points to same directory or parent
|
||||
new_url = "/".join(rel_parts) if rel_parts else "."
|
||||
else:
|
||||
new_url = "/".join(rel_parts)
|
||||
|
||||
return f"[{text}]({new_url})"
|
||||
except (ValueError, IndexError):
|
||||
return m.group(0)
|
||||
|
||||
return LINK_RE.sub(repl, md)
|
||||
|
||||
|
||||
def write_qmd(out_path: Path, title: str, body_md: str):
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fm = f"---\ntitle: {title!r}\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
|
||||
out_path.write_text(fm + body_md, encoding="utf-8")
|
||||
|
||||
|
||||
def update_quarto_yml(generated: list[tuple[str, str, str]]):
|
||||
"""
|
||||
Update _quarto.yml with the generated example files in the correct order.
|
||||
This keeps the sidebar in sync with the allowlist.
|
||||
|
||||
Model Guides is now nested under "Getting Started" section.
|
||||
Creates nested sections for models with sub-entries (e.g., magistral, ministral3).
|
||||
Parent pages are now flat files (e.g., ministral3.qmd) with sub-pages in subdirs.
|
||||
"""
|
||||
quarto_yml = ROOT / "_quarto.yml"
|
||||
if not quarto_yml.exists():
|
||||
print(f"[WARN] {quarto_yml} not found, skipping update", file=sys.stderr)
|
||||
return
|
||||
|
||||
content = quarto_yml.read_text(encoding="utf-8")
|
||||
|
||||
# First pass: find all parents that have sub-entries
|
||||
parents_with_subs = set()
|
||||
for path, _name, _title in generated:
|
||||
if "/" in path:
|
||||
parent = path.split("/")[0]
|
||||
parents_with_subs.add(parent)
|
||||
|
||||
# Build the YAML contents while preserving allowlist order
|
||||
lines = []
|
||||
processed_sections = set()
|
||||
|
||||
for path, _name, title in generated:
|
||||
# Check if this is a parent page that has sub-pages
|
||||
if path in parents_with_subs:
|
||||
# This is a parent page with sub-pages - create a nested section
|
||||
if path not in processed_sections:
|
||||
processed_sections.add(path)
|
||||
section_title = (
|
||||
title or path.replace("-", " ").replace("_", " ").title()
|
||||
)
|
||||
lines.append(f' - section: "{section_title}"')
|
||||
lines.append(" contents:")
|
||||
# Add the parent page first
|
||||
lines.append(f" - docs/models/{path}.qmd")
|
||||
# Then add all sub-pages
|
||||
for sub_path, _sub_name, _sub_title in generated:
|
||||
if "/" in sub_path and sub_path.split("/")[0] == path:
|
||||
lines.append(
|
||||
f" - docs/models/{sub_path}.qmd"
|
||||
)
|
||||
elif "/" not in path:
|
||||
# This is a flat item with no sub-pages
|
||||
# Skip if it was already included as part of a parent section
|
||||
if path not in processed_sections:
|
||||
lines.append(f" - docs/models/{path}.qmd")
|
||||
|
||||
yaml_content = "\n".join(lines) + "\n"
|
||||
|
||||
# Pattern to match only the Model Guides contents, stopping at the next item
|
||||
# in Getting Started (lines starting with 12 spaces: same level as the section)
|
||||
pattern = r'( - section: "Model Guides"\n contents:)([^\n]*|.*?)(?=\n - |\n - section:|\n\nformat:)'
|
||||
|
||||
def replacement(match):
|
||||
prefix = match.group(1)
|
||||
return prefix + "\n" + yaml_content
|
||||
|
||||
new_content = re.sub(pattern, replacement, content, flags=re.DOTALL)
|
||||
|
||||
if new_content != content:
|
||||
quarto_yml.write_text(new_content, encoding="utf-8")
|
||||
print(f"Updated {quarto_yml}")
|
||||
else:
|
||||
print(f"No changes needed for {quarto_yml}")
|
||||
|
||||
|
||||
def main():
|
||||
allow = read_allowlist()
|
||||
if not EXAMPLES_DIR.exists():
|
||||
print(f"[WARN] {EXAMPLES_DIR} not found", file=sys.stderr)
|
||||
return
|
||||
|
||||
(OUTPUT_DIR / "assets").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# First pass: identify which parents have their own entry vs only sub-entries
|
||||
parent_entries = set() # Parents that have their own entry
|
||||
parent_with_subs = set() # Parents that have sub-entries
|
||||
allowlist_entries = set() # All entries in allowlist
|
||||
|
||||
for item in allow:
|
||||
if isinstance(item, str):
|
||||
name = item
|
||||
else:
|
||||
name = item.get("name")
|
||||
|
||||
allowlist_entries.add(name)
|
||||
|
||||
if "/" in name:
|
||||
parent = name.split("/")[0]
|
||||
parent_with_subs.add(parent)
|
||||
else:
|
||||
parent_entries.add(name)
|
||||
|
||||
# Parents with subs that DON'T have their own entry -> use index.qmd
|
||||
parent_index_only = parent_with_subs - parent_entries
|
||||
|
||||
generated = []
|
||||
seen_dirs = set() # Track which parent directories we've created index for
|
||||
|
||||
for item in allow:
|
||||
if isinstance(item, str):
|
||||
name = item
|
||||
title = None
|
||||
else:
|
||||
name = item.get("name")
|
||||
title = item.get("title")
|
||||
|
||||
if not name:
|
||||
print(f"[WARN] Skipping item without name: {item}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
src_dir = EXAMPLES_DIR / name
|
||||
if not src_dir.exists() or not src_dir.is_dir():
|
||||
print(f"[WARN] Skipping {name} (not a directory)", file=sys.stderr)
|
||||
continue
|
||||
|
||||
readme = find_readme(src_dir)
|
||||
if not readme:
|
||||
print(f"[WARN] Skipping {name} (no README.md)", file=sys.stderr)
|
||||
continue
|
||||
|
||||
md = readme.read_text(encoding="utf-8")
|
||||
|
||||
# Determine output path first (needed for link rewriting)
|
||||
parts = name.split("/")
|
||||
if len(parts) == 1:
|
||||
# Simple case: no subdirectory
|
||||
out_path = OUTPUT_DIR / f"{parts[0]}.qmd"
|
||||
sidebar_path = parts[0]
|
||||
else:
|
||||
# Has subdirectory: e.g., magistral/think
|
||||
parent = parts[0]
|
||||
child = "-".join(parts[1:]) # handle nested subdirs
|
||||
out_path = OUTPUT_DIR / parent / f"{child}.qmd"
|
||||
sidebar_path = f"{parent}/{child}"
|
||||
|
||||
# Remove the first H1 (we use frontmatter title instead)
|
||||
md, _ = remove_first_h1(md)
|
||||
# Rewrite links between README files
|
||||
md = rewrite_readme_links(
|
||||
md,
|
||||
src_dir,
|
||||
EXAMPLES_DIR,
|
||||
parent_index_only,
|
||||
name,
|
||||
allowlist_entries,
|
||||
sidebar_path,
|
||||
)
|
||||
md = rewrite_and_copy_assets(md, src_dir, OUTPUT_DIR)
|
||||
|
||||
# Handle parent page generation for sub-entries
|
||||
if len(parts) > 1:
|
||||
# Has subdirectory: e.g., magistral/think
|
||||
parent = parts[0]
|
||||
|
||||
# Create parent.qmd if not already done and parent doesn't have own entry
|
||||
if parent not in seen_dirs and parent in parent_index_only:
|
||||
parent_readme = find_readme(EXAMPLES_DIR / parent)
|
||||
if parent_readme:
|
||||
parent_md = parent_readme.read_text(encoding="utf-8")
|
||||
parent_md, _ = remove_first_h1(parent_md)
|
||||
parent_md = rewrite_readme_links(
|
||||
parent_md,
|
||||
EXAMPLES_DIR / parent,
|
||||
EXAMPLES_DIR,
|
||||
parent_index_only,
|
||||
parent,
|
||||
allowlist_entries,
|
||||
parent,
|
||||
)
|
||||
parent_md = rewrite_and_copy_assets(
|
||||
parent_md, EXAMPLES_DIR / parent, OUTPUT_DIR
|
||||
)
|
||||
parent_title = parent.replace("-", " ").replace("_", " ").title()
|
||||
write_qmd(OUTPUT_DIR / f"{parent}.qmd", parent_title, parent_md)
|
||||
generated.append((parent, parent, parent_title))
|
||||
seen_dirs.add(parent)
|
||||
|
||||
if not title:
|
||||
title = name.replace("/", " ").replace("-", " ").title()
|
||||
|
||||
write_qmd(out_path, title, md)
|
||||
generated.append((sidebar_path, name, title))
|
||||
|
||||
# Index page - preserve allowlist order
|
||||
if generated:
|
||||
listing = "\n".join(
|
||||
[f"- [{title}]({path}.qmd)" for path, name, title in generated]
|
||||
)
|
||||
index_md = (
|
||||
"# Model Guides\n\nBelow are the curated examples for training various model architectures:\n\n"
|
||||
+ listing
|
||||
+ "\n"
|
||||
)
|
||||
index_fm = (
|
||||
"---\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
|
||||
)
|
||||
(OUTPUT_DIR / "index.qmd").write_text(index_fm + index_md, encoding="utf-8")
|
||||
|
||||
# Auto-update _quarto.yml to keep sidebar in sync
|
||||
update_quarto_yml(generated)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -32,10 +32,6 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -28,10 +28,6 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -29,10 +29,6 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -28,10 +28,6 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -41,10 +41,6 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -41,10 +41,6 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
# Finetune OpenGV's InternVL with Axolotl
|
||||
|
||||
[InternVL 3.5](https://huggingface.co/OpenGVLab/InternVL3_5-8B-HF) is a family of powerful vision-language models supporting dynamic resolution and multi-image understanding by OpenGV. It features a ViT-style vision encoder and strong language model backbone for tasks like visual question answering, OCR, and scene text understanding.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install `timm` for vision model support:
|
||||
|
||||
```bash
|
||||
pip install timm==1.0.19
|
||||
```
|
||||
|
||||
3. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/internvl3_5/internvl3_5-8b-qlora.yml
|
||||
```
|
||||
|
||||
This config uses about 8.21 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Tips
|
||||
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the multi-modal format as seen [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [InternVL Paper](https://huggingface.co/papers/2508.18265)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -1,61 +0,0 @@
|
||||
base_model: OpenGVLab/InternVL3_5-8B-HF
|
||||
processor_type: AutoProcessor
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,47 +0,0 @@
|
||||
# Finetune MoonshotAI's Kimi Linear with Axolotl
|
||||
|
||||
[Kimi Linear](https://huggingface.co/collections/moonshotai/kimi-linear-a3b) is a MoE model (48B total, 3B active) by MoonshotAI using a hybrid linear attention architecture to achieve a 1M token context length. It uses Kimi Delta Attention (KDA), a refined version of Gated DeltaNet that reduces KV cache size by up to 75% and boosts decoding throughput by up to 6x for long contexts.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
**Note:** Axolotl uses experimental training code for Kimi Linear as their original modeling code is inference-only.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install CCE via [docs](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/kimi-linear/kimi-48b-lora.yaml
|
||||
```
|
||||
|
||||
This config uses about 98.7GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning!
|
||||
|
||||
### TIPS
|
||||
|
||||
- Kimi Linear requires `trust_remote_code: true`.
|
||||
- You can run a full finetuning by removing the `adapter: lora` and `load_in_8bit: true`.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html)
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template)
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
See 👉 [docs](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Limitations
|
||||
|
||||
This is not yet compatible with MoE kernels from transformers v5.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Kimi Linear Paper](https://huggingface.co/papers/2510.26692)
|
||||
- [Kimi Linear GitHub](https://github.com/MoonshotAI/Kimi-Linear)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -1,81 +0,0 @@
|
||||
base_model: moonshotai/Kimi-Linear-48B-A3B-Instruct
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
split: train
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.2
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -29,6 +29,7 @@ flex_attention: true
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
save_strategy: no
|
||||
torch_compile: true
|
||||
|
||||
wandb_project:
|
||||
|
||||
@@ -5,7 +5,6 @@ This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mist
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
|
||||
- Installed Axolotl (see [main README](../README.md))
|
||||
|
||||
## Getting Started
|
||||
|
||||
@@ -5,8 +5,7 @@ This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mist
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
|
||||
- Installed Axolotl from source (see [main README](../README.md))
|
||||
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
||||
|
||||
## Getting started
|
||||
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# Finetune Xiaomi's MiMo with Axolotl
|
||||
|
||||
[MiMo](https://huggingface.co/XiaomiMiMo/MiMo-7B-RL) is a family of models trained from scratch for reasoning tasks, incorporating **Multiple-Token Prediction (MTP)** as an additional training objective for enhanced performance and faster inference. Pre-trained on ~25T tokens with a three-stage data mixture strategy and optimized reasoning pattern density.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/mimo/mimo-7b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 17.2 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Tips
|
||||
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Limitations
|
||||
|
||||
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for MiMo in the near future.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [MiMo Paper](https://arxiv.org/abs/2505.07608)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -1,67 +0,0 @@
|
||||
base_model: XiaomiMiMo/MiMo-7B-RL
|
||||
trust_remote_code: true
|
||||
revision_of_model: 6299b5a
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# CCE - N/A as of now
|
||||
# plugins:
|
||||
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -5,7 +5,6 @@ This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collectio
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
|
||||
- Installed Axolotl (see [main README](../README.md))
|
||||
|
||||
## Getting Started
|
||||
|
||||
@@ -5,8 +5,7 @@ This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collectio
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
|
||||
- Installed Axolotl from source (see [main README](../README.md))
|
||||
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
||||
|
||||
## Getting started
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ This guide covers fine-tuning [Mistral Small 3.1](mistralai/Mistral-Small-3.1-24
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
|
||||
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
|
||||
|
||||
## Getting Started
|
||||
@@ -16,7 +16,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
axolotl train examples/olmo3/olmo3-7b-qlora.yaml
|
||||
```
|
||||
|
||||
This uses about 11.3 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
|
||||
@@ -42,10 +42,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_8bit
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
# Finetune Katanemo's Plano-Orchestrator with Axolotl
|
||||
|
||||
[Plano-Orchestrator](https://huggingface.co/collections/katanemo/plano-orchestrator) is a family of 4B and 30B-A3B routing and orchestration models designed for multi-agent systems. It analyzes user intent and conversation context to make precise routing decisions, excelling at multi-turn context understanding, multi-intent detection, and context-dependent routing.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/plano/plano-4b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 5.1 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Orchestration Prompt
|
||||
|
||||
Plano-Orchestrator uses a specific orchestration prompt format for routing/agent decisions. Please check the [official model card](https://huggingface.co/katanemo/Plano-Orchestrator-4B) for proper prompt formatting and the `ORCHESTRATION_PROMPT` template.
|
||||
|
||||
### Tips
|
||||
|
||||
- To use the larger [Plano-Orchestrator-30B-A3B](https://huggingface.co/katanemo/Plano-Orchestrator-30B-A3B) MoE model, simply change `base_model: katanemo/Plano-Orchestrator-30B-A3B` in the config and enable multi-GPU training if needed.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Plano GitHub](https://github.com/katanemo/plano)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -1,65 +0,0 @@
|
||||
base_model: katanemo/Plano-Orchestrator-4B
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
chat_template: qwen3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,70 +0,0 @@
|
||||
base_model: Qwen/Qwen2.5-0.5B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
# Use random initialization for fair comparison
|
||||
reinit_weights: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
# Pretraining dataset
|
||||
pretraining_dataset:
|
||||
- path: allenai/c4
|
||||
name: en
|
||||
type: pretrain
|
||||
split: train
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/compare-adamw-pretrain
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project: dist_muon
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name: adamw
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 1
|
||||
max_steps: 305
|
||||
|
||||
# AdamW optimizer settings (standard LR for AdamW)
|
||||
optimizer: adamw_torch_fused
|
||||
learning_rate: 0.0002
|
||||
weight_decay: 0.01
|
||||
lr_scheduler: cosine
|
||||
|
||||
train_on_inputs: true
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 0
|
||||
saves_per_epoch: 1
|
||||
|
||||
# Reproducibility
|
||||
seed: 42
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: false
|
||||
fsdp_reshard_after_forward: true
|
||||
|
||||
special_tokens:
|
||||
@@ -1,70 +0,0 @@
|
||||
base_model: Qwen/Qwen2.5-0.5B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
# Use random initialization for fair comparison
|
||||
reinit_weights: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
# Pretraining dataset
|
||||
pretraining_dataset:
|
||||
- path: allenai/c4
|
||||
name: en
|
||||
type: pretrain
|
||||
split: train
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/compare-muon-pretrain
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project: dist_muon
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name: muon
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 1
|
||||
max_steps: 305
|
||||
|
||||
# Muon optimizer settings
|
||||
optimizer: muon
|
||||
learning_rate: 0.02
|
||||
weight_decay: 0.01
|
||||
lr_scheduler: cosine
|
||||
|
||||
train_on_inputs: true
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 0
|
||||
saves_per_epoch: 1
|
||||
|
||||
# Reproducibility
|
||||
seed: 42
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: false
|
||||
fsdp_reshard_after_forward: true
|
||||
|
||||
special_tokens:
|
||||
@@ -29,10 +29,6 @@ Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Limitations
|
||||
|
||||
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
base_model: arcee-ai/Trinity-Nano-Preview
|
||||
trust_remote_code: true
|
||||
revision_of_model: 2ee94b0
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
@@ -14,21 +14,20 @@ huggingface_hub>=0.36.0
|
||||
peft>=0.18.0
|
||||
tokenizers>=0.22.1
|
||||
transformers==4.57.1
|
||||
accelerate==1.12.0
|
||||
datasets==4.4.2
|
||||
deepspeed>=0.18.3
|
||||
trl==0.25.1
|
||||
accelerate==1.11.0
|
||||
datasets==4.4.1
|
||||
deepspeed>=0.17.0
|
||||
trl==0.25.0
|
||||
hf_xet==1.2.0
|
||||
kernels==0.11.5
|
||||
trackio>=0.13.0
|
||||
typing-extensions>=4.15.0
|
||||
kernels>=0.9.0
|
||||
trackio
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
sentencepiece
|
||||
gradio>=6.2.0,<7.0
|
||||
gradio==5.49.1
|
||||
|
||||
modal==1.3.0.post1
|
||||
modal==1.0.2
|
||||
pydantic>=2.10.6
|
||||
addict
|
||||
fire
|
||||
@@ -63,12 +62,13 @@ langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.15.0
|
||||
torchao==0.13.0
|
||||
openenv-core==0.1.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.7
|
||||
axolotl-contribs-mit==0.0.6
|
||||
axolotl-contribs-mit==0.0.5
|
||||
|
||||
# telemetry
|
||||
posthog==6.7.11
|
||||
|
||||
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"'
|
||||
)
|
||||
|
||||
2
setup.py
2
setup.py
@@ -156,7 +156,7 @@ extras_require = {
|
||||
"came_pytorch==0.1.3",
|
||||
],
|
||||
"ray": [
|
||||
"ray[train]>=2.52.1",
|
||||
"ray[train]",
|
||||
],
|
||||
"vllm": [
|
||||
"vllm==0.10.0",
|
||||
|
||||
@@ -24,7 +24,8 @@ if launcher_args:
|
||||
launcher_args_str = "-- " + " ".join(launcher_args)
|
||||
|
||||
# 1. Define a base image for your training job
|
||||
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu128-2.9.1"
|
||||
# must use torch 2.7.0 for vllm
|
||||
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
|
||||
|
||||
# 2. Define the Runtime Environment for the Training Job
|
||||
# This includes start commands and environment variables.a
|
||||
|
||||
@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
|
||||
return res
|
||||
|
||||
def get_image(self):
|
||||
docker_tag = "main-py3.11-cu128-2.9.1"
|
||||
docker_tag = "main-py3.11-cu126-2.7.1"
|
||||
if self.config.docker_tag:
|
||||
docker_tag = self.config.docker_tag
|
||||
docker_image = f"axolotlai/axolotl:{docker_tag}"
|
||||
|
||||
@@ -26,7 +26,6 @@ from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.tee import prepare_debug_log
|
||||
from axolotl.utils.trackio_ import setup_trackio_env_vars
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
@@ -228,7 +227,6 @@ def load_cfg(
|
||||
cfg,
|
||||
capabilities={
|
||||
"bf16": is_torch_bf16_gpu_available(),
|
||||
"fp8": compute_supports_fp8(),
|
||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||
"compute_capability": gpu_version,
|
||||
},
|
||||
@@ -247,7 +245,6 @@ def load_cfg(
|
||||
setup_wandb_env_vars(cfg)
|
||||
setup_mlflow_env_vars(cfg)
|
||||
setup_comet_env_vars(cfg)
|
||||
setup_trackio_env_vars(cfg)
|
||||
plugin_set_cfg(cfg)
|
||||
|
||||
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
||||
@@ -262,11 +259,3 @@ def load_cfg(
|
||||
)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def compute_supports_fp8() -> bool:
|
||||
try:
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
return compute_capability >= (9, 0)
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
@@ -288,8 +288,8 @@ def do_inference_gradio(
|
||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||
)
|
||||
|
||||
demo.launch(
|
||||
footer_links=["gradio", "settings"],
|
||||
demo.queue().launch(
|
||||
show_api=False,
|
||||
share=cfg.get("gradio_share", True),
|
||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||
server_port=cfg.get("gradio_server_port", None),
|
||||
|
||||
@@ -366,8 +366,8 @@ def launch_diffusion_gradio_ui(
|
||||
outputs=[masked_preview, html_out],
|
||||
)
|
||||
|
||||
demo.launch(
|
||||
footer_links=["gradio", "settings"],
|
||||
demo.queue().launch(
|
||||
show_api=False,
|
||||
share=cfg.get("gradio_share", True),
|
||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||
server_port=cfg.get("gradio_server_port", None),
|
||||
|
||||
158
src/axolotl/core/attention/flex_block_mask.py
Normal file
158
src/axolotl/core/attention/flex_block_mask.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
monkeypatch for flex + packing
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
from transformers import Cache, PretrainedConfig
|
||||
from transformers.masking_utils import (
|
||||
ALL_MASK_ATTENTION_FUNCTIONS,
|
||||
_preprocess_mask_arguments,
|
||||
and_masks,
|
||||
causal_mask_function,
|
||||
or_masks,
|
||||
)
|
||||
from transformers.utils import is_torch_greater_or_equal
|
||||
|
||||
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
|
||||
|
||||
|
||||
def create_causal_mask(
|
||||
config: PretrainedConfig,
|
||||
input_embeds: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Optional[Cache],
|
||||
or_mask_function: Optional[Callable] = None,
|
||||
and_mask_function: Optional[Callable] = None,
|
||||
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
||||
"""
|
||||
Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
|
||||
has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
|
||||
to what is needed in the `modeling_xxx.py` files).
|
||||
|
||||
Args:
|
||||
config (`PretrainedConfig`):
|
||||
The model config.
|
||||
input_embeds (`torch.Tensor`):
|
||||
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
|
||||
batch size, query length and dtype.
|
||||
attention_mask (`torch.Tensor`, optional):
|
||||
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
|
||||
It can also be an already prepared 4D mask, in which case it is returned as-is.
|
||||
cache_position (`torch.Tensor`):
|
||||
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
|
||||
past_key_values (`Cache`, optional):
|
||||
The past key values, if we use a cache.
|
||||
or_mask_function (`Callable`, optional):
|
||||
An optional mask function to combine with the causal mask function (by doing the union of both). This is
|
||||
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
||||
and_mask_function (`Callable`, optional):
|
||||
An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
|
||||
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
||||
"""
|
||||
# If we have an HybridCache structure, here we want to create the mask for the full layers
|
||||
if (
|
||||
past_key_values
|
||||
and hasattr(past_key_values, "is_sliding")
|
||||
and False in past_key_values.is_sliding
|
||||
):
|
||||
layer_idx = past_key_values.is_sliding.index(False)
|
||||
else:
|
||||
layer_idx = 0
|
||||
|
||||
original_attention_mask = (
|
||||
None
|
||||
if attention_mask is None
|
||||
else attention_mask.clone().to(cache_position.device)
|
||||
)
|
||||
early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
|
||||
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
|
||||
)
|
||||
if early_exit:
|
||||
return attention_mask
|
||||
|
||||
batch_size, total_seq_len = cache_position.shape
|
||||
key_length = total_seq_len
|
||||
document_ids = torch.nn.functional.pad(
|
||||
original_attention_mask, value=0, pad=(0, key_length)
|
||||
)
|
||||
|
||||
batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
|
||||
if attention_mask is not None:
|
||||
|
||||
def causal_doc_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
"""
|
||||
Defines the logic of a block causal mask by combining both a standard causal mask
|
||||
and a block diagonal document mask.
|
||||
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
|
||||
for an illustration.
|
||||
"""
|
||||
causal_mask_ = q_idx >= kv_idx # not valid when decoding
|
||||
document_mask = (
|
||||
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
|
||||
)
|
||||
final_mask = causal_mask_ & document_mask
|
||||
return final_mask
|
||||
|
||||
mask_factory_function = causal_doc_mask_mod
|
||||
else:
|
||||
mask_factory_function = causal_mask_function
|
||||
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
||||
|
||||
# Do not allow skip if we are compiling (this is to match BC)
|
||||
allow_is_causal_skip = (
|
||||
not past_key_values.is_compileable if past_key_values is not None else True
|
||||
)
|
||||
|
||||
# Allow slight deviations from causal mask
|
||||
if or_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError(
|
||||
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
|
||||
)
|
||||
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
if and_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError(
|
||||
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
|
||||
)
|
||||
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
|
||||
# We now create the mask
|
||||
causal_mask = mask_interface(
|
||||
batch_size=batch_size,
|
||||
cache_position=cache_position,
|
||||
kv_length=kv_length,
|
||||
kv_offset=kv_offset,
|
||||
mask_function=mask_factory_function,
|
||||
attention_mask=attention_mask,
|
||||
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
|
||||
dtype=dtype, # Additional kwarg for eager
|
||||
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
def patch_create_causal_mask(model_type):
|
||||
import transformers.masking_utils
|
||||
|
||||
transformers.masking_utils.create_causal_mask = create_causal_mask
|
||||
|
||||
if model_type:
|
||||
try:
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
module = __import__(module_path)
|
||||
module.create_causal_mask = create_causal_mask
|
||||
del sys.modules[module_path]
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(
|
||||
f"Could not import attention class for model_type: {model_type}. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
@@ -35,7 +35,6 @@ from axolotl.utils import (
|
||||
is_comet_available,
|
||||
is_mlflow_available,
|
||||
is_opentelemetry_available,
|
||||
is_trackio_available,
|
||||
)
|
||||
from axolotl.utils.callbacks import (
|
||||
GCCallback,
|
||||
@@ -148,14 +147,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_trackio and is_trackio_available():
|
||||
from axolotl.utils.callbacks.trackio_ import (
|
||||
SaveAxolotlConfigtoTrackioCallback,
|
||||
)
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoTrackioCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_otel_metrics and is_opentelemetry_available():
|
||||
from axolotl.utils.callbacks.opentelemetry import (
|
||||
OpenTelemetryMetricsCallback,
|
||||
@@ -290,22 +281,11 @@ class TrainerBuilderBase(abc.ABC):
|
||||
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
|
||||
|
||||
if self.cfg.optimizer == "muon":
|
||||
_, device_mesh = build_parallelism_config(self.cfg)
|
||||
|
||||
if device_mesh is not None:
|
||||
from axolotl.contribs.mit.muon.dist_muon import (
|
||||
DistMuonOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls = DistMuonOptimizerFactory
|
||||
optimizer_kwargs["device_mesh"] = device_mesh
|
||||
else:
|
||||
from axolotl.contribs.mit.muon import (
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls = MuonOptimizerFactory
|
||||
from axolotl.contribs.mit.muon import (
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls = MuonOptimizerFactory
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "dion":
|
||||
from axolotl.contribs.mit.dion import (
|
||||
@@ -443,8 +423,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
report_to.append("tensorboard")
|
||||
if self.cfg.use_comet:
|
||||
report_to.append("comet_ml")
|
||||
if self.cfg.use_trackio:
|
||||
report_to.append("trackio")
|
||||
|
||||
training_args_kwargs["report_to"] = report_to
|
||||
|
||||
@@ -452,8 +430,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
elif self.cfg.use_mlflow:
|
||||
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||
elif self.cfg.use_trackio:
|
||||
training_args_kwargs["run_name"] = self.cfg.trackio_run_name
|
||||
else:
|
||||
training_args_kwargs["run_name"] = None
|
||||
|
||||
|
||||
@@ -72,9 +72,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.include_tkps:
|
||||
callbacks.append(
|
||||
TokensPerSecondCallback(
|
||||
self.cfg.tensor_parallel_size,
|
||||
self.cfg.context_parallel_size,
|
||||
resume_from_checkpoint=self.cfg.resume_from_checkpoint,
|
||||
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
|
||||
)
|
||||
)
|
||||
return callbacks
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import partial, wraps
|
||||
@@ -51,8 +49,6 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
TOKENS_STATE_FILE = "tokens_state."
|
||||
|
||||
REDUCTION_FNS = {
|
||||
"mean": torch.mean,
|
||||
"min": torch.min,
|
||||
@@ -352,33 +348,24 @@ class AxolotlTrainer(
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
|
||||
# track number of tokens for tokens per second calculation
|
||||
if self.args.include_tkps and model.training:
|
||||
if self.args.include_tkps:
|
||||
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
||||
trainable_tokens = (inputs[inputs_key] != -100).sum()
|
||||
total_tokens = inputs[inputs_key].numel()
|
||||
total_tokens = torch.tensor(total_tokens, device=inputs[inputs_key].device)
|
||||
|
||||
num_tokens = (inputs[inputs_key] != -100).sum()
|
||||
if is_distributed():
|
||||
torch.distributed.all_reduce(
|
||||
trainable_tokens, op=torch.distributed.ReduceOp.SUM
|
||||
num_tokens, op=torch.distributed.ReduceOp.SUM
|
||||
)
|
||||
torch.distributed.all_reduce(
|
||||
total_tokens, op=torch.distributed.ReduceOp.SUM
|
||||
if hasattr(self.state, "num_tokens"):
|
||||
self.state.num_tokens = (
|
||||
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
|
||||
)
|
||||
else:
|
||||
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
|
||||
|
||||
if not hasattr(self.state, "tokens"):
|
||||
self.state.tokens = {
|
||||
"trainable": torch.zeros(1),
|
||||
"total": torch.zeros(1),
|
||||
}
|
||||
|
||||
# trainable tokens for throughput and total token slots for summaries
|
||||
self.state.tokens["trainable"] = (
|
||||
self.state.tokens["trainable"] + trainable_tokens.detach().cpu()
|
||||
)
|
||||
self.state.tokens["total"] = self.state.tokens["total"] + total_tokens.cpu()
|
||||
# Store per-step trainable tokens for throughput calculation
|
||||
self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()
|
||||
if hasattr(self.state, "total_tokens"):
|
||||
self.state.total_tokens += num_tokens
|
||||
else:
|
||||
self.state.total_tokens = num_tokens
|
||||
|
||||
if self.args.orpo_alpha:
|
||||
return self.orpo_compute_loss(
|
||||
@@ -616,7 +603,6 @@ class AxolotlTrainer(
|
||||
"""
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
metric_ndigits = int(os.getenv("AXOLOTL_METRIC_NDIGITS", "5"))
|
||||
|
||||
for key, metric_data in self._stored_metrics[train_eval].items():
|
||||
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
|
||||
@@ -627,18 +613,7 @@ class AxolotlTrainer(
|
||||
raise NotImplementedError(
|
||||
"Metric reduction must be one of [mean, min, max, sum]"
|
||||
)
|
||||
logs[key] = round(fn(values).item(), metric_ndigits)
|
||||
|
||||
if "loss" in logs:
|
||||
try:
|
||||
logs["ppl"] = round(math.exp(logs["loss"]), metric_ndigits)
|
||||
except OverflowError:
|
||||
logs["ppl"] = float("inf")
|
||||
if "eval_loss" in logs:
|
||||
try:
|
||||
logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), metric_ndigits)
|
||||
except OverflowError:
|
||||
logs["eval_ppl"] = float("inf")
|
||||
logs[key] = round(fn(values).item(), 4)
|
||||
|
||||
if is_main_process():
|
||||
# Add memory usage
|
||||
@@ -650,14 +625,10 @@ class AxolotlTrainer(
|
||||
except (ValueError, TypeError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
if (
|
||||
self.args.include_tkps
|
||||
and train_eval == "train"
|
||||
and hasattr(self.state, "tokens")
|
||||
):
|
||||
if self.args.include_tkps and train_eval == "train":
|
||||
# each rank will log its own tokens per second
|
||||
# for logging_steps > 1 we obtain a moving average of this metric
|
||||
logs["tokens/train_per_sec_per_gpu"] = round(
|
||||
logs["tokens_per_second_per_gpu"] = round(
|
||||
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
||||
)
|
||||
if (
|
||||
@@ -699,19 +670,6 @@ class AxolotlTrainer(
|
||||
run_dir = self._get_output_dir(trial=trial)
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Save total_tokens state if tracking is enabled
|
||||
if self.args.include_tkps and hasattr(self.state, "tokens"):
|
||||
tokens_state = {
|
||||
"total": int(torch.as_tensor(self.state.tokens.get("total", 0)).item()),
|
||||
"trainable": int(
|
||||
torch.as_tensor(self.state.tokens.get("trainable", 0)).item()
|
||||
),
|
||||
}
|
||||
tokens_state_path = os.path.join(output_dir, TOKENS_STATE_FILE)
|
||||
with open(tokens_state_path, "w", encoding="utf-8") as f:
|
||||
json.dump(tokens_state, f)
|
||||
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
||||
|
||||
@@ -36,6 +36,4 @@ class DPOStrategy:
|
||||
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
||||
if cfg.dpo_use_logits_to_keep is not None:
|
||||
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
||||
if cfg.dpo_use_liger_kernel is not None:
|
||||
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
|
||||
return training_args_kwargs
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -54,8 +54,6 @@ plugins:
|
||||
- granitemoehybrid
|
||||
- hunyuan_v1_dense
|
||||
- hunyuan_v1_moe
|
||||
- internvl
|
||||
- kimi_linear
|
||||
- lfm2
|
||||
- lfm2_moe
|
||||
- lfm2_vl
|
||||
|
||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@f643b88"`'
|
||||
)
|
||||
|
||||
|
||||
@@ -96,11 +96,7 @@ class CutCrossEntropyPlugin(BasePlugin):
|
||||
)
|
||||
|
||||
# The patch checks model_type internally
|
||||
|
||||
cce_patch(
|
||||
cfg.model_config_type,
|
||||
remote_model_id=cfg.base_model if cfg.trust_remote_code else None,
|
||||
)
|
||||
cce_patch(cfg.model_config_type)
|
||||
|
||||
def patch_llama_like(
|
||||
self,
|
||||
@@ -111,9 +107,7 @@ class CutCrossEntropyPlugin(BasePlugin):
|
||||
"""
|
||||
from cut_cross_entropy.transformers.patch import PATCH_FNS
|
||||
|
||||
def patch_generic(
|
||||
maybe_model, patch_options, model_type: str, remote_model_id: str | None
|
||||
):
|
||||
def patch_generic(maybe_model, patch_options, model_type: str):
|
||||
import cut_cross_entropy.transformers.llama
|
||||
from cut_cross_entropy.transformers.llama import cce_forward
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class DenseMixerPlugin(BasePlugin):
|
||||
if cfg.dense_mixer:
|
||||
if not importlib.util.find_spec("densemixer"):
|
||||
raise RuntimeError(
|
||||
"DenseMixer is not installed. Install it with `pip install densemixer`"
|
||||
"DenseMixer is not installed. Install it with `pip install densemizer`"
|
||||
)
|
||||
|
||||
from densemixer.patching import (
|
||||
|
||||
@@ -26,48 +26,6 @@ PLUGIN_MANAGER = PluginManager.get_instance()
|
||||
class PatchManager:
|
||||
"""Manages the application of patches during the model loading process."""
|
||||
|
||||
@staticmethod
|
||||
def apply_pre_config_load_patches(cfg: DictDefault):
|
||||
"""
|
||||
Apply patches that must be set up before config loading.
|
||||
This is for patches that intercept remote code loading from HuggingFace,
|
||||
which needs to be in place before AutoConfig.from_pretrained() is called.
|
||||
|
||||
Args:
|
||||
cfg: Configuration dictionary with model and training settings.
|
||||
"""
|
||||
if (
|
||||
hasattr(cfg, "base_model_config")
|
||||
and cfg.base_model_config
|
||||
and "kimi-linear" in cfg.base_model_config.lower()
|
||||
):
|
||||
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||
patch_kimi_config,
|
||||
)
|
||||
|
||||
patch_kimi_config()
|
||||
|
||||
@staticmethod
|
||||
def apply_pre_tokenizer_load_patches(cfg: DictDefault):
|
||||
"""
|
||||
Apply patches that must be set up before tokenizer loading.
|
||||
This is for patches that intercept remote code loading from HuggingFace,
|
||||
which needs to be in place before AutoTokenizer.from_pretrained() is called.
|
||||
|
||||
Args:
|
||||
cfg: Configuration dictionary with model and training settings.
|
||||
"""
|
||||
if (
|
||||
hasattr(cfg, "tokenizer_config")
|
||||
and cfg.tokenizer_config
|
||||
and "kimi-linear" in cfg.tokenizer_config.lower()
|
||||
):
|
||||
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||
patch_kimi_tokenizer,
|
||||
)
|
||||
|
||||
patch_kimi_tokenizer()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg: DictDefault,
|
||||
@@ -199,6 +157,12 @@ class PatchManager:
|
||||
|
||||
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
||||
patch_flex_wrapper(**flex_attn_compile_kwargs)
|
||||
if self.cfg.sample_packing:
|
||||
from axolotl.core.attention.flex_block_mask import (
|
||||
patch_create_causal_mask,
|
||||
)
|
||||
|
||||
patch_create_causal_mask(self.cfg.model_config_type)
|
||||
|
||||
def _apply_model_specific_patches(self):
|
||||
"""Apply patches specific to model architectures."""
|
||||
@@ -226,13 +190,6 @@ class PatchManager:
|
||||
|
||||
apply_mistral_tokenizer_image_patch()
|
||||
|
||||
if self.cfg.model_config_type == "kimi_linear":
|
||||
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||
patch_kimi_model,
|
||||
)
|
||||
|
||||
patch_kimi_model()
|
||||
|
||||
def _apply_fp8_patches(self):
|
||||
"""Apply patches for FP8 support."""
|
||||
if self.cfg.fp8:
|
||||
|
||||
@@ -124,11 +124,6 @@ def modify_tokenizer_files(
|
||||
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
"""Load and configure the tokenizer based on the provided config."""
|
||||
|
||||
# Apply patches that need to be in place before tokenizer loading
|
||||
from axolotl.loaders.patch_manager import PatchManager
|
||||
|
||||
PatchManager.apply_pre_tokenizer_load_patches(cfg)
|
||||
|
||||
def _load_mistral_common_tokenizer(cfg: DictDefault):
|
||||
"""Load mistral-common tokenizer"""
|
||||
from axolotl.utils.mistral import HFMistralTokenizer
|
||||
|
||||
@@ -79,11 +79,7 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
|
||||
and hasattr(model_config, "vision_config")
|
||||
and hasattr(model_config.vision_config, "image_size")
|
||||
):
|
||||
image_size = model_config.vision_config.image_size
|
||||
if isinstance(image_size, list):
|
||||
cfg.image_size = tuple(image_size)
|
||||
else:
|
||||
cfg.image_size = image_size
|
||||
cfg.image_size = model_config.vision_config.image_size
|
||||
LOG.debug(f"Loaded image size: {cfg.image_size} from model config")
|
||||
|
||||
quant_config_exists = (
|
||||
|
||||
@@ -75,33 +75,3 @@ def patch_parallelism_config():
|
||||
|
||||
ParallelismConfig._validate_accelerator = _validate_accelerator
|
||||
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
|
||||
|
||||
|
||||
def patch_prepare_cp():
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
|
||||
def patched_prepare_cp(self, *args):
|
||||
if self.parallelism_config.cp_backend == "deepspeed":
|
||||
return args
|
||||
|
||||
from accelerate.big_modeling import _attach_context_parallel_hooks
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||
|
||||
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
|
||||
set_rotate_method(cp_comm_strategy)
|
||||
|
||||
self._cp_context = functools.partial(
|
||||
context_parallel, mesh=self.torch_device_mesh["cp"]
|
||||
)
|
||||
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.nn.Module):
|
||||
_attach_context_parallel_hooks(arg)
|
||||
|
||||
return args
|
||||
|
||||
Accelerator._prepare_cp = patched_prepare_cp
|
||||
|
||||
@@ -1,148 +0,0 @@
|
||||
"""
|
||||
Kimi-Linear configuration.
|
||||
|
||||
Source: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/configuration_kimi.py
|
||||
Revision: 6e163f3
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class KimiLinearConfig(PretrainedConfig):
|
||||
model_type = "kimi_linear"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type="kimi_linear",
|
||||
vocab_size=163840,
|
||||
hidden_size=4096,
|
||||
head_dim=None,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
tie_word_embeddings=False,
|
||||
moe_intermediate_size: Optional[int] = None,
|
||||
moe_renormalize: bool = True,
|
||||
moe_router_activation_func: str = "sigmoid",
|
||||
num_experts: Optional[int] = None,
|
||||
num_experts_per_token: Optional[int] = None,
|
||||
num_shared_experts: int = 0,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
first_k_dense_replace: int = 0,
|
||||
moe_layer_freq: int = 1,
|
||||
use_grouped_topk: bool = True,
|
||||
num_expert_group: int = 1,
|
||||
topk_group: int = 1,
|
||||
q_lora_rank: Optional[int] = None,
|
||||
kv_lora_rank: Optional[int] = None,
|
||||
qk_nope_head_dim: Optional[int] = None,
|
||||
qk_rope_head_dim: Optional[int] = None,
|
||||
v_head_dim: Optional[int] = None,
|
||||
mla_use_nope: Optional[bool] = False,
|
||||
num_nextn_predict_layers: int = 0,
|
||||
linear_attn_config: Optional[dict] = None,
|
||||
router_aux_loss_coef: float = 0.01,
|
||||
**kwargs,
|
||||
):
|
||||
self.model_type = model_type
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.head_dim = (
|
||||
head_dim if head_dim is not None else hidden_size // num_attention_heads
|
||||
)
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.mla_use_nope = mla_use_nope
|
||||
# moe config
|
||||
self.num_experts = num_experts
|
||||
self.num_experts_per_token = num_experts_per_token
|
||||
self.moe_renormalize = moe_renormalize
|
||||
self.num_shared_experts = num_shared_experts
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.moe_router_activation_func = moe_router_activation_func
|
||||
assert self.moe_router_activation_func in ("softmax", "sigmoid")
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.moe_layer_freq = moe_layer_freq
|
||||
self.use_grouped_topk = use_grouped_topk
|
||||
self.num_expert_group = num_expert_group
|
||||
self.topk_group = topk_group
|
||||
self.num_nextn_predict_layers = num_nextn_predict_layers
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
|
||||
if linear_attn_config is not None:
|
||||
assert linear_attn_config["kda_layers"] is not None
|
||||
assert linear_attn_config["full_attn_layers"] is not None
|
||||
self.linear_attn_config = linear_attn_config
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_mla(self):
|
||||
return (
|
||||
self.q_lora_rank is not None
|
||||
or self.kv_lora_rank is not None
|
||||
or self.qk_nope_head_dim is not None
|
||||
or self.qk_rope_head_dim is not None
|
||||
or self.v_head_dim is not None
|
||||
or self.mla_use_nope is True
|
||||
)
|
||||
|
||||
@property
|
||||
def is_moe(self):
|
||||
return self.num_experts is not None
|
||||
|
||||
@property
|
||||
def is_linear_attn(self) -> bool:
|
||||
return not (
|
||||
self.linear_attn_config is None
|
||||
or (
|
||||
isinstance(self.linear_attn_config, dict)
|
||||
and self.linear_attn_config["kda_layers"] is not None
|
||||
and len(self.linear_attn_config["kda_layers"]) == 0
|
||||
)
|
||||
)
|
||||
|
||||
def is_kda_layer(self, layer_idx: int):
|
||||
return (
|
||||
self.linear_attn_config is not None
|
||||
and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,85 +0,0 @@
|
||||
import importlib.resources
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear"
|
||||
|
||||
|
||||
def get_patch_file_path(package_dot_path: str, filename: str) -> Path:
|
||||
"""
|
||||
Gets the absolute path to a patch file using importlib.resources.files.
|
||||
"""
|
||||
try:
|
||||
return importlib.resources.files(package_dot_path) / filename
|
||||
except ModuleNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
def _load_local_module(module_name: str, filename: str):
|
||||
"""Helper to load a local module if not already loaded."""
|
||||
if module_name in sys.modules:
|
||||
return sys.modules[module_name]
|
||||
|
||||
patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, filename)
|
||||
if patch_path and patch_path.exists():
|
||||
spec = importlib.util.spec_from_file_location(module_name, patch_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
return None
|
||||
|
||||
|
||||
def _patch_get_class_in_module():
|
||||
"""
|
||||
Core patch function that hijacks Transformers' dynamic module loading.
|
||||
"""
|
||||
from transformers.dynamic_module_utils import get_class_in_module
|
||||
|
||||
if hasattr(get_class_in_module, "_axolotl_patched"):
|
||||
return
|
||||
|
||||
original_get_class_in_module = get_class_in_module
|
||||
|
||||
# Mapping of module path patterns to (module_name, filename)
|
||||
KIMI_MODULE_MAP = {
|
||||
"configuration_kimi": ("configuration_kimi", "configuration_kimi.py"),
|
||||
"modeling_kimi": ("modeling_kimi", "modeling_kimi.py"),
|
||||
"tokenization_kimi": ("tokenization_kimi", "tokenization_kimi.py"),
|
||||
}
|
||||
|
||||
def patched_get_class_in_module(class_name, module_path, **kwargs):
|
||||
"""Patched version that returns our local modules instead of remote ones."""
|
||||
for pattern, (module_name, filename) in KIMI_MODULE_MAP.items():
|
||||
if pattern in module_path:
|
||||
module = _load_local_module(module_name, filename)
|
||||
if module:
|
||||
return getattr(module, class_name)
|
||||
break # Pattern matched but file not found, fall through
|
||||
|
||||
return original_get_class_in_module(class_name, module_path, **kwargs)
|
||||
|
||||
import transformers.dynamic_module_utils
|
||||
|
||||
transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module
|
||||
patched_get_class_in_module._axolotl_patched = True
|
||||
|
||||
|
||||
def patch_kimi():
|
||||
"""
|
||||
Apply all Kimi patches.
|
||||
Must be called BEFORE loading config/tokenizer/model.
|
||||
"""
|
||||
_patch_get_class_in_module()
|
||||
LOG.info("Kimi patches applied successfully!")
|
||||
|
||||
|
||||
# Keep these for backward compatibility if needed
|
||||
patch_kimi_config = patch_kimi
|
||||
patch_kimi_tokenizer = patch_kimi
|
||||
patch_kimi_model = patch_kimi
|
||||
@@ -1,357 +0,0 @@
|
||||
"""
|
||||
Adapted Kimi-Linear tokenizer to use proper template defaults and misc fixes.
|
||||
|
||||
Source: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/tokenization_kimi.py
|
||||
Revision: 919416f
|
||||
"""
|
||||
|
||||
import os
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import tiktoken
|
||||
from tiktoken.load import load_tiktoken_bpe
|
||||
from tokenizers import AddedToken
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
logger = getLogger(__name__)
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
|
||||
|
||||
|
||||
class TikTokenTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.
|
||||
|
||||
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
||||
this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
The path to the Tiktoken model file.
|
||||
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|begin_of_text|>",`):
|
||||
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
||||
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|end_of_text|>"`):
|
||||
The end of sequence token.
|
||||
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_249|>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead. The second to last item in special_tokens.
|
||||
pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_250|>"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
additional_special_tokens (list of `str`, *optional*):
|
||||
A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
|
||||
skipped when decoding if `skip_special_tokens` is set to `True`.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
special_tokens: Dict[str, int]
|
||||
|
||||
num_reserved_special_tokens = 256
|
||||
|
||||
pat_str = "|".join(
|
||||
[
|
||||
r"""[\p{Han}]+""",
|
||||
r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
|
||||
r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
|
||||
r"""\p{N}{1,3}""",
|
||||
r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
|
||||
r"""\s*[\r\n]+""",
|
||||
r"""\s+(?!\S)""",
|
||||
r"""\s+""",
|
||||
]
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
bos_token: Union[str, AddedToken] = "[BOS]", # nosec: B107
|
||||
eos_token: Union[str, AddedToken] = "[EOS]", # nosec: B107
|
||||
unk_token: Union[str, AddedToken, None] = None,
|
||||
pad_token: Union[str, AddedToken, None] = None,
|
||||
additional_special_tokens: List[str] = None,
|
||||
added_tokens_decoder: Optional[dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
assert os.path.isfile(vocab_file), vocab_file
|
||||
|
||||
if additional_special_tokens is None:
|
||||
additional_special_tokens = [
|
||||
"<|im_end|>",
|
||||
"<|im_user|>",
|
||||
"<|im_assistant|>",
|
||||
"<|start_header_id|>",
|
||||
"<|end_header_id|>",
|
||||
"[EOT]",
|
||||
"<|im_system|>",
|
||||
"<|im_middle|>",
|
||||
]
|
||||
|
||||
special_tokens_mapping = {
|
||||
i: added_tokens_decoder[i].content for i in added_tokens_decoder
|
||||
}
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
mergeable_ranks = load_tiktoken_bpe(vocab_file)
|
||||
num_base_tokens = len(mergeable_ranks)
|
||||
self.special_tokens = {
|
||||
special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
|
||||
for i in range(
|
||||
num_base_tokens, num_base_tokens + self.num_reserved_special_tokens + 2
|
||||
)
|
||||
}
|
||||
|
||||
self.model = tiktoken.Encoding(
|
||||
name=Path(vocab_file).name,
|
||||
pat_str=self.pat_str,
|
||||
mergeable_ranks=mergeable_ranks,
|
||||
special_tokens=self.special_tokens,
|
||||
)
|
||||
logger.info(f"Reloaded tiktoken model from {vocab_file}")
|
||||
|
||||
self.n_words: int = self.model.n_vocab
|
||||
# BOS / EOS token IDs
|
||||
self.bos_id: int = self.special_tokens[str(bos_token)]
|
||||
self.eos_id: int = self.special_tokens[str(eos_token)]
|
||||
logger.info(
|
||||
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
|
||||
)
|
||||
|
||||
self.pad_id: int = self.special_tokens[str(pad_token)]
|
||||
self.unk_id: int = self.special_tokens[str(unk_token)]
|
||||
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
|
||||
self.decoder = {}
|
||||
for i in range(self.n_words):
|
||||
# Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
|
||||
decoding = "".join(
|
||||
[
|
||||
self.byte_encoder[ord(char)]
|
||||
for char in self.model.decode_single_token_bytes(i).decode(
|
||||
"latin-1"
|
||||
)
|
||||
]
|
||||
)
|
||||
self.decoder[i] = decoding
|
||||
|
||||
self.encoder = {}
|
||||
for i in range(self.n_words):
|
||||
if i in self.decoder:
|
||||
self.encoder[self.decoder[i]] = i
|
||||
|
||||
super().__init__(
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
additional_special_tokens=additional_special_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
self.all_special_ids_set = set(self.all_special_ids)
|
||||
|
||||
def encode(
|
||||
self, text: str, allow_special_tokens: bool = True, **kwargs
|
||||
) -> List[int]:
|
||||
"""
|
||||
Encodes a string into a list of token IDs.
|
||||
|
||||
Args:
|
||||
text (str): The input string to be encoded.
|
||||
|
||||
Returns:
|
||||
list[int]: A list of token IDs.
|
||||
"""
|
||||
# If there are other args, we should call super().encode because there are a lot of code
|
||||
# to handle those args. supper().encode finally will call _tokenize and _convert_token_to_id.
|
||||
# NOTE: our encode method is not compatible with the super().encode method,
|
||||
# e.g. split_special_tokens' default is True in our encode method.
|
||||
if len(kwargs) > 0:
|
||||
# logger.warning(f"Calling super().encode with {kwargs}")
|
||||
return super().encode(text, **kwargs)
|
||||
|
||||
assert type(text) is str
|
||||
|
||||
# The tiktoken tokenizer can handle <=400k chars without
|
||||
# pyo3_runtime.PanicException.
|
||||
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
||||
|
||||
# https://github.com/openai/tiktoken/issues/195
|
||||
# Here we iterate over subsequences and split if we exceed the limit
|
||||
# of max consecutive non-whitespace or whitespace characters.
|
||||
MAX_NO_WHITESPACES_CHARS = 25_000
|
||||
|
||||
texts = self.pre_tokenizer_process(text)
|
||||
|
||||
all_substrs = []
|
||||
for text in texts:
|
||||
substrs = (
|
||||
substr
|
||||
for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
|
||||
for substr in self._split_whitespaces_or_nonwhitespaces(
|
||||
text[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
|
||||
)
|
||||
)
|
||||
all_substrs.extend(substrs)
|
||||
|
||||
t: List[int] = []
|
||||
for substr in all_substrs:
|
||||
if allow_special_tokens:
|
||||
t.extend(
|
||||
# we should consider special token as a common token
|
||||
self.model.encode(
|
||||
substr,
|
||||
allowed_special="all",
|
||||
)
|
||||
)
|
||||
else:
|
||||
t.extend(
|
||||
# we should consider special token as a common token
|
||||
self.model.encode(
|
||||
substr,
|
||||
disallowed_special=(),
|
||||
)
|
||||
)
|
||||
|
||||
return t
|
||||
|
||||
def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
|
||||
"""
|
||||
Decodes a list of token IDs into a string.
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The list of token IDs to be decoded.
|
||||
|
||||
Returns:
|
||||
str: The decoded string.
|
||||
"""
|
||||
# If there are other args, we should call super().decode because there are a lot of code
|
||||
# to handle those args. supper().encode finally will call convert_tokens_to_string and _convert_id_to_token.
|
||||
if len(kwargs) > 0:
|
||||
return super().decode(token_ids, **kwargs)
|
||||
|
||||
if type(token_ids) is int:
|
||||
token_ids = [token_ids]
|
||||
|
||||
return self.model.decode(cast(List[int], token_ids))
|
||||
|
||||
@staticmethod
|
||||
def _split_whitespaces_or_nonwhitespaces(
|
||||
s: str, max_consecutive_slice_len: int
|
||||
) -> Iterator[str]:
|
||||
"""
|
||||
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
|
||||
consecutive whitespaces or consecutive non-whitespaces.
|
||||
"""
|
||||
current_slice_len = 0
|
||||
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
|
||||
slice_start = 0
|
||||
|
||||
for i in range(len(s)):
|
||||
is_now_space = s[i].isspace()
|
||||
|
||||
if current_slice_is_space ^ is_now_space:
|
||||
current_slice_len = 1
|
||||
current_slice_is_space = is_now_space
|
||||
else:
|
||||
current_slice_len += 1
|
||||
if current_slice_len > max_consecutive_slice_len:
|
||||
yield s[slice_start:i]
|
||||
slice_start = i
|
||||
current_slice_len = 1
|
||||
yield s[slice_start:]
|
||||
|
||||
def pre_tokenizer_process(self, text: str) -> List[str]:
|
||||
"""
|
||||
pre-tokenizes the input text into a list of tokens.
|
||||
This method is used to split the input text into smaller chunks for internal processing.
|
||||
"""
|
||||
return [text]
|
||||
|
||||
""" ----- Below are the abstract methods required by PreTrainedTokenizer ----- """
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.n_words
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
return self.encoder
|
||||
|
||||
def _tokenize(self, text: str, **kwargs) -> List[str]:
|
||||
return [self.decoder[t] for t in self.encode(text)]
|
||||
|
||||
def _convert_token_to_id(self, token: str) -> int:
|
||||
return self.encoder.get(token, self.unk_id)
|
||||
|
||||
def _convert_id_to_token(self, index: int) -> str:
|
||||
return self.decoder.get(index)
|
||||
|
||||
@staticmethod
|
||||
def clean_up_tokenization(out_string: str) -> str:
|
||||
return out_string
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
text = "".join(tokens)
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode(
|
||||
"utf-8", "replace"
|
||||
)
|
||||
return text
|
||||
|
||||
def save_vocabulary(
|
||||
self, save_directory: str, filename_prefix: Optional[str] = None
|
||||
) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
raise ValueError(
|
||||
f"vocabulary path ({save_directory}) should be a directory"
|
||||
)
|
||||
out_vocab_file = os.path.join(
|
||||
save_directory,
|
||||
(filename_prefix + "-" if filename_prefix else "")
|
||||
+ VOCAB_FILES_NAMES["vocab_file"],
|
||||
)
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(
|
||||
out_vocab_file
|
||||
) and os.path.isfile(self.vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
def apply_chat_template(
|
||||
self,
|
||||
conversation,
|
||||
tools: Optional[list[dict]] = None,
|
||||
tokenize: bool = True,
|
||||
add_generation_prompt: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
tools = deep_sort_dict(tools)
|
||||
return super().apply_chat_template(
|
||||
conversation,
|
||||
tools=tools,
|
||||
tokenize=tokenize,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def deep_sort_dict(obj: Any) -> Any:
|
||||
if isinstance(obj, dict):
|
||||
return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}
|
||||
if isinstance(obj, list):
|
||||
return [deep_sort_dict(item) for item in obj]
|
||||
return obj
|
||||
@@ -8,7 +8,6 @@ from PIL.Image import Resampling
|
||||
from torch import Tensor, zeros_like
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.image_utils import load_image
|
||||
from transformers.models.internvl import InternVLProcessor
|
||||
from transformers.models.smolvlm import SmolVLMProcessor
|
||||
from transformers.models.voxtral import VoxtralProcessor
|
||||
|
||||
@@ -455,37 +454,6 @@ class Mistral3ProcessingStrategy(ProcessingStrategy):
|
||||
return labels
|
||||
|
||||
|
||||
class InternVLProcessingStrategy(ProcessingStrategy):
|
||||
"""Processing Strategy class for InternVL"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processor: ProcessorMixin,
|
||||
chat_template: Optional[str] = None,
|
||||
image_size: int | tuple[int, int] | None = None,
|
||||
image_resize_algorithm: Resampling | None = None,
|
||||
):
|
||||
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
|
||||
|
||||
if not hasattr(processor, "image_ids"):
|
||||
raise ValueError("'image_ids' missing from InternVL Processor.")
|
||||
|
||||
self.image_token_ids = processor.image_ids
|
||||
|
||||
def process_labels(self, input_ids):
|
||||
labels = input_ids.clone()
|
||||
|
||||
labels[labels == self.processor.tokenizer.pad_token_id] = -100
|
||||
|
||||
for ids in self.image_token_ids:
|
||||
labels[labels == ids] = -100
|
||||
|
||||
# Note: Check if need to mask 'video_token' as it gets converted to
|
||||
# image patches during media processing
|
||||
|
||||
return labels
|
||||
|
||||
|
||||
def get_processing_strategy(
|
||||
processor: ProcessorMixin,
|
||||
chat_template,
|
||||
@@ -533,11 +501,6 @@ def get_processing_strategy(
|
||||
**processing_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(processor, InternVLProcessor):
|
||||
return InternVLProcessingStrategy(
|
||||
**processing_kwargs,
|
||||
)
|
||||
|
||||
# llama3_2_vision, llama4, llava
|
||||
# mistral_v7_tekken, pixtral, lfm2vl
|
||||
return ProcessingStrategy(
|
||||
|
||||
@@ -24,10 +24,6 @@ def is_opentelemetry_available():
|
||||
)
|
||||
|
||||
|
||||
def is_trackio_available():
|
||||
return importlib.util.find_spec("trackio") is not None
|
||||
|
||||
|
||||
def get_pytorch_version() -> tuple[int, int, int]:
|
||||
"""
|
||||
Get Pytorch version as a tuple of (major, minor, patch).
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""A callback for calculating tokens per second during training."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
@@ -12,52 +10,22 @@ from transformers import (
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
TOKENS_STATE_FILE = "tokens_state.json"
|
||||
|
||||
|
||||
class TokensPerSecondCallback(TrainerCallback):
|
||||
"""
|
||||
A callback to measure and log tokens per second during training.
|
||||
Also handles saving/restoring total_tokens state across checkpoint resumes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, tensor_parallel_size, context_parallel_size, resume_from_checkpoint=None
|
||||
):
|
||||
def __init__(self, tensor_parallel_size, context_parallel_size):
|
||||
super().__init__()
|
||||
self.step_time = 0.0
|
||||
self.start_time = 0.0
|
||||
self.non_data_parallel_size = 1
|
||||
self.resume_from_checkpoint = resume_from_checkpoint
|
||||
if tensor_parallel_size is not None:
|
||||
self.non_data_parallel_size *= tensor_parallel_size
|
||||
if context_parallel_size is not None:
|
||||
self.non_data_parallel_size *= context_parallel_size
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
): # pylint: disable=unused-argument
|
||||
"""Restore total_tokens state when resuming from checkpoint."""
|
||||
if not isinstance(self.resume_from_checkpoint, str):
|
||||
return
|
||||
tokens_state_path = os.path.join(self.resume_from_checkpoint, TOKENS_STATE_FILE)
|
||||
if os.path.isfile(tokens_state_path):
|
||||
with open(tokens_state_path, "r", encoding="utf-8") as f:
|
||||
tokens_state = json.load(f)
|
||||
state.tokens = {
|
||||
"total": torch.tensor(tokens_state.get("total", 0)),
|
||||
"trainable": torch.tensor(tokens_state.get("trainable", 0)),
|
||||
}
|
||||
LOG.info(f"Restored total_tokens: {state.tokens['total']}")
|
||||
|
||||
def on_step_begin(
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
@@ -65,8 +33,6 @@ class TokensPerSecondCallback(TrainerCallback):
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
): # pylint: disable=unused-argument
|
||||
if not hasattr(state, "tokens"):
|
||||
state.tokens = {"trainable": torch.zeros(1), "total": torch.zeros(1)}
|
||||
self.start_time = time.perf_counter()
|
||||
state.last_tokens_per_second = torch.zeros(1)
|
||||
|
||||
@@ -77,10 +43,9 @@ class TokensPerSecondCallback(TrainerCallback):
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
): # pylint: disable=unused-argument
|
||||
tokens = getattr(state, "tokens", None)
|
||||
if tokens and "trainable_tokens" in tokens:
|
||||
if hasattr(state, "num_tokens"):
|
||||
step_time = time.perf_counter() - self.start_time
|
||||
num_tokens_per_device = tokens["trainable_tokens"].clone()
|
||||
num_tokens_per_device = state.num_tokens.clone()
|
||||
# non data parallel groups have duplicated tokens, so we avoid double-counting
|
||||
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
|
||||
state.last_tokens_per_second = num_tokens_per_device / step_time
|
||||
@@ -95,15 +60,5 @@ class TokensPerSecondCallback(TrainerCallback):
|
||||
): # pylint: disable=unused-argument
|
||||
# after logging, clear the running metrics
|
||||
if hasattr(state, "last_tokens_per_second"):
|
||||
logs["tokens/train_per_sec_per_gpu"] = state.last_tokens_per_second.item()
|
||||
state.last_tokens_per_second.zero_()
|
||||
tokens = getattr(state, "tokens", None)
|
||||
# Clear per-step tokens after logging
|
||||
if tokens and "trainable_tokens" in tokens:
|
||||
tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"])
|
||||
|
||||
if tokens and "total" in tokens:
|
||||
logs["tokens/total"] = tokens["total"].item()
|
||||
|
||||
if tokens and "trainable" in tokens:
|
||||
logs["tokens/trainable"] = tokens["trainable"].item()
|
||||
state.num_tokens = torch.zeros(1)
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
"""Trackio module for trainer callbacks"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import trackio
|
||||
from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.environment import is_package_version_ge
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.training_args import AxolotlTrainingArguments
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoTrackioCallback(TrainerCallback):
|
||||
"""Callback for trackio integration"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: "AxolotlTrainingArguments",
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
if not is_package_version_ge("trackio", "0.11.0"):
|
||||
LOG.warning(
|
||||
"Trackio version 0.11.0 or higher is required to save config files. "
|
||||
"Please upgrade trackio: pip install --upgrade trackio"
|
||||
)
|
||||
return control
|
||||
|
||||
trackio.save(self.axolotl_config_path)
|
||||
LOG.info("The Axolotl config has been saved to Trackio.")
|
||||
except (FileNotFoundError, ConnectionError, AttributeError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to Trackio: {err}")
|
||||
return control
|
||||
@@ -151,11 +151,6 @@ def normalize_config(cfg):
|
||||
if not cfg.base_model_config:
|
||||
cfg.base_model_config = cfg.base_model
|
||||
|
||||
# Apply pre-config load patches (e.g., for Kimi Linear remote code patching)
|
||||
from axolotl.loaders.patch_manager import PatchManager
|
||||
|
||||
PatchManager.apply_pre_config_load_patches(cfg)
|
||||
|
||||
model_config = load_model_config(cfg)
|
||||
|
||||
cfg.tokenizer_config = (
|
||||
|
||||
@@ -180,18 +180,20 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
def handle_long_seq_in_dataset(
|
||||
dataset: Dataset, sequence_len: int, cfg: DictDefault
|
||||
) -> Dataset:
|
||||
"""Remove sequences longer than configured maximum from dataset.
|
||||
|
||||
Args:
|
||||
dataset: Dataset to filter.
|
||||
sequence_len: Maximum length for sequences to keep
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
"""
|
||||
Remove or truncate sequences that exceed the configured maximum length from a dataset.
|
||||
|
||||
Parameters:
|
||||
dataset (Dataset): Dataset to process; if it lacks an "input_ids" column or is streaming, it is returned unchanged.
|
||||
sequence_len (int): Maximum allowed sequence length; sequences longer than this are either removed or truncated.
|
||||
cfg (DictDefault): Configuration object with keys:
|
||||
- excess_length_strategy: "drop", "truncate", or "raise" — determines how to handle overlong sequences.
|
||||
- min_sample_len: minimum allowed sequence length (used when truncating or dropping).
|
||||
- dataset_num_proc: number of processes to use for non-streaming datasets.
|
||||
- is_preprocess: when true, bypasses cached preprocessing during filtering.
|
||||
|
||||
Returns:
|
||||
Filtered dataset with long sequences handled according to the excess_length_strategy value:
|
||||
'drop' (default) excludes any sequence longer than sequence_len
|
||||
'truncate' truncates them down to sequence_len
|
||||
'raise' raises a ValueError if any sequence was found that was longer than sequence_len
|
||||
Dataset: The input dataset with sequences longer than `sequence_len` removed or truncated according to `cfg`.
|
||||
"""
|
||||
if (
|
||||
hasattr(dataset, "column_names")
|
||||
@@ -234,12 +236,7 @@ def handle_long_seq_in_dataset(
|
||||
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
action = (
|
||||
"Checking Sequence Lengths"
|
||||
if excess_length_strategy == "raise"
|
||||
else "Dropping Long Sequences"
|
||||
)
|
||||
drop_long_kwargs["desc"] = f"{action} (>{sequence_len})"
|
||||
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
|
||||
|
||||
if excess_length_strategy == "truncate":
|
||||
process_fn = functools.partial(
|
||||
@@ -269,4 +266,4 @@ def handle_long_seq_in_dataset(
|
||||
)
|
||||
LOG.warning(f"{action.title()} {dropped} samples from dataset")
|
||||
|
||||
return dataset
|
||||
return dataset
|
||||
@@ -2,17 +2,9 @@
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
# Suppress noisy bitsandbytes warnings about dtype casting during quantization
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=".*MatMul8bitLt: inputs will be cast from.*",
|
||||
category=UserWarning,
|
||||
)
|
||||
|
||||
# Adapted from Accelerate
|
||||
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py
|
||||
|
||||
|
||||
@@ -9,10 +9,6 @@ from torchao.quantization import quantize_
|
||||
from torchao.quantization.qat import (
|
||||
QATConfig,
|
||||
)
|
||||
from torchao.quantization.qat import fake_quantizer
|
||||
from torchao.quantization.qat.fake_quantizer import (
|
||||
Int4WeightFakeQuantizer as AoInt4WeightFakeQuantizer,
|
||||
)
|
||||
from torchao.quantization.quant_api import (
|
||||
Float8DynamicActivationFloat8WeightConfig,
|
||||
Float8DynamicActivationInt4WeightConfig,
|
||||
@@ -21,27 +17,6 @@ from torchao.quantization.quant_api import (
|
||||
|
||||
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||
|
||||
|
||||
class Int4WeightFakeQuantizer(AoInt4WeightFakeQuantizer):
|
||||
"""
|
||||
Adds 'enabled' attribute to Int4WeightFakeQuantizer (removed in torchao 0.15).
|
||||
Allows toggling fake quantization on/off for fake_quant_after_n_steps.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.enabled = True
|
||||
|
||||
def forward(self, w: torch.Tensor) -> torch.Tensor:
|
||||
if not self.enabled:
|
||||
return w
|
||||
return super().forward(w)
|
||||
|
||||
|
||||
# Replace the original Int4WeightFakeQuantizer in the fake_quantizer module
|
||||
# so that torchao's quantize_() function will use our version
|
||||
fake_quantizer.Int4WeightFakeQuantizer = Int4WeightFakeQuantizer
|
||||
|
||||
quantization_config_to_str = {
|
||||
Int8DynamicActivationInt4WeightConfig: "int8int4",
|
||||
Float8DynamicActivationFloat8WeightConfig: "fp8fp8",
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from accelerate.utils import is_fp8_available
|
||||
from annotated_types import MinLen
|
||||
from packaging import version
|
||||
from pydantic import (
|
||||
@@ -34,7 +33,6 @@ from axolotl.utils.schemas.integrations import (
|
||||
MLFlowConfig,
|
||||
OpenTelemetryConfig,
|
||||
RayConfig,
|
||||
TrackioConfig,
|
||||
WandbConfig,
|
||||
)
|
||||
from axolotl.utils.schemas.internal import EnvCapabilities, GPUCapabilities
|
||||
@@ -64,7 +62,6 @@ class AxolotlInputConfig(
|
||||
WandbConfig,
|
||||
MLFlowConfig,
|
||||
CometConfig,
|
||||
TrackioConfig,
|
||||
OpenTelemetryConfig,
|
||||
LISAConfig,
|
||||
GradioConfig,
|
||||
@@ -176,12 +173,6 @@ class AxolotlInputConfig(
|
||||
dpo_use_logits_to_keep: bool | None = None
|
||||
dpo_label_smoothing: float | None = None
|
||||
dpo_norm_loss: bool | None = None
|
||||
|
||||
dpo_use_liger_kernel: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Whether to use Liger kernel for DPO loss."},
|
||||
)
|
||||
|
||||
dpo_padding_free: bool | None = None
|
||||
dpo_generate_during_eval: bool | None = None
|
||||
|
||||
@@ -454,10 +445,10 @@ class AxolotlInputConfig(
|
||||
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
||||
},
|
||||
)
|
||||
excess_length_strategy: Literal["drop", "truncate", "raise"] | None = Field(
|
||||
excess_length_strategy: Literal["drop", "truncate"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len; 'raise' raises a ValueError. Defaults to 'drop' for backward compatibility."
|
||||
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
|
||||
},
|
||||
)
|
||||
eval_sequence_len: int | None = Field(
|
||||
@@ -1101,16 +1092,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_fp8(self):
|
||||
if self.fp8 and not self.capabilities.fp8:
|
||||
raise ValueError("fp8 requested, but fp8 is not supported on this GPU")
|
||||
elif self.fp8 and self.capabilities.fp8 and not is_fp8_available():
|
||||
raise ValueError(
|
||||
"fp8 requested, but missing one of ms-amp, transformers-engine or torchao."
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_sample_packing_w_sdpa_bf16(cls, data):
|
||||
|
||||
@@ -200,23 +200,3 @@ class OpenTelemetryConfig(BaseModel):
|
||||
"description": "Port for the Prometheus metrics HTTP server"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class TrackioConfig(BaseModel):
|
||||
"""Trackio configuration subset"""
|
||||
|
||||
use_trackio: bool | None = None
|
||||
trackio_project_name: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Your trackio project name"},
|
||||
)
|
||||
trackio_run_name: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Set the name of your trackio run"},
|
||||
)
|
||||
trackio_space_id: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Hugging Face Space ID to sync dashboard to (optional, runs locally if not provided)"
|
||||
},
|
||||
)
|
||||
|
||||
@@ -751,19 +751,12 @@ class OptimizationValidationMixin:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_muon_deepspeed_fsdp(cls, data):
|
||||
if data.get("optimizer") == "muon":
|
||||
if data.get("deepspeed"):
|
||||
raise ValueError(
|
||||
"Muon optimizer is currently incompatible with DeepSpeed"
|
||||
)
|
||||
if data.get("fsdp") or data.get("fsdp_config"):
|
||||
fsdp_version = data.get("fsdp_version")
|
||||
if fsdp_version is None:
|
||||
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
|
||||
if str(fsdp_version) != "2":
|
||||
raise ValueError(
|
||||
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
|
||||
)
|
||||
if data.get("optimizer") == "muon" and (
|
||||
data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config")
|
||||
):
|
||||
raise ValueError(
|
||||
"Muon optimizer is currently incompatible with DeepSpeed and FSDP"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -801,36 +794,6 @@ class OptimizationValidationMixin:
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_cross_entropy_conflicts(cls, data):
|
||||
"""Check for mutual exclusivity between cross entropy patch options.
|
||||
|
||||
Only one of the following can be enabled at a time:
|
||||
- cut_cross_entropy (CutCrossEntropyPlugin)
|
||||
- chunked_cross_entropy
|
||||
- liger_cross_entropy (LigerPlugin)
|
||||
- liger_fused_linear_cross_entropy (LigerPlugin)
|
||||
"""
|
||||
ce_options = {
|
||||
"cut_cross_entropy": data.get("cut_cross_entropy"),
|
||||
"chunked_cross_entropy": data.get("chunked_cross_entropy"),
|
||||
"liger_cross_entropy": data.get("liger_cross_entropy"),
|
||||
"liger_fused_linear_cross_entropy": data.get(
|
||||
"liger_fused_linear_cross_entropy"
|
||||
),
|
||||
}
|
||||
|
||||
enabled_options = [k for k, v in ce_options.items() if v]
|
||||
|
||||
if len(enabled_options) > 1:
|
||||
raise ValueError(
|
||||
f"Only one cross entropy optimization can be enabled at a time. "
|
||||
f"Found {len(enabled_options)} enabled: {', '.join(enabled_options)}. "
|
||||
"Please disable all but one."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_version(cls, data):
|
||||
@@ -877,6 +840,40 @@ class OptimizationValidationMixin:
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_version_in_fsdp_config(cls, data):
|
||||
fsdp_config = data.get("fsdp_config") or {}
|
||||
if fsdp_config and fsdp_config.get("fsdp_version"):
|
||||
LOG.warning(
|
||||
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
|
||||
"Please configure `fsdp_version` as a top-level field."
|
||||
)
|
||||
data["fsdp_version"] = fsdp_config.pop("fsdp_version")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_config_kwargs_prefix(cls, data):
|
||||
if fsdp_config := data.get("fsdp_config"):
|
||||
should_fix = False
|
||||
for key, _ in fsdp_config.items():
|
||||
if key.startswith("fsdp_"):
|
||||
should_fix = True
|
||||
LOG.warning_once(
|
||||
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
|
||||
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
|
||||
)
|
||||
if should_fix:
|
||||
update_fsdp_config = {}
|
||||
for key, value in fsdp_config.items():
|
||||
if key.startswith("fsdp_") and key != "fsdp_version":
|
||||
update_fsdp_config[key.replace("fsdp_", "")] = value
|
||||
else:
|
||||
update_fsdp_config[key] = value
|
||||
data["fsdp_config"] = update_fsdp_config
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_fsdp_offload_w_8bit_optimizer(self):
|
||||
if (
|
||||
@@ -978,40 +975,6 @@ class OptimizationValidationMixin:
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_version_in_fsdp_config(cls, data):
|
||||
fsdp_config = data.get("fsdp_config") or {}
|
||||
if fsdp_config and fsdp_config.get("fsdp_version"):
|
||||
LOG.warning(
|
||||
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
|
||||
"Please configure `fsdp_version` as a top-level field."
|
||||
)
|
||||
data["fsdp_version"] = fsdp_config.pop("fsdp_version")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_config_kwargs_prefix(cls, data):
|
||||
if fsdp_config := data.get("fsdp_config"):
|
||||
should_fix = False
|
||||
for key, _ in fsdp_config.items():
|
||||
if key.startswith("fsdp_"):
|
||||
should_fix = True
|
||||
LOG.warning_once(
|
||||
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
|
||||
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
|
||||
)
|
||||
if should_fix:
|
||||
update_fsdp_config = {}
|
||||
for key, value in fsdp_config.items():
|
||||
if key.startswith("fsdp_") and key != "fsdp_version":
|
||||
update_fsdp_config[key.replace("fsdp_", "")] = value
|
||||
else:
|
||||
update_fsdp_config[key] = value
|
||||
data["fsdp_config"] = update_fsdp_config
|
||||
return data
|
||||
|
||||
|
||||
class SystemValidationMixin:
|
||||
"""Validation methods related to system and hardware configuration."""
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
"""Module for trackio utilities"""
|
||||
|
||||
import os
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def setup_trackio_env_vars(cfg: DictDefault):
|
||||
for key in cfg.keys():
|
||||
if key.startswith("trackio_"):
|
||||
value = cfg.get(key, "")
|
||||
|
||||
if value and isinstance(value, str) and len(value) > 0:
|
||||
os.environ[key.upper()] = value
|
||||
|
||||
if cfg.trackio_project_name and len(cfg.trackio_project_name) > 0:
|
||||
cfg.use_trackio = True
|
||||
@@ -201,19 +201,33 @@ def add_pose_position_ids(
|
||||
|
||||
|
||||
def add_length(sample):
|
||||
"""
|
||||
Set the "length" field on a sample to the number of input tokens.
|
||||
|
||||
Parameters:
|
||||
sample (Mapping-like): A sample containing an "input_ids" sequence.
|
||||
|
||||
Returns:
|
||||
sample (dict-like): The same sample with "length" set to len(sample["input_ids"]).
|
||||
"""
|
||||
sample["length"] = len(sample["input_ids"])
|
||||
return sample
|
||||
|
||||
|
||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
|
||||
"""
|
||||
Drop samples whose sequence length is either too long (> sequence_len)
|
||||
or too short (< min_sequence_len).
|
||||
|
||||
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||
|
||||
If raise_on_drop is set, the code raises a ValueError if a sample is
|
||||
encountered that is too long and would have been dropped.
|
||||
Return whether a sample (single or batched) should be kept based on sequence length constraints.
|
||||
|
||||
Determines if each sequence's length falls within [min_sequence_len, sequence_len]. Supports a single example (list[int]) or a batch (list[list[int]]). If the sample's "input_ids" is empty, the sample is treated as dropped. When raise_on_drop is True, encountering any sequence longer than sequence_len raises a ValueError.
|
||||
|
||||
Parameters:
|
||||
sample (dict): A mapping containing "input_ids" with either a single sequence or a batch of sequences.
|
||||
sequence_len (int): Maximum allowed sequence length (inclusive).
|
||||
min_sequence_len (int): Minimum allowed sequence length (inclusive).
|
||||
raise_on_drop (bool): If True, raise ValueError when a sequence exceeds sequence_len.
|
||||
|
||||
Returns:
|
||||
bool or list[bool]: For a single example, returns True if its length is within the bounds, False otherwise. For a batch, returns a list of booleans indicating which sequences should be kept.
|
||||
"""
|
||||
min_sequence_len = min_sequence_len or 2
|
||||
|
||||
@@ -645,9 +659,6 @@ def setup_parallelism_envs(cfg):
|
||||
set_accelerate_parallelism_config = True
|
||||
os.environ["PARALLELISM_CONFIG_CP_SIZE"] = str(cfg.context_parallel_size)
|
||||
os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
|
||||
from axolotl.monkeypatch.accelerate.parallelism_config import patch_prepare_cp
|
||||
|
||||
patch_prepare_cp()
|
||||
if set_accelerate_parallelism_config:
|
||||
os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
|
||||
|
||||
@@ -729,4 +740,4 @@ def setup_trainer(
|
||||
trainer_builder.train_dataset = train_dataset
|
||||
trainer_builder.eval_dataset = eval_dataset
|
||||
|
||||
return trainer_builder.build(total_num_steps)
|
||||
return trainer_builder.build(total_num_steps)
|
||||
@@ -62,7 +62,7 @@ def snapshot_download_w_retry(*args, **kwargs):
|
||||
"""
|
||||
with hf_offline_context(True):
|
||||
try:
|
||||
return snapshot_download(*args, local_files_only=True, **kwargs)
|
||||
return snapshot_download(*args, **kwargs)
|
||||
except LocalEntryNotFoundError:
|
||||
pass
|
||||
with hf_offline_context(False):
|
||||
|
||||
@@ -474,8 +474,10 @@ def rand_reward_func(prompts, completions) -> list[float]:
|
||||
|
||||
assert trainer.optimizer_cls_and_kwargs is not None
|
||||
|
||||
from axolotl.contribs.mit.muon import MuonOptimizerFactory
|
||||
from axolotl.contribs.mit.muon.muon import Muon
|
||||
from axolotl.contribs.mit.muon import (
|
||||
Muon,
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
|
||||
assert optimizer_cls is MuonOptimizerFactory
|
||||
@@ -554,8 +556,10 @@ class TestHFCausalTrainerBuilder:
|
||||
|
||||
assert trainer.optimizer_cls_and_kwargs is not None
|
||||
|
||||
from axolotl.contribs.mit.muon import MuonOptimizerFactory
|
||||
from axolotl.contribs.mit.muon.muon import Muon
|
||||
from axolotl.contribs.mit.muon import (
|
||||
Muon,
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
|
||||
assert optimizer_cls is MuonOptimizerFactory
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
"""Test module for DistMuon optimizer with FSDP2 multi-GPU functionality."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from tbparse import SummaryReader
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
def verify_training_success(temp_dir):
|
||||
"""Verify that training completed successfully by checking artifacts and loss."""
|
||||
output_path = Path(temp_dir)
|
||||
|
||||
model_files = list(output_path.glob("*.bin")) + list(
|
||||
output_path.glob("*.safetensors")
|
||||
)
|
||||
assert len(model_files) > 0, "No model files found - training may have failed"
|
||||
|
||||
checkpoint_files = list(output_path.glob("checkpoint-*"))
|
||||
assert len(checkpoint_files) > 0, (
|
||||
"No checkpoint files found - training may have failed"
|
||||
)
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
if tb_log_path:
|
||||
event_files = sorted(os.listdir(tb_log_path))
|
||||
if event_files:
|
||||
event_file = os.path.join(tb_log_path, event_files[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars
|
||||
train_loss_df = df[df.tag == "train/train_loss"]
|
||||
if len(train_loss_df) > 0:
|
||||
final_loss = train_loss_df.value.values[-1]
|
||||
assert not torch.isnan(torch.tensor(final_loss)), (
|
||||
f"Training loss is NaN: {final_loss}"
|
||||
)
|
||||
|
||||
|
||||
class TestDistMuon:
|
||||
"""Test class for DistMuon optimizer with FSDP2 functionality."""
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_fft_sft(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.02,
|
||||
"optimizer": "muon",
|
||||
"weight_decay": 0.01,
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"fsdp_version": 2,
|
||||
"fsdp_config": {
|
||||
"offload_params": False,
|
||||
"cpu_ram_efficient_loading": False,
|
||||
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||
"state_dict_type": "FULL_STATE_DICT",
|
||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
verify_training_success(temp_dir)
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_lora_sft(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.02,
|
||||
"optimizer": "muon",
|
||||
"weight_decay": 0.01,
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"fsdp_version": 2,
|
||||
"fsdp_config": {
|
||||
"offload_params": False,
|
||||
"cpu_ram_efficient_loading": False,
|
||||
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||
"state_dict_type": "FULL_STATE_DICT",
|
||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
verify_training_success(temp_dir)
|
||||
@@ -2,7 +2,6 @@
|
||||
E2E tests for resuming training
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
@@ -10,7 +9,6 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.callbacks.tokens_per_second import TOKENS_STATE_FILE
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -60,7 +58,6 @@ class TestResumeLlama:
|
||||
"use_tensorboard": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
"include_tkps": True,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -71,19 +68,8 @@ class TestResumeLlama:
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
initial_total_num_tokens = cfg.total_num_tokens
|
||||
assert initial_total_num_tokens is not None, (
|
||||
"total_num_tokens should be calculated during load_datasets"
|
||||
)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
checkpoint_path = f"{temp_dir}/checkpoint-9"
|
||||
tokens_state_path = os.path.join(checkpoint_path, TOKENS_STATE_FILE)
|
||||
assert os.path.isfile(tokens_state_path), (
|
||||
f"{TOKENS_STATE_FILE} should exist in checkpoint at {tokens_state_path}"
|
||||
)
|
||||
|
||||
resume_cfg = cfg | DictDefault(
|
||||
{
|
||||
"resume_from_checkpoint": f"{temp_dir}/checkpoint-9/",
|
||||
@@ -91,24 +77,7 @@ class TestResumeLlama:
|
||||
)
|
||||
normalize_config(resume_cfg)
|
||||
|
||||
assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
|
||||
f"total_num_tokens should be preserved on resume. "
|
||||
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
|
||||
)
|
||||
|
||||
resume_dataset_meta = load_datasets(cfg=resume_cfg)
|
||||
|
||||
assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
|
||||
f"total_num_tokens should not be recalculated when resuming. "
|
||||
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
|
||||
)
|
||||
|
||||
train(cfg=resume_cfg, dataset_meta=resume_dataset_meta)
|
||||
|
||||
assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
|
||||
f"total_num_tokens should remain unchanged after resume training. "
|
||||
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
|
||||
)
|
||||
train(cfg=resume_cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
|
||||
|
||||
@@ -6,6 +6,8 @@ import os
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
|
||||
from huggingface_hub.utils import reset_sessions
|
||||
|
||||
|
||||
def reload_modules(hf_hub_offline):
|
||||
# Force reload of the modules that check this variable
|
||||
@@ -19,6 +21,7 @@ def reload_modules(hf_hub_offline):
|
||||
huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
|
||||
importlib.reload(datasets.config)
|
||||
datasets.config.HF_HUB_OFFLINE = hf_hub_offline
|
||||
reset_sessions()
|
||||
|
||||
|
||||
def enable_hf_offline(test_func):
|
||||
|
||||
@@ -7,7 +7,6 @@ import unittest
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
from axolotl.utils.data import encode_streaming, md5
|
||||
from axolotl.utils.trainer import drop_long_seq
|
||||
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
@@ -64,42 +63,6 @@ class TestEncodePretraining(unittest.TestCase):
|
||||
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
|
||||
)
|
||||
|
||||
def test_excess_length_strategy(self):
|
||||
"""Test that excess_length_strategy results in a value error when set to 'raise'."""
|
||||
|
||||
# -- single sequence --
|
||||
# This should work
|
||||
data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}
|
||||
drop_long_seq(data, 32, raise_on_drop=True)
|
||||
|
||||
# This should return True, since data fits
|
||||
dropped = drop_long_seq(data, 32)
|
||||
self.assertTrue(dropped)
|
||||
|
||||
# This should raise
|
||||
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
|
||||
|
||||
# This should return False, since data doesn't fit
|
||||
dropped = drop_long_seq(data, 15)
|
||||
self.assertFalse(dropped)
|
||||
|
||||
# -- batch sequence --
|
||||
# This should work
|
||||
data = {
|
||||
"input_ids": [
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
||||
]
|
||||
}
|
||||
drop_long_seq(data, 32, raise_on_drop=True)
|
||||
|
||||
# This should raise
|
||||
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
|
||||
|
||||
# This should keep the first but drop the second entry
|
||||
dropped = drop_long_seq(data, 15)
|
||||
self.assertEqual(dropped, [True, False])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -13,9 +13,7 @@ from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.loaders.tokenizer import load_tokenizer
|
||||
from axolotl.utils.data.rl import prepare_preference_datasets
|
||||
from axolotl.utils.data.sft import (
|
||||
_load_tokenized_prepared_datasets,
|
||||
)
|
||||
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.constants import (
|
||||
|
||||
@@ -363,5 +363,5 @@ class TestOptimizerValidation(BaseValidation):
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*only compatible with FSDP2.*"):
|
||||
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
|
||||
validate_config(cfg)
|
||||
|
||||
@@ -123,17 +123,6 @@ class TestFSDPValidation:
|
||||
assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer"
|
||||
assert cfg.fsdp_config.reshard_after_forward is True
|
||||
|
||||
def test_muon_fsdp1_rejected(self, min_base_cfg):
|
||||
cfg = min_base_cfg | DictDefault(
|
||||
optimizer="muon",
|
||||
fsdp_version=1,
|
||||
fsdp_config={"reshard_after_forward": True},
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError, match="Muon optimizer is only compatible with FSDP2"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"rl",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user