Compare commits
6 Commits
v0.11.0.po
...
codecov-pu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f3c8a25b30 | ||
|
|
016eb8055f | ||
|
|
639ddeff6a | ||
|
|
753e4e3dec | ||
|
|
2538c3b761 | ||
|
|
aa3639b7ad |
2
.bandit
2
.bandit
@@ -1,3 +1,3 @@
|
|||||||
[bandit]
|
[bandit]
|
||||||
exclude = tests
|
exclude = tests
|
||||||
skips = B101,B615
|
skips = B101
|
||||||
|
|||||||
14
.github/workflows/base.yml
vendored
14
.github/workflows/base.yml
vendored
@@ -5,13 +5,11 @@ on:
|
|||||||
branches:
|
branches:
|
||||||
- "main"
|
- "main"
|
||||||
paths:
|
paths:
|
||||||
- 'docker/Dockerfile-base'
|
- 'Dockerfile-base'
|
||||||
- 'docker/Dockerfile-uv-base'
|
|
||||||
- '.github/workflows/base.yml'
|
- '.github/workflows/base.yml'
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- 'docker/Dockerfile-base'
|
- 'Dockerfile-base'
|
||||||
- 'docker/Dockerfile-uv-base'
|
|
||||||
- '.github/workflows/base.yml'
|
- '.github/workflows/base.yml'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
@@ -29,11 +27,11 @@ jobs:
|
|||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.5.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base"
|
||||||
- cuda: "126"
|
- cuda: "124"
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.4.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
@@ -43,7 +41,7 @@ jobs:
|
|||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.0
|
pytorch: 2.6.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base"
|
||||||
- cuda: "126"
|
- cuda: "126"
|
||||||
|
|||||||
33
.github/workflows/main.yml
vendored
33
.github/workflows/main.yml
vendored
@@ -15,16 +15,17 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
- cuda: 124
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.5.1
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.0
|
|
||||||
axolotl_extras: vllm
|
axolotl_extras: vllm
|
||||||
|
is_latest: true
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -82,17 +83,17 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
- cuda: 124
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.5.1
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.0
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -145,8 +146,8 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
- cuda: 124
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
|||||||
11
.github/workflows/multi-gpu-e2e.yml
vendored
11
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -26,10 +26,17 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
- cuda: 124
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
|
axolotl_extras: vllm
|
||||||
|
num_gpus: 2
|
||||||
|
nightly_build: "true"
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
|
|||||||
11
.github/workflows/nightlies.yml
vendored
11
.github/workflows/nightlies.yml
vendored
@@ -12,6 +12,11 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.5.1
|
||||||
|
axolotl_extras:
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -63,10 +68,10 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 126
|
- cuda: 124
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
|||||||
123
.github/workflows/tests-nightly.yml
vendored
123
.github/workflows/tests-nightly.yml
vendored
@@ -18,26 +18,116 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
SKIP: no-commit-to-branch
|
SKIP: no-commit-to-branch
|
||||||
|
|
||||||
|
preload-cache:
|
||||||
|
name: Preload HF cache
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
python_version: ["3.11"]
|
||||||
|
pytorch_version: ["2.6.0"]
|
||||||
|
timeout-minutes: 20
|
||||||
|
|
||||||
|
env:
|
||||||
|
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Check out repository code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Restore HF cache
|
||||||
|
id: hf-cache-restore
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
/home/runner/.cache/huggingface/hub/datasets--*
|
||||||
|
/home/runner/.cache/huggingface/hub/models--*
|
||||||
|
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||||
|
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python_version }}
|
||||||
|
cache: 'pip' # caching pip dependencies
|
||||||
|
|
||||||
|
- name: upgrade pip
|
||||||
|
run: |
|
||||||
|
pip3 install --upgrade pip
|
||||||
|
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||||
|
|
||||||
|
- name: Install PyTorch
|
||||||
|
run: |
|
||||||
|
pip3 install torch==${{ matrix.pytorch_version }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip3 show torch
|
||||||
|
pip3 install --no-build-isolation -U -e .
|
||||||
|
python scripts/unsloth_install.py | sh
|
||||||
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
|
- name: Make sure PyTorch version wasn't clobbered
|
||||||
|
run: |
|
||||||
|
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||||
|
|
||||||
|
- name: Ensure axolotl CLI was installed
|
||||||
|
run: |
|
||||||
|
axolotl --help
|
||||||
|
|
||||||
|
- name: Pre-Download dataset fixture
|
||||||
|
run: |
|
||||||
|
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: |
|
||||||
|
pytest -v tests/conftest.py
|
||||||
|
|
||||||
|
- name: Upload coverage to Codecov
|
||||||
|
uses: codecov/codecov-action@v5
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
files: ./coverage.xml
|
||||||
|
flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
||||||
|
fail_ci_if_error: false
|
||||||
|
|
||||||
|
- name: cleanup pip cache
|
||||||
|
run: |
|
||||||
|
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||||
|
|
||||||
|
- name: Save HF cache
|
||||||
|
id: hf-cache
|
||||||
|
uses: actions/cache/save@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
/home/runner/.cache/huggingface/hub/datasets--*
|
||||||
|
/home/runner/.cache/huggingface/hub/models--*
|
||||||
|
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||||
|
|
||||||
pytest:
|
pytest:
|
||||||
name: PyTest
|
name: PyTest
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
needs: [preload-cache]
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
max-parallel: 2
|
max-parallel: 2
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.6.0", "2.7.0"]
|
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Restore Cache from S3
|
- name: Restore HF cache
|
||||||
id: hf-cache-restore-s3
|
id: hf-cache-restore
|
||||||
run: |
|
uses: actions/cache/restore@v4
|
||||||
mkdir -p /home/runner/.cache/huggingface/hub
|
with:
|
||||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
path: |
|
||||||
|
/home/runner/.cache/huggingface/hub/datasets--*
|
||||||
|
/home/runner/.cache/huggingface/hub/models--*
|
||||||
|
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
@@ -78,11 +168,15 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|
||||||
|
- name: Pre-Download dataset fixture
|
||||||
|
run: |
|
||||||
|
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v --durations=10 tests/patched/
|
pytest -v tests/patched/
|
||||||
pytest -v --durations=10 tests/cli/
|
pytest -v tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -99,8 +193,15 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
- cuda: 124
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.5.1
|
||||||
|
num_gpus: 1
|
||||||
|
axolotl_extras:
|
||||||
|
nightly_build: "true"
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
|
|||||||
90
.github/workflows/tests.yml
vendored
90
.github/workflows/tests.yml
vendored
@@ -52,7 +52,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
|
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -102,17 +102,16 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
|
||||||
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage artifacts
|
||||||
uses: codecov/codecov-action@v5
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
name: coverage-${{ matrix.pytorch_version }}-${{ github.run_id }}
|
||||||
files: ./coverage.xml
|
path: ./coverage.xml
|
||||||
flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
retention-days: 1
|
||||||
fail_ci_if_error: false
|
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -125,7 +124,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
|
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -175,9 +174,9 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v --durations=10 tests/patched/
|
pytest -v tests/patched/
|
||||||
pytest -v --durations=10 tests/cli/
|
pytest -v tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -195,12 +194,12 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 126
|
- cuda: 124
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras: vllm
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -234,6 +233,14 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
modal run cicd.e2e_tests
|
modal run cicd.e2e_tests
|
||||||
|
|
||||||
|
- name: Upload coverage artifacts
|
||||||
|
if: always()
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: coverage-e2e-1st-${{ github.run_id }}
|
||||||
|
path: ./e2e-coverage.xml
|
||||||
|
retention-days: 1
|
||||||
|
|
||||||
docker-e2e-tests:
|
docker-e2e-tests:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
@@ -247,10 +254,22 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.6.0
|
||||||
|
num_gpus: 1
|
||||||
|
axolotl_extras: llmcompressor
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.5.1
|
||||||
|
num_gpus: 1
|
||||||
|
axolotl_extras:
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.7.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
@@ -285,6 +304,14 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
modal run cicd.e2e_tests
|
modal run cicd.e2e_tests
|
||||||
|
|
||||||
|
- name: Upload coverage artifacts
|
||||||
|
if: always()
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: coverage-e2e-${{ matrix.cuda }}-${{ matrix.pytorch }}-${{ github.run_id }}
|
||||||
|
path: ./e2e-coverage.xml
|
||||||
|
retention-days: 1
|
||||||
|
|
||||||
docker-e2e-cleanup:
|
docker-e2e-cleanup:
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 90
|
timeout-minutes: 90
|
||||||
@@ -299,7 +326,7 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras: vllm
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -324,3 +351,26 @@ jobs:
|
|||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.cleanup
|
modal run cicd.cleanup
|
||||||
|
|
||||||
|
upload-coverage:
|
||||||
|
name: Upload Coverage to Codecov
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: [pytest, docker-e2e-tests, docker-e2e-tests-1st]
|
||||||
|
if: github.event_name == 'pull_request' || github.ref == 'refs/heads/main'
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Download coverage reports
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
path: coverage-reports
|
||||||
|
|
||||||
|
- name: Upload coverage to Codecov
|
||||||
|
uses: codecov/codecov-action@v5
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
directory: coverage-reports
|
||||||
|
fail_ci_if_error: false
|
||||||
|
verbose: true
|
||||||
|
name: codecov-umbrella
|
||||||
|
override_commit: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||||
|
override_pr: ${{ github.event.pull_request.number }}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
- repo: https://github.com/PyCQA/flake8
|
- repo: https://github.com/PyCQA/flake8
|
||||||
rev: 7.3.0
|
rev: 7.2.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
- repo: https://github.com/pylint-dev/pylint
|
- repo: https://github.com/pylint-dev/pylint
|
||||||
@@ -27,7 +27,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: pylint
|
- id: pylint
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v1.16.1
|
rev: v1.16.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
@@ -36,7 +36,7 @@ repos:
|
|||||||
'pydantic>=2.5.3',
|
'pydantic>=2.5.3',
|
||||||
]
|
]
|
||||||
- repo: https://github.com/PyCQA/bandit
|
- repo: https://github.com/PyCQA/bandit
|
||||||
rev: 1.8.6
|
rev: 1.8.3
|
||||||
hooks:
|
hooks:
|
||||||
- id: bandit
|
- id: bandit
|
||||||
args: [
|
args: [
|
||||||
|
|||||||
@@ -2,5 +2,4 @@ include requirements.txt
|
|||||||
include README.md
|
include README.md
|
||||||
include LICENSE
|
include LICENSE
|
||||||
include src/setuptools_axolotl_dynamic_dependencies.py
|
include src/setuptools_axolotl_dynamic_dependencies.py
|
||||||
include src/axolotl/utils/chat_templates/templates/*.jinja
|
|
||||||
recursive-include axolotl *.py
|
recursive-include axolotl *.py
|
||||||
|
|||||||
13
README.md
13
README.md
@@ -43,7 +43,7 @@ Features:
|
|||||||
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
|
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
|
||||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||||
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
|
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
|
||||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
|
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!
|
||||||
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
||||||
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
||||||
|
|
||||||
@@ -55,12 +55,10 @@ Features:
|
|||||||
|
|
||||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||||
- Python 3.11
|
- Python 3.11
|
||||||
- PyTorch ≥2.6.0
|
- PyTorch ≥2.5.1
|
||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
#### Using pip
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
|
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
@@ -70,13 +68,6 @@ axolotl fetch examples
|
|||||||
axolotl fetch deepspeed_configs # OPTIONAL
|
axolotl fetch deepspeed_configs # OPTIONAL
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Using Docker
|
|
||||||
|
|
||||||
Installing with Docker can be less error prone than installing in your own environment.
|
|
||||||
```bash
|
|
||||||
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
|
||||||
```
|
|
||||||
|
|
||||||
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
### Your First Fine-tune
|
### Your First Fine-tune
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
|
|||||||
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
||||||
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
||||||
ENV HF_HOME="{{ HF_HOME }}"
|
ENV HF_HOME="{{ HF_HOME }}"
|
||||||
ENV AXOLOTL_DATASET_PROCESSES="8"
|
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
|
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
|
||||||
|
|||||||
@@ -51,5 +51,3 @@ pytest -v --durations=10 \
|
|||||||
--cov=axolotl \
|
--cov=axolotl \
|
||||||
--cov-append \
|
--cov-append \
|
||||||
--cov-report=xml:e2e-coverage.xml
|
--cov-report=xml:e2e-coverage.xml
|
||||||
|
|
||||||
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Modal app to run axolotl GPU tests"""
|
"""Modal app to run axolotl GPU tests"""
|
||||||
|
|
||||||
|
import pathlib
|
||||||
|
|
||||||
from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
|
from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
|
||||||
|
|
||||||
|
|
||||||
@@ -12,9 +14,21 @@ from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
|
|||||||
volumes=VOLUME_CONFIG,
|
volumes=VOLUME_CONFIG,
|
||||||
)
|
)
|
||||||
def cicd_pytest():
|
def cicd_pytest():
|
||||||
|
|
||||||
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
|
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
|
||||||
|
|
||||||
|
# Read the coverage file if it exists
|
||||||
|
coverage_file = pathlib.Path("/workspace/axolotl/e2e-coverage.xml")
|
||||||
|
if coverage_file.exists():
|
||||||
|
return coverage_file.read_text(encoding="utf-8")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@app.local_entrypoint()
|
@app.local_entrypoint()
|
||||||
def main():
|
def main():
|
||||||
cicd_pytest.remote()
|
coverage = cicd_pytest.remote()
|
||||||
|
|
||||||
|
# Save the coverage file to the local filesystem if it was generated
|
||||||
|
if coverage:
|
||||||
|
with open("e2e-coverage.xml", "w", encoding="utf-8") as f:
|
||||||
|
f.write(coverage)
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ df_template = template_env.get_template("Dockerfile.jinja")
|
|||||||
df_args = {
|
df_args = {
|
||||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
|
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
|
||||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
|
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
|
||||||
"CUDA": os.environ.get("CUDA", "126"),
|
"CUDA": os.environ.get("CUDA", "124"),
|
||||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||||
@@ -77,7 +77,18 @@ def run_cmd(cmd: str, run_folder: str):
|
|||||||
def cicd_pytest():
|
def cicd_pytest():
|
||||||
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
|
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
|
||||||
|
|
||||||
|
# Read the coverage file if it exists
|
||||||
|
coverage_file = pathlib.Path("/workspace/axolotl/multigpu-coverage.xml")
|
||||||
|
if coverage_file.exists():
|
||||||
|
return coverage_file.read_text(encoding="utf-8")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@app.local_entrypoint()
|
@app.local_entrypoint()
|
||||||
def main():
|
def main():
|
||||||
cicd_pytest.remote()
|
coverage = cicd_pytest.remote()
|
||||||
|
|
||||||
|
# Save the coverage file to the local filesystem if it was generated
|
||||||
|
if coverage:
|
||||||
|
with open("multigpu-coverage.xml", "w", encoding="utf-8") as file:
|
||||||
|
file.write(coverage)
|
||||||
|
|||||||
@@ -24,16 +24,14 @@ df_template = template_env.get_template(dockerfile)
|
|||||||
df_args = {
|
df_args = {
|
||||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
|
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
|
||||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
|
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
|
||||||
"CUDA": os.environ.get("CUDA", "126"),
|
"CUDA": os.environ.get("CUDA", "124"),
|
||||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||||
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
|
|
||||||
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dockerfile_contents = df_template.render(**df_args)
|
dockerfile_contents = df_template.render(**df_args)
|
||||||
|
|||||||
@@ -38,6 +38,6 @@ RUN git lfs install --skip-repo && \
|
|||||||
# The base image ships with `pydantic==1.8.2` which is not working
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||||
|
|
||||||
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
|
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
|
||||||
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
|
pip3 install flash-attn==2.7.4.post1; \
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -34,3 +34,7 @@ RUN uv pip install packaging setuptools wheel psutil \
|
|||||||
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
||||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||||
&& uv pip install awscli pydantic
|
&& uv pip install awscli pydantic
|
||||||
|
|
||||||
|
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
|
||||||
|
uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
|
||||||
|
fi
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ toc-depth: 3
|
|||||||
```{python}
|
```{python}
|
||||||
#| echo: false
|
#| echo: false
|
||||||
|
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
def process_readme(integration_name):
|
def process_readme(integration_name):
|
||||||
@@ -54,24 +53,6 @@ sections = [
|
|||||||
("LLMCompressor", "llm_compressor")
|
("LLMCompressor", "llm_compressor")
|
||||||
]
|
]
|
||||||
|
|
||||||
for folder_name in os.listdir("../src/axolotl/integrations/"):
|
|
||||||
if folder_name in [path for name, path in sections]:
|
|
||||||
# skip if already in sections
|
|
||||||
continue
|
|
||||||
if os.path.exists(f"../src/axolotl/integrations/{folder_name}/README.md"):
|
|
||||||
# grab the first heading in README.md as the section name
|
|
||||||
with open(f"../src/axolotl/integrations/{folder_name}/README.md", "r") as f:
|
|
||||||
txt = f.read()
|
|
||||||
matches = re.search(r'^# (.*)\n?', txt, flags=re.MULTILINE)
|
|
||||||
if matches:
|
|
||||||
name = matches.group(1)
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
sections.append((name, folder_name))
|
|
||||||
|
|
||||||
# sort sections by name
|
|
||||||
sections = sorted(sections, key=lambda x: x[0])
|
|
||||||
|
|
||||||
for section_name, folder_name in sections:
|
for section_name, folder_name in sections:
|
||||||
print(print_section(section_name, folder_name))
|
print(print_section(section_name, folder_name))
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ order: 3
|
|||||||
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
|
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
|
||||||
|
|
||||||
```{.json filename="data.jsonl"}
|
```{.json filename="data.jsonl"}
|
||||||
{"messages": [{"role": "...", "content": "..."}, {"role": "...", "content": "..."}, ...]}
|
{"conversations": [{"role": "...", "content": "..."}]}
|
||||||
```
|
```
|
||||||
|
|
||||||
See [configs](../config-reference.qmd) for full configs and supported templates.
|
See [configs](../config-reference.qmd) for full configs and supported templates.
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ format:
|
|||||||
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
|
For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
## Base
|
## Base
|
||||||
@@ -34,8 +34,8 @@ Tags examples:
|
|||||||
|
|
||||||
- `main-base-py3.11-cu128-2.7.1`
|
- `main-base-py3.11-cu128-2.7.1`
|
||||||
- `main-base-py3.11-cu126-2.7.1`
|
- `main-base-py3.11-cu126-2.7.1`
|
||||||
- `main-base-py3.11-cu126-2.6.0`
|
|
||||||
- `main-base-py3.11-cu124-2.6.0`
|
- `main-base-py3.11-cu124-2.6.0`
|
||||||
|
- `main-base-py3.11-cu124-2.5.1`
|
||||||
|
|
||||||
## Main
|
## Main
|
||||||
|
|
||||||
@@ -73,14 +73,13 @@ There may be some extra tags appended to the image, like `-vllm` which installs
|
|||||||
|
|
||||||
Tags examples:
|
Tags examples:
|
||||||
|
|
||||||
- `main-py3.11-cu128-2.7.1`
|
- `main-py3.11-cu126-2.7.0`
|
||||||
- `main-py3.11-cu126-2.7.1`
|
|
||||||
- `main-py3.11-cu126-2.6.0`
|
|
||||||
- `main-py3.11-cu124-2.6.0`
|
- `main-py3.11-cu124-2.6.0`
|
||||||
|
- `main-py3.11-cu124-2.5.1`
|
||||||
- `main-latest`
|
- `main-latest`
|
||||||
- `main-20250303-py3.11-cu124-2.6.0`
|
- `main-20250303-py3.11-cu124-2.6.0`
|
||||||
- `main-20250303-py3.11-cu126-2.6.0`
|
- `main-20250303-py3.11-cu124-2.5.1`
|
||||||
- `0.10.1`
|
- `0.9.2`
|
||||||
|
|
||||||
## Cloud
|
## Cloud
|
||||||
|
|
||||||
|
|||||||
16
docs/faq.qmd
16
docs/faq.qmd
@@ -9,11 +9,11 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
> A: Usually an issue with the GPUs communicating with each other. See the [NCCL doc](nccl.qmd)
|
> A: Usually an issue with the GPUs communicating with each other. See the [NCCL doc](nccl.qmd)
|
||||||
|
|
||||||
**Q: exitcode: -9**
|
**Q: Exitcode -9**
|
||||||
|
|
||||||
> A: This usually happens when you run out of system RAM.
|
> A: This usually happens when you run out of system RAM.
|
||||||
|
|
||||||
**Q: exitcode: -7 while using deepspeed**
|
**Q: Exitcode -7 while using deepspeed**
|
||||||
|
|
||||||
> A: Try upgrading deepspeed w: `pip install -U deepspeed`
|
> A: Try upgrading deepspeed w: `pip install -U deepspeed`
|
||||||
|
|
||||||
@@ -51,18 +51,6 @@ description: Frequently asked questions
|
|||||||
> pad_token: "..."
|
> pad_token: "..."
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
**Q: `IterableDataset error` or `KeyError: 'input_ids'` when using `preprocess` CLI**
|
|
||||||
|
|
||||||
> A: This is because you may be using `preprocess` CLI with `pretraining_dataset:` or `skip_prepare_dataset: true` respectively. Please use `axolotl train` CLI directly instead as these datasets are prepared on demand.
|
|
||||||
|
|
||||||
**Q: vLLM is not working with Axolotl**
|
|
||||||
|
|
||||||
> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
|
|
||||||
|
|
||||||
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
|
|
||||||
|
|
||||||
> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
|
|
||||||
|
|
||||||
### Chat templates
|
### Chat templates
|
||||||
|
|
||||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
|
|||||||
> See the [example config](#example-config) file in addition to reading these instructions.
|
> See the [example config](#example-config) file in addition to reading these instructions.
|
||||||
|
|
||||||
1. Set `adapter: qlora` in your axolotl config file.
|
1. Set `adapter: qlora` in your axolotl config file.
|
||||||
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
|
2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
|
||||||
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
|
||||||
|
|
||||||
## Example Config
|
## Example Config
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
|||||||
|
|
||||||
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
||||||
- Python ≥3.11
|
- Python ≥3.11
|
||||||
- PyTorch ≥2.6.0
|
- PyTorch ≥2.5.1
|
||||||
|
|
||||||
## Installation Methods {#sec-installation-methods}
|
## Installation Methods {#sec-installation-methods}
|
||||||
|
|
||||||
|
|||||||
@@ -66,15 +66,6 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
|
|||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
|
|
||||||
Using ZeRO Stage 3 with Single-GPU training
|
|
||||||
|
|
||||||
ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
|
|
||||||
`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
## FSDP {#sec-fsdp}
|
## FSDP {#sec-fsdp}
|
||||||
|
|
||||||
### Basic FSDP Configuration {#sec-fsdp-config}
|
### Basic FSDP Configuration {#sec-fsdp-config}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,69 +0,0 @@
|
|||||||
# Finetune Devstral with Axolotl
|
|
||||||
|
|
||||||
Devstral Small is a 24B parameter opensource model from MistralAI found on HuggingFace [Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505). This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking.
|
|
||||||
|
|
||||||
The model was fine-tuned ontop of [Mistral-Small-3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503) without the vision layer and has a context of upto 128k tokens.
|
|
||||||
|
|
||||||
## Getting started
|
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
|
||||||
|
|
||||||
Here is an example of how to install from main for pip:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0+)
|
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
|
||||||
cd axolotl
|
|
||||||
|
|
||||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
|
||||||
|
|
||||||
# Install the latest mistral-common from source
|
|
||||||
pip3 uninstall mistral-common
|
|
||||||
pip3 install git+https://github.com/mistralai/mistral-common.git@039465d
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Run the finetuning example:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
axolotl train examples/devstral/devstral-small-qlora.yml
|
|
||||||
```
|
|
||||||
|
|
||||||
This config uses about 21GB VRAM.
|
|
||||||
|
|
||||||
Let us know how it goes. Happy finetuning! 🚀
|
|
||||||
|
|
||||||
### TIPS
|
|
||||||
|
|
||||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
|
||||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
|
||||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
|
||||||
|
|
||||||
## Optimization Guides
|
|
||||||
|
|
||||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
|
||||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
|
||||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
|
||||||
- [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
|
|
||||||
- [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels)
|
|
||||||
|
|
||||||
## Limitations
|
|
||||||
|
|
||||||
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
|
||||||
|
|
||||||
In addition, we do not support overriding tokens yet.
|
|
||||||
|
|
||||||
## Related Resources
|
|
||||||
|
|
||||||
- [MistralAI Devstral Blog](https://mistral.ai/news/devstral)
|
|
||||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
|
||||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
|
||||||
- [Axolotl Website](https://axolotl.ai)
|
|
||||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
|
||||||
|
|
||||||
|
|
||||||
## Future Work
|
|
||||||
|
|
||||||
- Add parity to Preference Tuning, RL, Multi-modal, etc.
|
|
||||||
- Add parity to other tokenizer configs like overriding tokens.
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
base_model: mistralai/Devstral-Small-2505
|
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
# Enable to use mistral-common tokenizer
|
|
||||||
tokenizer_use_mistral_common: true
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
|
||||||
type: chat_template
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./outputs/qlora-out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0
|
|
||||||
lora_target_linear: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
loss_watchdog_threshold: 5.0
|
|
||||||
loss_watchdog_patience: 3
|
|
||||||
|
|
||||||
warmup_ratio: 0.05
|
|
||||||
evals_per_epoch: 4
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
weight_decay: 0.0
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
base_model: tiiuae/Falcon-H1-1.5B-Deep-Base
|
|
||||||
# optionally might have model_type or tokenizer_type
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
|
|
||||||
# huggingface repo
|
|
||||||
chat_template: falcon_h1
|
|
||||||
datasets:
|
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
|
||||||
type: chat_template
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- k_proj
|
|
||||||
- v_proj
|
|
||||||
- o_proj
|
|
||||||
- in_proj
|
|
||||||
- gate_proj
|
|
||||||
- up_proj
|
|
||||||
- down_proj
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: false
|
|
||||||
eval_sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch:
|
|
||||||
saves_per_epoch: 1
|
|
||||||
weight_decay: 0.0
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
base_model: tiiuae/Falcon-H1-1.5B-Base
|
|
||||||
# optionally might have model_type or tokenizer_type
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
|
|
||||||
# huggingface repo
|
|
||||||
chat_template: falcon_h1
|
|
||||||
datasets:
|
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
|
||||||
type: chat_template
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- k_proj
|
|
||||||
- v_proj
|
|
||||||
- o_proj
|
|
||||||
- in_proj
|
|
||||||
- gate_proj
|
|
||||||
- up_proj
|
|
||||||
- down_proj
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: false
|
|
||||||
eval_sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch:
|
|
||||||
saves_per_epoch: 1
|
|
||||||
weight_decay: 0.0
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
base_model: tiiuae/Falcon-H1-34B-Base
|
|
||||||
# optionally might have model_type or tokenizer_type
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
|
|
||||||
# huggingface repo
|
|
||||||
chat_template: falcon_h1
|
|
||||||
datasets:
|
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
|
||||||
type: chat_template
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- k_proj
|
|
||||||
- v_proj
|
|
||||||
- o_proj
|
|
||||||
- in_proj
|
|
||||||
- gate_proj
|
|
||||||
- up_proj
|
|
||||||
- down_proj
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: false
|
|
||||||
eval_sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch:
|
|
||||||
saves_per_epoch: 1
|
|
||||||
weight_decay: 0.0
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
base_model: tiiuae/Falcon-H1-3B-Base
|
|
||||||
# optionally might have model_type or tokenizer_type
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
|
|
||||||
# huggingface repo
|
|
||||||
chat_template: falcon_h1
|
|
||||||
datasets:
|
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
|
||||||
type: chat_template
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- k_proj
|
|
||||||
- v_proj
|
|
||||||
- o_proj
|
|
||||||
- in_proj
|
|
||||||
- gate_proj
|
|
||||||
- up_proj
|
|
||||||
- down_proj
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: false
|
|
||||||
eval_sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
weight_decay: 0.0
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
base_model: tiiuae/Falcon-H1-0.5B-Instruct
|
|
||||||
# optionally might have model_type or tokenizer_type
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
|
|
||||||
# huggingface repo
|
|
||||||
chat_template: falcon_h1
|
|
||||||
datasets:
|
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
|
||||||
type: chat_template
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- k_proj
|
|
||||||
- v_proj
|
|
||||||
- o_proj
|
|
||||||
- in_proj
|
|
||||||
- gate_proj
|
|
||||||
- up_proj
|
|
||||||
- down_proj
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: false
|
|
||||||
eval_sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch:
|
|
||||||
saves_per_epoch: 1
|
|
||||||
weight_decay: 0.0
|
|
||||||
special_tokens:
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
base_model: tiiuae/Falcon-H1-7B-Base
|
|
||||||
# optionally might have model_type or tokenizer_type
|
|
||||||
model_type: AutoModelForCausalLM
|
|
||||||
tokenizer_type: AutoTokenizer
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
|
|
||||||
# huggingface repo
|
|
||||||
chat_template: falcon_h1
|
|
||||||
datasets:
|
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
|
||||||
type: chat_template
|
|
||||||
field_messages: conversations
|
|
||||||
message_property_mappings:
|
|
||||||
role: from
|
|
||||||
content: value
|
|
||||||
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules:
|
|
||||||
- q_proj
|
|
||||||
- k_proj
|
|
||||||
- v_proj
|
|
||||||
- o_proj
|
|
||||||
- in_proj
|
|
||||||
- gate_proj
|
|
||||||
- up_proj
|
|
||||||
- down_proj
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: false
|
|
||||||
eval_sample_packing: false
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 4
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
gradient_checkpointing_kwargs:
|
|
||||||
use_reentrant: false
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
weight_decay: 0.0
|
|
||||||
special_tokens:
|
|
||||||
@@ -13,8 +13,6 @@ load_in_4bit: true
|
|||||||
|
|
||||||
# huggingface repo
|
# huggingface repo
|
||||||
chat_template: gemma3
|
chat_template: gemma3
|
||||||
eot_tokens:
|
|
||||||
- <end_of_turn>
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
type: chat_template
|
type: chat_template
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ load_in_4bit: true
|
|||||||
ddp_find_unused_parameters: true
|
ddp_find_unused_parameters: true
|
||||||
|
|
||||||
chat_template: gemma3
|
chat_template: gemma3
|
||||||
eot_tokens:
|
|
||||||
- <end_of_turn>
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
type: chat_template
|
type: chat_template
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ sample_packing: false
|
|||||||
ddp_find_unused_parameters: true
|
ddp_find_unused_parameters: true
|
||||||
|
|
||||||
chat_template: gemma3
|
chat_template: gemma3
|
||||||
eot_tokens:
|
|
||||||
- <end_of_turn>
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
type: chat_template
|
type: chat_template
|
||||||
|
|||||||
@@ -18,10 +18,16 @@ git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
|||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Run the finetuning example:
|
2. Download the example config:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
axolotl fetch examples
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Run the finetuning example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
axolotl train examples/magistral/magistral-small-qlora.yaml
|
axolotl train examples/magistral/magistral-small-qlora.yaml
|
||||||
@@ -36,7 +42,7 @@ Let us know how it goes. Happy finetuning! 🚀
|
|||||||
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
||||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||||
|
|
||||||
## Optimization Guides
|
## Optimization Guides
|
||||||
|
|
||||||
@@ -48,7 +54,7 @@ Let us know how it goes. Happy finetuning! 🚀
|
|||||||
|
|
||||||
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
||||||
|
|
||||||
In addition, we do not support overriding tokens yet.
|
The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
|
||||||
|
|
||||||
## Related Resources
|
## Related Resources
|
||||||
|
|
||||||
|
|||||||
@@ -1,55 +0,0 @@
|
|||||||
base_model: Qwen/Qwen2.5-VL-7B-Instruct
|
|
||||||
processor_type: AutoProcessor
|
|
||||||
|
|
||||||
# these 3 lines are needed for now to handle vision chat templates w images
|
|
||||||
skip_prepare_dataset: true
|
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
chat_template: qwen2_vl
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
|
||||||
type: chat_template
|
|
||||||
split: train[:1%]
|
|
||||||
field_messages: messages
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 8192
|
|
||||||
pad_to_sequence_len: false
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: true
|
|
||||||
fp16:
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
weight_decay: 0.0
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
|
|
||||||
# START section of dependencies that don't install on Darwin/MacOS
|
# START section of dependencies that don't install on Darwin/MacOS
|
||||||
bitsandbytes==0.46.0
|
bitsandbytes==0.45.4
|
||||||
triton>=3.0.0
|
triton>=3.0.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
@@ -13,12 +13,12 @@ packaging==23.2
|
|||||||
|
|
||||||
huggingface_hub==0.32.2
|
huggingface_hub==0.32.2
|
||||||
peft==0.15.2
|
peft==0.15.2
|
||||||
transformers==4.53.1
|
transformers==4.52.4
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.8.1
|
accelerate==1.7.0
|
||||||
datasets==3.6.0
|
datasets==3.6.0
|
||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.18.2
|
trl==0.18.1
|
||||||
hf_xet==1.1.2
|
hf_xet==1.1.2
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
@@ -68,4 +68,4 @@ schedulefree==1.4.1
|
|||||||
axolotl-contribs-lgpl==0.0.6
|
axolotl-contribs-lgpl==0.0.6
|
||||||
axolotl-contribs-mit==0.0.3
|
axolotl-contribs-mit==0.0.3
|
||||||
|
|
||||||
mistral-common==1.6.3
|
mistral-common==1.6.0
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"'
|
||||||
)
|
)
|
||||||
|
|||||||
13
setup.py
13
setup.py
@@ -66,11 +66,8 @@ def parse_requirements(extras_require_map):
|
|||||||
|
|
||||||
if (major, minor) >= (2, 7):
|
if (major, minor) >= (2, 7):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
if patch == 0:
|
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
|
||||||
_install_requires.append("xformers==0.0.30")
|
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||||
else:
|
|
||||||
_install_requires.append("xformers==0.0.31.post1")
|
|
||||||
extras_require_map["vllm"] = ["vllm>=0.9.0"]
|
|
||||||
elif (major, minor) >= (2, 6):
|
elif (major, minor) >= (2, 6):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append(
|
_install_requires.append(
|
||||||
@@ -114,10 +111,10 @@ def get_package_version():
|
|||||||
|
|
||||||
|
|
||||||
extras_require = {
|
extras_require = {
|
||||||
"flash-attn": ["flash-attn==2.8.0.post2"],
|
"flash-attn": ["flash-attn==2.7.4.post1"],
|
||||||
"ring-flash-attn": [
|
"ring-flash-attn": [
|
||||||
"flash-attn==2.8.0.post2",
|
"flash-attn==2.7.4.post1",
|
||||||
"ring-flash-attn>=0.1.5",
|
"ring-flash-attn>=0.1.4",
|
||||||
"yunchang==0.6.0",
|
"yunchang==0.6.0",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
|
|||||||
@@ -4,4 +4,4 @@ import pkgutil
|
|||||||
|
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||||
|
|
||||||
__version__ = "0.11.0"
|
__version__ = "0.11.0.dev"
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
|||||||
from accelerate.commands.config import config_args
|
from accelerate.commands.config import config_args
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||||
from requests import HTTPError
|
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
@@ -47,8 +46,3 @@ def check_user_token() -> bool:
|
|||||||
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
except HTTPError:
|
|
||||||
LOG.warning(
|
|
||||||
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import Union
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.cloud.modal_ import ModalCloud
|
from axolotl.cli.cloud.modal_ import ModalCloud
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -23,6 +24,7 @@ def do_cli_preprocess(
|
|||||||
cloud_config: Union[Path, str],
|
cloud_config: Union[Path, str],
|
||||||
config: Union[Path, str],
|
config: Union[Path, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
print_axolotl_text_art()
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
@@ -37,6 +39,7 @@ def do_cli_train(
|
|||||||
cwd=None,
|
cwd=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
print_axolotl_text_art()
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
@@ -51,6 +54,7 @@ def do_cli_lm_eval(
|
|||||||
cloud_config: Union[Path, str],
|
cloud_config: Union[Path, str],
|
||||||
config: Union[Path, str],
|
config: Union[Path, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
print_axolotl_text_art()
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
|
|||||||
@@ -26,9 +26,7 @@ from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
|||||||
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__, use_environ=True)
|
||||||
|
|
||||||
API_KEY_FIELDS = {"comet_api_key"}
|
|
||||||
|
|
||||||
|
|
||||||
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||||
@@ -235,15 +233,4 @@ def load_cfg(
|
|||||||
setup_comet_env_vars(cfg)
|
setup_comet_env_vars(cfg)
|
||||||
plugin_set_cfg(cfg)
|
plugin_set_cfg(cfg)
|
||||||
|
|
||||||
cfg_to_log = {
|
|
||||||
k: "[REDACTED]" if k in API_KEY_FIELDS else v
|
|
||||||
for k, v in cfg.items()
|
|
||||||
if v is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
"config:\n%s",
|
|
||||||
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from dotenv import load_dotenv
|
|||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
@@ -34,6 +35,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
patch_optimized_env()
|
patch_optimized_env()
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from dotenv import load_dotenv
|
|||||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||||
|
|
||||||
from axolotl.cli.args import InferenceCliArgs
|
from axolotl.cli.args import InferenceCliArgs
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.utils.chat_templates import (
|
from axolotl.utils.chat_templates import (
|
||||||
@@ -254,6 +255,7 @@ def do_cli(
|
|||||||
kwargs: Additional keyword arguments to override config file values.
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
print_axolotl_text_art()
|
||||||
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
||||||
parsed_cfg.sample_packing = False
|
parsed_cfg.sample_packing = False
|
||||||
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from axolotl.cli.args import (
|
|||||||
TrainerCliArgs,
|
TrainerCliArgs,
|
||||||
VllmServeCliArgs,
|
VllmServeCliArgs,
|
||||||
)
|
)
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.sweeps import generate_sweep_configs
|
from axolotl.cli.sweeps import generate_sweep_configs
|
||||||
from axolotl.cli.utils import (
|
from axolotl.cli.utils import (
|
||||||
add_options_from_config,
|
add_options_from_config,
|
||||||
@@ -41,7 +40,6 @@ LOG = get_logger(__name__)
|
|||||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||||
def cli():
|
def cli():
|
||||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||||
print_axolotl_text_art()
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import Union
|
|||||||
import fire
|
import fire
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -22,6 +23,8 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
"""
|
"""
|
||||||
|
print_axolotl_text_art()
|
||||||
|
|
||||||
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from huggingface_hub import split_torch_state_dict_into_shards
|
|||||||
from safetensors.torch import save_file as safe_save_file
|
from safetensors.torch import save_file as safe_save_file
|
||||||
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||||
|
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
@@ -193,6 +194,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
kwargs: Additional keyword arguments to override config file values.
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
print_axolotl_text_art()
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
|
||||||
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from dotenv import load_dotenv
|
|||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from axolotl.cli.args import PreprocessCliArgs
|
from axolotl.cli.args import PreprocessCliArgs
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
@@ -32,15 +33,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Preprocessing-specific CLI arguments.
|
cli_args: Preprocessing-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|
||||||
for key in ["skip_prepare_dataset", "pretraining_dataset"]:
|
|
||||||
if cfg.get("key"):
|
|
||||||
raise ValueError(
|
|
||||||
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not cfg.dataset_prepared_path:
|
if not cfg.dataset_prepared_path:
|
||||||
msg = (
|
msg = (
|
||||||
Fore.RED
|
Fore.RED
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import Union
|
|||||||
|
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.loaders import load_tokenizer
|
from axolotl.loaders import load_tokenizer
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -26,6 +27,7 @@ def do_quantize(
|
|||||||
config (Union[Path, str]): The path to the config file
|
config (Union[Path, str]): The path to the config file
|
||||||
cli_args (dict): Additional command-line arguments
|
cli_args (dict): Additional command-line arguments
|
||||||
"""
|
"""
|
||||||
|
print_axolotl_text_art()
|
||||||
|
|
||||||
cfg = load_cfg(config)
|
cfg = load_cfg(config)
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from dotenv import load_dotenv
|
|||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
@@ -34,6 +35,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
|||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
patch_optimized_env()
|
patch_optimized_env()
|
||||||
|
|
||||||
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|||||||
@@ -75,17 +75,13 @@ def load_datasets(
|
|||||||
|
|
||||||
num_examples = cli_args.debug_num_examples if cli_args else 1
|
num_examples = cli_args.debug_num_examples if cli_args else 1
|
||||||
text_only = cli_args.debug_text_only if cli_args else False
|
text_only = cli_args.debug_text_only if cli_args else False
|
||||||
try:
|
train_samples = sample_dataset(train_dataset, num_examples)
|
||||||
train_samples = sample_dataset(train_dataset, num_examples)
|
check_dataset_labels(
|
||||||
check_dataset_labels(
|
train_samples,
|
||||||
train_samples,
|
tokenizer,
|
||||||
tokenizer,
|
num_examples=num_examples,
|
||||||
num_examples=num_examples,
|
text_only=text_only,
|
||||||
text_only=text_only,
|
)
|
||||||
)
|
|
||||||
except AttributeError:
|
|
||||||
# can't sample iterable datasets
|
|
||||||
pass
|
|
||||||
|
|
||||||
LOG.info("printing prompters...")
|
LOG.info("printing prompters...")
|
||||||
for prompter in prompters:
|
for prompter in prompters:
|
||||||
|
|||||||
@@ -1,162 +0,0 @@
|
|||||||
"""
|
|
||||||
monkeypatch for flex + packing
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from typing import Callable, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn.attention.flex_attention import BlockMask
|
|
||||||
from transformers import Cache, PretrainedConfig
|
|
||||||
from transformers.masking_utils import (
|
|
||||||
ALL_MASK_ATTENTION_FUNCTIONS,
|
|
||||||
_preprocess_mask_arguments,
|
|
||||||
and_masks,
|
|
||||||
causal_mask_function,
|
|
||||||
or_masks,
|
|
||||||
)
|
|
||||||
from transformers.utils import is_torch_greater_or_equal
|
|
||||||
|
|
||||||
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
|
|
||||||
|
|
||||||
|
|
||||||
def create_causal_mask(
|
|
||||||
config: PretrainedConfig,
|
|
||||||
input_embeds: torch.Tensor,
|
|
||||||
attention_mask: torch.Tensor,
|
|
||||||
cache_position: torch.Tensor,
|
|
||||||
past_key_values: Optional[Cache],
|
|
||||||
or_mask_function: Optional[Callable] = None,
|
|
||||||
and_mask_function: Optional[Callable] = None,
|
|
||||||
) -> Optional[Union[torch.Tensor, BlockMask]]:
|
|
||||||
"""
|
|
||||||
Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
|
|
||||||
has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
|
|
||||||
to what is needed in the `modeling_xxx.py` files).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (`PretrainedConfig`):
|
|
||||||
The model config.
|
|
||||||
input_embeds (`torch.Tensor`):
|
|
||||||
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
|
|
||||||
batch size, query length and dtype.
|
|
||||||
attention_mask (`torch.Tensor`, optional):
|
|
||||||
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
|
|
||||||
It can also be an already prepared 4D mask, in which case it is returned as-is.
|
|
||||||
cache_position (`torch.Tensor`):
|
|
||||||
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
|
|
||||||
past_key_values (`Cache`, optional):
|
|
||||||
The past key values, if we use a cache.
|
|
||||||
or_mask_function (`Callable`, optional):
|
|
||||||
An optional mask function to combine with the causal mask function (by doing the union of both). This is
|
|
||||||
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
|
||||||
and_mask_function (`Callable`, optional):
|
|
||||||
An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
|
|
||||||
useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
|
|
||||||
"""
|
|
||||||
# If we have an HybridCache structure, here we want to create the mask for the full layers
|
|
||||||
if (
|
|
||||||
past_key_values
|
|
||||||
and hasattr(past_key_values, "is_sliding")
|
|
||||||
and False in past_key_values.is_sliding
|
|
||||||
):
|
|
||||||
layer_idx = past_key_values.is_sliding.index(False)
|
|
||||||
else:
|
|
||||||
layer_idx = 0
|
|
||||||
|
|
||||||
original_attention_mask = (
|
|
||||||
None
|
|
||||||
if attention_mask is None
|
|
||||||
else attention_mask.clone().to(cache_position.device)
|
|
||||||
)
|
|
||||||
early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
|
|
||||||
config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
|
|
||||||
)
|
|
||||||
if early_exit:
|
|
||||||
return attention_mask
|
|
||||||
|
|
||||||
batch_size, total_seq_len = cache_position.shape
|
|
||||||
key_length = total_seq_len
|
|
||||||
document_ids = torch.nn.functional.pad(
|
|
||||||
original_attention_mask, value=0, pad=(0, key_length)
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
|
|
||||||
if attention_mask is not None:
|
|
||||||
|
|
||||||
def causal_doc_mask_mod(
|
|
||||||
batch_idx, head_idx, q_idx, kv_idx
|
|
||||||
): # pylint: disable=unused-argument
|
|
||||||
"""
|
|
||||||
Defines the logic of a block causal mask by combining both a standard causal mask
|
|
||||||
and a block diagonal document mask.
|
|
||||||
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
|
|
||||||
for an illustration.
|
|
||||||
"""
|
|
||||||
causal_mask_ = q_idx >= kv_idx # not valid when decoding
|
|
||||||
document_mask = (
|
|
||||||
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
|
|
||||||
)
|
|
||||||
final_mask = causal_mask_ & document_mask
|
|
||||||
return final_mask
|
|
||||||
|
|
||||||
mask_factory_function = causal_doc_mask_mod
|
|
||||||
else:
|
|
||||||
mask_factory_function = causal_mask_function
|
|
||||||
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[
|
|
||||||
config._attn_implementation # pylint: disable=protected-access
|
|
||||||
]
|
|
||||||
|
|
||||||
# Do not allow skip if we are compiling (this is to match BC)
|
|
||||||
allow_is_causal_skip = (
|
|
||||||
not past_key_values.is_compileable if past_key_values is not None else True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Allow slight deviations from causal mask
|
|
||||||
if or_mask_function is not None:
|
|
||||||
if not _is_torch_greater_or_equal_than_2_6:
|
|
||||||
raise ValueError(
|
|
||||||
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
|
|
||||||
)
|
|
||||||
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
|
|
||||||
allow_is_causal_skip = False
|
|
||||||
if and_mask_function is not None:
|
|
||||||
if not _is_torch_greater_or_equal_than_2_6:
|
|
||||||
raise ValueError(
|
|
||||||
"Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
|
|
||||||
)
|
|
||||||
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
|
|
||||||
allow_is_causal_skip = False
|
|
||||||
|
|
||||||
# We now create the mask
|
|
||||||
causal_mask = mask_interface(
|
|
||||||
batch_size=batch_size,
|
|
||||||
cache_position=cache_position,
|
|
||||||
kv_length=kv_length,
|
|
||||||
kv_offset=kv_offset,
|
|
||||||
mask_function=mask_factory_function,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
|
|
||||||
dtype=dtype, # Additional kwarg for eager
|
|
||||||
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
|
|
||||||
)
|
|
||||||
return causal_mask
|
|
||||||
|
|
||||||
|
|
||||||
def patch_create_causal_mask(model_type):
|
|
||||||
import transformers.masking_utils
|
|
||||||
|
|
||||||
transformers.masking_utils.create_causal_mask = create_causal_mask
|
|
||||||
|
|
||||||
if model_type:
|
|
||||||
try:
|
|
||||||
# Dynamically import the module and attention class
|
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
|
||||||
module = __import__(module_path)
|
|
||||||
module.create_causal_mask = create_causal_mask
|
|
||||||
del sys.modules[module_path]
|
|
||||||
except (ImportError, AttributeError) as e:
|
|
||||||
raise ValueError(
|
|
||||||
f"Could not import attention class for model_type: {model_type}. "
|
|
||||||
f"Error: {str(e)}"
|
|
||||||
) from e
|
|
||||||
@@ -219,9 +219,7 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if self.cfg.bf16 == "full":
|
if self.cfg.bf16 == "full":
|
||||||
training_args_kwargs["bf16_full_eval"] = True
|
training_args_kwargs["bf16_full_eval"] = True
|
||||||
else:
|
else:
|
||||||
bf16 = self.cfg.bf16 or self.cfg.bfloat16
|
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
|
||||||
bf16 = bf16 if bf16 is not None else False
|
|
||||||
training_args_kwargs["bf16"] = bf16
|
|
||||||
|
|
||||||
def _configure_scheduler(self, training_args_kwargs: dict):
|
def _configure_scheduler(self, training_args_kwargs: dict):
|
||||||
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
|
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
|
||||||
|
|||||||
@@ -245,27 +245,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||||
|
|
||||||
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
||||||
training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool(
|
|
||||||
self.cfg.flash_attention
|
|
||||||
or self.cfg.xformers_attention
|
|
||||||
or self.cfg.flex_attention
|
|
||||||
)
|
|
||||||
training_arguments_kwargs["multipack_real_batches"] = (
|
training_arguments_kwargs["multipack_real_batches"] = (
|
||||||
self.cfg.multipack_real_batches
|
self.cfg.multipack_real_batches
|
||||||
if self.cfg.multipack_real_batches is not None
|
if self.cfg.multipack_real_batches is not None
|
||||||
else not (
|
else not self.cfg.flash_attention
|
||||||
self.cfg.flash_attention
|
|
||||||
or self.cfg.flex_attention
|
|
||||||
or self.cfg.xformers_attention
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["eval_sample_packing"] = bool(
|
training_arguments_kwargs["eval_sample_packing"] = bool(
|
||||||
self.cfg.eval_sample_packing
|
self.cfg.eval_sample_packing
|
||||||
)
|
)
|
||||||
if self.cfg.sample_packing_sequentially is not None:
|
|
||||||
training_arguments_kwargs["sample_packing_sequentially"] = (
|
|
||||||
self.cfg.sample_packing_sequentially
|
|
||||||
)
|
|
||||||
if self.cfg.sample_packing_bin_size is not None:
|
if self.cfg.sample_packing_bin_size is not None:
|
||||||
training_arguments_kwargs["sample_packing_bin_size"] = (
|
training_arguments_kwargs["sample_packing_bin_size"] = (
|
||||||
self.cfg.sample_packing_bin_size
|
self.cfg.sample_packing_bin_size
|
||||||
@@ -426,8 +413,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
or self.cfg.micro_batch_size > 1
|
or self.cfg.micro_batch_size > 1
|
||||||
):
|
):
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
|
return None
|
||||||
return None
|
|
||||||
|
|
||||||
if self.cfg.model_config_type == "mamba":
|
if self.cfg.model_config_type == "mamba":
|
||||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||||
|
|||||||
@@ -20,14 +20,13 @@ from torch.utils.data import (
|
|||||||
SequentialSampler,
|
SequentialSampler,
|
||||||
)
|
)
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import (
|
from axolotl.core.trainers.mixins import (
|
||||||
CheckpointSaveMixin,
|
CheckpointSaveMixin,
|
||||||
OptimizerMixin,
|
OptimizerMixin,
|
||||||
PackingMixin,
|
|
||||||
RngLoaderMixin,
|
RngLoaderMixin,
|
||||||
SchedulerMixin,
|
SchedulerMixin,
|
||||||
)
|
)
|
||||||
@@ -43,12 +42,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(
|
class AxolotlTrainer(
|
||||||
PackingMixin,
|
SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
|
||||||
SchedulerMixin,
|
|
||||||
OptimizerMixin,
|
|
||||||
RngLoaderMixin,
|
|
||||||
CheckpointSaveMixin,
|
|
||||||
Trainer,
|
|
||||||
):
|
):
|
||||||
"""Extend the base Trainer for axolotl helpers"""
|
"""Extend the base Trainer for axolotl helpers"""
|
||||||
|
|
||||||
@@ -122,15 +116,14 @@ class AxolotlTrainer(
|
|||||||
sequential=self.args.sample_packing_sequentially,
|
sequential=self.args.sample_packing_sequentially,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
num_processes=self.args.dataset_num_proc,
|
num_processes=self.args.dataset_num_proc,
|
||||||
mp_start_method=self.args.sample_packing_mp_start_method or "fork",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
len(sampler)
|
len(sampler)
|
||||||
return sampler
|
return sampler
|
||||||
|
|
||||||
def _get_train_sampler(
|
def _get_train_sampler(
|
||||||
self, train_dataset: Dataset | None = None
|
self, train_dataset: Optional[Dataset] = None
|
||||||
) -> Sampler | None:
|
) -> Optional[Sampler]:
|
||||||
"""
|
"""
|
||||||
Helper method to get the sampler for training. Handles cases for sample packing
|
Helper method to get the sampler for training. Handles cases for sample packing
|
||||||
and curriculum sampling (sequential).
|
and curriculum sampling (sequential).
|
||||||
@@ -139,22 +132,16 @@ class AxolotlTrainer(
|
|||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
depends on the passed training args.
|
depends on the passed training args.
|
||||||
"""
|
"""
|
||||||
# from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L969C1-L972C24
|
|
||||||
if train_dataset is None:
|
|
||||||
train_dataset = self.train_dataset
|
|
||||||
if train_dataset is None or not has_length(train_dataset):
|
|
||||||
return None
|
|
||||||
|
|
||||||
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
||||||
|
|
||||||
# Determine the base sampler first
|
# Determine the base sampler first
|
||||||
if self.args.curriculum_sampling:
|
if self.args.curriculum_sampling:
|
||||||
base_sampler = SequentialSampler(train_dataset)
|
base_sampler = SequentialSampler(self.train_dataset)
|
||||||
elif use_sample_packing:
|
elif use_sample_packing:
|
||||||
base_sampler = RandomSampler(train_dataset)
|
base_sampler = RandomSampler(self.train_dataset)
|
||||||
else:
|
else:
|
||||||
# Default to parent class implementation for standard random sampling
|
# Default to parent class implementation for standard random sampling
|
||||||
return super()._get_train_sampler(train_dataset)
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
# Apply multipack wrapper if needed
|
# Apply multipack wrapper if needed
|
||||||
if use_sample_packing:
|
if use_sample_packing:
|
||||||
@@ -173,10 +160,6 @@ class AxolotlTrainer(
|
|||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
depends on the passed training args.
|
depends on the passed training args.
|
||||||
"""
|
"""
|
||||||
# from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L1065C9-L1066C24
|
|
||||||
if eval_dataset is None or not has_length(eval_dataset):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
||||||
use_multipack = (
|
use_multipack = (
|
||||||
self.args.sample_packing and self.args.eval_sample_packing is not False
|
self.args.sample_packing and self.args.eval_sample_packing is not False
|
||||||
@@ -212,14 +195,6 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
if dataset.column_names and "length" in dataset.column_names:
|
if dataset.column_names and "length" in dataset.column_names:
|
||||||
dataset = dataset.remove_columns(["length"])
|
dataset = dataset.remove_columns(["length"])
|
||||||
if (
|
|
||||||
dataset.column_names
|
|
||||||
and "position_ids" in dataset.column_names
|
|
||||||
and "attention_mask" in dataset.column_names
|
|
||||||
and self.args.sample_packing
|
|
||||||
and self.args.sample_packing_drop_attention_mask
|
|
||||||
):
|
|
||||||
dataset = dataset.remove_columns(["attention_mask"])
|
|
||||||
|
|
||||||
if isinstance(dataset, datasets.Dataset):
|
if isinstance(dataset, datasets.Dataset):
|
||||||
if is_training:
|
if is_training:
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class DPOStrategy:
|
|||||||
training_args_kwargs["max_completion_length"] = None
|
training_args_kwargs["max_completion_length"] = None
|
||||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||||
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
||||||
if cfg.dpo_use_weighting is not None:
|
if cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||||
if cfg.dpo_padding_free is not None:
|
if cfg.dpo_padding_free is not None:
|
||||||
|
|||||||
@@ -5,6 +5,5 @@
|
|||||||
|
|
||||||
from .checkpoints import CheckpointSaveMixin
|
from .checkpoints import CheckpointSaveMixin
|
||||||
from .optimizer import OptimizerMixin
|
from .optimizer import OptimizerMixin
|
||||||
from .packing import PackingMixin
|
|
||||||
from .rng_state_loader import RngLoaderMixin
|
from .rng_state_loader import RngLoaderMixin
|
||||||
from .scheduler import SchedulerMixin
|
from .scheduler import SchedulerMixin
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
"""Trainer mixin to support packing"""
|
|
||||||
|
|
||||||
from transformers import Trainer
|
|
||||||
|
|
||||||
|
|
||||||
class PackingMixin(Trainer):
|
|
||||||
"""
|
|
||||||
Trainer mixin to support packing
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _set_signature_columns_if_needed(self):
|
|
||||||
super()._set_signature_columns_if_needed()
|
|
||||||
if (
|
|
||||||
self._signature_columns
|
|
||||||
and self.args.sample_packing
|
|
||||||
and self.args.sample_packing_drop_attention_mask
|
|
||||||
):
|
|
||||||
set_sig_columns = set(self._signature_columns)
|
|
||||||
set_sig_columns.remove("attention_mask")
|
|
||||||
self._signature_columns = list(set_sig_columns)
|
|
||||||
@@ -38,14 +38,6 @@ class AxolotlTrainingMixins:
|
|||||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
sample_packing_mp_start_method: str | None = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "The multiprocessing start method to use."},
|
|
||||||
)
|
|
||||||
sample_packing_drop_attention_mask: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Drop attention mask from inputs when using packing."},
|
|
||||||
)
|
|
||||||
multipack_real_batches: bool = field(
|
multipack_real_batches: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use real batches for efficient training."},
|
metadata={"help": "Use real batches for efficient training."},
|
||||||
|
|||||||
@@ -48,6 +48,13 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
features = dataset.features.keys()
|
features = dataset.features.keys()
|
||||||
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
||||||
|
|
||||||
|
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
|
||||||
|
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
|
||||||
|
LOG.info(
|
||||||
|
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
|
||||||
|
)
|
||||||
|
num_proc = 1
|
||||||
|
|
||||||
map_kwargs = {}
|
map_kwargs = {}
|
||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
map_kwargs["batched"] = True
|
map_kwargs["batched"] = True
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from transformers import PreTrainedModel, Trainer
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__, use_environ=True)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.common.datasets import TrainDatasetMeta
|
from axolotl.common.datasets import TrainDatasetMeta
|
||||||
|
|||||||
@@ -19,11 +19,19 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
|
**NOTE**: If you are training a VLM model, please use older version of Axolotl as upstream has applied a major VLM refactor, and our patches have not been updated yet.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git checkout 787880215b3ab32ccaf81c1b2e9588c6f3e6e764
|
||||||
|
|
||||||
|
pip3 install --no-build-isolation -e .
|
||||||
|
```
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
plugins:
|
plugins:
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
@@ -31,29 +39,27 @@ plugins:
|
|||||||
|
|
||||||
## Supported Models
|
## Supported Models
|
||||||
|
|
||||||
- cohere
|
- llama
|
||||||
- cohere2
|
- llama4
|
||||||
|
- llama4_text
|
||||||
|
- mllama
|
||||||
|
- phi3
|
||||||
- gemma
|
- gemma
|
||||||
- gemma2
|
- gemma2
|
||||||
- gemma3
|
- gemma3
|
||||||
- gemma3_text
|
- gemma3_text
|
||||||
- glm
|
|
||||||
- glm4
|
|
||||||
- llama
|
|
||||||
- llama4
|
|
||||||
- llama4_text
|
|
||||||
- mistral
|
- mistral
|
||||||
- mistral3
|
- mistral3
|
||||||
- mllama
|
|
||||||
- phi
|
|
||||||
- phi3
|
|
||||||
- phi4_multimodal
|
|
||||||
- qwen2
|
- qwen2
|
||||||
- qwen2_vl
|
|
||||||
- qwen2_moe
|
- qwen2_moe
|
||||||
|
- qwen2_vl
|
||||||
- qwen2_5_vl
|
- qwen2_5_vl
|
||||||
- qwen3
|
- qwen3
|
||||||
- qwen3_moe
|
- qwen3_moe
|
||||||
|
- cohere
|
||||||
|
- cohere2
|
||||||
|
- glm
|
||||||
|
- glm4
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
|
|||||||
@@ -28,11 +28,11 @@ from axolotl.utils.logging import get_logger
|
|||||||
|
|
||||||
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__, use_environ=True)
|
||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -64,28 +64,16 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
"cut_cross_entropy.transformers"
|
"cut_cross_entropy.transformers"
|
||||||
)
|
)
|
||||||
if cce_spec_transformers is None:
|
if cce_spec_transformers is None:
|
||||||
raise ImportError(
|
raise ImportError(_CCE_INSTALL_MESSAGE)
|
||||||
"Transformers support is not installed. " + _CCE_INSTALL_MESSAGE
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if Axolotl's cce fork is installed
|
|
||||||
try:
|
|
||||||
from cut_cross_entropy.transformers.patch import AXOLOTL_CCE_FORK
|
|
||||||
|
|
||||||
if not AXOLOTL_CCE_FORK:
|
|
||||||
raise ImportError
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(
|
|
||||||
"Axolotl's fork of cut_cross_entropy is not installed. "
|
|
||||||
+ _CCE_INSTALL_MESSAGE
|
|
||||||
) from e
|
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
"""Apply cut cross entropy before model loading if enabled."""
|
"""Apply cut cross entropy before model loading if enabled."""
|
||||||
if cfg.cut_cross_entropy:
|
if cfg.cut_cross_entropy:
|
||||||
self._check_requirements()
|
self._check_requirements()
|
||||||
|
|
||||||
from cut_cross_entropy.transformers.patch import cce_patch
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
|
||||||
|
cce_patch,
|
||||||
|
)
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
|
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
|
||||||
|
|||||||
191
src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py
Normal file
191
src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
"""Cohere and Cohere2 CCE patch."""
|
||||||
|
|
||||||
|
# This patch is based off transformers 4.50.0.
|
||||||
|
# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM.
|
||||||
|
# It scales the hidden states by the logit scale in advance instead of the logits as the
|
||||||
|
# operation is done internally and should be mathematically equivalent.
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.cohere.modeling_cohere import (
|
||||||
|
KwargsForCausalLM,
|
||||||
|
)
|
||||||
|
from transformers.processing_utils import Unpack
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
def cce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**kwargs: Unpack[KwargsForCausalLM],
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>> from transformers import AutoTokenizer, CohereForCausalLM
|
||||||
|
|
||||||
|
>> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||||
|
>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
||||||
|
|
||||||
|
>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>> # Generate
|
||||||
|
>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
# scale hidden_states by logit_scale in-place of logits
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :] * self.logit_scale,
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
logits = logits * self.logit_scale # main diff from Llama
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits=logits,
|
||||||
|
labels=labels,
|
||||||
|
vocab_size=self.config.vocab_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_cohere(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.cohere import modeling_cohere
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_cohere.CohereForCausalLM
|
||||||
|
), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_cohere.CohereForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def patch_cohere2(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.cohere2 import modeling_cohere2
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_cohere2.Cohere2ForCausalLM
|
||||||
|
), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
165
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py
Normal file
165
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
"""Gemma CCE patch"""
|
||||||
|
|
||||||
|
# This patch is based off transformers 4.50.0.
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.gemma.modeling_gemma import (
|
||||||
|
KwargsForCausalLM,
|
||||||
|
)
|
||||||
|
from transformers.processing_utils import Unpack
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
def cce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**kwargs: Unpack[KwargsForCausalLM],
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
||||||
|
|
||||||
|
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
|
||||||
|
|
||||||
|
>>> prompt = "What is your favorite condiment?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"What is your favorite condiment?"
|
||||||
|
```"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits=logits,
|
||||||
|
labels=labels,
|
||||||
|
vocab_size=self.config.vocab_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_gemma(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.gemma import modeling_gemma
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_gemma.GemmaForCausalLM
|
||||||
|
), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_gemma.GemmaForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
447
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py
Normal file
447
src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py
Normal file
@@ -0,0 +1,447 @@
|
|||||||
|
"""Gemma2 and Gemma3 (text and multimodal) CCE patch."""
|
||||||
|
|
||||||
|
# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29
|
||||||
|
# and updated for transformers 4.50.0.
|
||||||
|
# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works
|
||||||
|
# with both gemma3 (text and multimodal) models.
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
)
|
||||||
|
from torch import nn
|
||||||
|
from transformers.cache_utils import Cache, HybridCache
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.gemma3.modeling_gemma3 import (
|
||||||
|
Gemma3CausalLMOutputWithPast,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
from transformers.utils import (
|
||||||
|
is_torchdynamo_compiling,
|
||||||
|
)
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
def cce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[HybridCache] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
defer_logits_calculation: bool = False,
|
||||||
|
**loss_kwargs,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
defer_logits_calculation (`bool`, *optional*):
|
||||||
|
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||||
|
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, Gemma3ForCausalLM
|
||||||
|
|
||||||
|
>>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
||||||
|
|
||||||
|
>>> prompt = "What is your favorite condiment?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"What is your favorite condiment?"
|
||||||
|
```"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||||
|
# defer logits calculation to the ConditionalGeneration forward
|
||||||
|
logits = hidden_states[:, slice_indices, :]
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
if self.config.final_logit_softcapping is not None:
|
||||||
|
logits = logits / self.config.final_logit_softcapping
|
||||||
|
logits = torch.tanh(logits)
|
||||||
|
logits = logits * self.config.final_logit_softcapping
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
def cce_forward_multimodal(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
pixel_values: torch.FloatTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
|
||||||
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**lm_kwargs,
|
||||||
|
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||||
|
|
||||||
|
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
|
||||||
|
|
||||||
|
>>> prompt = "answer en Where is the cow standing?"
|
||||||
|
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(**inputs, max_length=30)
|
||||||
|
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"answer en Where is the cow standing?\nbeach"
|
||||||
|
```"""
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
is_training = token_type_ids is not None and labels is not None
|
||||||
|
|
||||||
|
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
||||||
|
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
||||||
|
special_image_mask = input_ids == self.config.image_token_index
|
||||||
|
llm_input_ids = input_ids.clone()
|
||||||
|
llm_input_ids[special_image_mask] = 0
|
||||||
|
else:
|
||||||
|
llm_input_ids = input_ids # type: ignore
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = (
|
||||||
|
past_key_values.get_seq_length() if past_key_values is not None else 0 # type: ignore
|
||||||
|
)
|
||||||
|
cache_position = torch.arange( # type: ignore
|
||||||
|
past_seen_tokens,
|
||||||
|
past_seen_tokens + inputs_embeds.shape[1],
|
||||||
|
device=inputs_embeds.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Merge text and images
|
||||||
|
if pixel_values is not None:
|
||||||
|
image_features = self.get_image_features(pixel_values)
|
||||||
|
|
||||||
|
if input_ids is None:
|
||||||
|
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||||
|
torch.tensor(
|
||||||
|
self.config.image_token_index,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=inputs_embeds.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
|
||||||
|
-1
|
||||||
|
)
|
||||||
|
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||||
|
inputs_embeds.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
not is_torchdynamo_compiling()
|
||||||
|
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||||
|
):
|
||||||
|
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||||
|
raise ValueError(
|
||||||
|
f"Number of images does not match number of special image tokens in the input text. "
|
||||||
|
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||||
|
"tokens from image embeddings."
|
||||||
|
)
|
||||||
|
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
||||||
|
|
||||||
|
# mask out pad-token-ids in labels for BC
|
||||||
|
if labels is not None and self.pad_token_id in labels:
|
||||||
|
logger.warning_once(
|
||||||
|
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
|
||||||
|
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
|
||||||
|
)
|
||||||
|
labels = torch.where( # type: ignore
|
||||||
|
input_ids == self.pad_token_id, self.config.ignore_index, labels
|
||||||
|
)
|
||||||
|
|
||||||
|
causal_mask = self._update_causal_mask( # pylint: disable=protected-access
|
||||||
|
attention_mask,
|
||||||
|
token_type_ids,
|
||||||
|
past_key_values,
|
||||||
|
cache_position,
|
||||||
|
inputs_embeds,
|
||||||
|
is_training,
|
||||||
|
)
|
||||||
|
outputs = self.language_model(
|
||||||
|
attention_mask=causal_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
logits_to_keep=logits_to_keep,
|
||||||
|
defer_logits_calculation=True, # enable deferred logits calculation
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states,
|
||||||
|
self.language_model.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = hidden_states
|
||||||
|
if labels is not None:
|
||||||
|
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||||
|
logits = logits.float()
|
||||||
|
shift_logits = logits[..., :-1, :]
|
||||||
|
shift_labels = labels[..., 1:]
|
||||||
|
if attention_mask is not None:
|
||||||
|
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||||
|
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||||
|
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(
|
||||||
|
logits.device
|
||||||
|
)
|
||||||
|
shift_logits = shift_logits[
|
||||||
|
shift_attention_mask.to(logits.device) != 0
|
||||||
|
].contiguous()
|
||||||
|
shift_labels = shift_labels[
|
||||||
|
shift_attention_mask.to(shift_labels.device) != 0
|
||||||
|
].contiguous()
|
||||||
|
else:
|
||||||
|
shift_logits = shift_logits.contiguous()
|
||||||
|
shift_labels = shift_labels.contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
||||||
|
flat_labels = shift_labels.view(-1).to(shift_logits.device)
|
||||||
|
loss = loss_fct(flat_logits, flat_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return Gemma3CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
image_hidden_states=image_features if pixel_values is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_gemma2(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.gemma2 import modeling_gemma2
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_gemma2.Gemma2ForCausalLM
|
||||||
|
), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def patch_gemma3_text(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.gemma3 import modeling_gemma3
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_gemma3.Gemma3ForCausalLM
|
||||||
|
), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def patch_gemma3(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.gemma3 import modeling_gemma3
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration
|
||||||
|
), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||||
|
|
||||||
|
# patch the causal model to enable deferred logits calculation
|
||||||
|
maybe_model.language_model.forward = MethodType(
|
||||||
|
cce_forward, maybe_model.language_model
|
||||||
|
)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal
|
||||||
|
# patch the causal model to enable deferred logits calculation
|
||||||
|
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
"""GLM 4 patch. GLM family inherits from Llama."""
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_glm(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
|
||||||
|
# Set the _PATCH_OPTS in the llama patch file
|
||||||
|
import cut_cross_entropy.transformers.llama as llama_patch
|
||||||
|
|
||||||
|
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
|
||||||
|
|
||||||
|
from cut_cross_entropy.transformers.llama import cce_forward
|
||||||
|
from transformers.models.glm import modeling_glm
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_glm.GlmForCausalLM
|
||||||
|
), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_glm.GlmForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def patch_glm4(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
|
||||||
|
# Set the _PATCH_OPTS in the llama patch file
|
||||||
|
import cut_cross_entropy.transformers.llama as llama_patch
|
||||||
|
|
||||||
|
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
|
||||||
|
|
||||||
|
from cut_cross_entropy.transformers.llama import cce_forward
|
||||||
|
from transformers.models.glm4 import modeling_glm4
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_glm4.Glm4ForCausalLM
|
||||||
|
), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_glm4.Glm4ForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
164
src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py
Normal file
164
src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""Llama CCE patch. Adapted from transformers v4.51.2"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_outputs import (
|
||||||
|
BaseModelOutputWithPast,
|
||||||
|
CausalLMOutputWithPast,
|
||||||
|
)
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
KwargsForCausalLM,
|
||||||
|
)
|
||||||
|
from transformers.processing_utils import Unpack
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
from transformers.utils.generic import can_return_tuple
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@can_return_tuple
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
def cce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Cache] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**kwargs: Unpack[KwargsForCausalLM],
|
||||||
|
) -> CausalLMOutputWithPast:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||||
|
|
||||||
|
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs: BaseModelOutputWithPast = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs.last_hidden_state
|
||||||
|
if hidden_states is None:
|
||||||
|
raise ValueError("hidden_states is None")
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits=logits,
|
||||||
|
labels=labels,
|
||||||
|
vocab_size=self.config.vocab_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_llama(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
"""Patch Llama for CCE."""
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.llama import modeling_llama
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_llama.LlamaForCausalLM
|
||||||
|
), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_llama.LlamaForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
401
src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py
Normal file
401
src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py
Normal file
@@ -0,0 +1,401 @@
|
|||||||
|
"""Llama4 CCE patch. Adapted from transformers 4.51.0."""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from torch import nn
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.llama4.modeling_llama4 import (
|
||||||
|
Llama4CausalLMOutputWithPast,
|
||||||
|
)
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def cce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
defer_logits_calculation: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
defer_logits_calculation (`bool`, *optional*, defaults to `False`):
|
||||||
|
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||||
|
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, Llama4ForCausalLM
|
||||||
|
|
||||||
|
>>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||||
|
# defer logits calculation to the ConditionalGeneration forward
|
||||||
|
logits = hidden_states[:, slice_indices, :]
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits=logits,
|
||||||
|
labels=labels,
|
||||||
|
vocab_size=self.config.vocab_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cce_forward_multimodal(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None, # type: ignore
|
||||||
|
pixel_values: torch.FloatTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
vision_feature_layer: Optional[Union[int, list[int]]] = None,
|
||||||
|
vision_feature_select_strategy: Optional[str] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
image_sizes: torch.Tensor | None = None,
|
||||||
|
**lm_kwargs,
|
||||||
|
) -> Union[Tuple, Llama4CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, LlavaForConditionalGeneration
|
||||||
|
|
||||||
|
>>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||||
|
|
||||||
|
>>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
|
||||||
|
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||||
|
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
vision_feature_layer = (
|
||||||
|
vision_feature_layer
|
||||||
|
if vision_feature_layer is not None
|
||||||
|
else self.config.vision_config.vision_feature_layer
|
||||||
|
)
|
||||||
|
vision_feature_select_strategy = (
|
||||||
|
vision_feature_select_strategy
|
||||||
|
if vision_feature_select_strategy is not None
|
||||||
|
else self.config.vision_config.vision_feature_select_strategy
|
||||||
|
)
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if pixel_values is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||||
|
)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.get_input_embeddings()(input_ids) # type: ignore
|
||||||
|
|
||||||
|
if pixel_values is not None:
|
||||||
|
image_features = self.get_image_features(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
vision_feature_layer=vision_feature_layer,
|
||||||
|
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||||
|
image_sizes=image_sizes,
|
||||||
|
)
|
||||||
|
original_inputs_embeds_shape = inputs_embeds.shape # type: ignore
|
||||||
|
|
||||||
|
vision_flat = image_features.view(-1, image_features.size(-1))
|
||||||
|
projected_vision_flat = self.multi_modal_projector(vision_flat)
|
||||||
|
|
||||||
|
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||||
|
final_mask = special_image_mask.to(inputs_embeds.device) # type: ignore
|
||||||
|
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) # type: ignore
|
||||||
|
|
||||||
|
final_mask_1d = final_mask[..., 0].reshape(-1)
|
||||||
|
num_tokens_to_fill = final_mask_1d.sum()
|
||||||
|
|
||||||
|
if num_tokens_to_fill != projected_vision_flat.size(0):
|
||||||
|
raise ValueError(
|
||||||
|
f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
|
||||||
|
f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(
|
||||||
|
expanded_mask, projected_vision_flat
|
||||||
|
) # type: ignore
|
||||||
|
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) # type: ignore
|
||||||
|
|
||||||
|
outputs = self.language_model(
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
logits_to_keep=logits_to_keep,
|
||||||
|
defer_logits_calculation=True, # enable deferred logits calculation
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
# TODO: check if need to handle attention_mask
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states,
|
||||||
|
self.language_model.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = hidden_states
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
if attention_mask is not None:
|
||||||
|
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||||
|
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||||
|
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
|
||||||
|
logits.device
|
||||||
|
)
|
||||||
|
shift_logits = logits[..., :-1, :][
|
||||||
|
shift_attention_mask.to(logits.device) != 0
|
||||||
|
].contiguous()
|
||||||
|
shift_labels = labels[..., 1:][
|
||||||
|
shift_attention_mask.to(labels.device) != 0
|
||||||
|
].contiguous()
|
||||||
|
else:
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = nn.CrossEntropyLoss()
|
||||||
|
loss = loss_fct(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)),
|
||||||
|
shift_labels.view(-1).to(shift_logits.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return Llama4CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits, # type: ignore # TODO: check if need to create dummy logits
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
image_hidden_states=image_features if pixel_values is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_llama4_text(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.llama4 import modeling_llama4
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_llama4.Llama4ForCausalLM
|
||||||
|
), f"Expected a Llama4ForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
modeling_llama4.Llama4ForCausalLM,
|
||||||
|
"forward",
|
||||||
|
cce_forward,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def patch_llama4(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.llama4 import modeling_llama4
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_llama4.Llama4ForConditionalGeneration
|
||||||
|
), f"Expected a Llama4ForConditionalGeneration model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||||
|
|
||||||
|
# patch the language model
|
||||||
|
maybe_model.language_model.forward = MethodType(
|
||||||
|
cce_forward, maybe_model.language_model
|
||||||
|
)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
modeling_llama4.Llama4ForConditionalGeneration,
|
||||||
|
"forward",
|
||||||
|
cce_forward_multimodal,
|
||||||
|
)
|
||||||
|
|
||||||
|
# patch the causal language model
|
||||||
|
setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
|
||||||
|
return None
|
||||||
@@ -0,0 +1,384 @@
|
|||||||
|
"""Mistral and Mistral3 CCE patch."""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from torch import nn
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.mistral3.modeling_mistral3 import (
|
||||||
|
Mistral3CausalLMOutputWithPast,
|
||||||
|
)
|
||||||
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
|
KwargsForCausalLM,
|
||||||
|
)
|
||||||
|
from transformers.processing_utils import Unpack
|
||||||
|
from transformers.utils import (
|
||||||
|
is_torchdynamo_compiling,
|
||||||
|
)
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
def cce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] | None = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
defer_logits_calculation: bool = False,
|
||||||
|
**kwargs: Unpack[KwargsForCausalLM],
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
defer_logits_calculation (`bool`, *optional*):
|
||||||
|
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||||
|
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
||||||
|
|
||||||
|
>>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||||
|
# defer logits calculation to the ConditionalGeneration forward
|
||||||
|
logits = hidden_states[:, slice_indices, :]
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits=logits,
|
||||||
|
labels=labels,
|
||||||
|
vocab_size=self.config.vocab_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cce_forward_multimodal(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
pixel_values: torch.FloatTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
vision_feature_layer: Optional[Union[int, list[int]]] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
image_sizes: torch.Tensor | None = None,
|
||||||
|
**lm_kwargs,
|
||||||
|
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||||
|
|
||||||
|
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||||
|
|
||||||
|
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
|
||||||
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
||||||
|
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"What is the image?The image depicts two cats lying on a pink blanket."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
vision_feature_layer = (
|
||||||
|
vision_feature_layer
|
||||||
|
if vision_feature_layer is not None
|
||||||
|
else self.config.vision_feature_layer
|
||||||
|
)
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if pixel_values is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||||
|
)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
|
if pixel_values is not None:
|
||||||
|
image_features = self.get_image_features(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
vision_feature_layer=vision_feature_layer,
|
||||||
|
image_sizes=image_sizes,
|
||||||
|
)
|
||||||
|
|
||||||
|
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
||||||
|
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||||
|
inputs_embeds.device
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
not is_torchdynamo_compiling()
|
||||||
|
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||||
|
):
|
||||||
|
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||||
|
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||||
|
raise ValueError(
|
||||||
|
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||||
|
)
|
||||||
|
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
||||||
|
|
||||||
|
outputs = self.language_model(
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
logits_to_keep=logits_to_keep,
|
||||||
|
defer_logits_calculation=True, # enable deferred logits calculation
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states,
|
||||||
|
self.language_model.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**lm_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = hidden_states
|
||||||
|
if labels is not None:
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
if attention_mask is not None:
|
||||||
|
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
||||||
|
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
||||||
|
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
|
||||||
|
logits.device
|
||||||
|
)
|
||||||
|
shift_logits = logits[..., :-1, :][
|
||||||
|
shift_attention_mask.to(logits.device) != 0
|
||||||
|
].contiguous()
|
||||||
|
shift_labels = labels[..., 1:][
|
||||||
|
shift_attention_mask.to(labels.device) != 0
|
||||||
|
].contiguous()
|
||||||
|
else:
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = nn.CrossEntropyLoss()
|
||||||
|
loss = loss_fct(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)),
|
||||||
|
shift_labels.view(-1).to(shift_logits.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return Mistral3CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
image_hidden_states=image_features if pixel_values is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_mistral(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.mistral import modeling_mistral
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_mistral.MistralForCausalLM
|
||||||
|
), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def patch_mistral3(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.mistral import modeling_mistral
|
||||||
|
from transformers.models.mistral3 import modeling_mistral3
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration
|
||||||
|
), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||||
|
|
||||||
|
# patch the causal model to enable deferred logits calculation
|
||||||
|
maybe_model.language_model.forward = MethodType(
|
||||||
|
cce_forward, maybe_model.language_model
|
||||||
|
)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal
|
||||||
|
# patch the causal model to enable deferred logits calculation
|
||||||
|
modeling_mistral.MistralForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
366
src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py
Normal file
366
src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py
Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
"""Mllama CCE patch."""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.mllama.modeling_mllama import (
|
||||||
|
_prepare_cross_attention_mask,
|
||||||
|
)
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
def cce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor | None = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
cross_attention_states: Optional[torch.LongTensor] = None,
|
||||||
|
cross_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
defer_logits_calculation: bool = False,
|
||||||
|
**loss_kwargs,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
defer_logits_calculation (`bool`, *optional*):
|
||||||
|
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||||
|
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, MllamaForCausalLM
|
||||||
|
|
||||||
|
>>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
|
||||||
|
|
||||||
|
>>> prompt = "If I had to write a haiku, it would be:"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
|
||||||
|
>>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
>>> print(result)
|
||||||
|
If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
|
||||||
|
I love the idea of snowflakes gently falling, each one
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
cross_attention_states=cross_attention_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cross_attention_mask=cross_attention_mask,
|
||||||
|
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||||
|
# defer logits calculation to the ConditionalGeneration forward
|
||||||
|
logits = hidden_states[:, slice_indices, :]
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
def cce_forward_multimodal(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
|
aspect_ratio_mask: Optional[torch.Tensor] = None,
|
||||||
|
aspect_ratio_ids: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
cross_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
cross_attention_states: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**loss_kwargs,
|
||||||
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, MllamaForConditionalGeneration
|
||||||
|
|
||||||
|
>>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
|
||||||
|
>>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
|
||||||
|
>>> processor = AutoProcessor.from_pretrained(checkpoint)
|
||||||
|
|
||||||
|
>>> prompt = "<|image|>If I had to write a haiku for this one"
|
||||||
|
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> output = model.generate(**inputs, max_new_tokens=15)
|
||||||
|
|
||||||
|
>>> prompt_len = inputs.input_ids.shape[-1]
|
||||||
|
>>> generated_ids = output[:, prompt_len:]
|
||||||
|
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
|
>>> print(generated_text)
|
||||||
|
[', it would be:.\\nA stop sign in Chinatown.\\n']
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
if pixel_values is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
||||||
|
)
|
||||||
|
|
||||||
|
if pixel_values is not None and cross_attention_states is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
|
||||||
|
)
|
||||||
|
|
||||||
|
if pixel_values is not None:
|
||||||
|
if aspect_ratio_ids is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
|
||||||
|
)
|
||||||
|
# get vision tokens from vision model
|
||||||
|
vision_outputs = self.vision_model(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
aspect_ratio_ids=aspect_ratio_ids,
|
||||||
|
aspect_ratio_mask=aspect_ratio_mask,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
cross_attention_states = vision_outputs[0]
|
||||||
|
cross_attention_states = self.multi_modal_projector(
|
||||||
|
cross_attention_states
|
||||||
|
).reshape(
|
||||||
|
-1, cross_attention_states.shape[-2], self.hidden_size # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
if cross_attention_mask is not None:
|
||||||
|
cross_attention_mask, full_text_row_masked_out_mask = (
|
||||||
|
_prepare_cross_attention_mask(
|
||||||
|
cross_attention_mask,
|
||||||
|
num_vision_tokens=self.vision_model.num_patches,
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
full_text_row_masked_out_mask = None
|
||||||
|
|
||||||
|
if cross_attention_mask is not None and cache_position is not None:
|
||||||
|
cross_attention_mask = cross_attention_mask[:, :, cache_position]
|
||||||
|
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
|
||||||
|
:, :, cache_position
|
||||||
|
]
|
||||||
|
|
||||||
|
outputs = self.language_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cross_attention_states=cross_attention_states,
|
||||||
|
cross_attention_mask=cross_attention_mask,
|
||||||
|
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
logits_to_keep=logits_to_keep,
|
||||||
|
defer_logits_calculation=True, # enable deferred logits calculation
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states,
|
||||||
|
self.language_model.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
|
||||||
|
logits = hidden_states
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (loss,) + outputs if loss is not None else outputs
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=outputs.logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_mllama(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.mllama import modeling_mllama
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_mllama.MllamaForConditionalGeneration
|
||||||
|
), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||||
|
|
||||||
|
# patch the language model
|
||||||
|
maybe_model.language_model.forward = MethodType(
|
||||||
|
cce_forward, maybe_model.language_model
|
||||||
|
)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal
|
||||||
|
|
||||||
|
# patch the causal language model
|
||||||
|
modeling_mllama.MllamaForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
126
src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py
Normal file
126
src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
"""Cut Cross Entropy patcher"""
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
|
||||||
|
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
|
||||||
|
from cut_cross_entropy.transformers.phi3 import patch_phi3
|
||||||
|
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
|
||||||
|
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
|
||||||
|
patch_cohere,
|
||||||
|
patch_cohere2,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
|
||||||
|
patch_gemma2,
|
||||||
|
patch_gemma3,
|
||||||
|
patch_gemma3_text,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
|
||||||
|
patch_glm,
|
||||||
|
patch_glm4,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
|
||||||
|
patch_llama,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
|
||||||
|
patch_llama4,
|
||||||
|
patch_llama4_text,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
|
||||||
|
patch_mistral,
|
||||||
|
patch_mistral3,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import (
|
||||||
|
patch_qwen2,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import (
|
||||||
|
patch_qwen2_5_vl,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import (
|
||||||
|
patch_qwen2_moe,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import (
|
||||||
|
patch_qwen2_vl,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import (
|
||||||
|
patch_qwen3_moe,
|
||||||
|
)
|
||||||
|
|
||||||
|
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
|
||||||
|
"llama": patch_llama,
|
||||||
|
"llama4": patch_llama4,
|
||||||
|
"llama4_text": patch_llama4_text,
|
||||||
|
"mllama": patch_mllama,
|
||||||
|
"phi3": patch_phi3,
|
||||||
|
"gemma": patch_gemma,
|
||||||
|
"gemma2": patch_gemma2,
|
||||||
|
"gemma3": patch_gemma3,
|
||||||
|
"gemma3_text": patch_gemma3_text,
|
||||||
|
"mistral": patch_mistral,
|
||||||
|
"mistral3": patch_mistral3,
|
||||||
|
"qwen2": patch_qwen2,
|
||||||
|
"qwen2_moe": patch_qwen2_moe,
|
||||||
|
"qwen2_vl": patch_qwen2_vl,
|
||||||
|
"qwen2_5_vl": patch_qwen2_5_vl,
|
||||||
|
"qwen3": patch_qwen3,
|
||||||
|
"qwen3_moe": patch_qwen3_moe,
|
||||||
|
"cohere": patch_cohere,
|
||||||
|
"cohere2": patch_cohere2,
|
||||||
|
"glm": patch_glm,
|
||||||
|
"glm4": patch_glm4,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def cce_patch(
|
||||||
|
model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig,
|
||||||
|
impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
|
||||||
|
reduction: str = "mean",
|
||||||
|
filter_eps: float | str | None = "auto",
|
||||||
|
accum_e_fp32: bool = False,
|
||||||
|
accum_c_fp32: bool = False,
|
||||||
|
filter_e_grad: bool = True,
|
||||||
|
filter_c_grad: bool = True,
|
||||||
|
train_only: bool = False,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
if isinstance(impl, LinearCrossEntropyImpl):
|
||||||
|
impl = impl.name.lower()
|
||||||
|
|
||||||
|
if impl not in (v.name.lower() for v in LinearCrossEntropyImpl):
|
||||||
|
raise ValueError(f"Unknown {impl=}")
|
||||||
|
|
||||||
|
if isinstance(model_type_or_model, transformers.PreTrainedModel):
|
||||||
|
if hasattr(model_type_or_model, "config"):
|
||||||
|
model_type = getattr(
|
||||||
|
getattr(model_type_or_model, "config", None), "model_type", None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"model_type_or_model is a PreTrainedModel but does not have a config attribute"
|
||||||
|
)
|
||||||
|
elif isinstance(model_type_or_model, transformers.PretrainedConfig):
|
||||||
|
model_type = model_type_or_model.model_type
|
||||||
|
else:
|
||||||
|
model_type = model_type_or_model
|
||||||
|
|
||||||
|
patch_options = PatchOptions(
|
||||||
|
impl=impl,
|
||||||
|
reduction=reduction,
|
||||||
|
filter_eps=filter_eps,
|
||||||
|
accum_e_fp32=accum_e_fp32,
|
||||||
|
accum_c_fp32=accum_c_fp32,
|
||||||
|
filter_e_grad=filter_e_grad,
|
||||||
|
filter_c_grad=filter_c_grad,
|
||||||
|
train_only=train_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
|
||||||
|
return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type](
|
||||||
|
model_type_or_model, patch_options
|
||||||
|
)
|
||||||
|
|
||||||
|
raise RuntimeError(f"Unknown model type {model_type}")
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
"""Qwen2 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen2(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
from transformers.models.qwen2 import modeling_qwen2
|
||||||
|
|
||||||
|
# Set the _PATCH_OPTS in the llama patch file
|
||||||
|
import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch
|
||||||
|
|
||||||
|
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
|
||||||
|
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
|
||||||
|
cce_forward,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_qwen2.Qwen2ForCausalLM
|
||||||
|
), f"Expected a Qwen2ForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
@@ -0,0 +1,246 @@
|
|||||||
|
"""Qwen2.5 VL CCE patch. Adapted from transformers v4.51.2"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||||
|
Qwen2_5_VLCausalLMOutputWithPast,
|
||||||
|
)
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def cce_forward_multimodal(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
pixel_values: Optional[torch.Tensor] = None,
|
||||||
|
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||||
|
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
|
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
|
rope_deltas: Optional[torch.LongTensor] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||||
|
|
||||||
|
>>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||||
|
|
||||||
|
>>> messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image"},
|
||||||
|
{"type": "text", "text": "What is shown in this image?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||||
|
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||||
|
if pixel_values is not None:
|
||||||
|
pixel_values = pixel_values.type(self.visual.dtype)
|
||||||
|
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||||
|
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||||
|
n_image_features = image_embeds.shape[0]
|
||||||
|
if n_image_tokens != n_image_features:
|
||||||
|
raise ValueError(
|
||||||
|
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||||
|
)
|
||||||
|
|
||||||
|
mask = input_ids == self.config.image_token_id
|
||||||
|
mask_unsqueezed = mask.unsqueeze(-1)
|
||||||
|
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||||
|
image_mask = mask_expanded.to(inputs_embeds.device)
|
||||||
|
|
||||||
|
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore
|
||||||
|
|
||||||
|
if pixel_values_videos is not None:
|
||||||
|
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||||
|
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||||
|
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||||
|
n_video_features = video_embeds.shape[0]
|
||||||
|
if n_video_tokens != n_video_features:
|
||||||
|
raise ValueError(
|
||||||
|
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||||
|
)
|
||||||
|
|
||||||
|
mask = input_ids == self.config.video_token_id
|
||||||
|
mask_unsqueezed = mask.unsqueeze(-1)
|
||||||
|
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||||
|
video_mask = mask_expanded.to(inputs_embeds.device)
|
||||||
|
|
||||||
|
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||||
|
|
||||||
|
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||||
|
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||||
|
# calculate RoPE index once per generation in the pre-fill stage only
|
||||||
|
if (
|
||||||
|
(cache_position is not None and cache_position[0] == 0)
|
||||||
|
or self.rope_deltas is None
|
||||||
|
or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore
|
||||||
|
):
|
||||||
|
position_ids, rope_deltas = self.get_rope_index(
|
||||||
|
input_ids,
|
||||||
|
image_grid_thw,
|
||||||
|
video_grid_thw,
|
||||||
|
second_per_grid_ts,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
self.rope_deltas = rope_deltas
|
||||||
|
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||||
|
else:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
delta = (
|
||||||
|
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
||||||
|
if cache_position is not None
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore
|
||||||
|
position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore
|
||||||
|
if cache_position is not None: # otherwise `deltas` is an int `0`
|
||||||
|
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore
|
||||||
|
position_ids = position_ids.add(delta) # type: ignore
|
||||||
|
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore
|
||||||
|
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=None,
|
||||||
|
position_ids=position_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
logits = None
|
||||||
|
loss = None
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states,
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||||
|
logits = logits.float()
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return Qwen2_5_VLCausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
rope_deltas=self.rope_deltas,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen2_5_vl(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
|
||||||
|
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
|
||||||
|
), f"Expected a Qwen2_5_VLForConditionalGeneration model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||||
|
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = (
|
||||||
|
cce_forward_multimodal
|
||||||
|
)
|
||||||
|
return None
|
||||||
@@ -0,0 +1,178 @@
|
|||||||
|
"""Qwen2 MoE CCE patch. Adapted from transformers v4.51.2"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
|
||||||
|
MoeCausalLMOutputWithPast,
|
||||||
|
MoeModelOutputWithPast,
|
||||||
|
load_balancing_loss_func,
|
||||||
|
)
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
from transformers.utils.generic import can_return_tuple
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@can_return_tuple
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
output_router_logits: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**loss_kwargs,
|
||||||
|
) -> MoeCausalLMOutputWithPast:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM
|
||||||
|
|
||||||
|
>>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_router_logits = (
|
||||||
|
output_router_logits
|
||||||
|
if output_router_logits is not None
|
||||||
|
else self.config.output_router_logits
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs: MoeModelOutputWithPast = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
output_router_logits=output_router_logits,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs.last_hidden_state
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
if hidden_states is None:
|
||||||
|
raise ValueError("hidden_states is None")
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||||
|
|
||||||
|
aux_loss = None
|
||||||
|
if output_router_logits:
|
||||||
|
aux_loss = load_balancing_loss_func(
|
||||||
|
outputs.router_logits,
|
||||||
|
self.num_experts,
|
||||||
|
self.num_experts_per_tok,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
if labels is not None:
|
||||||
|
loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore
|
||||||
|
loss.device # type: ignore
|
||||||
|
) # make sure to reside in the same device
|
||||||
|
|
||||||
|
return MoeCausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
aux_loss=aux_loss, # type: ignore
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
router_logits=outputs.router_logits,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen2_moe(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
|
||||||
|
from transformers.models.qwen2_moe import modeling_qwen2_moe
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_qwen2_moe.Qwen2MoeForCausalLM
|
||||||
|
), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(forward, maybe_model)
|
||||||
|
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_qwen2_moe.Qwen2MoeForCausalLM.forward = forward
|
||||||
|
return None
|
||||||
@@ -0,0 +1,239 @@
|
|||||||
|
"""Qwen2 VL CCE patch. Adapted from transformers v4.51.2"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||||
|
Qwen2VLCausalLMOutputWithPast,
|
||||||
|
)
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def cce_forward_multimodal(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
pixel_values: Optional[torch.Tensor] = None,
|
||||||
|
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||||
|
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
|
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
|
rope_deltas: Optional[torch.LongTensor] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||||
|
|
||||||
|
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||||
|
|
||||||
|
>>> messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image"},
|
||||||
|
{"type": "text", "text": "What is shown in this image?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||||
|
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||||
|
if pixel_values is not None:
|
||||||
|
pixel_values = pixel_values.type(self.visual.get_dtype())
|
||||||
|
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||||
|
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||||
|
n_image_features = image_embeds.shape[0]
|
||||||
|
if n_image_tokens != n_image_features:
|
||||||
|
raise ValueError(
|
||||||
|
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||||
|
)
|
||||||
|
image_mask = (
|
||||||
|
(input_ids == self.config.image_token_id)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.expand_as(inputs_embeds)
|
||||||
|
.to(inputs_embeds.device)
|
||||||
|
)
|
||||||
|
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore
|
||||||
|
|
||||||
|
if pixel_values_videos is not None:
|
||||||
|
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
||||||
|
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||||
|
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||||
|
n_video_features = video_embeds.shape[0]
|
||||||
|
if n_video_tokens != n_video_features:
|
||||||
|
raise ValueError(
|
||||||
|
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||||
|
)
|
||||||
|
video_mask = (
|
||||||
|
(input_ids == self.config.video_token_id)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.expand_as(inputs_embeds)
|
||||||
|
.to(inputs_embeds.device)
|
||||||
|
)
|
||||||
|
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||||
|
|
||||||
|
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||||
|
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||||
|
# calculate RoPE index once per generation in the pre-fill stage only
|
||||||
|
if (
|
||||||
|
(cache_position is not None and cache_position[0] == 0)
|
||||||
|
or self.rope_deltas is None
|
||||||
|
or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore
|
||||||
|
):
|
||||||
|
position_ids, rope_deltas = self.get_rope_index(
|
||||||
|
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
||||||
|
)
|
||||||
|
self.rope_deltas = rope_deltas
|
||||||
|
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||||
|
else:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
delta = (
|
||||||
|
cache_position[0] + self.rope_deltas
|
||||||
|
if cache_position is not None
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore
|
||||||
|
position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore
|
||||||
|
if cache_position is not None: # otherwise `deltas` is an int `0`
|
||||||
|
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore
|
||||||
|
delta = delta.to(position_ids.device) # type: ignore
|
||||||
|
position_ids = position_ids.add(delta) # type: ignore
|
||||||
|
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore
|
||||||
|
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=None,
|
||||||
|
position_ids=position_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
logits = None
|
||||||
|
loss = None
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states,
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||||
|
logits = logits.float()
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return Qwen2VLCausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
rope_deltas=self.rope_deltas,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen2_vl(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
|
||||||
|
from transformers.models.qwen2_vl import modeling_qwen2_vl
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_qwen2_vl.Qwen2VLForConditionalGeneration
|
||||||
|
), f"Expected a Qwen2VLForConditionalGeneration model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||||
|
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = cce_forward_multimodal
|
||||||
|
return None
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
"""Qwen3 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen3(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
from transformers.models.qwen3 import modeling_qwen3
|
||||||
|
|
||||||
|
# Set the _PATCH_OPTS in the llama patch file
|
||||||
|
import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch
|
||||||
|
|
||||||
|
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
|
||||||
|
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import cce_forward
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_qwen3.Qwen3ForCausalLM
|
||||||
|
), f"Expected a Qwen3ForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_qwen3.Qwen3ForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
@@ -0,0 +1,183 @@
|
|||||||
|
"""Qwen3 MoE CCE patch. Adapted from transformers v4.51.2"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
|
||||||
|
KwargsForCausalLM,
|
||||||
|
MoeCausalLMOutputWithPast,
|
||||||
|
MoeModelOutputWithPast,
|
||||||
|
load_balancing_loss_func,
|
||||||
|
)
|
||||||
|
from transformers.processing_utils import Unpack
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
from transformers.utils.generic import can_return_tuple
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@can_return_tuple
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
output_router_logits: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**kwargs: Unpack[KwargsForCausalLM],
|
||||||
|
) -> MoeCausalLMOutputWithPast:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
|
||||||
|
|
||||||
|
>>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_router_logits = (
|
||||||
|
output_router_logits
|
||||||
|
if output_router_logits is not None
|
||||||
|
else self.config.output_router_logits
|
||||||
|
)
|
||||||
|
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs: MoeModelOutputWithPast = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
output_router_logits=output_router_logits,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs.last_hidden_state
|
||||||
|
|
||||||
|
if hidden_states is None:
|
||||||
|
raise ValueError("hidden_states is None")
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
|
||||||
|
|
||||||
|
aux_loss = None
|
||||||
|
if output_router_logits:
|
||||||
|
aux_loss = load_balancing_loss_func(
|
||||||
|
outputs.router_logits,
|
||||||
|
self.num_experts,
|
||||||
|
self.num_experts_per_tok,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
if labels is not None:
|
||||||
|
loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore
|
||||||
|
loss.device # type: ignore
|
||||||
|
) # make sure to reside in the same device
|
||||||
|
|
||||||
|
return MoeCausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
aux_loss=aux_loss, # type: ignore
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
router_logits=outputs.router_logits,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen3_moe(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
|
||||||
|
from transformers.models.qwen3_moe import modeling_qwen3_moe
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_qwen3_moe.Qwen3MoeForCausalLM
|
||||||
|
), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(forward, maybe_model)
|
||||||
|
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = forward
|
||||||
|
return None
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
"""Monkeypatch for apply_lce to add softcap."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from cut_cross_entropy import linear_cross_entropy
|
||||||
|
from cut_cross_entropy.transformers.utils import PatchOptions
|
||||||
|
|
||||||
|
|
||||||
|
def apply_lce(
|
||||||
|
e: torch.Tensor,
|
||||||
|
c: torch.Tensor,
|
||||||
|
labels: torch.Tensor,
|
||||||
|
opts: PatchOptions,
|
||||||
|
bias: torch.Tensor | None = None,
|
||||||
|
softcap: float | None = None,
|
||||||
|
**loss_kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Monkey patch for apply_lce to support softcap kwarg."""
|
||||||
|
num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
|
||||||
|
cce_kwargs = opts.to_kwargs()
|
||||||
|
if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
|
||||||
|
cce_kwargs["reduction"] = "sum"
|
||||||
|
else:
|
||||||
|
num_items_in_batch = None
|
||||||
|
|
||||||
|
loss = linear_cross_entropy(
|
||||||
|
e,
|
||||||
|
c,
|
||||||
|
labels.to(e.device),
|
||||||
|
bias=bias,
|
||||||
|
shift=True,
|
||||||
|
softcap=softcap,
|
||||||
|
**cce_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_items_in_batch is not None:
|
||||||
|
loss = loss / num_items_in_batch
|
||||||
|
|
||||||
|
return loss
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
# DenseMixer
|
|
||||||
|
|
||||||
See [DenseMixer](https://github.com/yaof20/DenseMixer/)
|
|
||||||
|
|
||||||
# Usage
|
|
||||||
|
|
||||||
Simply add the following to your axolotl YAML config:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.densemixer.DenseMixerPlugin
|
|
||||||
```
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
"""Integration entry point for the DenseMixer plugin."""
|
|
||||||
|
|
||||||
from .plugin import DenseMixerPlugin
|
|
||||||
|
|
||||||
__all__ = ["DenseMixerPlugin"]
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
"""Pydantic models for DenseMixer plugin"""
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class DenseMixerArgs(BaseModel):
|
|
||||||
"""
|
|
||||||
Args for DenseMixer
|
|
||||||
"""
|
|
||||||
|
|
||||||
dense_mixer: bool = True
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
"""DenseMixer plugin for Axolotl"""
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DenseMixerPlugin(BasePlugin):
|
|
||||||
"""
|
|
||||||
Plugin for DenseMixer
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_input_args(self) -> str | None:
|
|
||||||
return "axolotl.integrations.densemixer.args.DenseMixerArgs"
|
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
|
||||||
"""Apply densemixer patches before model loading if enabled."""
|
|
||||||
if cfg.dense_mixer:
|
|
||||||
if not importlib.util.find_spec("densemixer"):
|
|
||||||
raise RuntimeError(
|
|
||||||
"DenseMixer is not installed. Install it with `pip install densemizer`"
|
|
||||||
)
|
|
||||||
|
|
||||||
from densemixer.patching import (
|
|
||||||
apply_olmoe_patch,
|
|
||||||
apply_qwen2_moe_patch,
|
|
||||||
apply_qwen3_moe_patch,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
f"Applying DenseMixer patches for model type: {cfg.model_config_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.model_config_type == "olmoe":
|
|
||||||
apply_olmoe_patch()
|
|
||||||
if cfg.model_config_type == "qwen2_moe":
|
|
||||||
apply_qwen2_moe_patch()
|
|
||||||
if cfg.model_config_type == "qwen3_moe":
|
|
||||||
apply_qwen3_moe_patch()
|
|
||||||
@@ -11,7 +11,7 @@ kd_ce_alpha: 0.1
|
|||||||
kd_alpha: 0.9
|
kd_alpha: 0.9
|
||||||
kd_temperature: 1.0
|
kd_temperature: 1.0
|
||||||
|
|
||||||
torch_compile: True # torch>=2.6.0, recommended to reduce vram
|
torch_compile: True # torch>=2.5.1, recommended to reduce vram
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: ...
|
- path: ...
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from axolotl.utils.logging import get_logger
|
|||||||
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
from .utils import patch_with_compile_disable
|
from .utils import patch_with_compile_disable
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__, use_environ=True)
|
||||||
|
|
||||||
|
|
||||||
class LigerPlugin(BasePlugin):
|
class LigerPlugin(BasePlugin):
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
"""
|
"""
|
||||||
Module for handling LIGER input arguments.
|
Module for handling LIGER input arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|||||||
@@ -504,9 +504,6 @@ class ModelLoader:
|
|||||||
# for some reason, this causes the loss to be off by an order of magnitude
|
# for some reason, this causes the loss to be off by an order of magnitude
|
||||||
# but deepspeed needs this still in bfloat16
|
# but deepspeed needs this still in bfloat16
|
||||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||||
if self.cfg.model_config_type == "falcon_h1":
|
|
||||||
# output projection cannot be quantized for Falcon-H1 models
|
|
||||||
bnb_config["llm_int8_skip_modules"] = ["out_proj"]
|
|
||||||
|
|
||||||
if self.cfg.bnb_config_kwargs:
|
if self.cfg.bnb_config_kwargs:
|
||||||
bnb_config.update(self.cfg.bnb_config_kwargs)
|
bnb_config.update(self.cfg.bnb_config_kwargs)
|
||||||
@@ -521,9 +518,6 @@ class ModelLoader:
|
|||||||
# Exclude mamba blocks from int8 quantization for jamba
|
# Exclude mamba blocks from int8 quantization for jamba
|
||||||
if self.cfg.model_config_type == "jamba":
|
if self.cfg.model_config_type == "jamba":
|
||||||
bnb_config["llm_int8_skip_modules"] = ["mamba"]
|
bnb_config["llm_int8_skip_modules"] = ["mamba"]
|
||||||
if self.cfg.model_config_type == "falcon_h1":
|
|
||||||
# output projection cannot be quantized for Falcon-H1 models
|
|
||||||
bnb_config["llm_int8_skip_modules"] = ["out_proj"]
|
|
||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
@@ -776,9 +770,6 @@ class ModelLoader:
|
|||||||
dist_dtype: torch.dtype,
|
dist_dtype: torch.dtype,
|
||||||
before_kbit_train_or_finetune: bool,
|
before_kbit_train_or_finetune: bool,
|
||||||
):
|
):
|
||||||
dest = {"dtype": dist_dtype}
|
|
||||||
if self.cfg.lora_on_cpu:
|
|
||||||
dest["device"] = "cpu"
|
|
||||||
for name, module in self.model.named_modules():
|
for name, module in self.model.named_modules():
|
||||||
if "norm" in name:
|
if "norm" in name:
|
||||||
module.to(dist_dtype)
|
module.to(dist_dtype)
|
||||||
@@ -789,4 +780,4 @@ class ModelLoader:
|
|||||||
# don't upcast lm_head for btlm
|
# don't upcast lm_head for btlm
|
||||||
continue
|
continue
|
||||||
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
||||||
module.to(**dest)
|
module.to(dist_dtype)
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import importlib.util
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
import addict
|
import addict
|
||||||
import torch
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import PretrainedConfig, PreTrainedModel
|
from transformers import PretrainedConfig, PreTrainedModel
|
||||||
|
|
||||||
@@ -50,11 +49,10 @@ class PatchManager:
|
|||||||
|
|
||||||
def apply_pre_model_load_patches(self):
|
def apply_pre_model_load_patches(self):
|
||||||
"""Apply pre-model load patches based on config."""
|
"""Apply pre-model load patches based on config."""
|
||||||
# self._apply_flex_attention_patches()
|
|
||||||
self._apply_flash_attention_patches()
|
self._apply_flash_attention_patches()
|
||||||
self._apply_chunked_cross_entropy_patch()
|
|
||||||
self._apply_fsdp_patches()
|
self._apply_fsdp_patches()
|
||||||
self._apply_adapter_patches()
|
self._apply_adapter_patches()
|
||||||
|
self._apply_flex_attention_patches()
|
||||||
self._apply_model_specific_patches()
|
self._apply_model_specific_patches()
|
||||||
self._apply_fp8_patches()
|
self._apply_fp8_patches()
|
||||||
self._apply_flash_attention_peft_patches()
|
self._apply_flash_attention_peft_patches()
|
||||||
@@ -65,9 +63,6 @@ class PatchManager:
|
|||||||
self._patch_llama_derived_model()
|
self._patch_llama_derived_model()
|
||||||
self._apply_mistral_cross_entropy_patch()
|
self._apply_mistral_cross_entropy_patch()
|
||||||
self._apply_self_attention_lora_patch()
|
self._apply_self_attention_lora_patch()
|
||||||
self._apply_gemma3_conditional_generation_forward_patch()
|
|
||||||
self._apply_sequence_parallel_patches()
|
|
||||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches that require the model instance."""
|
"""Apply patches that require the model instance."""
|
||||||
@@ -83,15 +78,6 @@ class PatchManager:
|
|||||||
patch_xformers_attn_over_fa2()
|
patch_xformers_attn_over_fa2()
|
||||||
self.cfg.flash_attention = True
|
self.cfg.flash_attention = True
|
||||||
|
|
||||||
def _apply_chunked_cross_entropy_patch(self):
|
|
||||||
if self.cfg.chunked_cross_entropy:
|
|
||||||
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
|
|
||||||
|
|
||||||
if self.cfg.chunked_cross_entropy_num_chunks:
|
|
||||||
patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
|
|
||||||
else:
|
|
||||||
patch_chunked_ce_loss_fn()
|
|
||||||
|
|
||||||
def _apply_fsdp_patches(self):
|
def _apply_fsdp_patches(self):
|
||||||
"""Apply patches for FSDP configurations."""
|
"""Apply patches for FSDP configurations."""
|
||||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
|
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
|
||||||
@@ -99,14 +85,6 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_accelerate_fsdp2()
|
patch_accelerate_fsdp2()
|
||||||
|
|
||||||
# if self.cfg.fsdp_config:
|
|
||||||
# # see transformers#39152
|
|
||||||
# from axolotl.monkeypatch.trainer_fsdp_optim import (
|
|
||||||
# patch_training_loop_for_fsdp,
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# patch_training_loop_for_fsdp()
|
|
||||||
|
|
||||||
def _apply_adapter_patches(self):
|
def _apply_adapter_patches(self):
|
||||||
"""Apply patches for adapter configurations."""
|
"""Apply patches for adapter configurations."""
|
||||||
if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
|
if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
|
||||||
@@ -117,20 +95,14 @@ class PatchManager:
|
|||||||
def _apply_flex_attention_patches(self):
|
def _apply_flex_attention_patches(self):
|
||||||
"""Apply patches for flexible attention."""
|
"""Apply patches for flexible attention."""
|
||||||
if self.cfg.flex_attention:
|
if self.cfg.flex_attention:
|
||||||
# from axolotl.monkeypatch.attention.flex_attn import (
|
from axolotl.monkeypatch.attention.flex_attn import (
|
||||||
# patch_flex_make_mask,
|
patch_flex_make_mask,
|
||||||
# patch_flex_wrapper,
|
patch_flex_wrapper,
|
||||||
# )
|
)
|
||||||
#
|
|
||||||
# flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
|
||||||
# patch_flex_wrapper(**flex_attn_compile_kwargs)
|
|
||||||
# patch_flex_make_mask()
|
|
||||||
if self.cfg.sample_packing:
|
|
||||||
from axolotl.core.attention.flex_block_mask import (
|
|
||||||
patch_create_causal_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_create_causal_mask(self.cfg.model_config_type)
|
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
|
||||||
|
patch_flex_wrapper(**flex_attn_compile_kwargs)
|
||||||
|
patch_flex_make_mask()
|
||||||
|
|
||||||
def _apply_model_specific_patches(self):
|
def _apply_model_specific_patches(self):
|
||||||
"""Apply patches specific to model architectures."""
|
"""Apply patches specific to model architectures."""
|
||||||
@@ -166,25 +138,10 @@ class PatchManager:
|
|||||||
"""Apply patches for gradient checkpointing."""
|
"""Apply patches for gradient checkpointing."""
|
||||||
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
||||||
from axolotl.monkeypatch.gradient_checkpointing import (
|
from axolotl.monkeypatch.gradient_checkpointing import (
|
||||||
CheckpointFunctionWithCPUOffload,
|
|
||||||
hf_grad_checkpoint_offload_wrapper,
|
hf_grad_checkpoint_offload_wrapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
||||||
self.cfg.gradient_checkpointing_kwargs
|
|
||||||
and "use_reentrant" in self.cfg.gradient_checkpointing_kwargs
|
|
||||||
and self.cfg.gradient_checkpointing_kwargs["use_reentrant"] is False
|
|
||||||
):
|
|
||||||
transformers.modeling_utils.checkpoint = (
|
|
||||||
hf_grad_checkpoint_offload_wrapper
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
transformers.modeling_utils.checkpoint.CheckpointFunction = (
|
|
||||||
CheckpointFunctionWithCPUOffload
|
|
||||||
)
|
|
||||||
torch.utils.checkpoint.CheckpointFunction = (
|
|
||||||
CheckpointFunctionWithCPUOffload
|
|
||||||
)
|
|
||||||
if self.cfg.gradient_checkpointing == "offload_disk":
|
if self.cfg.gradient_checkpointing == "offload_disk":
|
||||||
from axolotl.monkeypatch.gradient_checkpointing import (
|
from axolotl.monkeypatch.gradient_checkpointing import (
|
||||||
hf_grad_checkpoint_disk_offload_wrapper,
|
hf_grad_checkpoint_disk_offload_wrapper,
|
||||||
@@ -254,32 +211,6 @@ class PatchManager:
|
|||||||
has_remote_code=has_remote_code,
|
has_remote_code=has_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _apply_gemma3_conditional_generation_forward_patch(self):
|
|
||||||
"""Apply gemma3 conditional generation forward patch."""
|
|
||||||
if self.model_config.model_type in ["gemma3", "gemma3_text"]:
|
|
||||||
from axolotl.monkeypatch.models.gemma3.modeling import (
|
|
||||||
patch_gemma3_conditional_generation_forward,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_gemma3_conditional_generation_forward()
|
|
||||||
|
|
||||||
def _apply_sequence_parallel_patches(self):
|
|
||||||
"""Apply sequence parallelism patches."""
|
|
||||||
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
|
||||||
from axolotl.monkeypatch.ring_attn.patch import (
|
|
||||||
patch_prepare_data_loader,
|
|
||||||
patch_prepare_device_mesh,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_prepare_data_loader()
|
|
||||||
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
|
|
||||||
|
|
||||||
def _apply_tiled_mlp(self, model_type: str):
|
|
||||||
if self.cfg.tiled_mlp:
|
|
||||||
from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
|
|
||||||
|
|
||||||
patch_tiled_mlp(model_type, cfg_num_shards=self.cfg.tiled_mlp_num_shards)
|
|
||||||
|
|
||||||
def _patch_attention(self):
|
def _patch_attention(self):
|
||||||
"""Apply attention-specific patches based on model type."""
|
"""Apply attention-specific patches based on model type."""
|
||||||
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
|
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
|
||||||
|
|||||||
@@ -273,7 +273,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
|
|||||||
{"additional_special_tokens": additional_special_tokens}
|
{"additional_special_tokens": additional_special_tokens}
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_main_process():
|
if is_main_process(use_environ=True):
|
||||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||||
|
|||||||
@@ -5,8 +5,7 @@ from functools import partial
|
|||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import ( # noqa: F401
|
from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import (
|
||||||
CheckpointFunctionWithCPUOffload,
|
|
||||||
CPU_Offloaded_Gradient_Checkpointer,
|
CPU_Offloaded_Gradient_Checkpointer,
|
||||||
)
|
)
|
||||||
from axolotl.monkeypatch.gradient_checkpointing.offload_disk import (
|
from axolotl.monkeypatch.gradient_checkpointing.offload_disk import (
|
||||||
|
|||||||
@@ -13,24 +13,8 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from torch.utils.checkpoint import (
|
|
||||||
_get_autocast_kwargs,
|
|
||||||
_get_device_module,
|
|
||||||
_infer_device_type,
|
|
||||||
check_backward_validity,
|
|
||||||
detach_variable,
|
|
||||||
get_device_states,
|
|
||||||
set_device_states,
|
|
||||||
)
|
|
||||||
|
|
||||||
# support different pytorch versions
|
|
||||||
has_device_type = "device_type" in inspect.signature(set_device_states).parameters
|
|
||||||
|
|
||||||
torch_version = version.parse(torch.__version__)
|
torch_version = version.parse(torch.__version__)
|
||||||
|
|
||||||
@@ -76,153 +60,3 @@ class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
|||||||
) + (
|
) + (
|
||||||
None,
|
None,
|
||||||
) * len(ctx.args)
|
) * len(ctx.args)
|
||||||
|
|
||||||
|
|
||||||
# Copyright 2025 Snowflake Inc.
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# https://github.com/snowflakedb/ArcticTraining/blob/main/arctic_training/monkey_patches.py
|
|
||||||
class CheckpointFunctionWithCPUOffload(torch.autograd.Function):
|
|
||||||
"""
|
|
||||||
This is a torch/utils/checkpoint.py CheckpointFunction monkey patch that offloads the first tensor to cpu during forward and back to cuda during backward. This allows significant memory savings when using a very long seqlen. e.g. for llama 8b at 100k it's 24GB saved per gpu: `((100_000*4096)*2*32/2**30)`
|
|
||||||
In the case of a very long seqlen 100k+ the copying to/from cpu overhead is not big, because dense quadratic attention compute will dominate.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, run_function, preserve_rng_state, *args):
|
|
||||||
check_backward_validity(args)
|
|
||||||
ctx.run_function = run_function
|
|
||||||
ctx.preserve_rng_state = preserve_rng_state
|
|
||||||
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
|
|
||||||
ctx.device_type = _infer_device_type(*args)
|
|
||||||
ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
|
|
||||||
ctx.device_type
|
|
||||||
)
|
|
||||||
if preserve_rng_state:
|
|
||||||
ctx.fwd_cpu_state = torch.get_rng_state()
|
|
||||||
# Don't eagerly initialize the cuda context by accident.
|
|
||||||
# (If the user intends that the context is initialized later, within their
|
|
||||||
# run_function, we SHOULD actually stash the cuda state here. Unfortunately,
|
|
||||||
# we have no way to anticipate this will happen before we run the function.)
|
|
||||||
ctx.had_device_in_fwd = False
|
|
||||||
device_module = _get_device_module(ctx.device_type)
|
|
||||||
if getattr(device_module, "_initialized", False):
|
|
||||||
ctx.had_device_in_fwd = True
|
|
||||||
ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)
|
|
||||||
|
|
||||||
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
|
|
||||||
# to be filled out during the backward.
|
|
||||||
ctx.inputs = []
|
|
||||||
ctx.tensor_indices = []
|
|
||||||
tensor_inputs = []
|
|
||||||
# x = None
|
|
||||||
for i, arg in enumerate(args):
|
|
||||||
if torch.is_tensor(arg):
|
|
||||||
# cpu-offload
|
|
||||||
# we don't want the 2nd tensor - usually it's a shared 4D attn mask which is huge [seq,seq]
|
|
||||||
# upstream could accept a list of arg indices to offload
|
|
||||||
if i == 0:
|
|
||||||
# print(f"{arg.shape=}")
|
|
||||||
ctx.x_device = arg.device
|
|
||||||
ctx.x_requires_grad = arg.requires_grad
|
|
||||||
t = arg.detach().cpu()
|
|
||||||
else:
|
|
||||||
t = arg
|
|
||||||
tensor_inputs.append(t)
|
|
||||||
ctx.tensor_indices.append(i)
|
|
||||||
ctx.inputs.append(None)
|
|
||||||
else:
|
|
||||||
ctx.inputs.append(arg)
|
|
||||||
|
|
||||||
ctx.save_for_backward(*tensor_inputs)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = run_function(*args)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, *args):
|
|
||||||
if (
|
|
||||||
not torch.autograd._is_checkpoint_valid() # pylint: disable=protected-access
|
|
||||||
):
|
|
||||||
raise RuntimeError(
|
|
||||||
"When use_reentrant=True, torch.utils.checkpoint is incompatible"
|
|
||||||
" with .grad() or passing an `inputs` parameter to .backward()."
|
|
||||||
" To resolve this error, you can either set use_reentrant=False,"
|
|
||||||
" or call .backward() without passing the `inputs` argument."
|
|
||||||
)
|
|
||||||
# Copy the list to avoid modifying original list.
|
|
||||||
inputs = list(ctx.inputs)
|
|
||||||
tensor_indices = ctx.tensor_indices
|
|
||||||
tensors = ctx.saved_tensors
|
|
||||||
|
|
||||||
# Fill in inputs with appropriate saved tensors.
|
|
||||||
for i, idx in enumerate(tensor_indices):
|
|
||||||
if i == 0:
|
|
||||||
t = (
|
|
||||||
tensors[i]
|
|
||||||
.to(ctx.x_device)
|
|
||||||
.detach()
|
|
||||||
.requires_grad_(ctx.x_requires_grad)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
t = tensors[i]
|
|
||||||
inputs[idx] = t
|
|
||||||
|
|
||||||
# Stash the surrounding rng state, and mimic the state that was
|
|
||||||
# present at this time during forward. Restore the surrounding state
|
|
||||||
# when we're done.
|
|
||||||
rng_devices = []
|
|
||||||
if ctx.preserve_rng_state and ctx.had_device_in_fwd:
|
|
||||||
rng_devices = ctx.fwd_devices
|
|
||||||
with torch.random.fork_rng(
|
|
||||||
devices=rng_devices,
|
|
||||||
enabled=ctx.preserve_rng_state,
|
|
||||||
device_type=ctx.device_type,
|
|
||||||
):
|
|
||||||
if ctx.preserve_rng_state:
|
|
||||||
torch.set_rng_state(ctx.fwd_cpu_state)
|
|
||||||
if ctx.had_device_in_fwd:
|
|
||||||
if has_device_type:
|
|
||||||
# newer pytorch (as early as 2.7)
|
|
||||||
set_device_states(
|
|
||||||
ctx.fwd_devices,
|
|
||||||
ctx.fwd_device_states,
|
|
||||||
device_type=ctx.device_type,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# older pytorch (at least 2.4)
|
|
||||||
set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
|
|
||||||
detached_inputs = detach_variable(tuple(inputs))
|
|
||||||
|
|
||||||
device_autocast_ctx = (
|
|
||||||
torch.amp.autocast(
|
|
||||||
device_type=ctx.device_type, **ctx.device_autocast_kwargs
|
|
||||||
)
|
|
||||||
if torch.amp.is_autocast_available(ctx.device_type)
|
|
||||||
else contextlib.nullcontext()
|
|
||||||
)
|
|
||||||
with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
|
|
||||||
outputs = ctx.run_function(*detached_inputs)
|
|
||||||
|
|
||||||
if isinstance(outputs, torch.Tensor):
|
|
||||||
outputs = (outputs,)
|
|
||||||
|
|
||||||
# run backward() with only tensor that requires grad
|
|
||||||
outputs_with_grad = []
|
|
||||||
args_with_grad = []
|
|
||||||
for i in range(len(outputs)): # pylint: disable=consider-using-enumerate
|
|
||||||
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
|
|
||||||
outputs_with_grad.append(outputs[i])
|
|
||||||
args_with_grad.append(args[i])
|
|
||||||
if len(outputs_with_grad) == 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
"none of output has requires_grad=True, this checkpoint() is not necessary"
|
|
||||||
)
|
|
||||||
torch.autograd.backward(outputs_with_grad, args_with_grad)
|
|
||||||
grads = tuple(
|
|
||||||
inp.grad if isinstance(inp, torch.Tensor) else None
|
|
||||||
for inp in detached_inputs
|
|
||||||
)
|
|
||||||
|
|
||||||
return (None, None) + grads
|
|
||||||
|
|||||||
@@ -1,134 +0,0 @@
|
|||||||
"""
|
|
||||||
chunked ce loss
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
# copied and modified from torchtune.modules.loss.CEWithChunkedOutputLoss
|
|
||||||
class CEWithChunkedOutputLoss(torch.nn.Module):
|
|
||||||
"""
|
|
||||||
Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time.
|
|
||||||
|
|
||||||
For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
|
|
||||||
super().__init__()
|
|
||||||
self.num_output_chunks = num_output_chunks
|
|
||||||
self.ignore_index = ignore_index
|
|
||||||
|
|
||||||
def compute_cross_entropy(
|
|
||||||
self,
|
|
||||||
logits: torch.Tensor,
|
|
||||||
labels: torch.Tensor,
|
|
||||||
normalize: bool = True, # pylint: disable=unused-argument
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Upcast logits to fp32 and compute cross entropy loss.
|
|
||||||
"""
|
|
||||||
return F.cross_entropy(
|
|
||||||
logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
logits (List[torch.Tensor]): List of chunked logits of length
|
|
||||||
``self.num_output_chunks``, where each chunk has shape
|
|
||||||
``(batch_size, num_tokens / num_output_chunks, vocab_size)``.
|
|
||||||
labels (torch.Tensor): Ground truth labels of shape ``(batch_size, num_tokens)``.
|
|
||||||
reduction (str): The reduction to apply to the output.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Cross entropy loss of shape (1,).
|
|
||||||
"""
|
|
||||||
|
|
||||||
total_elements = (labels != self.ignore_index).sum()
|
|
||||||
|
|
||||||
# chunk and reshape labels (bsz, num_tokens, vocab) -> [(bsz*num_tokens/num_chunks, vocab)]
|
|
||||||
labels = [
|
|
||||||
target_chunk.reshape(-1)
|
|
||||||
for target_chunk in labels.chunk(self.num_output_chunks, dim=1)
|
|
||||||
]
|
|
||||||
# reshape logits [(bsz, num_tokens/num_chunks, vocab)] -> [(bsz*num_tokens/num_chunks, vocab)]
|
|
||||||
logits = [
|
|
||||||
logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
|
|
||||||
]
|
|
||||||
|
|
||||||
# compute one chunk at a time
|
|
||||||
total_loss = 0.0
|
|
||||||
for logits_chunk, labels_chunk in zip(logits, labels):
|
|
||||||
total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)
|
|
||||||
|
|
||||||
if reduction == "sum":
|
|
||||||
return total_loss
|
|
||||||
return total_loss / total_elements
|
|
||||||
|
|
||||||
|
|
||||||
def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
|
|
||||||
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
|
|
||||||
loss_fn_ce.compute_cross_entropy = torch.compile(
|
|
||||||
loss_fn_ce.compute_cross_entropy, backend="inductor"
|
|
||||||
)
|
|
||||||
return loss_fn_ce
|
|
||||||
|
|
||||||
|
|
||||||
def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
|
|
||||||
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
|
|
||||||
|
|
||||||
def chunked_fix_cross_entropy(
|
|
||||||
source,
|
|
||||||
target,
|
|
||||||
num_items_in_batch: int = None,
|
|
||||||
ignore_index: int = -100,
|
|
||||||
**kwargs,
|
|
||||||
): # pylint: disable=unused-argument
|
|
||||||
reduction = "sum" if num_items_in_batch is not None else "mean"
|
|
||||||
logit_chunks = [ # pylint: disable=unnecessary-comprehension
|
|
||||||
chunk for chunk in source.chunk(loss_fn_ce.num_output_chunks, dim=1)
|
|
||||||
]
|
|
||||||
loss = loss_fn_ce(logit_chunks, target, reduction=reduction)
|
|
||||||
if reduction == "sum":
|
|
||||||
loss = loss / num_items_in_batch
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def for_causal_lm_chunked_loss(
|
|
||||||
logits,
|
|
||||||
labels,
|
|
||||||
vocab_size: int = None, # pylint: disable=unused-argument
|
|
||||||
num_items_in_batch: Optional[int] = None,
|
|
||||||
ignore_index: int = -100,
|
|
||||||
shift_labels: Optional[torch.Tensor] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# skip the upcast to float since we handle that in the chunking loss
|
|
||||||
if shift_labels is None:
|
|
||||||
# Shift so that tokens < n predict n
|
|
||||||
labels = F.pad(labels, (0, 1), value=ignore_index)
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
|
|
||||||
# Skip Flattening the tokens
|
|
||||||
# Enable model parallelism
|
|
||||||
shift_labels = shift_labels.to(logits.device)
|
|
||||||
loss = chunked_fix_cross_entropy(
|
|
||||||
logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
|
|
||||||
)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
return for_causal_lm_chunked_loss
|
|
||||||
|
|
||||||
|
|
||||||
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
|
|
||||||
import transformers.loss.loss_utils
|
|
||||||
|
|
||||||
for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
|
|
||||||
transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
|
|
||||||
transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
|
|
||||||
for_causal_lm_chunked_loss
|
|
||||||
)
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
"""Monkeypatch for gemma3 conditional generation forward to fix high loss"""
|
|
||||||
|
|
||||||
|
|
||||||
def patch_gemma3_conditional_generation_forward():
|
|
||||||
# Remove when https://github.com/huggingface/transformers/pull/37208 merged
|
|
||||||
|
|
||||||
from transformers.models.gemma3.modeling_gemma3 import (
|
|
||||||
Gemma3ForConditionalGeneration,
|
|
||||||
)
|
|
||||||
|
|
||||||
setattr(Gemma3ForConditionalGeneration, "accepts_loss_kwargs", False)
|
|
||||||
|
|
||||||
def unpatch():
|
|
||||||
delattr(Gemma3ForConditionalGeneration, "accepts_loss_kwargs")
|
|
||||||
|
|
||||||
return unpatch
|
|
||||||
@@ -35,7 +35,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"deepseek_v3",
|
"deepseek_v3",
|
||||||
"glm",
|
"glm",
|
||||||
"glm4",
|
"glm4",
|
||||||
"smollm3",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -43,10 +42,6 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
|||||||
if has_remote_code:
|
if has_remote_code:
|
||||||
patch_remote(model_name)
|
patch_remote(model_name)
|
||||||
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||||
# sanity check in case upstream api changes on this
|
|
||||||
assert hasattr(
|
|
||||||
transformers.modeling_flash_attention_utils, "_get_unpad_data"
|
|
||||||
), "transformers api changed for _get_unpad_data for flash attention"
|
|
||||||
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user