Compare commits
36 Commits
coderabbit
...
upgrade-to
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a08e4117a | ||
|
|
b26ba3a5cb | ||
|
|
afe18ace35 | ||
|
|
2b199f9915 | ||
|
|
e73dab6df9 | ||
|
|
f45a97a9ff | ||
|
|
11c0b5b256 | ||
|
|
66a3de3629 | ||
|
|
a6080df73c | ||
|
|
4f5e8a328a | ||
|
|
418933f0d1 | ||
|
|
372f664c63 | ||
|
|
97f1b1758d | ||
|
|
f2155eaf79 | ||
|
|
92ee4256f7 | ||
|
|
efeb5a4e41 | ||
|
|
faaff6c792 | ||
|
|
43cef27458 | ||
|
|
07c41a6c2a | ||
|
|
bbd3486f57 | ||
|
|
3750d7dd64 | ||
|
|
2197b0bf89 | ||
|
|
3e51a680c2 | ||
|
|
2cf254b4af | ||
|
|
83d4d97dcc | ||
|
|
a1d07f42e4 | ||
|
|
2a664dc8ad | ||
|
|
4ac78aa562 | ||
|
|
b3f4aa149f | ||
|
|
75b20fb66f | ||
|
|
5992e607a2 | ||
|
|
2b66ee189c | ||
|
|
86d8cca149 | ||
|
|
4a0f98e612 | ||
|
|
c6ddcdd06a | ||
|
|
7fb6a947d9 |
49
.github/workflows/base.yml
vendored
49
.github/workflows/base.yml
vendored
@@ -25,27 +25,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -53,6 +32,13 @@ jobs:
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -121,20 +107,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -149,6 +121,13 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
|
||||
3
.github/workflows/docs.yml
vendored
3
.github/workflows/docs.yml
vendored
@@ -12,6 +12,9 @@ jobs:
|
||||
build-deploy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: cleanup node
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Quarto
|
||||
|
||||
64
.github/workflows/main.yml
vendored
64
.github/workflows/main.yml
vendored
@@ -15,21 +15,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
@@ -46,6 +31,11 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
# - cuda: 130
|
||||
# cuda_version: 13.0.0
|
||||
# python_version: "3.11"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -92,27 +82,6 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
@@ -129,6 +98,11 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
# - cuda: 130
|
||||
# cuda_version: 13.0.0
|
||||
# python_version: "3.11"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -170,24 +144,18 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
16
.github/workflows/multi-gpu-e2e.yml
vendored
16
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -19,6 +19,9 @@ concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
env:
|
||||
MODAL_IMAGE_BUILDER_VERSION: "2025.06"
|
||||
|
||||
jobs:
|
||||
test-axolotl-multigpu:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||
@@ -26,13 +29,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
@@ -43,7 +39,7 @@ jobs:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
@@ -59,7 +55,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -72,4 +68,4 @@ jobs:
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.multigpu
|
||||
modal run -m cicd.multigpu
|
||||
|
||||
20
.github/workflows/nightlies.yml
vendored
20
.github/workflows/nightlies.yml
vendored
@@ -12,16 +12,16 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -64,16 +64,16 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
5
.github/workflows/preview-docs.yml
vendored
5
.github/workflows/preview-docs.yml
vendored
@@ -11,6 +11,7 @@ on:
|
||||
- '_quarto.yml'
|
||||
- docs/scripts/generate_config_docs.py
|
||||
- src/axolotl/utils/schemas/**.py
|
||||
- .github/workflows/preview-docs.yml
|
||||
|
||||
permissions:
|
||||
checks: write
|
||||
@@ -27,6 +28,10 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
steps:
|
||||
- name: cleanup node
|
||||
run: |
|
||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
|
||||
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
|
||||
20
.github/workflows/tests-nightly.yml
vendored
20
.github/workflows/tests-nightly.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.7.1", "2.8.0"]
|
||||
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -99,17 +99,17 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.8.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
@@ -123,7 +123,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -148,10 +148,10 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 2
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
@@ -165,7 +165,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
|
||||
59
.github/workflows/tests.yml
vendored
59
.github/workflows/tests.yml
vendored
@@ -55,7 +55,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
|
||||
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -69,8 +69,9 @@ jobs:
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p /home/runner/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
mkdir -p ~/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
ls -ltr ~/.cache/huggingface/hub/
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -111,13 +112,23 @@ jobs:
|
||||
run: |
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache scan
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
df -h
|
||||
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
df -h
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
df -h
|
||||
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||
df -h
|
||||
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache scan
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
@@ -134,7 +145,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.7.1", "2.8.0", "2.9.0"]
|
||||
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -148,8 +159,9 @@ jobs:
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p /home/runner/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
mkdir -p ~/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
ls -ltr ~/.cache/huggingface/hub/
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -188,14 +200,17 @@ jobs:
|
||||
axolotl --help
|
||||
|
||||
- name: Show HF cache
|
||||
run: huggingface-cli scan-cache
|
||||
run: hf cache scan
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/cli/
|
||||
|
||||
- name: Show HF cache
|
||||
run: hf cache scan
|
||||
|
||||
gate-skip-e2e:
|
||||
needs: [pre-commit, pytest, pytest-sdist]
|
||||
runs-on: ubuntu-latest
|
||||
@@ -256,7 +271,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -288,18 +303,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
# - cuda: 128
|
||||
# cuda_version: 12.8.1
|
||||
# python_version: "3.11"
|
||||
# pytorch: 2.7.1
|
||||
# num_gpus: 1
|
||||
# axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
@@ -310,7 +313,7 @@ jobs:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.0
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
@@ -323,7 +326,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -350,10 +353,10 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
@@ -366,7 +369,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==1.3.0.post1 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
|
||||
@@ -11,13 +11,13 @@ repos:
|
||||
- id: no-commit-to-branch
|
||||
args: ['--branch', 'main']
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.3
|
||||
rev: v0.14.10
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.18.2
|
||||
rev: v1.19.1
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
@@ -26,7 +26,7 @@ repos:
|
||||
'pydantic>=2.5.3',
|
||||
]
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.8.6
|
||||
rev: 1.9.2
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: [
|
||||
|
||||
@@ -10,6 +10,7 @@ ARG BASE_VOLUME="/runpod-volume"
|
||||
ENV BASE_VOLUME=$BASE_VOLUME
|
||||
ENV HF_DATASETS_CACHE="${BASE_VOLUME}/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
|
||||
ENV HF_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
|
||||
ENV TRANSFORMERS_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
|
||||
|
||||
COPY .runpod/src /src
|
||||
|
||||
14
README.md
14
README.md
@@ -29,15 +29,15 @@
|
||||
|
||||
## 🎉 Latest Updates
|
||||
|
||||
- 2025/11: Axolotl now includes support for [Olmo3](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/olmo3).
|
||||
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/qwen3-next), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3), [Granite 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/granite4), [HunYuan](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/hunyuan), [Magistral 2509](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral#vision), [Apertus](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/apertus), and [Seed-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/seed-oss).
|
||||
- 2025/12: Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
|
||||
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).
|
||||
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
|
||||
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
|
||||
- 2025/07:
|
||||
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
|
||||
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
|
||||
- Axolotl adds more models: [GPT-OSS](https://docs.axolotl.ai/docs/models/gpt-oss.html), [Gemma 3n](https://docs.axolotl.ai/docs/models/gemma3n.html), [Liquid Foundation Model 2 (LFM2)](https://docs.axolotl.ai/docs/models/LiquidAI.html), and [Arcee Foundation Models (AFM)](https://docs.axolotl.ai/docs/models/arcee.html).
|
||||
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
|
||||
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
|
||||
- [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support has been integrated in Axolotl!
|
||||
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||
|
||||
@@ -46,8 +46,8 @@
|
||||
<summary>Expand older updates</summary>
|
||||
|
||||
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
|
||||
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
|
||||
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
|
||||
@@ -77,7 +77,7 @@ Features:
|
||||
|
||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python 3.11
|
||||
- PyTorch ≥2.7.1
|
||||
- PyTorch ≥2.8.0
|
||||
|
||||
### Google Colab
|
||||
|
||||
|
||||
44
_quarto.yml
44
_quarto.yml
@@ -1,6 +1,8 @@
|
||||
project:
|
||||
type: website
|
||||
pre-render: docs/scripts/generate_config_docs.py
|
||||
pre-render:
|
||||
- docs/scripts/generate_config_docs.py
|
||||
- docs/scripts/generate_examples_docs.py
|
||||
|
||||
quartodoc:
|
||||
dir: docs/api
|
||||
@@ -240,6 +242,46 @@ website:
|
||||
- docs/getting-started.qmd
|
||||
- docs/installation.qmd
|
||||
- docs/inference.qmd
|
||||
- section: "Model Guides"
|
||||
contents:
|
||||
- docs/models/kimi-linear.qmd
|
||||
- docs/models/plano.qmd
|
||||
- docs/models/mimo.qmd
|
||||
- docs/models/internvl3_5.qmd
|
||||
- docs/models/olmo3.qmd
|
||||
- docs/models/trinity.qmd
|
||||
- docs/models/arcee.qmd
|
||||
- docs/models/mistral.qmd
|
||||
- section: "Ministral3"
|
||||
contents:
|
||||
- docs/models/ministral3.qmd
|
||||
- docs/models/ministral3/think.qmd
|
||||
- docs/models/ministral3/vision.qmd
|
||||
- section: "Magistral"
|
||||
contents:
|
||||
- docs/models/magistral.qmd
|
||||
- docs/models/magistral/think.qmd
|
||||
- docs/models/magistral/vision.qmd
|
||||
- docs/models/ministral.qmd
|
||||
- docs/models/mistral-small.qmd
|
||||
- docs/models/voxtral.qmd
|
||||
- docs/models/devstral.qmd
|
||||
- docs/models/llama-4.qmd
|
||||
- docs/models/llama-2.qmd
|
||||
- docs/models/qwen3-next.qmd
|
||||
- docs/models/qwen3.qmd
|
||||
- docs/models/gemma3n.qmd
|
||||
- docs/models/apertus.qmd
|
||||
- docs/models/gpt-oss.qmd
|
||||
- docs/models/seed-oss.qmd
|
||||
- docs/models/phi.qmd
|
||||
- docs/models/smolvlm2.qmd
|
||||
- docs/models/granite4.qmd
|
||||
- docs/models/LiquidAI.qmd
|
||||
- docs/models/hunyuan.qmd
|
||||
- docs/models/jamba.qmd
|
||||
- docs/models/orpheus.qmd
|
||||
|
||||
- docs/cli.qmd
|
||||
- docs/telemetry.qmd
|
||||
- docs/config-reference.qmd
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
set -e
|
||||
|
||||
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
|
||||
pytest -v --durations=10 -n2 \
|
||||
pytest -v --durations=10 -n2 --maxfail=4 \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
|
||||
/workspace/axolotl/tests/e2e/multigpu/ \
|
||||
|
||||
@@ -51,7 +51,7 @@ RUN git lfs install --skip-repo && \
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||
pip3 cache purge
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.9.1" ] && [ "$CUDA" = "128" ] ; then \
|
||||
RUN if [ "$PYTORCH_VERSION" =~ ^2\.9\.[0-9]+$ ] && [ "$CUDA" = "128" ] ; then \
|
||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
|
||||
2
docs/.gitignore
vendored
2
docs/.gitignore
vendored
@@ -3,3 +3,5 @@ _site/
|
||||
/api/*.qmd
|
||||
/api/*.html
|
||||
config-reference.qmd
|
||||
models/**/*.qmd
|
||||
models/**/*.html
|
||||
|
||||
86
docs/checkpoint_saving.qmd
Normal file
86
docs/checkpoint_saving.qmd
Normal file
@@ -0,0 +1,86 @@
|
||||
---
|
||||
title: "Checkpoint Saving"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
toc-depth: 2
|
||||
number-sections: true
|
||||
execute:
|
||||
enabled: false
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
Axolotl supports on-demand checkpoint saving during training. You can trigger checkpoints via file-based triggers (for programmatic control) or Control+C (for interactive use).
|
||||
|
||||
## File-Based Checkpoint Trigger
|
||||
|
||||
### Configuration
|
||||
|
||||
Enable in your config:
|
||||
|
||||
```yaml
|
||||
dynamic_checkpoint:
|
||||
enabled: true
|
||||
check_interval: 100 # Optional: check every N steps (default: 100)
|
||||
trigger_file_path: "axolotl_checkpoint.save" # Optional: custom filename
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `enabled`: `true` to enable (required)
|
||||
- `check_interval`: Steps between file checks. Default: 100. Lower = faster response, higher I/O overhead.
|
||||
- `trigger_file_path`: Custom trigger filename. Default: `axolotl_checkpoint.save`
|
||||
|
||||
### How It Works
|
||||
|
||||
1. Rank 0 checks for trigger file every `check_interval` steps in `output_dir`
|
||||
2. When detected, file is deleted and checkpoint is saved
|
||||
3. In distributed training, rank 0 broadcasts to synchronize all ranks
|
||||
|
||||
### Usage
|
||||
|
||||
**Command line:**
|
||||
```bash
|
||||
touch /path/to/output_dir/axolotl_checkpoint.save
|
||||
```
|
||||
|
||||
**Programmatic:**
|
||||
```python
|
||||
from pathlib import Path
|
||||
Path("/path/to/output_dir/axolotl_checkpoint.save").touch()
|
||||
```
|
||||
|
||||
Checkpoint saves within the next `check_interval` steps. The trigger file is auto-deleted after detection, so you can create it multiple times.
|
||||
|
||||
**Custom filename:**
|
||||
```yaml
|
||||
dynamic_checkpoint:
|
||||
enabled: true
|
||||
trigger_file_path: "my_trigger.save"
|
||||
```
|
||||
```bash
|
||||
touch /path/to/output_dir/my_trigger.save
|
||||
```
|
||||
|
||||
## Control+C (SIGINT) Checkpoint
|
||||
|
||||
Pressing `Ctrl+C` during training saves the model state and exits gracefully. **Note:** This saves only the model weights, not optimizer state. For resumable checkpoints, use the file-based trigger.
|
||||
|
||||
## Best Practices
|
||||
|
||||
- **Check interval**: Lower values (10-50) for fast training, default 100 for slower training
|
||||
- **Distributed training**: Create trigger file once; rank 0 handles synchronization
|
||||
- **Resume**: Dynamic checkpoints can be resumed like regular checkpoints via `resume_from_checkpoint`
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
output_dir: ./outputs/lora-out
|
||||
save_steps: 500 # Scheduled checkpoints
|
||||
|
||||
dynamic_checkpoint:
|
||||
enabled: true
|
||||
check_interval: 50
|
||||
```
|
||||
|
||||
This enables scheduled checkpoints every 500 steps plus on-demand saves via file trigger (checked every 50 steps).
|
||||
@@ -32,11 +32,8 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-base-py3.11-cu128-2.7.1`
|
||||
- `main-base-py3.11-cu126-2.7.1`
|
||||
- `main-base-py3.11-cu126-2.7.0`
|
||||
- `main-base-py3.11-cu126-2.6.0`
|
||||
- `main-base-py3.11-cu124-2.6.0`
|
||||
- `main-base-py3.11-cu128-2.8.0`
|
||||
- `main-base-py3.11-cu128-2.9.1`
|
||||
|
||||
## Main
|
||||
|
||||
@@ -74,15 +71,12 @@ There may be some extra tags appended to the image, like `-vllm` which installs
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-py3.11-cu128-2.7.1`
|
||||
- `main-py3.11-cu126-2.7.1`
|
||||
- `main-py3.11-cu126-2.7.0`
|
||||
- `main-py3.11-cu126-2.6.0`
|
||||
- `main-py3.11-cu124-2.6.0`
|
||||
- `main-py3.11-cu128-2.8.0`
|
||||
- `main-py3.11-cu128-2.9.1`
|
||||
- `main-latest`
|
||||
- `main-20250303-py3.11-cu124-2.6.0`
|
||||
- `main-20250303-py3.11-cu126-2.6.0`
|
||||
- `0.10.1`
|
||||
- `0.12.0`
|
||||
|
||||
## Cloud
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://p
|
||||
:::
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use Pytorch 2.7.0 and CUDA 12.8.
|
||||
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
### PyPI Installation (Recommended) {#sec-pypi}
|
||||
@@ -111,7 +111,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
|
||||
:::
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.7.0` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.7.0`.
|
||||
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
|
||||
:::
|
||||
|
||||
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
||||
|
||||
@@ -21,6 +21,7 @@ format:
|
||||
- [Qwen2.5-VL](#sec-qwen25-vl)
|
||||
- [SmolVLM2](#sec-smolvlm2)
|
||||
- [LFM2-VL](#sec-lfm2-vl)
|
||||
- [Intern-VL](#sec-intern-vl)
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -202,6 +203,16 @@ Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
|
||||
base_model: LiquidAI/LFM2-VL-450M
|
||||
```
|
||||
|
||||
### Intern-VL {#sec-intern-vl}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install `timm` via `pip3 install timm==1.0.19`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: OpenGVLab/InternVL3_5-8B
|
||||
```
|
||||
|
||||
## Dataset Format
|
||||
|
||||
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
|
||||
|
||||
90
docs/scripts/examples-allowlist.yml
Normal file
90
docs/scripts/examples-allowlist.yml
Normal file
@@ -0,0 +1,90 @@
|
||||
examples:
|
||||
# December 2025
|
||||
- name: kimi-linear
|
||||
title: Kimi Linear
|
||||
- name: plano
|
||||
title: Plano Orchestrator
|
||||
- name: mimo
|
||||
title: MiMo
|
||||
- name: internvl3_5
|
||||
title: InternVL 3.5
|
||||
|
||||
# AllenAI
|
||||
- name: olmo3
|
||||
title: OLMo 3
|
||||
|
||||
# ArceeAI
|
||||
- name: trinity
|
||||
title: Trinity
|
||||
- name: arcee
|
||||
title: Arcee AFM
|
||||
|
||||
# MistralAI
|
||||
- name: ministral3/think
|
||||
title: Ministral 3 Thinking
|
||||
- name: ministral3/vision
|
||||
title: Ministral 3 Vision
|
||||
- name: magistral/think
|
||||
title: Magistral Thinking
|
||||
- name: magistral/vision
|
||||
title: Magistral Vision
|
||||
- name: ministral
|
||||
title: Ministral
|
||||
- name: mistral-small
|
||||
title: Mistral Small 3.1/3.2
|
||||
- name: voxtral
|
||||
title: Voxtral
|
||||
- name: devstral
|
||||
title: Devstral
|
||||
- name: mistral
|
||||
title: Mistral 7B
|
||||
|
||||
# Meta
|
||||
- name: llama-4
|
||||
title: Llama 4
|
||||
- name: llama-2
|
||||
title: Llama 2
|
||||
|
||||
# Alibaba
|
||||
- name: qwen3-next
|
||||
title: Qwen 3 Next
|
||||
- name: qwen3
|
||||
title: Qwen 3
|
||||
|
||||
# Google
|
||||
- name: gemma3n
|
||||
title: Gemma 3n
|
||||
|
||||
# Swiss AI
|
||||
- name: apertus
|
||||
title: Apertus
|
||||
|
||||
# GPT-OSS
|
||||
- name: gpt-oss
|
||||
title: GPT-OSS
|
||||
- name: seed-oss
|
||||
title: Seed-OSS
|
||||
|
||||
# Microsoft
|
||||
- name: phi
|
||||
title: Phi
|
||||
|
||||
# SmolVLM
|
||||
- name: smolvlm2
|
||||
title: SmolVLM 2
|
||||
|
||||
# IBM
|
||||
- name: granite4
|
||||
title: Granite 4
|
||||
|
||||
# LiquidAI
|
||||
- name: LiquidAI
|
||||
title: Liquid Foundation Models 2
|
||||
|
||||
# Other
|
||||
- name: hunyuan
|
||||
title: Hunyuan
|
||||
- name: jamba
|
||||
title: Jamba
|
||||
- name: orpheus
|
||||
title: Orpheus
|
||||
424
docs/scripts/generate_examples_docs.py
Executable file
424
docs/scripts/generate_examples_docs.py
Executable file
@@ -0,0 +1,424 @@
|
||||
"""
|
||||
auto generate example docs from allowlist
|
||||
"""
|
||||
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
# Paths
|
||||
THIS = Path(__file__).resolve()
|
||||
ROOT = THIS.parents[2] # repo root (docs/scripts -> docs -> ROOT)
|
||||
EXAMPLES_DIR = ROOT / "examples"
|
||||
OUTPUT_DIR = ROOT / "docs" / "models"
|
||||
ALLOWLIST_YML = THIS.parent / "examples-allowlist.yml"
|
||||
|
||||
|
||||
def slugify(name: str) -> str:
|
||||
"""Convert a name to a slug (lowercase, hyphens for spaces)."""
|
||||
s = re.sub(r"[^a-zA-Z0-9\s\-]+", "", name.strip())
|
||||
s = re.sub(r"\s+", "-", s).strip("-").lower()
|
||||
return s or "example"
|
||||
|
||||
|
||||
def read_allowlist():
|
||||
with open(ALLOWLIST_YML, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
items = data.get("examples", [])
|
||||
if not isinstance(items, list):
|
||||
raise ValueError("`examples` must be a list in examples-allowlist.yml")
|
||||
return items
|
||||
|
||||
|
||||
def find_readme(folder: Path) -> Path | None:
|
||||
for name in ("README.md", "Readme.md", "readme.md"):
|
||||
p = folder / name
|
||||
if p.exists():
|
||||
return p
|
||||
return None
|
||||
|
||||
|
||||
def remove_first_h1(md: str) -> tuple[str, str | None]:
|
||||
"""
|
||||
Remove the first H1 from markdown and return (modified_md, h1_title).
|
||||
The H1 is removed since we use the frontmatter title instead.
|
||||
"""
|
||||
lines = md.splitlines()
|
||||
result = []
|
||||
h1_title = None
|
||||
skipped_first = False
|
||||
|
||||
for line in lines:
|
||||
if not skipped_first and line.startswith("# "):
|
||||
h1_title = line[2:].strip()
|
||||
skipped_first = True
|
||||
continue
|
||||
result.append(line)
|
||||
|
||||
return "\n".join(result), h1_title
|
||||
|
||||
|
||||
IMG_RE = re.compile(r"!\[[^\]]*\]\(([^)]+)\)")
|
||||
LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
|
||||
|
||||
|
||||
def rewrite_and_copy_assets(md: str, src_dir: Path, dest_assets_root: Path) -> str:
|
||||
"""
|
||||
Copy local image assets referenced in markdown to
|
||||
docs/examples/assets/... and rewrite the links.
|
||||
"""
|
||||
dest_assets = dest_assets_root / "assets"
|
||||
|
||||
def repl(m):
|
||||
url = m.group(1).strip()
|
||||
if re.match(r"^(https?:)?//", url):
|
||||
return m.group(0) # leave remote URLs
|
||||
src_path = (src_dir / url).resolve()
|
||||
if not src_path.exists():
|
||||
return m.group(0) # leave as-is if not found
|
||||
rel = src_path.relative_to(src_dir)
|
||||
# Create a unique asset path based on source directory name
|
||||
asset_name = src_dir.name.replace("/", "-")
|
||||
dest_path = dest_assets / asset_name / rel
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_path, dest_path)
|
||||
new_rel = f"assets/{asset_name}/{rel.as_posix()}"
|
||||
return m.group(0).replace(url, new_rel)
|
||||
|
||||
return IMG_RE.sub(repl, md)
|
||||
|
||||
|
||||
def rewrite_readme_links(
|
||||
md: str,
|
||||
src_dir: Path,
|
||||
examples_dir: Path,
|
||||
parent_index_only: set,
|
||||
current_src_path: str,
|
||||
allowlist_entries: set,
|
||||
current_output_path: str,
|
||||
) -> str:
|
||||
"""
|
||||
Rewrite links between README.md files to point to the correct .qmd files.
|
||||
"""
|
||||
|
||||
def repl(m):
|
||||
text = m.group(1)
|
||||
url = m.group(2).strip()
|
||||
|
||||
# Skip remote URLs and anchor links
|
||||
if re.match(r"^(https?:)?//", url) or url.startswith("#"):
|
||||
return m.group(0)
|
||||
|
||||
# Skip non-markdown files
|
||||
if not url.lower().endswith(".md"):
|
||||
return m.group(0)
|
||||
|
||||
# Resolve the target path
|
||||
try:
|
||||
target_path = (src_dir / url).resolve()
|
||||
|
||||
# Check if target is outside examples_dir
|
||||
try:
|
||||
rel_path = target_path.relative_to(examples_dir)
|
||||
except ValueError:
|
||||
# Target is outside examples_dir, leave as-is
|
||||
return m.group(0)
|
||||
|
||||
parts = list(rel_path.parts)
|
||||
|
||||
# Determine the output path for the target
|
||||
if len(parts) > 0 and parts[-1].lower() in ("readme.md", "readme"):
|
||||
# This is a README link
|
||||
if len(parts) == 1:
|
||||
# Link to root README -> index.qmd
|
||||
target_output = "index.qmd"
|
||||
elif len(parts) == 2:
|
||||
if parts[0] == ".":
|
||||
# Current directory README
|
||||
target_output = "index.qmd"
|
||||
else:
|
||||
# subdir/README.md
|
||||
parent_dir = parts[0]
|
||||
if parent_dir in parent_index_only:
|
||||
target_output = f"{parent_dir}/index.qmd"
|
||||
else:
|
||||
target_output = f"{parent_dir}.qmd"
|
||||
else:
|
||||
# Deeper nesting: parent/subdir/README.md
|
||||
# Build the full path like "parent/subdir"
|
||||
full_path = "/".join(parts[:-1]) # Remove README.md
|
||||
# Check if this exact path is in allowlist
|
||||
if full_path in allowlist_entries:
|
||||
# This is a sub-entry with its own entry -> use .qmd
|
||||
target_output = f"{full_path}.qmd"
|
||||
elif parts[0] == ".":
|
||||
# ./subdir/README.md -> check if subdir has own entry
|
||||
subdir = parts[1]
|
||||
if subdir in parent_index_only:
|
||||
target_output = f"{subdir}/index.qmd"
|
||||
else:
|
||||
target_output = f"{subdir}.qmd"
|
||||
else:
|
||||
# parent/subdir where parent doesn't have own entry
|
||||
target_output = f"{full_path}/index.qmd"
|
||||
else:
|
||||
# Regular .md file -> convert to .qmd, keep path structure
|
||||
target_output = "/".join(parts)[:-2] + "qmd"
|
||||
|
||||
# Compute relative path from current output file to target
|
||||
current_parts = current_output_path.split("/")
|
||||
target_parts = target_output.split("/")
|
||||
|
||||
# Special case: if current is a subdir file and target is a single-component file at root
|
||||
# Example: current="magistral/vision", target="magistral.qmd"
|
||||
if len(current_parts) > 1 and len(target_parts) == 1:
|
||||
# Current is in subdir, target is at root level
|
||||
# Go up to root: ../ for each level
|
||||
up_count = len(current_parts) - 1
|
||||
rel_parts = [".."] * up_count + [target_parts[0]]
|
||||
new_url = "/".join(rel_parts)
|
||||
else:
|
||||
# Find common prefix
|
||||
i = 0
|
||||
while (
|
||||
i < min(len(current_parts) - 1, len(target_parts))
|
||||
and current_parts[i] == target_parts[i]
|
||||
):
|
||||
i += 1
|
||||
|
||||
# Build relative path: go up (../) then down to target
|
||||
up_count = len(current_parts) - 1 - i
|
||||
rel_parts = [".."] * up_count + target_parts[i:]
|
||||
|
||||
if not rel_parts or rel_parts == [".."]:
|
||||
# Points to same directory or parent
|
||||
new_url = "/".join(rel_parts) if rel_parts else "."
|
||||
else:
|
||||
new_url = "/".join(rel_parts)
|
||||
|
||||
return f"[{text}]({new_url})"
|
||||
except (ValueError, IndexError):
|
||||
return m.group(0)
|
||||
|
||||
return LINK_RE.sub(repl, md)
|
||||
|
||||
|
||||
def write_qmd(out_path: Path, title: str, body_md: str):
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fm = f"---\ntitle: {title!r}\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
|
||||
out_path.write_text(fm + body_md, encoding="utf-8")
|
||||
|
||||
|
||||
def update_quarto_yml(generated: list[tuple[str, str, str]]):
|
||||
"""
|
||||
Update _quarto.yml with the generated example files in the correct order.
|
||||
This keeps the sidebar in sync with the allowlist.
|
||||
|
||||
Model Guides is now nested under "Getting Started" section.
|
||||
Creates nested sections for models with sub-entries (e.g., magistral, ministral3).
|
||||
Parent pages are now flat files (e.g., ministral3.qmd) with sub-pages in subdirs.
|
||||
"""
|
||||
quarto_yml = ROOT / "_quarto.yml"
|
||||
if not quarto_yml.exists():
|
||||
print(f"[WARN] {quarto_yml} not found, skipping update", file=sys.stderr)
|
||||
return
|
||||
|
||||
content = quarto_yml.read_text(encoding="utf-8")
|
||||
|
||||
# First pass: find all parents that have sub-entries
|
||||
parents_with_subs = set()
|
||||
for path, _name, _title in generated:
|
||||
if "/" in path:
|
||||
parent = path.split("/")[0]
|
||||
parents_with_subs.add(parent)
|
||||
|
||||
# Build the YAML contents while preserving allowlist order
|
||||
lines = []
|
||||
processed_sections = set()
|
||||
|
||||
for path, _name, title in generated:
|
||||
# Check if this is a parent page that has sub-pages
|
||||
if path in parents_with_subs:
|
||||
# This is a parent page with sub-pages - create a nested section
|
||||
if path not in processed_sections:
|
||||
processed_sections.add(path)
|
||||
section_title = (
|
||||
title or path.replace("-", " ").replace("_", " ").title()
|
||||
)
|
||||
lines.append(f' - section: "{section_title}"')
|
||||
lines.append(" contents:")
|
||||
# Add the parent page first
|
||||
lines.append(f" - docs/models/{path}.qmd")
|
||||
# Then add all sub-pages
|
||||
for sub_path, _sub_name, _sub_title in generated:
|
||||
if "/" in sub_path and sub_path.split("/")[0] == path:
|
||||
lines.append(
|
||||
f" - docs/models/{sub_path}.qmd"
|
||||
)
|
||||
elif "/" not in path:
|
||||
# This is a flat item with no sub-pages
|
||||
# Skip if it was already included as part of a parent section
|
||||
if path not in processed_sections:
|
||||
lines.append(f" - docs/models/{path}.qmd")
|
||||
|
||||
yaml_content = "\n".join(lines) + "\n"
|
||||
|
||||
# Pattern to match only the Model Guides contents, stopping at the next item
|
||||
# in Getting Started (lines starting with 12 spaces: same level as the section)
|
||||
pattern = r'( - section: "Model Guides"\n contents:)([^\n]*|.*?)(?=\n - |\n - section:|\n\nformat:)'
|
||||
|
||||
def replacement(match):
|
||||
prefix = match.group(1)
|
||||
return prefix + "\n" + yaml_content
|
||||
|
||||
new_content = re.sub(pattern, replacement, content, flags=re.DOTALL)
|
||||
|
||||
if new_content != content:
|
||||
quarto_yml.write_text(new_content, encoding="utf-8")
|
||||
print(f"Updated {quarto_yml}")
|
||||
else:
|
||||
print(f"No changes needed for {quarto_yml}")
|
||||
|
||||
|
||||
def main():
|
||||
allow = read_allowlist()
|
||||
if not EXAMPLES_DIR.exists():
|
||||
print(f"[WARN] {EXAMPLES_DIR} not found", file=sys.stderr)
|
||||
return
|
||||
|
||||
(OUTPUT_DIR / "assets").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# First pass: identify which parents have their own entry vs only sub-entries
|
||||
parent_entries = set() # Parents that have their own entry
|
||||
parent_with_subs = set() # Parents that have sub-entries
|
||||
allowlist_entries = set() # All entries in allowlist
|
||||
|
||||
for item in allow:
|
||||
if isinstance(item, str):
|
||||
name = item
|
||||
else:
|
||||
name = item.get("name")
|
||||
|
||||
allowlist_entries.add(name)
|
||||
|
||||
if "/" in name:
|
||||
parent = name.split("/")[0]
|
||||
parent_with_subs.add(parent)
|
||||
else:
|
||||
parent_entries.add(name)
|
||||
|
||||
# Parents with subs that DON'T have their own entry -> use index.qmd
|
||||
parent_index_only = parent_with_subs - parent_entries
|
||||
|
||||
generated = []
|
||||
seen_dirs = set() # Track which parent directories we've created index for
|
||||
|
||||
for item in allow:
|
||||
if isinstance(item, str):
|
||||
name = item
|
||||
title = None
|
||||
else:
|
||||
name = item.get("name")
|
||||
title = item.get("title")
|
||||
|
||||
if not name:
|
||||
print(f"[WARN] Skipping item without name: {item}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
src_dir = EXAMPLES_DIR / name
|
||||
if not src_dir.exists() or not src_dir.is_dir():
|
||||
print(f"[WARN] Skipping {name} (not a directory)", file=sys.stderr)
|
||||
continue
|
||||
|
||||
readme = find_readme(src_dir)
|
||||
if not readme:
|
||||
print(f"[WARN] Skipping {name} (no README.md)", file=sys.stderr)
|
||||
continue
|
||||
|
||||
md = readme.read_text(encoding="utf-8")
|
||||
|
||||
# Determine output path first (needed for link rewriting)
|
||||
parts = name.split("/")
|
||||
if len(parts) == 1:
|
||||
# Simple case: no subdirectory
|
||||
out_path = OUTPUT_DIR / f"{parts[0]}.qmd"
|
||||
sidebar_path = parts[0]
|
||||
else:
|
||||
# Has subdirectory: e.g., magistral/think
|
||||
parent = parts[0]
|
||||
child = "-".join(parts[1:]) # handle nested subdirs
|
||||
out_path = OUTPUT_DIR / parent / f"{child}.qmd"
|
||||
sidebar_path = f"{parent}/{child}"
|
||||
|
||||
# Remove the first H1 (we use frontmatter title instead)
|
||||
md, _ = remove_first_h1(md)
|
||||
# Rewrite links between README files
|
||||
md = rewrite_readme_links(
|
||||
md,
|
||||
src_dir,
|
||||
EXAMPLES_DIR,
|
||||
parent_index_only,
|
||||
name,
|
||||
allowlist_entries,
|
||||
sidebar_path,
|
||||
)
|
||||
md = rewrite_and_copy_assets(md, src_dir, OUTPUT_DIR)
|
||||
|
||||
# Handle parent page generation for sub-entries
|
||||
if len(parts) > 1:
|
||||
# Has subdirectory: e.g., magistral/think
|
||||
parent = parts[0]
|
||||
|
||||
# Create parent.qmd if not already done and parent doesn't have own entry
|
||||
if parent not in seen_dirs and parent in parent_index_only:
|
||||
parent_readme = find_readme(EXAMPLES_DIR / parent)
|
||||
if parent_readme:
|
||||
parent_md = parent_readme.read_text(encoding="utf-8")
|
||||
parent_md, _ = remove_first_h1(parent_md)
|
||||
parent_md = rewrite_readme_links(
|
||||
parent_md,
|
||||
EXAMPLES_DIR / parent,
|
||||
EXAMPLES_DIR,
|
||||
parent_index_only,
|
||||
parent,
|
||||
allowlist_entries,
|
||||
parent,
|
||||
)
|
||||
parent_md = rewrite_and_copy_assets(
|
||||
parent_md, EXAMPLES_DIR / parent, OUTPUT_DIR
|
||||
)
|
||||
parent_title = parent.replace("-", " ").replace("_", " ").title()
|
||||
write_qmd(OUTPUT_DIR / f"{parent}.qmd", parent_title, parent_md)
|
||||
generated.append((parent, parent, parent_title))
|
||||
seen_dirs.add(parent)
|
||||
|
||||
if not title:
|
||||
title = name.replace("/", " ").replace("-", " ").title()
|
||||
|
||||
write_qmd(out_path, title, md)
|
||||
generated.append((sidebar_path, name, title))
|
||||
|
||||
# Index page - preserve allowlist order
|
||||
if generated:
|
||||
listing = "\n".join(
|
||||
[f"- [{title}]({path}.qmd)" for path, name, title in generated]
|
||||
)
|
||||
index_md = (
|
||||
"# Model Guides\n\nBelow are the curated examples for training various model architectures:\n\n"
|
||||
+ listing
|
||||
+ "\n"
|
||||
)
|
||||
index_fm = (
|
||||
"---\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
|
||||
)
|
||||
(OUTPUT_DIR / "index.qmd").write_text(index_fm + index_md, encoding="utf-8")
|
||||
|
||||
# Auto-update _quarto.yml to keep sidebar in sync
|
||||
update_quarto_yml(generated)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2\""
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -253,7 +253,6 @@
|
||||
"source": [
|
||||
"from axolotl.utils import set_pytorch_cuda_alloc_conf\n",
|
||||
"\n",
|
||||
"# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
|
||||
"set_pytorch_cuda_alloc_conf()"
|
||||
]
|
||||
},
|
||||
|
||||
@@ -32,6 +32,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -28,6 +28,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -29,6 +29,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -28,6 +28,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -41,6 +41,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
@@ -41,6 +41,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
trackio_project_name:
|
||||
trackio_run_name:
|
||||
trackio_space_id:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
43
examples/internvl3_5/README.md
Normal file
43
examples/internvl3_5/README.md
Normal file
@@ -0,0 +1,43 @@
|
||||
# Finetune OpenGV's InternVL with Axolotl
|
||||
|
||||
[InternVL 3.5](https://huggingface.co/OpenGVLab/InternVL3_5-8B-HF) is a family of powerful vision-language models supporting dynamic resolution and multi-image understanding by OpenGV. It features a ViT-style vision encoder and strong language model backbone for tasks like visual question answering, OCR, and scene text understanding.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install `timm` for vision model support:
|
||||
|
||||
```bash
|
||||
pip install timm==1.0.19
|
||||
```
|
||||
|
||||
3. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
4. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/internvl3_5/internvl3_5-8b-qlora.yml
|
||||
```
|
||||
|
||||
This config uses about 8.21 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Tips
|
||||
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the multi-modal format as seen [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [InternVL Paper](https://huggingface.co/papers/2508.18265)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
61
examples/internvl3_5/internvl3_5-8b-qlora.yml
Normal file
61
examples/internvl3_5/internvl3_5-8b-qlora.yml
Normal file
@@ -0,0 +1,61 @@
|
||||
base_model: OpenGVLab/InternVL3_5-8B-HF
|
||||
processor_type: AutoProcessor
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
field_messages: messages
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
47
examples/kimi-linear/README.md
Normal file
47
examples/kimi-linear/README.md
Normal file
@@ -0,0 +1,47 @@
|
||||
# Finetune MoonshotAI's Kimi Linear with Axolotl
|
||||
|
||||
[Kimi Linear](https://huggingface.co/collections/moonshotai/kimi-linear-a3b) is a MoE model (48B total, 3B active) by MoonshotAI using a hybrid linear attention architecture to achieve a 1M token context length. It uses Kimi Delta Attention (KDA), a refined version of Gated DeltaNet that reduces KV cache size by up to 75% and boosts decoding throughput by up to 6x for long contexts.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
**Note:** Axolotl uses experimental training code for Kimi Linear as their original modeling code is inference-only.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install CCE via [docs](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/kimi-linear/kimi-48b-lora.yaml
|
||||
```
|
||||
|
||||
This config uses about 98.7GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning!
|
||||
|
||||
### TIPS
|
||||
|
||||
- Kimi Linear requires `trust_remote_code: true`.
|
||||
- You can run a full finetuning by removing the `adapter: lora` and `load_in_8bit: true`.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html)
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template)
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
See 👉 [docs](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Limitations
|
||||
|
||||
This is not yet compatible with MoE kernels from transformers v5.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Kimi Linear Paper](https://huggingface.co/papers/2510.26692)
|
||||
- [Kimi Linear GitHub](https://github.com/MoonshotAI/Kimi-Linear)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
81
examples/kimi-linear/kimi-48b-lora.yaml
Normal file
81
examples/kimi-linear/kimi-48b-lora.yaml
Normal file
@@ -0,0 +1,81 @@
|
||||
base_model: moonshotai/Kimi-Linear-48B-A3B-Instruct
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
trust_remote_code: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
split: train
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.2
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 2
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -29,7 +29,6 @@ flex_attention: true
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
save_strategy: no
|
||||
torch_compile: true
|
||||
|
||||
wandb_project:
|
||||
|
||||
@@ -13,7 +13,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
|
||||
Here is an example of how to install from pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
@@ -5,6 +5,7 @@ This guide covers fine-tuning [Magistral Small 2507](https://huggingface.co/mist
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
|
||||
- Installed Axolotl (see [main README](../README.md))
|
||||
|
||||
## Getting Started
|
||||
|
||||
@@ -5,7 +5,8 @@ This guide covers fine-tuning [Magistral Small 2509](https://huggingface.co/mist
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
- Installed Axolotl from source (see [main README](../README.md#getting-started))
|
||||
|
||||
- Installed Axolotl from source (see [main README](../README.md))
|
||||
|
||||
## Getting started
|
||||
|
||||
|
||||
39
examples/mimo/README.md
Normal file
39
examples/mimo/README.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# Finetune Xiaomi's MiMo with Axolotl
|
||||
|
||||
[MiMo](https://huggingface.co/XiaomiMiMo/MiMo-7B-RL) is a family of models trained from scratch for reasoning tasks, incorporating **Multiple-Token Prediction (MTP)** as an additional training objective for enhanced performance and faster inference. Pre-trained on ~25T tokens with a three-stage data mixture strategy and optimized reasoning pattern density.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/mimo/mimo-7b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 17.2 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Tips
|
||||
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Limitations
|
||||
|
||||
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for MiMo in the near future.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [MiMo Paper](https://arxiv.org/abs/2505.07608)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
67
examples/mimo/mimo-7b-qlora.yaml
Normal file
67
examples/mimo/mimo-7b-qlora.yaml
Normal file
@@ -0,0 +1,67 @@
|
||||
base_model: XiaomiMiMo/MiMo-7B-RL
|
||||
trust_remote_code: true
|
||||
revision_of_model: 6299b5a
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# CCE - N/A as of now
|
||||
# plugins:
|
||||
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
50
examples/ministral/README.md
Normal file
50
examples/ministral/README.md
Normal file
@@ -0,0 +1,50 @@
|
||||
# Finetune Ministral with Axolotl
|
||||
|
||||
Ministral is a family of openweight models from MistralAI found on [HuggingFace](mistralai/Ministral-8B-Instruct-2410). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/ministral/ministral-small-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 8.76 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Tips
|
||||
|
||||
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Limitations
|
||||
|
||||
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
||||
|
||||
In addition, we do not support overriding tokens yet.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [MistralAI Ministral Blog](https://mistral.ai/news/ministraux)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
|
||||
|
||||
## Future Work
|
||||
|
||||
- Add parity to Preference Tuning, RL, etc.
|
||||
- Add parity to other tokenizer configs like overriding tokens.
|
||||
67
examples/ministral/ministral-small-qlora.yaml
Normal file
67
examples/ministral/ministral-small-qlora.yaml
Normal file
@@ -0,0 +1,67 @@
|
||||
base_model: mistralai/Ministral-8B-Instruct-2410
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
79
examples/ministral3/README.md
Normal file
79
examples/ministral3/README.md
Normal file
@@ -0,0 +1,79 @@
|
||||
# Finetune Ministral3 with Axolotl
|
||||
|
||||
Ministral3 is a family of open-weight models from MistralAI found on [HuggingFace](https://huggingface.co/collections/mistralai/ministral-3). This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
Please see [Thinking](#thinking) and [Vision](#vision) for their respective fine-tuning.
|
||||
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for these releases.
|
||||
|
||||
Note: This is still experimental given it is based on transformers v5 RC.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Swap to the Axolotl transformers v5 branch
|
||||
|
||||
```bash
|
||||
cp examples/ministral3/ministral3-3b-qlora.yaml ministral3-3b-qlora.yaml
|
||||
|
||||
git fetch
|
||||
git checkout transformers-v5
|
||||
|
||||
# Install packages for transformers v5
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
4. Run the fine-tuning:
|
||||
|
||||
```bash
|
||||
axolotl train ministral3-3b-qlora.yaml
|
||||
```
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
|
||||
### Tips
|
||||
|
||||
- We recommend adding the same/similar SystemPrompt that the model is tuned for. You can find this within the repo's files titled `SYSTEM_PROMPT.txt`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
### Thinking
|
||||
|
||||
Ministral3 2512 model supports thinking capabilities, enabling Chain-of-Thought reasoning with explicit thinking steps.
|
||||
|
||||
📚 **[See the Thinking fine-tuning guide →](./think/README.md)**
|
||||
|
||||
### Vision
|
||||
|
||||
Ministral3 2512 model also supports vision capabilities.
|
||||
|
||||
📚 **[See the Vision fine-tuning guide →](./vision/README.md)**
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Limitations
|
||||
|
||||
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
||||
|
||||
In addition, we do not support overriding tokens yet.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [MistralAI Mistral3 Blog](https://mistral.ai/news/mistral-3)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
|
||||
|
||||
## Future Work
|
||||
|
||||
- Add parity to Preference Tuning, RL, etc.
|
||||
- Add parity to other tokenizer configs like overriding tokens.
|
||||
67
examples/ministral3/ministral3-3b-qlora.yaml
Normal file
67
examples/ministral3/ministral3-3b-qlora.yaml
Normal file
@@ -0,0 +1,67 @@
|
||||
base_model: mistralai/Ministral-3-3B-Reasoning-2512
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
74
examples/ministral3/think/README.md
Normal file
74
examples/ministral3/think/README.md
Normal file
@@ -0,0 +1,74 @@
|
||||
# Ministral3 2512 Thinking Fine-tuning
|
||||
|
||||
This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with thinking capabilities using Axolotl. The thinking model enables explicit Chain-of-Thought reasoning with separate thinking and response sections.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
|
||||
- Installed Axolotl (see [main README](../README.md))
|
||||
|
||||
## Getting Started
|
||||
|
||||
Run the thinking model fine-tuning:
|
||||
|
||||
```bash
|
||||
axolotl train examples/ministral3/think/ministral3-3b-think-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 4.76 GiB VRAM.
|
||||
|
||||
### Tips
|
||||
|
||||
- Dataset uses multi-content format with `type: thinking` support. See [Dataset Format](#dataset-format) below.
|
||||
- You cannot mix `content: str` and `content: list[dict]`, otherwise, dataset loading will fail. Keep it consistent.
|
||||
|
||||
## Dataset Format
|
||||
|
||||
The thinking model requires the multi-content dataset format with support for an extra `role: thinking` within system and assistant messages.
|
||||
|
||||
Example format:
|
||||
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{ "type": "text", "text": "{SYSTEM_PROMPT}"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{ "type": "text", "text": "Solve this step by step: What is 15% of 240?"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "I need to calculate 15% of 240. First, I'll convert 15% to decimal: 0.15. Then multiply: 0.15 × 240 = 36."
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "To find 15% of 240, I'll multiply 240 by 0.15:\n\n240 × 0.15 = 36\n\nTherefore, 15% of 240 is 36."
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Advanced Options
|
||||
|
||||
The `thinking` section supports an optional `closed` parameter:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "Internal reasoning here...",
|
||||
"closed": true // Default: true, controls adding the closing [/THINK] tag
|
||||
}
|
||||
```
|
||||
67
examples/ministral3/think/ministral3-3b-think-qlora.yaml
Normal file
67
examples/ministral3/think/ministral3-3b-think-qlora.yaml
Normal file
@@ -0,0 +1,67 @@
|
||||
base_model: mistralai/Ministral-3-3B-Reasoning-2512
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: Nanobit/text-think-2k-test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
58
examples/ministral3/vision/README.md
Normal file
58
examples/ministral3/vision/README.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# Ministral3 2512 Vision Fine-tuning
|
||||
|
||||
This guide covers fine-tuning [Ministral3 2512](https://huggingface.co/collections/mistralai/ministral-3) with vision capabilities using Axolotl.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
|
||||
- Installed Axolotl from source (see [main README](../README.md))
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install the required vision lib:
|
||||
```bash
|
||||
pip install 'mistral-common[opencv]==1.8.6'
|
||||
```
|
||||
|
||||
2. Download the example dataset image:
|
||||
```bash
|
||||
wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
```
|
||||
|
||||
3. Run the fine-tuning:
|
||||
```bash
|
||||
axolotl train examples/ministral3/vision/ministral3-3b-vision-qlora.yml
|
||||
```
|
||||
|
||||
WARNING: The loss and grad norm will be much higher than normal at first. We suspect this to be inherent to the model as of the moment. If anyone would like to submit a fix for this, we are happy to take a look.
|
||||
|
||||
### Tips
|
||||
|
||||
Key differences from text-only model:
|
||||
- Multi-modal dataset format required
|
||||
- Sample packing not supported
|
||||
|
||||
## Dataset Format
|
||||
|
||||
The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
|
||||
One exception is that, passing `"image": PIL.Image` is not supported. MistralTokenizer only supports `path`, `url`, and `base64` for now.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": [{ "type": "text", "text": "{SYSTEM_PROMPT}"}]},
|
||||
{"role": "user", "content": [
|
||||
{ "type": "text", "text": "What's in this image?"},
|
||||
{"type": "image", "path": "path/to/image.jpg" }
|
||||
]},
|
||||
{"role": "assistant", "content": [{ "type": "text", "text": "..." }]},
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
- Sample Packing is not supported for multi-modality training currently.
|
||||
64
examples/ministral3/vision/ministral3-3b-vision-qlora.yml
Normal file
64
examples/ministral3/vision/ministral3-3b-vision-qlora.yml
Normal file
@@ -0,0 +1,64 @@
|
||||
base_model: mistralai/Ministral-3-3B-Reasoning-2512
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
# sample dataset below requires downloading image in advance
|
||||
# wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
datasets:
|
||||
- path: Nanobit/text-vision-2k-test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -5,6 +5,7 @@ This guide covers fine-tuning [Mistral Small 3.1](mistralai/Mistral-Small-3.1-24
|
||||
## Prerequisites
|
||||
|
||||
Before starting, ensure you have:
|
||||
|
||||
- Installed Axolotl (see [Installation docs](https://docs.axolotl.ai/docs/installation.html))
|
||||
|
||||
## Getting Started
|
||||
@@ -6,25 +6,17 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
Here is an example of how to install from pip:
|
||||
```bash
|
||||
# Ensure you have a compatible version of Pytorch installed
|
||||
pip3 install packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
|
||||
# Install Cut Cross Entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
axolotl train examples/olmo3/olmo3-7b-qlora.yaml
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/olmo3/olmo3-7b-qlora.yaml
|
||||
```
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
This uses about 11.3 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
|
||||
@@ -42,10 +42,10 @@ wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
|
||||
42
examples/plano/README.md
Normal file
42
examples/plano/README.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# Finetune Katanemo's Plano-Orchestrator with Axolotl
|
||||
|
||||
[Plano-Orchestrator](https://huggingface.co/collections/katanemo/plano-orchestrator) is a family of 4B and 30B-A3B routing and orchestration models designed for multi-agent systems. It analyzes user intent and conversation context to make precise routing decisions, excelling at multi-turn context understanding, multi-intent detection, and context-dependent routing.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/plano/plano-4b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 5.1 GiB VRAM. Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Orchestration Prompt
|
||||
|
||||
Plano-Orchestrator uses a specific orchestration prompt format for routing/agent decisions. Please check the [official model card](https://huggingface.co/katanemo/Plano-Orchestrator-4B) for proper prompt formatting and the `ORCHESTRATION_PROMPT` template.
|
||||
|
||||
### Tips
|
||||
|
||||
- To use the larger [Plano-Orchestrator-30B-A3B](https://huggingface.co/katanemo/Plano-Orchestrator-30B-A3B) MoE model, simply change `base_model: katanemo/Plano-Orchestrator-30B-A3B` in the config and enable multi-GPU training if needed.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Plano GitHub](https://github.com/katanemo/plano)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
65
examples/plano/plano-4b-qlora.yaml
Normal file
65
examples/plano/plano-4b-qlora.yaml
Normal file
@@ -0,0 +1,65 @@
|
||||
base_model: katanemo/Plano-Orchestrator-4B
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
chat_template: qwen3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
67
examples/qat_nvfp4/Gemma3-12B_baseline.yml
Normal file
67
examples/qat_nvfp4/Gemma3-12B_baseline.yml
Normal file
@@ -0,0 +1,67 @@
|
||||
base_model: google/gemma-3-12b-it
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/out_gemma/
|
||||
|
||||
sequence_len: 8096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 4e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
72
examples/qat_nvfp4/Gemma3-12B_qat.yml
Normal file
72
examples/qat_nvfp4/Gemma3-12B_qat.yml
Normal file
@@ -0,0 +1,72 @@
|
||||
base_model: google/gemma-3-12b-it
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/qat_out_gemma/
|
||||
|
||||
sequence_len: 8096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
qat:
|
||||
activation_dtype: nvfp4
|
||||
weight_dtype: nvfp4
|
||||
group_size: 16 # only group_size of 16 is supported with nvfp4
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 4e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
67
examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml
Normal file
67
examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml
Normal file
@@ -0,0 +1,67 @@
|
||||
base_model: google/gemma-3-12b-it
|
||||
# Math finetuning configuration for Gemma3-12B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-CoT
|
||||
type: chat_template
|
||||
|
||||
output_dir: ./outputs/out_math_gemma/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 8
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 3e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
72
examples/qat_nvfp4/Math-Gemma3-12B_qat.yml
Normal file
72
examples/qat_nvfp4/Math-Gemma3-12B_qat.yml
Normal file
@@ -0,0 +1,72 @@
|
||||
base_model: google/gemma-3-12b-it
|
||||
# Math finetuning configuration for Gemma3-12B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-CoT
|
||||
type: chat_template
|
||||
|
||||
output_dir: ./outputs/qat_out_math_gemma/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
qat:
|
||||
activation_dtype: nvfp4
|
||||
weight_dtype: nvfp4
|
||||
group_size: 16 # only group_size of 16 is supported with nvfp4
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 8
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 3e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
68
examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml
Normal file
68
examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: google/gemma-3-27b-it
|
||||
# Math finetuning configuration for Gemma3-27B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-CoT
|
||||
type: chat_template
|
||||
|
||||
output_dir: ./outputs/out_math_gemma27/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-6
|
||||
eta_min: 7e-7
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
73
examples/qat_nvfp4/Math-Gemma3-27B_qat.yml
Normal file
73
examples/qat_nvfp4/Math-Gemma3-27B_qat.yml
Normal file
@@ -0,0 +1,73 @@
|
||||
base_model: google/gemma-3-27b-it
|
||||
# Math finetuning configuration for Gemma3-27B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: gemma3
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-CoT
|
||||
type: chat_template
|
||||
|
||||
output_dir: ./outputs/qat_out_math_gemma27/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
qat:
|
||||
activation_dtype: nvfp4
|
||||
weight_dtype: nvfp4
|
||||
group_size: 16 # only group_size of 16 is supported with nvfp4
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-6
|
||||
eta_min: 7e-7
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Gemma3DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
67
examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml
Normal file
67
examples/qat_nvfp4/Math-Qwen2.5-72B_baseline.yml
Normal file
@@ -0,0 +1,67 @@
|
||||
base_model: Qwen/Qwen2.5-72B
|
||||
# Math finetuning configuration for Qwen2.5-72B (non-instruct)
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: qwen_25
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-CoT
|
||||
type: chat_template
|
||||
|
||||
output_dir: ./outputs/out_math_72b/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 8
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-6
|
||||
eta_min: 7e-7
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
72
examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml
Normal file
72
examples/qat_nvfp4/Math-Qwen2.5-72B_qat.yml
Normal file
@@ -0,0 +1,72 @@
|
||||
base_model: Qwen/Qwen2.5-72B
|
||||
# Math finetuning configuration for Qwen2.5-72B (non-instruct)
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: qwen_25
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-CoT
|
||||
type: chat_template
|
||||
|
||||
output_dir: ./outputs/qat_out_math_72b/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
qat:
|
||||
activation_dtype: nvfp4
|
||||
weight_dtype: nvfp4
|
||||
group_size: 16 # only group_size of 16 is supported with nvfp4
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 8
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 5e-6
|
||||
eta_min: 7e-7
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
67
examples/qat_nvfp4/Qwen2.5-72B_baseline.yml
Normal file
67
examples/qat_nvfp4/Qwen2.5-72B_baseline.yml
Normal file
@@ -0,0 +1,67 @@
|
||||
base_model: Qwen/Qwen2.5-72B
|
||||
# Alpaca finetuning configuration for Qwen2.5-72B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: qwen_25
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/out_qwen72b/
|
||||
|
||||
sequence_len: 8096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
72
examples/qat_nvfp4/Qwen2.5-72B_qat.yml
Normal file
72
examples/qat_nvfp4/Qwen2.5-72B_qat.yml
Normal file
@@ -0,0 +1,72 @@
|
||||
base_model: Qwen/Qwen2.5-72B
|
||||
# Alpaca finetuning configuration for Qwen2.5-72B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
seed: 42
|
||||
chat_template: qwen_25
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/qat_out_qwen72b/
|
||||
|
||||
sequence_len: 8096
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
qat:
|
||||
activation_dtype: nvfp4
|
||||
weight_dtype: nvfp4
|
||||
group_size: 16 # only group_size of 16 is supported with nvfp4
|
||||
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
# evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
fsdp_version: 2
|
||||
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: true
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
70
examples/qwen2/adamw-pretrain-fsdp2.yaml
Normal file
70
examples/qwen2/adamw-pretrain-fsdp2.yaml
Normal file
@@ -0,0 +1,70 @@
|
||||
base_model: Qwen/Qwen2.5-0.5B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
# Use random initialization for fair comparison
|
||||
reinit_weights: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
# Pretraining dataset
|
||||
pretraining_dataset:
|
||||
- path: allenai/c4
|
||||
name: en
|
||||
type: pretrain
|
||||
split: train
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/compare-adamw-pretrain
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project: dist_muon
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name: adamw
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 1
|
||||
max_steps: 305
|
||||
|
||||
# AdamW optimizer settings (standard LR for AdamW)
|
||||
optimizer: adamw_torch_fused
|
||||
learning_rate: 0.0002
|
||||
weight_decay: 0.01
|
||||
lr_scheduler: cosine
|
||||
|
||||
train_on_inputs: true
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 0
|
||||
saves_per_epoch: 1
|
||||
|
||||
# Reproducibility
|
||||
seed: 42
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: false
|
||||
fsdp_reshard_after_forward: true
|
||||
|
||||
special_tokens:
|
||||
70
examples/qwen2/muon-pretrain-fsdp2.yaml
Normal file
70
examples/qwen2/muon-pretrain-fsdp2.yaml
Normal file
@@ -0,0 +1,70 @@
|
||||
base_model: Qwen/Qwen2.5-0.5B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
# Use random initialization for fair comparison
|
||||
reinit_weights: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
# Pretraining dataset
|
||||
pretraining_dataset:
|
||||
- path: allenai/c4
|
||||
name: en
|
||||
type: pretrain
|
||||
split: train
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/compare-muon-pretrain
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project: dist_muon
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name: muon
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 1
|
||||
max_steps: 305
|
||||
|
||||
# Muon optimizer settings
|
||||
optimizer: muon
|
||||
learning_rate: 0.02
|
||||
weight_decay: 0.01
|
||||
lr_scheduler: cosine
|
||||
|
||||
train_on_inputs: true
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 0
|
||||
saves_per_epoch: 1
|
||||
|
||||
# Reproducibility
|
||||
seed: 42
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: false
|
||||
fsdp_reshard_after_forward: true
|
||||
|
||||
special_tokens:
|
||||
46
examples/qwen3/README.md
Normal file
46
examples/qwen3/README.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# Finetune Qwen3 with Axolotl
|
||||
|
||||
[Qwen3](https://huggingface.co/collections/Qwen/qwen3) are a family of open source models trained by Alibaba.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/qwen3/32b-qlora.yaml
|
||||
```
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### Chat template masking a few tokens off
|
||||
|
||||
If you notice that the `chat_template` masking for assistant prompts are off by a few tokens, please ensure that you are adding the below to the yaml.
|
||||
|
||||
```yaml
|
||||
chat_template: qwen3
|
||||
```
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, please check the official model card as it depends on your reasoning mode.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Qwen3 Blog](https://qwenlm.github.io/blog/qwen3/)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
42
examples/trinity/README.md
Normal file
42
examples/trinity/README.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# Finetune ArceeAI's Trinity with Axolotl
|
||||
|
||||
[Trinity](https://huggingface.co/collections/arcee-ai/trinity) is a family of open weight MoE models trained by Arcee.ai.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 24.9 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, the official Arcee.ai team recommends `top_p: 0.75`, `temperature: 0.15`, `top_k: 50`, and `min_p: 0.06`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Limitations
|
||||
|
||||
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
68
examples/trinity/trinity-nano-preview-qlora.yaml
Normal file
68
examples/trinity/trinity-nano-preview-qlora.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: arcee-ai/Trinity-Nano-Preview
|
||||
trust_remote_code: true
|
||||
revision_of_model: 2ee94b0
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# CCE - N/A as of now
|
||||
# plugins:
|
||||
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
# flash_attention: true # Not supported
|
||||
sdp_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -5,7 +5,7 @@ bitsandbytes==0.48.2
|
||||
triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
liger-kernel==0.6.3
|
||||
liger-kernel==0.6.4
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
@@ -14,20 +14,21 @@ huggingface_hub>=0.36.0
|
||||
peft>=0.18.0
|
||||
tokenizers>=0.22.1
|
||||
transformers==4.57.1
|
||||
accelerate==1.11.0
|
||||
datasets==4.4.1
|
||||
deepspeed>=0.17.0
|
||||
trl==0.25.0
|
||||
accelerate==1.12.0
|
||||
datasets==4.4.2
|
||||
deepspeed>=0.18.3
|
||||
trl==0.25.1
|
||||
hf_xet==1.2.0
|
||||
kernels>=0.9.0
|
||||
trackio
|
||||
kernels==0.11.5
|
||||
trackio>=0.13.0
|
||||
typing-extensions>=4.15.0
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
sentencepiece
|
||||
gradio==5.49.1
|
||||
gradio>=6.2.0,<7.0
|
||||
|
||||
modal==1.0.2
|
||||
modal==1.3.0.post1
|
||||
pydantic>=2.10.6
|
||||
addict
|
||||
fire
|
||||
@@ -62,14 +63,13 @@ langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.13.0
|
||||
torchao==0.15.0
|
||||
openenv-core==0.1.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.7
|
||||
axolotl-contribs-mit==0.0.5
|
||||
|
||||
axolotl-contribs-mit==0.0.6
|
||||
# telemetry
|
||||
posthog==6.7.11
|
||||
|
||||
mistral-common==1.8.5
|
||||
mistral-common==1.8.6
|
||||
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"'
|
||||
)
|
||||
|
||||
3
setup.py
3
setup.py
@@ -66,7 +66,6 @@ def parse_requirements(extras_require_map):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
|
||||
extras_require_map["vllm"] = ["vllm==0.11.1"]
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
elif (major, minor) >= (2, 8):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
|
||||
@@ -157,7 +156,7 @@ extras_require = {
|
||||
"came_pytorch==0.1.3",
|
||||
],
|
||||
"ray": [
|
||||
"ray[train]",
|
||||
"ray[train]>=2.52.1",
|
||||
],
|
||||
"vllm": [
|
||||
"vllm==0.10.0",
|
||||
|
||||
@@ -24,8 +24,7 @@ if launcher_args:
|
||||
launcher_args_str = "-- " + " ".join(launcher_args)
|
||||
|
||||
# 1. Define a base image for your training job
|
||||
# must use torch 2.7.0 for vllm
|
||||
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
|
||||
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu128-2.9.1"
|
||||
|
||||
# 2. Define the Runtime Environment for the Training Job
|
||||
# This includes start commands and environment variables.a
|
||||
|
||||
@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
|
||||
return res
|
||||
|
||||
def get_image(self):
|
||||
docker_tag = "main-py3.11-cu126-2.7.1"
|
||||
docker_tag = "main-py3.11-cu128-2.9.1"
|
||||
if self.config.docker_tag:
|
||||
docker_tag = self.config.docker_tag
|
||||
docker_image = f"axolotlai/axolotl:{docker_tag}"
|
||||
|
||||
@@ -26,6 +26,7 @@ from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.tee import prepare_debug_log
|
||||
from axolotl.utils.trackio_ import setup_trackio_env_vars
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
@@ -227,6 +228,7 @@ def load_cfg(
|
||||
cfg,
|
||||
capabilities={
|
||||
"bf16": is_torch_bf16_gpu_available(),
|
||||
"fp8": compute_supports_fp8(),
|
||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||
"compute_capability": gpu_version,
|
||||
},
|
||||
@@ -245,6 +247,7 @@ def load_cfg(
|
||||
setup_wandb_env_vars(cfg)
|
||||
setup_mlflow_env_vars(cfg)
|
||||
setup_comet_env_vars(cfg)
|
||||
setup_trackio_env_vars(cfg)
|
||||
plugin_set_cfg(cfg)
|
||||
|
||||
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
|
||||
@@ -259,3 +262,11 @@ def load_cfg(
|
||||
)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
def compute_supports_fp8() -> bool:
|
||||
try:
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
return compute_capability >= (9, 0)
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
@@ -288,8 +288,8 @@ def do_inference_gradio(
|
||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||
)
|
||||
|
||||
demo.queue().launch(
|
||||
show_api=False,
|
||||
demo.launch(
|
||||
footer_links=["gradio", "settings"],
|
||||
share=cfg.get("gradio_share", True),
|
||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||
server_port=cfg.get("gradio_server_port", None),
|
||||
|
||||
@@ -26,7 +26,7 @@ from axolotl.cli.utils import (
|
||||
launch_training,
|
||||
)
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils import set_misc_env, set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
@@ -45,6 +45,7 @@ def cli():
|
||||
print_axolotl_text_art()
|
||||
load_dotenv()
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
set_misc_env()
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Union
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, TorchAoConfig
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.loaders import load_tokenizer
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.quantization import (
|
||||
TorchAOQuantDType,
|
||||
@@ -66,6 +66,11 @@ def do_quantize(
|
||||
|
||||
LOG.info(f"Loading model from {model_path}.")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
processor = None
|
||||
if cfg.is_multimodal:
|
||||
processor = load_processor(cfg, tokenizer)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@@ -107,6 +112,10 @@ def do_quantize(
|
||||
save_jinja_files=cfg.tokenizer_save_jinja_files,
|
||||
)
|
||||
|
||||
if processor:
|
||||
LOG.info(f"Saving processor to: {str(Path(output_dir) / 'quantized')}.")
|
||||
processor.save_pretrained(str(Path(output_dir) / "quantized"))
|
||||
|
||||
if hub_model_id:
|
||||
hub_model_id = (
|
||||
hub_model_id.rstrip("-")
|
||||
@@ -114,6 +123,8 @@ def do_quantize(
|
||||
)
|
||||
model.push_to_hub(hub_model_id, safe_serialization=False)
|
||||
tokenizer.push_to_hub(hub_model_id)
|
||||
if processor:
|
||||
processor.push_to_hub(hub_model_id)
|
||||
LOG.info(f"Quantized model pushed to: {hub_model_id}.")
|
||||
|
||||
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")
|
||||
|
||||
@@ -366,8 +366,8 @@ def launch_diffusion_gradio_ui(
|
||||
outputs=[masked_preview, html_out],
|
||||
)
|
||||
|
||||
demo.queue().launch(
|
||||
show_api=False,
|
||||
demo.launch(
|
||||
footer_links=["gradio", "settings"],
|
||||
share=cfg.get("gradio_share", True),
|
||||
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
||||
server_port=cfg.get("gradio_server_port", None),
|
||||
|
||||
@@ -17,4 +17,5 @@ MOE_ARCH_BLOCK = {
|
||||
"deepseek_v3": "DeepseekV3MoE",
|
||||
"gpt_oss": "GptOssDecoderLayer",
|
||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||
"afmoe": "AfmoeMoE",
|
||||
}
|
||||
|
||||
@@ -1,158 +0,0 @@
|
||||
"""
|
||||
monkeypatch for flex + packing
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
from transformers import Cache, PretrainedConfig
|
||||
from transformers.masking_utils import (
|
||||
ALL_MASK_ATTENTION_FUNCTIONS,
|
||||
_preprocess_mask_arguments,
|
||||
and_masks,
|
||||
causal_mask_function,
|
||||
or_masks,
|
||||
)
|
||||
from transformers.utils import is_torch_greater_or_equal
|
||||
|
||||
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
|
||||
|
||||
|
||||
def create_causal_mask(
|
||||
config: PretrainedConfig,
|
||||
input_embeds: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Optional[Cache],
|
||||
or_mask_function: Optional[Callable] = None,
|
||||
and_mask_function: Optional[Callable] = None,
|
||||
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
||||
"""
|
||||
Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
|
||||
has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
|
||||
to what is needed in the `modeling_xxx.py` files).
|
||||
|
||||
Args:
|
||||
config (`PretrainedConfig`):
|
||||
The model config.
|
||||
input_embeds (`torch.Tensor`):
|
||||
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
|
||||
batch size, query length and dtype.
|
||||
attention_mask (`torch.Tensor`, optional):
|
||||
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
|
||||
It can also be an already prepared 4D mask, in which case it is returned as-is.
|
||||
cache_position (`torch.Tensor`):
|
||||
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
|
||||
past_key_values (`Cache`, optional):
|
||||
The past key values, if we use a cache.
|
||||
or_mask_function (`Callable`, optional):
|
||||
An optional mask function to combine with the causal mask function (by doing the union of both). This is
|
||||
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
||||
and_mask_function (`Callable`, optional):
|
||||
An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
|
||||
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
||||
"""
|
||||
# If we have an HybridCache structure, here we want to create the mask for the full layers
|
||||
if (
|
||||
past_key_values
|
||||
and hasattr(past_key_values, "is_sliding")
|
||||
and False in past_key_values.is_sliding
|
||||
):
|
||||
layer_idx = past_key_values.is_sliding.index(False)
|
||||
else:
|
||||
layer_idx = 0
|
||||
|
||||
original_attention_mask = (
|
||||
None
|
||||
if attention_mask is None
|
||||
else attention_mask.clone().to(cache_position.device)
|
||||
)
|
||||
early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
|
||||
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
|
||||
)
|
||||
if early_exit:
|
||||
return attention_mask
|
||||
|
||||
batch_size, total_seq_len = cache_position.shape
|
||||
key_length = total_seq_len
|
||||
document_ids = torch.nn.functional.pad(
|
||||
original_attention_mask, value=0, pad=(0, key_length)
|
||||
)
|
||||
|
||||
batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
|
||||
if attention_mask is not None:
|
||||
|
||||
def causal_doc_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
"""
|
||||
Defines the logic of a block causal mask by combining both a standard causal mask
|
||||
and a block diagonal document mask.
|
||||
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
|
||||
for an illustration.
|
||||
"""
|
||||
causal_mask_ = q_idx >= kv_idx # not valid when decoding
|
||||
document_mask = (
|
||||
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
|
||||
)
|
||||
final_mask = causal_mask_ & document_mask
|
||||
return final_mask
|
||||
|
||||
mask_factory_function = causal_doc_mask_mod
|
||||
else:
|
||||
mask_factory_function = causal_mask_function
|
||||
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
|
||||
|
||||
# Do not allow skip if we are compiling (this is to match BC)
|
||||
allow_is_causal_skip = (
|
||||
not past_key_values.is_compileable if past_key_values is not None else True
|
||||
)
|
||||
|
||||
# Allow slight deviations from causal mask
|
||||
if or_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError(
|
||||
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
|
||||
)
|
||||
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
if and_mask_function is not None:
|
||||
if not _is_torch_greater_or_equal_than_2_6:
|
||||
raise ValueError(
|
||||
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
|
||||
)
|
||||
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
||||
allow_is_causal_skip = False
|
||||
|
||||
# We now create the mask
|
||||
causal_mask = mask_interface(
|
||||
batch_size=batch_size,
|
||||
cache_position=cache_position,
|
||||
kv_length=kv_length,
|
||||
kv_offset=kv_offset,
|
||||
mask_function=mask_factory_function,
|
||||
attention_mask=attention_mask,
|
||||
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
|
||||
dtype=dtype, # Additional kwarg for eager
|
||||
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
||||
)
|
||||
return causal_mask
|
||||
|
||||
|
||||
def patch_create_causal_mask(model_type):
|
||||
import transformers.masking_utils
|
||||
|
||||
transformers.masking_utils.create_causal_mask = create_causal_mask
|
||||
|
||||
if model_type:
|
||||
try:
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
module = __import__(module_path)
|
||||
module.create_causal_mask = create_causal_mask
|
||||
del sys.modules[module_path]
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise ValueError(
|
||||
f"Could not import attention class for model_type: {model_type}. "
|
||||
f"Error: {str(e)}"
|
||||
) from e
|
||||
@@ -35,6 +35,7 @@ from axolotl.utils import (
|
||||
is_comet_available,
|
||||
is_mlflow_available,
|
||||
is_opentelemetry_available,
|
||||
is_trackio_available,
|
||||
)
|
||||
from axolotl.utils.callbacks import (
|
||||
GCCallback,
|
||||
@@ -147,6 +148,14 @@ class TrainerBuilderBase(abc.ABC):
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_trackio and is_trackio_available():
|
||||
from axolotl.utils.callbacks.trackio_ import (
|
||||
SaveAxolotlConfigtoTrackioCallback,
|
||||
)
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoTrackioCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_otel_metrics and is_opentelemetry_available():
|
||||
from axolotl.utils.callbacks.opentelemetry import (
|
||||
OpenTelemetryMetricsCallback,
|
||||
@@ -281,11 +290,22 @@ class TrainerBuilderBase(abc.ABC):
|
||||
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
|
||||
|
||||
if self.cfg.optimizer == "muon":
|
||||
from axolotl.contribs.mit.muon import (
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
_, device_mesh = build_parallelism_config(self.cfg)
|
||||
|
||||
if device_mesh is not None:
|
||||
from axolotl.contribs.mit.muon.dist_muon import (
|
||||
DistMuonOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls = DistMuonOptimizerFactory
|
||||
optimizer_kwargs["device_mesh"] = device_mesh
|
||||
else:
|
||||
from axolotl.contribs.mit.muon import (
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls = MuonOptimizerFactory
|
||||
|
||||
optimizer_cls = MuonOptimizerFactory
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "dion":
|
||||
from axolotl.contribs.mit.dion import (
|
||||
@@ -423,6 +443,8 @@ class TrainerBuilderBase(abc.ABC):
|
||||
report_to.append("tensorboard")
|
||||
if self.cfg.use_comet:
|
||||
report_to.append("comet_ml")
|
||||
if self.cfg.use_trackio:
|
||||
report_to.append("trackio")
|
||||
|
||||
training_args_kwargs["report_to"] = report_to
|
||||
|
||||
@@ -430,6 +452,8 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
elif self.cfg.use_mlflow:
|
||||
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||
elif self.cfg.use_trackio:
|
||||
training_args_kwargs["run_name"] = self.cfg.trackio_run_name
|
||||
else:
|
||||
training_args_kwargs["run_name"] = None
|
||||
|
||||
|
||||
@@ -72,7 +72,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.include_tkps:
|
||||
callbacks.append(
|
||||
TokensPerSecondCallback(
|
||||
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
|
||||
self.cfg.tensor_parallel_size,
|
||||
self.cfg.context_parallel_size,
|
||||
resume_from_checkpoint=self.cfg.resume_from_checkpoint,
|
||||
)
|
||||
)
|
||||
return callbacks
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import partial, wraps
|
||||
@@ -49,6 +51,8 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
TOKENS_STATE_FILE = "tokens_state."
|
||||
|
||||
REDUCTION_FNS = {
|
||||
"mean": torch.mean,
|
||||
"min": torch.min,
|
||||
@@ -348,24 +352,33 @@ class AxolotlTrainer(
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
|
||||
# track number of tokens for tokens per second calculation
|
||||
if self.args.include_tkps:
|
||||
if self.args.include_tkps and model.training:
|
||||
inputs_key = "labels" if "labels" in inputs else "input_ids"
|
||||
num_tokens = (inputs[inputs_key] != -100).sum()
|
||||
trainable_tokens = (inputs[inputs_key] != -100).sum()
|
||||
total_tokens = inputs[inputs_key].numel()
|
||||
total_tokens = torch.tensor(total_tokens, device=inputs[inputs_key].device)
|
||||
|
||||
if is_distributed():
|
||||
torch.distributed.all_reduce(
|
||||
num_tokens, op=torch.distributed.ReduceOp.SUM
|
||||
trainable_tokens, op=torch.distributed.ReduceOp.SUM
|
||||
)
|
||||
if hasattr(self.state, "num_tokens"):
|
||||
self.state.num_tokens = (
|
||||
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
|
||||
torch.distributed.all_reduce(
|
||||
total_tokens, op=torch.distributed.ReduceOp.SUM
|
||||
)
|
||||
else:
|
||||
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
|
||||
|
||||
if hasattr(self.state, "total_tokens"):
|
||||
self.state.total_tokens += num_tokens
|
||||
else:
|
||||
self.state.total_tokens = num_tokens
|
||||
if not hasattr(self.state, "tokens"):
|
||||
self.state.tokens = {
|
||||
"trainable": torch.zeros(1),
|
||||
"total": torch.zeros(1),
|
||||
}
|
||||
|
||||
# trainable tokens for throughput and total token slots for summaries
|
||||
self.state.tokens["trainable"] = (
|
||||
self.state.tokens["trainable"] + trainable_tokens.detach().cpu()
|
||||
)
|
||||
self.state.tokens["total"] = self.state.tokens["total"] + total_tokens.cpu()
|
||||
# Store per-step trainable tokens for throughput calculation
|
||||
self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()
|
||||
|
||||
if self.args.orpo_alpha:
|
||||
return self.orpo_compute_loss(
|
||||
@@ -603,6 +616,7 @@ class AxolotlTrainer(
|
||||
"""
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
metric_ndigits = int(os.getenv("AXOLOTL_METRIC_NDIGITS", "5"))
|
||||
|
||||
for key, metric_data in self._stored_metrics[train_eval].items():
|
||||
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
|
||||
@@ -613,7 +627,18 @@ class AxolotlTrainer(
|
||||
raise NotImplementedError(
|
||||
"Metric reduction must be one of [mean, min, max, sum]"
|
||||
)
|
||||
logs[key] = round(fn(values).item(), 4)
|
||||
logs[key] = round(fn(values).item(), metric_ndigits)
|
||||
|
||||
if "loss" in logs:
|
||||
try:
|
||||
logs["ppl"] = round(math.exp(logs["loss"]), metric_ndigits)
|
||||
except OverflowError:
|
||||
logs["ppl"] = float("inf")
|
||||
if "eval_loss" in logs:
|
||||
try:
|
||||
logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), metric_ndigits)
|
||||
except OverflowError:
|
||||
logs["eval_ppl"] = float("inf")
|
||||
|
||||
if is_main_process():
|
||||
# Add memory usage
|
||||
@@ -625,13 +650,21 @@ class AxolotlTrainer(
|
||||
except (ValueError, TypeError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
if self.args.include_tkps and train_eval == "train":
|
||||
if (
|
||||
self.args.include_tkps
|
||||
and train_eval == "train"
|
||||
and hasattr(self.state, "tokens")
|
||||
):
|
||||
# each rank will log its own tokens per second
|
||||
# for logging_steps > 1 we obtain a moving average of this metric
|
||||
logs["tokens_per_second_per_gpu"] = round(
|
||||
logs["tokens/train_per_sec_per_gpu"] = round(
|
||||
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
|
||||
)
|
||||
logs["total_tokens"] = int(self.state.total_tokens.item())
|
||||
if (
|
||||
hasattr(self.state, "total_tokens")
|
||||
and self.state.total_tokens is not None
|
||||
):
|
||||
logs["total_tokens"] = int(self.state.total_tokens.item())
|
||||
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
@@ -666,6 +699,19 @@ class AxolotlTrainer(
|
||||
run_dir = self._get_output_dir(trial=trial)
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Save total_tokens state if tracking is enabled
|
||||
if self.args.include_tkps and hasattr(self.state, "tokens"):
|
||||
tokens_state = {
|
||||
"total": int(torch.as_tensor(self.state.tokens.get("total", 0)).item()),
|
||||
"trainable": int(
|
||||
torch.as_tensor(self.state.tokens.get("trainable", 0)).item()
|
||||
),
|
||||
}
|
||||
tokens_state_path = os.path.join(output_dir, TOKENS_STATE_FILE)
|
||||
with open(tokens_state_path, "w", encoding="utf-8") as f:
|
||||
json.dump(tokens_state, f)
|
||||
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
||||
|
||||
@@ -36,4 +36,6 @@ class DPOStrategy:
|
||||
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
||||
if cfg.dpo_use_logits_to_keep is not None:
|
||||
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
||||
if cfg.dpo_use_liger_kernel is not None:
|
||||
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
|
||||
return training_args_kwargs
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -54,6 +54,8 @@ plugins:
|
||||
- granitemoehybrid
|
||||
- hunyuan_v1_dense
|
||||
- hunyuan_v1_moe
|
||||
- internvl
|
||||
- kimi_linear
|
||||
- lfm2
|
||||
- lfm2_moe
|
||||
- lfm2_vl
|
||||
@@ -61,6 +63,8 @@ plugins:
|
||||
- llama4
|
||||
- llama4_text
|
||||
- llava
|
||||
- ministral
|
||||
- ministral3
|
||||
- mistral
|
||||
- mistral3
|
||||
- mixtral
|
||||
|
||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@5eff953"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"`'
|
||||
)
|
||||
|
||||
|
||||
@@ -96,7 +96,11 @@ class CutCrossEntropyPlugin(BasePlugin):
|
||||
)
|
||||
|
||||
# The patch checks model_type internally
|
||||
cce_patch(cfg.model_config_type)
|
||||
|
||||
cce_patch(
|
||||
cfg.model_config_type,
|
||||
remote_model_id=cfg.base_model if cfg.trust_remote_code else None,
|
||||
)
|
||||
|
||||
def patch_llama_like(
|
||||
self,
|
||||
@@ -107,7 +111,9 @@ class CutCrossEntropyPlugin(BasePlugin):
|
||||
"""
|
||||
from cut_cross_entropy.transformers.patch import PATCH_FNS
|
||||
|
||||
def patch_generic(maybe_model, patch_options, model_type: str):
|
||||
def patch_generic(
|
||||
maybe_model, patch_options, model_type: str, remote_model_id: str | None
|
||||
):
|
||||
import cut_cross_entropy.transformers.llama
|
||||
from cut_cross_entropy.transformers.llama import cce_forward
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class DenseMixerPlugin(BasePlugin):
|
||||
if cfg.dense_mixer:
|
||||
if not importlib.util.find_spec("densemixer"):
|
||||
raise RuntimeError(
|
||||
"DenseMixer is not installed. Install it with `pip install densemizer`"
|
||||
"DenseMixer is not installed. Install it with `pip install densemixer`"
|
||||
)
|
||||
|
||||
from densemixer.patching import (
|
||||
|
||||
@@ -179,8 +179,17 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
||||
logprobs = prompt.pop(self.logprobs_field)
|
||||
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
||||
tokenized_prompt[self.logprobs_field] = logprobs
|
||||
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||
|
||||
# let subclasses add fields before transform
|
||||
tokenized_prompt = self._prepare_kd_fields(tokenized_prompt, prompt)
|
||||
|
||||
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||
return tokenized_prompt
|
||||
|
||||
def _prepare_kd_fields(self, tokenized_prompt, original_prompt):
|
||||
"""
|
||||
Hook for subclasses to prepare additional KD fields before transform
|
||||
"""
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
@@ -283,14 +292,13 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
|
||||
|
||||
return sample
|
||||
|
||||
def _tokenize_single_prompt(self, prompt):
|
||||
target_token_ids = prompt.get("target_token_ids", None)
|
||||
|
||||
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
||||
|
||||
def _prepare_kd_fields(self, tokenized_prompt, original_prompt):
|
||||
"""
|
||||
Add pre-tokenized target_token_ids for v2 format
|
||||
"""
|
||||
target_token_ids = original_prompt.pop("target_token_ids", None)
|
||||
if target_token_ids is not None:
|
||||
tokenized_prompt["target_token_ids"] = target_token_ids
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
KD trainer
|
||||
"""
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
|
||||
@@ -60,6 +62,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
if columns_to_add:
|
||||
self._signature_columns += columns_to_add
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
self,
|
||||
model,
|
||||
@@ -79,10 +82,22 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
||||
):
|
||||
del inputs["attention_mask"]
|
||||
|
||||
if num_items_in_batch is None and "labels" in inputs:
|
||||
num_items_in_batch = (inputs["labels"] != -100).sum().item()
|
||||
|
||||
if self.model_accepts_loss_kwargs:
|
||||
loss_kwargs = {}
|
||||
if num_items_in_batch is not None:
|
||||
loss_kwargs["num_items_in_batch"] = num_items_in_batch
|
||||
inputs = {**inputs, **loss_kwargs}
|
||||
|
||||
outputs = model(**inputs)
|
||||
return outputs[0]
|
||||
|
||||
if isinstance(outputs, dict):
|
||||
loss = outputs["loss"]
|
||||
elif isinstance(outputs, tuple):
|
||||
loss = outputs[0]
|
||||
else:
|
||||
loss = outputs.loss if hasattr(outputs, "loss") else outputs
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
@@ -142,9 +142,12 @@ def load_lora(
|
||||
):
|
||||
setup_quantized_meta_for_peft(model)
|
||||
|
||||
model_kwargs: Any = {}
|
||||
if cfg.peft_autocast_adapter_dtype is not None:
|
||||
model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype
|
||||
|
||||
if cfg.lora_model_dir:
|
||||
LOG.debug("Loading pretrained PEFT - LoRA")
|
||||
model_kwargs: Any = {}
|
||||
if cfg.lora_on_cpu:
|
||||
model_kwargs["max_memory"] = {"cpu": "256GiB"}
|
||||
model_kwargs["device_map"] = {"": "cpu"}
|
||||
@@ -155,7 +158,7 @@ def load_lora(
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = get_peft_model(model, lora_config)
|
||||
model = get_peft_model(model, lora_config, **model_kwargs)
|
||||
|
||||
if rank == 0:
|
||||
try:
|
||||
|
||||
@@ -26,6 +26,48 @@ PLUGIN_MANAGER = PluginManager.get_instance()
|
||||
class PatchManager:
|
||||
"""Manages the application of patches during the model loading process."""
|
||||
|
||||
@staticmethod
|
||||
def apply_pre_config_load_patches(cfg: DictDefault):
|
||||
"""
|
||||
Apply patches that must be set up before config loading.
|
||||
This is for patches that intercept remote code loading from HuggingFace,
|
||||
which needs to be in place before AutoConfig.from_pretrained() is called.
|
||||
|
||||
Args:
|
||||
cfg: Configuration dictionary with model and training settings.
|
||||
"""
|
||||
if (
|
||||
hasattr(cfg, "base_model_config")
|
||||
and cfg.base_model_config
|
||||
and "kimi-linear" in cfg.base_model_config.lower()
|
||||
):
|
||||
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||
patch_kimi_config,
|
||||
)
|
||||
|
||||
patch_kimi_config()
|
||||
|
||||
@staticmethod
|
||||
def apply_pre_tokenizer_load_patches(cfg: DictDefault):
|
||||
"""
|
||||
Apply patches that must be set up before tokenizer loading.
|
||||
This is for patches that intercept remote code loading from HuggingFace,
|
||||
which needs to be in place before AutoTokenizer.from_pretrained() is called.
|
||||
|
||||
Args:
|
||||
cfg: Configuration dictionary with model and training settings.
|
||||
"""
|
||||
if (
|
||||
hasattr(cfg, "tokenizer_config")
|
||||
and cfg.tokenizer_config
|
||||
and "kimi-linear" in cfg.tokenizer_config.lower()
|
||||
):
|
||||
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||
patch_kimi_tokenizer,
|
||||
)
|
||||
|
||||
patch_kimi_tokenizer()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg: DictDefault,
|
||||
@@ -157,12 +199,6 @@ class PatchManager:
|
||||
|
||||
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
||||
patch_flex_wrapper(**flex_attn_compile_kwargs)
|
||||
if self.cfg.sample_packing:
|
||||
from axolotl.core.attention.flex_block_mask import (
|
||||
patch_create_causal_mask,
|
||||
)
|
||||
|
||||
patch_create_causal_mask(self.cfg.model_config_type)
|
||||
|
||||
def _apply_model_specific_patches(self):
|
||||
"""Apply patches specific to model architectures."""
|
||||
@@ -190,6 +226,13 @@ class PatchManager:
|
||||
|
||||
apply_mistral_tokenizer_image_patch()
|
||||
|
||||
if self.cfg.model_config_type == "kimi_linear":
|
||||
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
|
||||
patch_kimi_model,
|
||||
)
|
||||
|
||||
patch_kimi_model()
|
||||
|
||||
def _apply_fp8_patches(self):
|
||||
"""Apply patches for FP8 support."""
|
||||
if self.cfg.fp8:
|
||||
|
||||
@@ -124,6 +124,11 @@ def modify_tokenizer_files(
|
||||
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
||||
"""Load and configure the tokenizer based on the provided config."""
|
||||
|
||||
# Apply patches that need to be in place before tokenizer loading
|
||||
from axolotl.loaders.patch_manager import PatchManager
|
||||
|
||||
PatchManager.apply_pre_tokenizer_load_patches(cfg)
|
||||
|
||||
def _load_mistral_common_tokenizer(cfg: DictDefault):
|
||||
"""Load mistral-common tokenizer"""
|
||||
from axolotl.utils.mistral import HFMistralTokenizer
|
||||
|
||||
@@ -79,7 +79,11 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
|
||||
and hasattr(model_config, "vision_config")
|
||||
and hasattr(model_config.vision_config, "image_size")
|
||||
):
|
||||
cfg.image_size = model_config.vision_config.image_size
|
||||
image_size = model_config.vision_config.image_size
|
||||
if isinstance(image_size, list):
|
||||
cfg.image_size = tuple(image_size)
|
||||
else:
|
||||
cfg.image_size = image_size
|
||||
LOG.debug(f"Loaded image size: {cfg.image_size} from model config")
|
||||
|
||||
quant_config_exists = (
|
||||
|
||||
@@ -75,3 +75,33 @@ def patch_parallelism_config():
|
||||
|
||||
ParallelismConfig._validate_accelerator = _validate_accelerator
|
||||
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
|
||||
|
||||
|
||||
def patch_prepare_cp():
|
||||
import functools
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
|
||||
def patched_prepare_cp(self, *args):
|
||||
if self.parallelism_config.cp_backend == "deepspeed":
|
||||
return args
|
||||
|
||||
from accelerate.big_modeling import _attach_context_parallel_hooks
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||
|
||||
cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
|
||||
set_rotate_method(cp_comm_strategy)
|
||||
|
||||
self._cp_context = functools.partial(
|
||||
context_parallel, mesh=self.torch_device_mesh["cp"]
|
||||
)
|
||||
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.nn.Module):
|
||||
_attach_context_parallel_hooks(arg)
|
||||
|
||||
return args
|
||||
|
||||
Accelerator._prepare_cp = patched_prepare_cp
|
||||
|
||||
148
src/axolotl/monkeypatch/models/kimi_linear/configuration_kimi.py
Normal file
148
src/axolotl/monkeypatch/models/kimi_linear/configuration_kimi.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
Kimi-Linear configuration.
|
||||
|
||||
Source: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/configuration_kimi.py
|
||||
Revision: 6e163f3
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class KimiLinearConfig(PretrainedConfig):
|
||||
model_type = "kimi_linear"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type="kimi_linear",
|
||||
vocab_size=163840,
|
||||
hidden_size=4096,
|
||||
head_dim=None,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
tie_word_embeddings=False,
|
||||
moe_intermediate_size: Optional[int] = None,
|
||||
moe_renormalize: bool = True,
|
||||
moe_router_activation_func: str = "sigmoid",
|
||||
num_experts: Optional[int] = None,
|
||||
num_experts_per_token: Optional[int] = None,
|
||||
num_shared_experts: int = 0,
|
||||
routed_scaling_factor: float = 1.0,
|
||||
first_k_dense_replace: int = 0,
|
||||
moe_layer_freq: int = 1,
|
||||
use_grouped_topk: bool = True,
|
||||
num_expert_group: int = 1,
|
||||
topk_group: int = 1,
|
||||
q_lora_rank: Optional[int] = None,
|
||||
kv_lora_rank: Optional[int] = None,
|
||||
qk_nope_head_dim: Optional[int] = None,
|
||||
qk_rope_head_dim: Optional[int] = None,
|
||||
v_head_dim: Optional[int] = None,
|
||||
mla_use_nope: Optional[bool] = False,
|
||||
num_nextn_predict_layers: int = 0,
|
||||
linear_attn_config: Optional[dict] = None,
|
||||
router_aux_loss_coef: float = 0.01,
|
||||
**kwargs,
|
||||
):
|
||||
self.model_type = model_type
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.head_dim = (
|
||||
head_dim if head_dim is not None else hidden_size // num_attention_heads
|
||||
)
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.mla_use_nope = mla_use_nope
|
||||
# moe config
|
||||
self.num_experts = num_experts
|
||||
self.num_experts_per_token = num_experts_per_token
|
||||
self.moe_renormalize = moe_renormalize
|
||||
self.num_shared_experts = num_shared_experts
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.moe_router_activation_func = moe_router_activation_func
|
||||
assert self.moe_router_activation_func in ("softmax", "sigmoid")
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.moe_layer_freq = moe_layer_freq
|
||||
self.use_grouped_topk = use_grouped_topk
|
||||
self.num_expert_group = num_expert_group
|
||||
self.topk_group = topk_group
|
||||
self.num_nextn_predict_layers = num_nextn_predict_layers
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
|
||||
if linear_attn_config is not None:
|
||||
assert linear_attn_config["kda_layers"] is not None
|
||||
assert linear_attn_config["full_attn_layers"] is not None
|
||||
self.linear_attn_config = linear_attn_config
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_mla(self):
|
||||
return (
|
||||
self.q_lora_rank is not None
|
||||
or self.kv_lora_rank is not None
|
||||
or self.qk_nope_head_dim is not None
|
||||
or self.qk_rope_head_dim is not None
|
||||
or self.v_head_dim is not None
|
||||
or self.mla_use_nope is True
|
||||
)
|
||||
|
||||
@property
|
||||
def is_moe(self):
|
||||
return self.num_experts is not None
|
||||
|
||||
@property
|
||||
def is_linear_attn(self) -> bool:
|
||||
return not (
|
||||
self.linear_attn_config is None
|
||||
or (
|
||||
isinstance(self.linear_attn_config, dict)
|
||||
and self.linear_attn_config["kda_layers"] is not None
|
||||
and len(self.linear_attn_config["kda_layers"]) == 0
|
||||
)
|
||||
)
|
||||
|
||||
def is_kda_layer(self, layer_idx: int):
|
||||
return (
|
||||
self.linear_attn_config is not None
|
||||
and (layer_idx + 1) in self.linear_attn_config["kda_layers"]
|
||||
)
|
||||
1361
src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py
Normal file
1361
src/axolotl/monkeypatch/models/kimi_linear/modeling_kimi.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,85 @@
|
||||
import importlib.resources
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
KIMI_PATCH_PACKAGE = "axolotl.monkeypatch.models.kimi_linear"
|
||||
|
||||
|
||||
def get_patch_file_path(package_dot_path: str, filename: str) -> Path:
|
||||
"""
|
||||
Gets the absolute path to a patch file using importlib.resources.files.
|
||||
"""
|
||||
try:
|
||||
return importlib.resources.files(package_dot_path) / filename
|
||||
except ModuleNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
def _load_local_module(module_name: str, filename: str):
|
||||
"""Helper to load a local module if not already loaded."""
|
||||
if module_name in sys.modules:
|
||||
return sys.modules[module_name]
|
||||
|
||||
patch_path = get_patch_file_path(KIMI_PATCH_PACKAGE, filename)
|
||||
if patch_path and patch_path.exists():
|
||||
spec = importlib.util.spec_from_file_location(module_name, patch_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
return None
|
||||
|
||||
|
||||
def _patch_get_class_in_module():
|
||||
"""
|
||||
Core patch function that hijacks Transformers' dynamic module loading.
|
||||
"""
|
||||
from transformers.dynamic_module_utils import get_class_in_module
|
||||
|
||||
if hasattr(get_class_in_module, "_axolotl_patched"):
|
||||
return
|
||||
|
||||
original_get_class_in_module = get_class_in_module
|
||||
|
||||
# Mapping of module path patterns to (module_name, filename)
|
||||
KIMI_MODULE_MAP = {
|
||||
"configuration_kimi": ("configuration_kimi", "configuration_kimi.py"),
|
||||
"modeling_kimi": ("modeling_kimi", "modeling_kimi.py"),
|
||||
"tokenization_kimi": ("tokenization_kimi", "tokenization_kimi.py"),
|
||||
}
|
||||
|
||||
def patched_get_class_in_module(class_name, module_path, **kwargs):
|
||||
"""Patched version that returns our local modules instead of remote ones."""
|
||||
for pattern, (module_name, filename) in KIMI_MODULE_MAP.items():
|
||||
if pattern in module_path:
|
||||
module = _load_local_module(module_name, filename)
|
||||
if module:
|
||||
return getattr(module, class_name)
|
||||
break # Pattern matched but file not found, fall through
|
||||
|
||||
return original_get_class_in_module(class_name, module_path, **kwargs)
|
||||
|
||||
import transformers.dynamic_module_utils
|
||||
|
||||
transformers.dynamic_module_utils.get_class_in_module = patched_get_class_in_module
|
||||
patched_get_class_in_module._axolotl_patched = True
|
||||
|
||||
|
||||
def patch_kimi():
|
||||
"""
|
||||
Apply all Kimi patches.
|
||||
Must be called BEFORE loading config/tokenizer/model.
|
||||
"""
|
||||
_patch_get_class_in_module()
|
||||
LOG.info("Kimi patches applied successfully!")
|
||||
|
||||
|
||||
# Keep these for backward compatibility if needed
|
||||
patch_kimi_config = patch_kimi
|
||||
patch_kimi_tokenizer = patch_kimi
|
||||
patch_kimi_model = patch_kimi
|
||||
357
src/axolotl/monkeypatch/models/kimi_linear/tokenization_kimi.py
Normal file
357
src/axolotl/monkeypatch/models/kimi_linear/tokenization_kimi.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""
|
||||
Adapted Kimi-Linear tokenizer to use proper template defaults and misc fixes.
|
||||
|
||||
Source: https://huggingface.co/moonshotai/Kimi-Linear-48B-A3B-Instruct/blob/main/tokenization_kimi.py
|
||||
Revision: 919416f
|
||||
"""
|
||||
|
||||
import os
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import tiktoken
|
||||
from tiktoken.load import load_tiktoken_bpe
|
||||
from tokenizers import AddedToken
|
||||
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
logger = getLogger(__name__)
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
|
||||
|
||||
|
||||
class TikTokenTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.
|
||||
|
||||
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
|
||||
this superclass for more information regarding those methods.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
The path to the Tiktoken model file.
|
||||
bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|begin_of_text|>",`):
|
||||
The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
|
||||
eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|end_of_text|>"`):
|
||||
The end of sequence token.
|
||||
unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_249|>"`):
|
||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
||||
token instead. The second to last item in special_tokens.
|
||||
pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_250|>"`):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
additional_special_tokens (list of `str`, *optional*):
|
||||
A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
|
||||
skipped when decoding if `skip_special_tokens` is set to `True`.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
special_tokens: Dict[str, int]
|
||||
|
||||
num_reserved_special_tokens = 256
|
||||
|
||||
pat_str = "|".join(
|
||||
[
|
||||
r"""[\p{Han}]+""",
|
||||
r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
|
||||
r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
|
||||
r"""\p{N}{1,3}""",
|
||||
r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
|
||||
r"""\s*[\r\n]+""",
|
||||
r"""\s+(?!\S)""",
|
||||
r"""\s+""",
|
||||
]
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
bos_token: Union[str, AddedToken] = "[BOS]", # nosec: B107
|
||||
eos_token: Union[str, AddedToken] = "[EOS]", # nosec: B107
|
||||
unk_token: Union[str, AddedToken, None] = None,
|
||||
pad_token: Union[str, AddedToken, None] = None,
|
||||
additional_special_tokens: List[str] = None,
|
||||
added_tokens_decoder: Optional[dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
assert os.path.isfile(vocab_file), vocab_file
|
||||
|
||||
if additional_special_tokens is None:
|
||||
additional_special_tokens = [
|
||||
"<|im_end|>",
|
||||
"<|im_user|>",
|
||||
"<|im_assistant|>",
|
||||
"<|start_header_id|>",
|
||||
"<|end_header_id|>",
|
||||
"[EOT]",
|
||||
"<|im_system|>",
|
||||
"<|im_middle|>",
|
||||
]
|
||||
|
||||
special_tokens_mapping = {
|
||||
i: added_tokens_decoder[i].content for i in added_tokens_decoder
|
||||
}
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
mergeable_ranks = load_tiktoken_bpe(vocab_file)
|
||||
num_base_tokens = len(mergeable_ranks)
|
||||
self.special_tokens = {
|
||||
special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
|
||||
for i in range(
|
||||
num_base_tokens, num_base_tokens + self.num_reserved_special_tokens + 2
|
||||
)
|
||||
}
|
||||
|
||||
self.model = tiktoken.Encoding(
|
||||
name=Path(vocab_file).name,
|
||||
pat_str=self.pat_str,
|
||||
mergeable_ranks=mergeable_ranks,
|
||||
special_tokens=self.special_tokens,
|
||||
)
|
||||
logger.info(f"Reloaded tiktoken model from {vocab_file}")
|
||||
|
||||
self.n_words: int = self.model.n_vocab
|
||||
# BOS / EOS token IDs
|
||||
self.bos_id: int = self.special_tokens[str(bos_token)]
|
||||
self.eos_id: int = self.special_tokens[str(eos_token)]
|
||||
logger.info(
|
||||
f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
|
||||
)
|
||||
|
||||
self.pad_id: int = self.special_tokens[str(pad_token)]
|
||||
self.unk_id: int = self.special_tokens[str(unk_token)]
|
||||
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
|
||||
self.decoder = {}
|
||||
for i in range(self.n_words):
|
||||
# Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
|
||||
decoding = "".join(
|
||||
[
|
||||
self.byte_encoder[ord(char)]
|
||||
for char in self.model.decode_single_token_bytes(i).decode(
|
||||
"latin-1"
|
||||
)
|
||||
]
|
||||
)
|
||||
self.decoder[i] = decoding
|
||||
|
||||
self.encoder = {}
|
||||
for i in range(self.n_words):
|
||||
if i in self.decoder:
|
||||
self.encoder[self.decoder[i]] = i
|
||||
|
||||
super().__init__(
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
additional_special_tokens=additional_special_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
self.all_special_ids_set = set(self.all_special_ids)
|
||||
|
||||
def encode(
|
||||
self, text: str, allow_special_tokens: bool = True, **kwargs
|
||||
) -> List[int]:
|
||||
"""
|
||||
Encodes a string into a list of token IDs.
|
||||
|
||||
Args:
|
||||
text (str): The input string to be encoded.
|
||||
|
||||
Returns:
|
||||
list[int]: A list of token IDs.
|
||||
"""
|
||||
# If there are other args, we should call super().encode because there are a lot of code
|
||||
# to handle those args. supper().encode finally will call _tokenize and _convert_token_to_id.
|
||||
# NOTE: our encode method is not compatible with the super().encode method,
|
||||
# e.g. split_special_tokens' default is True in our encode method.
|
||||
if len(kwargs) > 0:
|
||||
# logger.warning(f"Calling super().encode with {kwargs}")
|
||||
return super().encode(text, **kwargs)
|
||||
|
||||
assert type(text) is str
|
||||
|
||||
# The tiktoken tokenizer can handle <=400k chars without
|
||||
# pyo3_runtime.PanicException.
|
||||
TIKTOKEN_MAX_ENCODE_CHARS = 400_000
|
||||
|
||||
# https://github.com/openai/tiktoken/issues/195
|
||||
# Here we iterate over subsequences and split if we exceed the limit
|
||||
# of max consecutive non-whitespace or whitespace characters.
|
||||
MAX_NO_WHITESPACES_CHARS = 25_000
|
||||
|
||||
texts = self.pre_tokenizer_process(text)
|
||||
|
||||
all_substrs = []
|
||||
for text in texts:
|
||||
substrs = (
|
||||
substr
|
||||
for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
|
||||
for substr in self._split_whitespaces_or_nonwhitespaces(
|
||||
text[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
|
||||
)
|
||||
)
|
||||
all_substrs.extend(substrs)
|
||||
|
||||
t: List[int] = []
|
||||
for substr in all_substrs:
|
||||
if allow_special_tokens:
|
||||
t.extend(
|
||||
# we should consider special token as a common token
|
||||
self.model.encode(
|
||||
substr,
|
||||
allowed_special="all",
|
||||
)
|
||||
)
|
||||
else:
|
||||
t.extend(
|
||||
# we should consider special token as a common token
|
||||
self.model.encode(
|
||||
substr,
|
||||
disallowed_special=(),
|
||||
)
|
||||
)
|
||||
|
||||
return t
|
||||
|
||||
def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
|
||||
"""
|
||||
Decodes a list of token IDs into a string.
|
||||
|
||||
Args:
|
||||
token_ids (List[int]): The list of token IDs to be decoded.
|
||||
|
||||
Returns:
|
||||
str: The decoded string.
|
||||
"""
|
||||
# If there are other args, we should call super().decode because there are a lot of code
|
||||
# to handle those args. supper().encode finally will call convert_tokens_to_string and _convert_id_to_token.
|
||||
if len(kwargs) > 0:
|
||||
return super().decode(token_ids, **kwargs)
|
||||
|
||||
if type(token_ids) is int:
|
||||
token_ids = [token_ids]
|
||||
|
||||
return self.model.decode(cast(List[int], token_ids))
|
||||
|
||||
@staticmethod
|
||||
def _split_whitespaces_or_nonwhitespaces(
|
||||
s: str, max_consecutive_slice_len: int
|
||||
) -> Iterator[str]:
|
||||
"""
|
||||
Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
|
||||
consecutive whitespaces or consecutive non-whitespaces.
|
||||
"""
|
||||
current_slice_len = 0
|
||||
current_slice_is_space = s[0].isspace() if len(s) > 0 else False
|
||||
slice_start = 0
|
||||
|
||||
for i in range(len(s)):
|
||||
is_now_space = s[i].isspace()
|
||||
|
||||
if current_slice_is_space ^ is_now_space:
|
||||
current_slice_len = 1
|
||||
current_slice_is_space = is_now_space
|
||||
else:
|
||||
current_slice_len += 1
|
||||
if current_slice_len > max_consecutive_slice_len:
|
||||
yield s[slice_start:i]
|
||||
slice_start = i
|
||||
current_slice_len = 1
|
||||
yield s[slice_start:]
|
||||
|
||||
def pre_tokenizer_process(self, text: str) -> List[str]:
|
||||
"""
|
||||
pre-tokenizes the input text into a list of tokens.
|
||||
This method is used to split the input text into smaller chunks for internal processing.
|
||||
"""
|
||||
return [text]
|
||||
|
||||
""" ----- Below are the abstract methods required by PreTrainedTokenizer ----- """
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.n_words
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
return self.encoder
|
||||
|
||||
def _tokenize(self, text: str, **kwargs) -> List[str]:
|
||||
return [self.decoder[t] for t in self.encode(text)]
|
||||
|
||||
def _convert_token_to_id(self, token: str) -> int:
|
||||
return self.encoder.get(token, self.unk_id)
|
||||
|
||||
def _convert_id_to_token(self, index: int) -> str:
|
||||
return self.decoder.get(index)
|
||||
|
||||
@staticmethod
|
||||
def clean_up_tokenization(out_string: str) -> str:
|
||||
return out_string
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
text = "".join(tokens)
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode(
|
||||
"utf-8", "replace"
|
||||
)
|
||||
return text
|
||||
|
||||
def save_vocabulary(
|
||||
self, save_directory: str, filename_prefix: Optional[str] = None
|
||||
) -> Tuple[str]:
|
||||
if not os.path.isdir(save_directory):
|
||||
raise ValueError(
|
||||
f"vocabulary path ({save_directory}) should be a directory"
|
||||
)
|
||||
out_vocab_file = os.path.join(
|
||||
save_directory,
|
||||
(filename_prefix + "-" if filename_prefix else "")
|
||||
+ VOCAB_FILES_NAMES["vocab_file"],
|
||||
)
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(
|
||||
out_vocab_file
|
||||
) and os.path.isfile(self.vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
def apply_chat_template(
|
||||
self,
|
||||
conversation,
|
||||
tools: Optional[list[dict]] = None,
|
||||
tokenize: bool = True,
|
||||
add_generation_prompt: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
tools = deep_sort_dict(tools)
|
||||
return super().apply_chat_template(
|
||||
conversation,
|
||||
tools=tools,
|
||||
tokenize=tokenize,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def deep_sort_dict(obj: Any) -> Any:
|
||||
if isinstance(obj, dict):
|
||||
return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}
|
||||
if isinstance(obj, list):
|
||||
return [deep_sort_dict(item) for item in obj]
|
||||
return obj
|
||||
@@ -52,6 +52,9 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"olmo",
|
||||
"olmo2",
|
||||
"olmo3",
|
||||
"ministral",
|
||||
"ministral3",
|
||||
"afmoe",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from PIL.Image import Resampling
|
||||
from torch import Tensor, zeros_like
|
||||
from transformers import ProcessorMixin
|
||||
from transformers.image_utils import load_image
|
||||
from transformers.models.internvl import InternVLProcessor
|
||||
from transformers.models.smolvlm import SmolVLMProcessor
|
||||
from transformers.models.voxtral import VoxtralProcessor
|
||||
|
||||
@@ -454,6 +455,37 @@ class Mistral3ProcessingStrategy(ProcessingStrategy):
|
||||
return labels
|
||||
|
||||
|
||||
class InternVLProcessingStrategy(ProcessingStrategy):
|
||||
"""Processing Strategy class for InternVL"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processor: ProcessorMixin,
|
||||
chat_template: Optional[str] = None,
|
||||
image_size: int | tuple[int, int] | None = None,
|
||||
image_resize_algorithm: Resampling | None = None,
|
||||
):
|
||||
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
|
||||
|
||||
if not hasattr(processor, "image_ids"):
|
||||
raise ValueError("'image_ids' missing from InternVL Processor.")
|
||||
|
||||
self.image_token_ids = processor.image_ids
|
||||
|
||||
def process_labels(self, input_ids):
|
||||
labels = input_ids.clone()
|
||||
|
||||
labels[labels == self.processor.tokenizer.pad_token_id] = -100
|
||||
|
||||
for ids in self.image_token_ids:
|
||||
labels[labels == ids] = -100
|
||||
|
||||
# Note: Check if need to mask 'video_token' as it gets converted to
|
||||
# image patches during media processing
|
||||
|
||||
return labels
|
||||
|
||||
|
||||
def get_processing_strategy(
|
||||
processor: ProcessorMixin,
|
||||
chat_template,
|
||||
@@ -501,6 +533,11 @@ def get_processing_strategy(
|
||||
**processing_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(processor, InternVLProcessor):
|
||||
return InternVLProcessingStrategy(
|
||||
**processing_kwargs,
|
||||
)
|
||||
|
||||
# llama3_2_vision, llama4, llava
|
||||
# mistral_v7_tekken, pixtral, lfm2vl
|
||||
return ProcessingStrategy(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user