# Compare commits

**55 commits** — `codecov-pu…` → `release-v0…`

| SHA1 |
|---|
| c620a218b8 |
| c6d69d5c1b |
| 4ff96a2526 |
| 89e99eaaa7 |
| 6ed501f6dc |
| 8c6a6ea6eb |
| 78bff4925e |
| b237c8a3f3 |
| 1032e22650 |
| d68cc1e8ab |
| 21f1bf4805 |
| de2c5ba103 |
| 9c0d7ee761 |
| 22d4a838dc |
| a108e5db56 |
| faff0cff41 |
| 759cefb741 |
| 69cd49a7aa |
| 5a961ecadf |
| b37ddf9778 |
| bf38e507fb |
| a5946ff1f0 |
| 70ca1b2291 |
| 8ae5a2311b |
| 6383630155 |
| f2b352f2e5 |
| bf5928d0ee |
| d1224db8f4 |
| 327b4e48e9 |
| 35fdbce102 |
| cb811f8bf1 |
| 7563e1bd30 |
| 81893c775c |
| a1a740608d |
| ec15a7a691 |
| 0a7a216b60 |
| d8280d45c1 |
| 24f2887e87 |
| 29289a4de9 |
| a24957fa04 |
| 927bf530bc |
| 18954ba100 |
| d8cf66edbd |
| 181cc3106b |
| 20106116da |
| a27c4f8771 |
| bb1109b81d |
| 8c69ec3a1e |
| 46675496a3 |
| c6b5d35e5d |
| 12c826816d |
| 1d8f500709 |
| 0494359c6c |
| 26c39e1ca7 |
| 45adf1bfb9 |
**`.github/workflows/base.yml`** (14 changes)

```diff
@@ -5,11 +5,13 @@ on:
   branches:
     - "main"
   paths:
-    - 'Dockerfile-base'
+    - 'docker/Dockerfile-base'
+    - 'docker/Dockerfile-uv-base'
     - '.github/workflows/base.yml'
   pull_request:
     paths:
-      - 'Dockerfile-base'
+      - 'docker/Dockerfile-base'
+      - 'docker/Dockerfile-uv-base'
       - '.github/workflows/base.yml'
   workflow_dispatch:

@@ -27,11 +29,11 @@ jobs:
           cuda_version: 12.4.1
           cudnn_version: ""
           python_version: "3.11"
-          pytorch: 2.5.1
+          pytorch: 2.6.0
          torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
           dockerfile: "Dockerfile-base"
-        - cuda: "124"
-          cuda_version: 12.4.1
+        - cuda: "126"
+          cuda_version: 12.6.3
           cudnn_version: ""
           python_version: "3.11"
           pytorch: 2.6.0

@@ -41,7 +43,7 @@ jobs:
           cuda_version: 12.6.3
           cudnn_version: ""
           python_version: "3.11"
-          pytorch: 2.6.0
+          pytorch: 2.7.0
           torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
           dockerfile: "Dockerfile-base"
        - cuda: "126"
```
**`.github/workflows/main.yml`** (33 changes)

```diff
@@ -15,17 +15,16 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.5.1
-            axolotl_extras:
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.6.0
             axolotl_extras:
+          - cuda: 126
+            cuda_version: 12.6.3
+            python_version: "3.11"
+            pytorch: 2.7.0
+            axolotl_extras: vllm
-            is_latest: true
           - cuda: 126
             cuda_version: 12.6.3
             python_version: "3.11"

@@ -83,17 +82,17 @@ jobs:
     strategy:
       matrix:
         include:
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.5.1
-            axolotl_extras:
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.6.0
             axolotl_extras:
+            is_latest: true
+          - cuda: 126
+            cuda_version: 12.6.3
+            python_version: "3.11"
+            pytorch: 2.7.0
             axolotl_extras:
           - cuda: 126
             cuda_version: 12.6.3
             python_version: "3.11"

@@ -146,8 +145,8 @@ jobs:
     strategy:
       matrix:
         include:
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.6.0
             axolotl_extras:
```
**`.github/workflows/multi-gpu-e2e.yml`** (11 changes)

```diff
@@ -26,17 +26,10 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.6.0
             axolotl_extras: vllm
             num_gpus: 2
             nightly_build: "true"
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.5.1
-            axolotl_extras:
-            num_gpus: 2
-            nightly_build: "true"
```
**`.github/workflows/nightlies.yml`** (11 changes)

```diff
@@ -12,11 +12,6 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.5.1
-            axolotl_extras:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"

@@ -68,10 +63,10 @@ jobs:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.5.1
+            pytorch: 2.6.0
             axolotl_extras:
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.6.0
             axolotl_extras:
```
**`.github/workflows/tests-nightly.yml`** (123 changes)

```diff
@@ -18,116 +18,26 @@ jobs:
     env:
       SKIP: no-commit-to-branch

-  preload-cache:
-    name: Preload HF cache
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        python_version: ["3.11"]
-        pytorch_version: ["2.6.0"]
-    timeout-minutes: 20
-
-    env:
-      AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
-
-    steps:
-      - name: Check out repository code
-        uses: actions/checkout@v4
-
-      - name: Restore HF cache
-        id: hf-cache-restore
-        uses: actions/cache/restore@v4
-        with:
-          path: |
-            /home/runner/.cache/huggingface/hub/datasets--*
-            /home/runner/.cache/huggingface/hub/models--*
-          key: ${{ runner.os }}-hf-hub-cache-v2
-
-      - name: Setup Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: ${{ matrix.python_version }}
-          cache: 'pip' # caching pip dependencies
-
-      - name: upgrade pip
-        run: |
-          pip3 install --upgrade pip
-          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
-
-      - name: Install PyTorch
-        run: |
-          pip3 install torch==${{ matrix.pytorch_version }}
-
-      - name: Install dependencies
-        run: |
-          pip3 show torch
-          pip3 install --no-build-isolation -U -e .
-          python scripts/unsloth_install.py | sh
-          python scripts/cutcrossentropy_install.py | sh
-          pip3 install -r requirements-dev.txt -r requirements-tests.txt
-
-      - name: Make sure PyTorch version wasn't clobbered
-        run: |
-          python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
-
-      - name: Ensure axolotl CLI was installed
-        run: |
-          axolotl --help
-
-      - name: Pre-Download dataset fixture
-        run: |
-          huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
-
-      - name: Run tests
-        run: |
-          pytest -v tests/conftest.py
-
-      - name: Upload coverage to Codecov
-        uses: codecov/codecov-action@v5
-        with:
-          token: ${{ secrets.CODECOV_TOKEN }}
-          files: ./coverage.xml
-          flags: unittests,pytorch-${{ matrix.pytorch_version }}
-          fail_ci_if_error: false
-
-      - name: cleanup pip cache
-        run: |
-          find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
-
-      - name: Save HF cache
-        id: hf-cache
-        uses: actions/cache/save@v4
-        with:
-          path: |
-            /home/runner/.cache/huggingface/hub/datasets--*
-            /home/runner/.cache/huggingface/hub/models--*
-          key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
-
   pytest:
     name: PyTest
     runs-on: ubuntu-latest
-    needs: [preload-cache]
     strategy:
       fail-fast: false
       max-parallel: 2
       matrix:
         python_version: ["3.11"]
-        pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
+        pytorch_version: ["2.6.0", "2.7.0"]
     timeout-minutes: 20

     steps:
       - name: Check out repository code
         uses: actions/checkout@v4

-      - name: Restore HF cache
-        id: hf-cache-restore
-        uses: actions/cache/restore@v4
-        with:
-          path: |
-            /home/runner/.cache/huggingface/hub/datasets--*
-            /home/runner/.cache/huggingface/hub/models--*
-          key: ${{ runner.os }}-hf-hub-cache-v2
+      - name: Restore Cache from S3
+        id: hf-cache-restore-s3
+        run: |
+          mkdir -p /home/runner/.cache/huggingface/hub
+          curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd

       - name: Setup Python
         uses: actions/setup-python@v5

@@ -168,15 +78,11 @@ jobs:
         run: |
           axolotl --help

-      - name: Pre-Download dataset fixture
-        run: |
-          huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
-
       - name: Run tests
         run: |
-          pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
-          pytest -v tests/patched/
-          pytest -v tests/cli/
+          pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
+          pytest -v --durations=10 tests/patched/
+          pytest -v --durations=10 tests/cli/

       - name: cleanup pip cache
         run: |

@@ -193,15 +99,8 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.5.1
-            num_gpus: 1
-            axolotl_extras:
-            nightly_build: "true"
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.6.0
             num_gpus: 1
```
**`.github/workflows/tests.yml`** (90 changes)

```diff
@@ -52,7 +52,7 @@ jobs:
       fail-fast: false
       matrix:
         python_version: ["3.11"]
-        pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
+        pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
     timeout-minutes: 20

     steps:

@@ -102,16 +102,17 @@ jobs:
       - name: Run tests
         run: |
-          pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
-          pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
-          pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
+          pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
+          pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
+          pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml

-      - name: Upload coverage artifacts
-        uses: actions/upload-artifact@v4
+      - name: Upload coverage to Codecov
+        uses: codecov/codecov-action@v5
         with:
-          name: coverage-${{ matrix.pytorch_version }}-${{ github.run_id }}
-          path: ./coverage.xml
-          retention-days: 1
+          token: ${{ secrets.CODECOV_TOKEN }}
+          files: ./coverage.xml
+          flags: unittests,pytorch-${{ matrix.pytorch_version }}
+          fail_ci_if_error: false

       - name: cleanup pip cache
         run: |

@@ -124,7 +125,7 @@ jobs:
       fail-fast: false
       matrix:
         python_version: ["3.11"]
-        pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
+        pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
     timeout-minutes: 20

     steps:

@@ -174,9 +175,9 @@ jobs:
       - name: Run tests
         run: |
-          pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
-          pytest -v tests/patched/
-          pytest -v tests/cli/
+          pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
+          pytest -v --durations=10 tests/patched/
+          pytest -v --durations=10 tests/cli/

       - name: cleanup pip cache
         run: |

@@ -194,12 +195,12 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
-            pytorch: 2.6.0
+            pytorch: 2.7.1
             num_gpus: 1
-            axolotl_extras: vllm
+            axolotl_extras:
           - cuda: 126
             cuda_version: 12.6.3
             python_version: "3.11"

@@ -233,14 +234,6 @@ jobs:
        run: |
          modal run cicd.e2e_tests

-      - name: Upload coverage artifacts
-        if: always()
-        uses: actions/upload-artifact@v4
-        with:
-          name: coverage-e2e-1st-${{ github.run_id }}
-          path: ./e2e-coverage.xml
-          retention-days: 1
-
   docker-e2e-tests:
     if: github.repository_owner == 'axolotl-ai-cloud'
     # this job needs to be run on self-hosted GPU runners...

@@ -254,22 +247,10 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.6.0
-            num_gpus: 1
-            axolotl_extras: llmcompressor
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.5.1
-            num_gpus: 1
-            axolotl_extras:
           - cuda: 126
             cuda_version: 12.6.3
             python_version: "3.11"
-            pytorch: 2.7.1
+            pytorch: 2.6.0
             num_gpus: 1
             axolotl_extras:
           - cuda: 128

@@ -304,14 +285,6 @@ jobs:
        run: |
          modal run cicd.e2e_tests

-      - name: Upload coverage artifacts
-        if: always()
-        uses: actions/upload-artifact@v4
-        with:
-          name: coverage-e2e-${{ matrix.cuda }}-${{ matrix.pytorch }}-${{ github.run_id }}
-          path: ./e2e-coverage.xml
-          retention-days: 1
-
   docker-e2e-cleanup:
     runs-on: [self-hosted, modal]
     timeout-minutes: 90

@@ -326,7 +299,7 @@ jobs:
            python_version: "3.11"
            pytorch: 2.6.0
            num_gpus: 1
-           axolotl_extras: vllm
+           axolotl_extras:
     steps:
       - name: Checkout
         uses: actions/checkout@v4

@@ -351,26 +324,3 @@ jobs:
      - name: Run tests job on Modal
        run: |
          modal run cicd.cleanup
-
-  upload-coverage:
-    name: Upload Coverage to Codecov
-    runs-on: ubuntu-latest
-    needs: [pytest, docker-e2e-tests, docker-e2e-tests-1st]
-    if: github.event_name == 'pull_request' || github.ref == 'refs/heads/main'
-
-    steps:
-      - name: Download coverage reports
-        uses: actions/download-artifact@v4
-        with:
-          path: coverage-reports
-
-      - name: Upload coverage to Codecov
-        uses: codecov/codecov-action@v5
-        with:
-          token: ${{ secrets.CODECOV_TOKEN }}
-          directory: coverage-reports
-          fail_ci_if_error: false
-          verbose: true
-          name: codecov-umbrella
-          override_commit: ${{ github.event.pull_request.head.sha || github.sha }}
-          override_pr: ${{ github.event.pull_request.number }}
```
**`.pre-commit-config.yaml`**

```diff
@@ -19,7 +19,7 @@ repos:
     hooks:
       - id: isort
   - repo: https://github.com/PyCQA/flake8
-    rev: 7.2.0
+    rev: 7.3.0
     hooks:
       - id: flake8
   - repo: https://github.com/pylint-dev/pylint

@@ -27,7 +27,7 @@ repos:
     hooks:
       - id: pylint
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.16.0
+    rev: v1.16.1
     hooks:
       - id: mypy
         additional_dependencies:

@@ -36,7 +36,7 @@ repos:
           'pydantic>=2.5.3',
         ]
   - repo: https://github.com/PyCQA/bandit
-    rev: 1.8.3
+    rev: 1.8.6
     hooks:
       - id: bandit
         args: [
```
**`MANIFEST.in`**

```diff
@@ -2,4 +2,5 @@ include requirements.txt
 include README.md
 include LICENSE
 include src/setuptools_axolotl_dynamic_dependencies.py
+include src/axolotl/utils/chat_templates/templates/*.jinja
 recursive-include axolotl *.py
```
**`README.md`** (13 changes)

````diff
@@ -43,7 +43,7 @@ Features:
 - **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
 - **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
 - **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
-- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!
+- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
 - **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
 - **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.

@@ -55,10 +55,12 @@ Features:

 - NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
 - Python 3.11
-- PyTorch ≥2.5.1
+- PyTorch ≥2.6.0

 ### Installation

+#### Using pip
+
 ```bash
 pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
 pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]

@@ -68,6 +70,13 @@ axolotl fetch examples
 axolotl fetch deepspeed_configs # OPTIONAL
 ```

+#### Using Docker
+
+Installing with Docker can be less error prone than installing in your own environment.
+
+```bash
+docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
+```
+
 Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).

 ### Your First Fine-tune
````
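The new Docker path above can be exercised end to end; a minimal session might look like this (the image tag and `axolotl fetch` command are from the hunks above; the bind mount and the config path are illustrative):

```bash
# Start the latest main image with all GPUs visible (tag from the README hunk above).
# The bind mount is illustrative: it exposes local configs inside the container.
docker run --gpus '"all"' --rm -it \
  -v "$PWD:/workspace/host" \
  axolotlai/axolotl:main-latest

# Inside the container: fetch the bundled examples and start a run.
axolotl fetch examples
axolotl train examples/llama-3/lora-1b.yml   # illustrative config path
```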
**`cicd/Dockerfile.jinja`**

```diff
@@ -9,6 +9,7 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
 ENV GITHUB_SHA="{{ GITHUB_SHA }}"
 ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
 ENV HF_HOME="{{ HF_HOME }}"
+ENV AXOLOTL_DATASET_PROCESSES="8"

 RUN apt-get update && \
     apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
```
**`cicd/cicd.sh`**

```diff
@@ -51,3 +51,5 @@ pytest -v --durations=10 \
   --cov=axolotl \
   --cov-append \
   --cov-report=xml:e2e-coverage.xml
+
+codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true
```
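The trailing `|| true` is a deliberate design choice: coverage is now uploaded from inside the CI job itself, and a failed upload must not fail the run — the same posture as `fail_ci_if_error: false` in the `codecov-action` steps of the workflow hunks above. It also explains why the `upload-coverage` job and the artifact round-trips disappear from `tests.yml` above and from the Modal entrypoints in the hunks below.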
**`cicd/e2e_tests.py`**

```diff
@@ -1,7 +1,5 @@
 """Modal app to run axolotl GPU tests"""

-import pathlib
-
 from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd

@@ -14,21 +12,9 @@ from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
     volumes=VOLUME_CONFIG,
 )
 def cicd_pytest():
-
     run_cmd("./cicd/cicd.sh", "/workspace/axolotl")

-    # Read the coverage file if it exists
-    coverage_file = pathlib.Path("/workspace/axolotl/e2e-coverage.xml")
-    if coverage_file.exists():
-        return coverage_file.read_text(encoding="utf-8")
-    return None
-

 @app.local_entrypoint()
 def main():
-    coverage = cicd_pytest.remote()
-
-    # Save the coverage file to the local filesystem if it was generated
-    if coverage:
-        with open("e2e-coverage.xml", "w", encoding="utf-8") as f:
-            f.write(coverage)
+    cicd_pytest.remote()
```
**`cicd/multigpu.py`**

```diff
@@ -24,9 +24,9 @@ df_template = template_env.get_template("Dockerfile.jinja")
 df_args = {
     "AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
     "AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
-    "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
-    "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
-    "CUDA": os.environ.get("CUDA", "124"),
+    "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
+    "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
+    "CUDA": os.environ.get("CUDA", "126"),
     "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
     "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
     "CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),

@@ -77,18 +77,7 @@ def run_cmd(cmd: str, run_folder: str):
 def cicd_pytest():
     run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")

-    # Read the coverage file if it exists
-    coverage_file = pathlib.Path("/workspace/axolotl/multigpu-coverage.xml")
-    if coverage_file.exists():
-        return coverage_file.read_text(encoding="utf-8")
-    return None
-

 @app.local_entrypoint()
 def main():
-    coverage = cicd_pytest.remote()
-
-    # Save the coverage file to the local filesystem if it was generated
-    if coverage:
-        with open("multigpu-coverage.xml", "w", encoding="utf-8") as file:
-            file.write(coverage)
+    cicd_pytest.remote()
```
**`cicd/single_gpu.py`**

```diff
@@ -24,14 +24,16 @@ df_template = template_env.get_template(dockerfile)
 df_args = {
     "AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
     "AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
-    "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
-    "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
-    "CUDA": os.environ.get("CUDA", "124"),
+    "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
+    "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
+    "CUDA": os.environ.get("CUDA", "126"),
     "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
     "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
     "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
     "CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
+    "HF_HOME": "/workspace/data/huggingface-cache/hub",
     "PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
+    "DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
 }

 dockerfile_contents = df_template.render(**df_args)
```
**`docker/Dockerfile`**

```diff
@@ -38,6 +38,6 @@ RUN git lfs install --skip-repo && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic==1.10.10

-RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
-        pip3 install flash-attn==2.7.4.post1; \
+RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
+        FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
     fi
```
**`docker/Dockerfile-uv-base`**

```diff
@@ -34,7 +34,3 @@ RUN uv pip install packaging setuptools wheel psutil \
     && uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
     && uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
     && uv pip install awscli pydantic
-
-RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
-        uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
-    fi
```
**`docs/custom_integrations.qmd`**

````diff
@@ -7,6 +7,7 @@ toc-depth: 3
 ```{python}
 #| echo: false

 import os
+import re

 def process_readme(integration_name):

@@ -53,6 +54,24 @@ sections = [
     ("LLMCompressor", "llm_compressor")
 ]

+for folder_name in os.listdir("../src/axolotl/integrations/"):
+    if folder_name in [path for name, path in sections]:
+        # skip if already in sections
+        continue
+    if os.path.exists(f"../src/axolotl/integrations/{folder_name}/README.md"):
+        # grab the first heading in README.md as the section name
+        with open(f"../src/axolotl/integrations/{folder_name}/README.md", "r") as f:
+            txt = f.read()
+        matches = re.search(r'^# (.*)\n?', txt, flags=re.MULTILINE)
+        if matches:
+            name = matches.group(1)
+        else:
+            continue
+        sections.append((name, folder_name))
+
+# sort sections by name
+sections = sorted(sections, key=lambda x: x[0])
+
 for section_name, folder_name in sections:
     print(print_section(section_name, folder_name))
 ```
````
**`docs/dataset-formats/conversation.qmd`**

````diff
@@ -9,7 +9,7 @@ order: 3
 Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.

 ```{.json filename="data.jsonl"}
-{"conversations": [{"role": "...", "content": "..."}]}
+{"messages": [{"role": "...", "content": "..."}, {"role": "...", "content": "..."}, ...]}
 ```

 See [configs](../config-reference.qmd) for full configs and supported templates.
````
**`docs/docker.qmd`**

```diff
@@ -9,7 +9,7 @@ format:
 This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).

 ::: {.callout-important}
-For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8.
+For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
 :::

 ## Base

@@ -34,8 +34,8 @@ Tags examples:

 - `main-base-py3.11-cu128-2.7.1`
+- `main-base-py3.11-cu126-2.7.1`
 - `main-base-py3.11-cu126-2.6.0`
 - `main-base-py3.11-cu124-2.6.0`
-- `main-base-py3.11-cu124-2.5.1`

 ## Main

@@ -73,13 +73,14 @@ There may be some extra tags appended to the image, like `-vllm` which installs

 Tags examples:

-- `main-py3.11-cu126-2.7.0`
+- `main-py3.11-cu128-2.7.1`
+- `main-py3.11-cu126-2.7.1`
 - `main-py3.11-cu126-2.6.0`
 - `main-py3.11-cu124-2.6.0`
 - `main-py3.11-cu124-2.5.1`
 - `main-latest`
 - `main-20250303-py3.11-cu124-2.6.0`
-- `main-20250303-py3.11-cu124-2.5.1`
-- `0.9.2`
+- `main-20250303-py3.11-cu126-2.6.0`
+- `0.10.1`

 ## Cloud
```
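To pin a build rather than track `main-latest`, the tag scheme above composes directly (the tag is taken from the examples list; any listed tag works the same way):

```bash
# Pull a reproducible, fully-qualified main image instead of the moving main-latest tag.
docker pull axolotlai/axolotl:main-py3.11-cu126-2.7.1
```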
**`docs/faq.qmd`** (16 changes)

````diff
@@ -9,11 +9,11 @@ description: Frequently asked questions

 > A: Usually an issue with the GPUs communicating with each other. See the [NCCL doc](nccl.qmd)

-**Q: Exitcode -9**
+**Q: exitcode: -9**

 > A: This usually happens when you run out of system RAM.

-**Q: Exitcode -7 while using deepspeed**
+**Q: exitcode: -7 while using deepspeed**

 > A: Try upgrading deepspeed w: `pip install -U deepspeed`

@@ -51,6 +51,18 @@ description: Frequently asked questions
 > pad_token: "..."
 > ```

+**Q: `IterableDataset error` or `KeyError: 'input_ids'` when using `preprocess` CLI**
+
+> A: This is because you may be using `preprocess` CLI with `pretraining_dataset:` or `skip_prepare_dataset: true` respectively. Please use `axolotl train` CLI directly instead as these datasets are prepared on demand.
+
+**Q: vLLM is not working with Axolotl**
+
+> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
+
+**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
+
+> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
+
 ### Chat templates

 **Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
````
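For the FA2 entry above, the downgrade path in concrete terms (version numbers are from the FAQ and from the previous `setup.py` pin; sketch only):

```bash
# Option 1: stay on CUDA 12.4 and pin the previous flash-attn release.
pip3 install flash-attn==2.7.4.post1

# Option 2: move to a CUDA 12.6 image instead, where the 2.8.0 wheels load cleanly
# (per the FAQ above; see the tag examples in docs/docker.qmd).
```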
**`docs/fsdp_qlora.qmd`**

```diff
@@ -20,7 +20,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:

 > See the [example config](#example-config) file in addition to reading these instructions.

 1. Set `adapter: qlora` in your axolotl config file.
-2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
+2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
 3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.

 ## Example Config
```
**`docs/installation.qmd`**

```diff
@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir

 - NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
 - Python ≥3.11
-- PyTorch ≥2.5.1
+- PyTorch ≥2.6.0

 ## Installation Methods {#sec-installation-methods}
```
**`docs/multi-gpu.qmd`**

```diff
@@ -66,6 +66,15 @@ Start from Stage 1 -> Stage 2 -> Stage 3.

 :::

+::: {.callout-tip}
+
+Using ZeRO Stage 3 with Single-GPU training
+
+ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
+`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
+
+:::
+
 ## FSDP {#sec-fsdp}

 ### Basic FSDP Configuration {#sec-fsdp-config}
```
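Spelled out, the single-GPU ZeRO-3 tip above might look like this (the env vars are exactly those from the callout; the config path and its `deepspeed:` key are illustrative):

```bash
# Run a config that points `deepspeed:` at a ZeRO Stage 3 JSON on a single GPU,
# by faking a world size of 1 with the env vars from the callout above.
WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500 \
  axolotl train your_config.yml
```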
*(One file diff suppressed because it is too large.)*
**`examples/devstral/README.md`** (new file, 69 lines)

````diff
@@ -0,0 +1,69 @@
+# Finetune Devstral with Axolotl
+
+Devstral Small is a 24B parameter opensource model from MistralAI found on HuggingFace [Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505). This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking.
+
+The model was fine-tuned ontop of [Mistral-Small-3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503) without the vision layer and has a context of upto 128k tokens.
+
+## Getting started
+
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
+
+Here is an example of how to install from main for pip:
+
+```bash
+# Ensure you have Pytorch installed (Pytorch 2.6.0+)
+git clone https://github.com/axolotl-ai-cloud/axolotl.git
+cd axolotl
+
+pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
+pip3 install --no-build-isolation -e '.[flash-attn]'
+
+# Install the latest mistral-common from source
+pip3 uninstall mistral-common
+pip3 install git+https://github.com/mistralai/mistral-common.git@039465d
+```
+
+2. Run the finetuning example:
+
+```bash
+axolotl train examples/devstral/devstral-small-qlora.yml
+```
+
+This config uses about 21GB VRAM.
+
+Let us know how it goes. Happy finetuning! 🚀
+
+### TIPS
+
+- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
+- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
+- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
+
+## Optimization Guides
+
+- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
+- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
+- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
+- [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
+- [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels)
+
+## Limitations
+
+We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
+
+In addition, we do not support overriding tokens yet.
+
+## Related Resources
+
+- [MistralAI Devstral Blog](https://mistral.ai/news/devstral)
+- [Axolotl Docs](https://docs.axolotl.ai)
+- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
+- [Axolotl Website](https://axolotl.ai)
+- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
+
+## Future Work
+
+- Add parity to Preference Tuning, RL, Multi-modal, etc.
+- Add parity to other tokenizer configs like overriding tokens.
````
**`examples/devstral/devstral-small-qlora.yml`** (new file, 64 lines)

```diff
@@ -0,0 +1,64 @@
+base_model: mistralai/Devstral-Small-2505
+
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+# Enable to use mistral-common tokenizer
+tokenizer_use_mistral_common: true
+
+load_in_8bit: false
+load_in_4bit: true
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+output_dir: ./outputs/qlora-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0
+lora_target_linear: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: false
+
+gradient_checkpointing: true
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+loss_watchdog_threshold: 5.0
+loss_watchdog_patience: 3
+
+warmup_ratio: 0.05
+evals_per_epoch: 4
+saves_per_epoch: 1
+
+weight_decay: 0.0
+special_tokens:
```
**`examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml`** (new file, 71 lines)

```diff
@@ -0,0 +1,71 @@
+base_model: tiiuae/Falcon-H1-1.5B-Deep-Base
+# optionally might have model_type or tokenizer_type
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: false
+load_in_4bit: true
+
+# huggingface repo
+chat_template: falcon_h1
+datasets:
+  - path: cgato/SlimOrcaDedupCleaned
+    type: chat_template
+    field_messages: conversations
+    message_property_mappings:
+      role: from
+      content: value
+
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: qlora
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - k_proj
+  - v_proj
+  - o_proj
+  - in_proj
+  - gate_proj
+  - up_proj
+  - down_proj
+
+sequence_len: 2048
+sample_packing: false
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch:
+saves_per_epoch: 1
+weight_decay: 0.0
+special_tokens:
```
**`examples/falcon-h1/falcon-h1-1b-qlora.yaml`** (new file, 71 lines)

```diff
@@ -0,0 +1,71 @@
+base_model: tiiuae/Falcon-H1-1.5B-Base
+# optionally might have model_type or tokenizer_type
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: false
+load_in_4bit: true
+
+# huggingface repo
+chat_template: falcon_h1
+datasets:
+  - path: cgato/SlimOrcaDedupCleaned
+    type: chat_template
+    field_messages: conversations
+    message_property_mappings:
+      role: from
+      content: value
+
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: qlora
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - k_proj
+  - v_proj
+  - o_proj
+  - in_proj
+  - gate_proj
+  - up_proj
+  - down_proj
+
+sequence_len: 2048
+sample_packing: false
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch:
+saves_per_epoch: 1
+weight_decay: 0.0
+special_tokens:
```
**`examples/falcon-h1/falcon-h1-34b-qlora.yaml`** (new file, 71 lines)

```diff
@@ -0,0 +1,71 @@
+base_model: tiiuae/Falcon-H1-34B-Base
+# optionally might have model_type or tokenizer_type
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: false
+load_in_4bit: true
+
+# huggingface repo
+chat_template: falcon_h1
+datasets:
+  - path: cgato/SlimOrcaDedupCleaned
+    type: chat_template
+    field_messages: conversations
+    message_property_mappings:
+      role: from
+      content: value
+
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: qlora
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - k_proj
+  - v_proj
+  - o_proj
+  - in_proj
+  - gate_proj
+  - up_proj
+  - down_proj
+
+sequence_len: 2048
+sample_packing: false
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch:
+saves_per_epoch: 1
+weight_decay: 0.0
+special_tokens:
```
**`examples/falcon-h1/falcon-h1-3b-qlora.yaml`** (new file, 71 lines)

```diff
@@ -0,0 +1,71 @@
+base_model: tiiuae/Falcon-H1-3B-Base
+# optionally might have model_type or tokenizer_type
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: false
+load_in_4bit: true
+
+# huggingface repo
+chat_template: falcon_h1
+datasets:
+  - path: cgato/SlimOrcaDedupCleaned
+    type: chat_template
+    field_messages: conversations
+    message_property_mappings:
+      role: from
+      content: value
+
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: qlora
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - k_proj
+  - v_proj
+  - o_proj
+  - in_proj
+  - gate_proj
+  - up_proj
+  - down_proj
+
+sequence_len: 2048
+sample_packing: false
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
+special_tokens:
```
**`examples/falcon-h1/falcon-h1-500m-qlora.yaml`** (new file, 71 lines)

```diff
@@ -0,0 +1,71 @@
+base_model: tiiuae/Falcon-H1-0.5B-Instruct
+# optionally might have model_type or tokenizer_type
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: false
+load_in_4bit: true
+
+# huggingface repo
+chat_template: falcon_h1
+datasets:
+  - path: cgato/SlimOrcaDedupCleaned
+    type: chat_template
+    field_messages: conversations
+    message_property_mappings:
+      role: from
+      content: value
+
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: qlora
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - k_proj
+  - v_proj
+  - o_proj
+  - in_proj
+  - gate_proj
+  - up_proj
+  - down_proj
+
+sequence_len: 2048
+sample_packing: false
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch:
+saves_per_epoch: 1
+weight_decay: 0.0
+special_tokens:
```
**`examples/falcon-h1/falcon-h1-7b-qlora.yaml`** (new file, 71 lines)

```diff
@@ -0,0 +1,71 @@
+base_model: tiiuae/Falcon-H1-7B-Base
+# optionally might have model_type or tokenizer_type
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: false
+load_in_4bit: true
+
+# huggingface repo
+chat_template: falcon_h1
+datasets:
+  - path: cgato/SlimOrcaDedupCleaned
+    type: chat_template
+    field_messages: conversations
+    message_property_mappings:
+      role: from
+      content: value
+
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: qlora
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - k_proj
+  - v_proj
+  - o_proj
+  - in_proj
+  - gate_proj
+  - up_proj
+  - down_proj
+
+sequence_len: 2048
+sample_packing: false
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: false
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
+special_tokens:
```
**`examples/gemma3/…`** (three files)

```diff
@@ -13,6 +13,8 @@ load_in_4bit: true

 # huggingface repo
 chat_template: gemma3
+eot_tokens:
+  - <end_of_turn>
 datasets:
   - path: cgato/SlimOrcaDedupCleaned
     type: chat_template

@@ -6,6 +6,8 @@ load_in_4bit: true
 ddp_find_unused_parameters: true

 chat_template: gemma3
+eot_tokens:
+  - <end_of_turn>
 datasets:
   - path: cgato/SlimOrcaDedupCleaned
     type: chat_template

@@ -12,6 +12,8 @@ sample_packing: false
 ddp_find_unused_parameters: true

 chat_template: gemma3
+eot_tokens:
+  - <end_of_turn>
 datasets:
   - path: HuggingFaceH4/llava-instruct-mix-vsft
     type: chat_template
```
**`examples/magistral/README.md`**

````diff
@@ -18,16 +18,10 @@ git clone https://github.com/axolotl-ai-cloud/axolotl.git
 cd axolotl

 pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
-pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
+pip3 install --no-build-isolation -e '.[flash-attn]'
 ```

-2. Download the example config:
-
-```bash
-axolotl fetch examples
-```
-
-3. Run the finetuning example:
+2. Run the finetuning example:

 ```bash
 axolotl train examples/magistral/magistral-small-qlora.yaml

@@ -42,7 +36,7 @@ Let us know how it goes. Happy finetuning! 🚀
 - For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
 - You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
 - Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
-- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
+- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).

 ## Optimization Guides

@@ -54,7 +48,7 @@ Let us know how it goes. Happy finetuning! 🚀

 We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.

-The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
+In addition, we do not support overriding tokens yet.

 ## Related Resources
````
**`examples/qwen2_5-vl/lora-7b.yaml`** (new file, 55 lines)

```diff
@@ -0,0 +1,55 @@
+base_model: Qwen/Qwen2.5-VL-7B-Instruct
+processor_type: AutoProcessor
+
+# these 3 lines are needed for now to handle vision chat templates w images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
+
+chat_template: qwen2_vl
+datasets:
+  - path: HuggingFaceH4/llava-instruct-mix-vsft
+    type: chat_template
+    split: train[:1%]
+    field_messages: messages
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: lora
+lora_model_dir:
+
+sequence_len: 8192
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+logging_steps: 1
+flash_attention: true
+eager_attention:
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
```
**`requirements.txt`**

```diff
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

 # START section of dependencies that don't install on Darwin/MacOS
-bitsandbytes==0.45.4
+bitsandbytes==0.46.0
 triton>=3.0.0
 mamba-ssm==1.2.0.post1
 xformers>=0.0.23.post1

@@ -13,12 +13,12 @@ packaging==23.2

 huggingface_hub==0.32.2
 peft==0.15.2
-transformers==4.52.4
+transformers==4.53.1
 tokenizers>=0.21.1
-accelerate==1.7.0
+accelerate==1.8.1
 datasets==3.6.0
 deepspeed>=0.17.0
-trl==0.18.1
+trl==0.18.2
 hf_xet==1.1.2

 optimum==1.16.2

@@ -68,4 +68,4 @@ schedulefree==1.4.1
 axolotl-contribs-lgpl==0.0.6
 axolotl-contribs-mit==0.0.3

-mistral-common==1.6.0
+mistral-common==1.6.3
```
**`scripts/cutcrossentropy_install.py`**

```diff
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""

 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"'
 )
```
**`setup.py`** (13 changes)

```diff
@@ -66,8 +66,11 @@ def parse_requirements(extras_require_map):

     if (major, minor) >= (2, 7):
         _install_requires.pop(_install_requires.index(xformers_version))
-        # _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
-        extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
+        if patch == 0:
+            _install_requires.append("xformers==0.0.30")
+        else:
+            _install_requires.append("xformers==0.0.31.post1")
+        extras_require_map["vllm"] = ["vllm>=0.9.0"]
     elif (major, minor) >= (2, 6):
         _install_requires.pop(_install_requires.index(xformers_version))
         _install_requires.append(

@@ -111,10 +114,10 @@ def get_package_version():


 extras_require = {
-    "flash-attn": ["flash-attn==2.7.4.post1"],
+    "flash-attn": ["flash-attn==2.8.0.post2"],
     "ring-flash-attn": [
-        "flash-attn==2.7.4.post1",
-        "ring-flash-attn>=0.1.4",
+        "flash-attn==2.8.0.post2",
+        "ring-flash-attn>=0.1.5",
         "yunchang==0.6.0",
     ],
     "deepspeed": [
```
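The first hunk replaces a commented-out hard pin with a patch-level gate; its selection logic, isolated as a sketch (the function name and the fallback are local to this sketch, not part of `setup.py`):

```python
def pick_pins(torch_version: str) -> tuple[str, str]:
    """Sketch of the torch >= 2.7 branch added in the setup.py hunk above."""
    major, minor, patch = (int(p) for p in torch_version.split(".")[:3])
    if (major, minor) >= (2, 7):
        # torch 2.7.0 pairs with xformers 0.0.30; later 2.7.x with 0.0.31.post1
        xformers = "xformers==0.0.30" if patch == 0 else "xformers==0.0.31.post1"
        return xformers, "vllm>=0.9.0"
    raise ValueError("this sketch only covers the torch >= 2.7 branch")


print(pick_pins("2.7.0"))  # ('xformers==0.0.30', 'vllm>=0.9.0')
print(pick_pins("2.7.1"))  # ('xformers==0.0.31.post1', 'vllm>=0.9.0')
```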
**`src/axolotl/__init__.py`**

```diff
@@ -4,4 +4,4 @@ import pkgutil

 __path__ = pkgutil.extend_path(__path__, __name__)  # Make this a namespace package

-__version__ = "0.11.0.dev"
+__version__ = "0.11.0"
```
**`src/axolotl/cli/checks.py`**

```diff
@@ -6,6 +6,7 @@ from pathlib import Path
 from accelerate.commands.config import config_args
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
+from requests import HTTPError

 from axolotl.utils.logging import get_logger

@@ -46,3 +47,8 @@ def check_user_token() -> bool:
             "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
         )
         return False
+    except HTTPError:
+        LOG.warning(
+            "Error accessing HuggingFace. This may be due to a network issue or rate limiting."
+        )
+        return False
```
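For context, the new `except HTTPError` branch slots into a token check shaped roughly like this (the warning strings and imports are from the diff; the function body and the `whoami` call are assumptions):

```python
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from requests import HTTPError


def check_user_token() -> bool:
    """Sketch: report whether a usable HF token is configured, never raising."""
    try:
        HfApi().whoami()  # assumption: the real check may call a different API
        return True
    except LocalTokenNotFoundError:
        print("Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` ...")
        return False
    except HTTPError:
        # the branch added by this diff: network errors / rate limits are non-fatal
        print("Error accessing HuggingFace. This may be due to a network issue or rate limiting.")
        return False
```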
**`src/axolotl/cli/cloud/__init__.py`**

```diff
@@ -7,7 +7,6 @@ from typing import Union

 import yaml

-from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.cloud.modal_ import ModalCloud
 from axolotl.utils.dict import DictDefault

@@ -24,7 +23,6 @@ def do_cli_preprocess(
     cloud_config: Union[Path, str],
     config: Union[Path, str],
 ) -> None:
-    print_axolotl_text_art()
     cloud_cfg = load_cloud_cfg(cloud_config)
     cloud = ModalCloud(cloud_cfg)
     with open(config, "r", encoding="utf-8") as file:

@@ -39,7 +37,6 @@ def do_cli_train(
     cwd=None,
     **kwargs,
 ) -> None:
-    print_axolotl_text_art()
     cloud_cfg = load_cloud_cfg(cloud_config)
     cloud = ModalCloud(cloud_cfg)
     with open(config, "r", encoding="utf-8") as file:

@@ -54,7 +51,6 @@ def do_cli_lm_eval(
     cloud_config: Union[Path, str],
     config: Union[Path, str],
 ) -> None:
-    print_axolotl_text_art()
     cloud_cfg = load_cloud_cfg(cloud_config)
     cloud = ModalCloud(cloud_cfg)
     with open(config, "r", encoding="utf-8") as file:
```
@@ -26,7 +26,9 @@ from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars

LOG = get_logger(__name__, use_environ=True)
LOG = get_logger(__name__)

API_KEY_FIELDS = {"comet_api_key"}


def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
@@ -233,4 +235,15 @@ def load_cfg(
    setup_comet_env_vars(cfg)
    plugin_set_cfg(cfg)

    cfg_to_log = {
        k: "[REDACTED]" if k in API_KEY_FIELDS else v
        for k, v in cfg.items()
        if v is not None
    }

    LOG.info(
        "config:\n%s",
        json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
    )

    return cfg
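The redaction logic added above is self-contained; a minimal sketch of its behavior (standalone, with a plain dict standing in for the loaded `DictDefault` config):

```python
# Standalone sketch of the [REDACTED] filtering added in load_cfg above
API_KEY_FIELDS = {"comet_api_key"}
cfg = {"comet_api_key": "s3cr3t", "learning_rate": 2e-5, "wandb_project": None}

cfg_to_log = {
    k: "[REDACTED]" if k in API_KEY_FIELDS else v
    for k, v in cfg.items()
    if v is not None  # unset keys are dropped from the logged config
}
print(cfg_to_log)  # {'comet_api_key': '[REDACTED]', 'learning_rate': 2e-05}
```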
@@ -9,7 +9,6 @@ from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser

from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
@@ -35,7 +34,6 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
    patch_optimized_env()

    # pylint: disable=duplicate-code
    print_axolotl_text_art()
    check_accelerate_default_config()
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        check_user_token()
@@ -13,7 +13,6 @@ from dotenv import load_dotenv
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer

from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.chat_templates import (
@@ -255,7 +254,6 @@ def do_cli(
        kwargs: Additional keyword arguments to override config file values.
    """
    # pylint: disable=duplicate-code
    print_axolotl_text_art()
    parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
    parsed_cfg.sample_packing = False
    parser = transformers.HfArgumentParser(InferenceCliArgs)
@@ -20,6 +20,7 @@ from axolotl.cli.args import (
    TrainerCliArgs,
    VllmServeCliArgs,
)
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.sweeps import generate_sweep_configs
from axolotl.cli.utils import (
    add_options_from_config,
@@ -40,6 +41,7 @@ LOG = get_logger(__name__)
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
    """Axolotl CLI - Train and fine-tune large language models"""
    print_axolotl_text_art()


@cli.command()
@@ -6,7 +6,6 @@ from typing import Union
import fire
from dotenv import load_dotenv

from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
@@ -23,8 +22,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
    """
    print_axolotl_text_art()

    model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
    safe_serialization = cfg.save_safetensors is True
@@ -22,7 +22,6 @@ from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner

from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.utils.logging import get_logger

@@ -194,7 +193,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
        kwargs: Additional keyword arguments to override config file values.
    """
    # pylint: disable=duplicate-code
    print_axolotl_text_art()
    parsed_cfg = load_cfg(config, **kwargs)

    fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
@@ -12,7 +12,6 @@ from dotenv import load_dotenv
from transformers import AutoModelForCausalLM

from axolotl.cli.args import PreprocessCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -33,10 +32,15 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
        cfg: Dictionary mapping `axolotl` config keys to values.
        cli_args: Preprocessing-specific CLI arguments.
    """
    print_axolotl_text_art()
    check_accelerate_default_config()
    check_user_token()

    for key in ["skip_prepare_dataset", "pretraining_dataset"]:
        if cfg.get(key):
            raise ValueError(
                f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
            )

    if not cfg.dataset_prepared_path:
        msg = (
            Fore.RED
@@ -7,7 +7,6 @@ from typing import Union

from transformers import AutoModelForCausalLM

from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
@@ -27,7 +26,6 @@ def do_quantize(
        config (Union[Path, str]): The path to the config file
        cli_args (dict): Additional command-line arguments
    """
    print_axolotl_text_art()

    cfg = load_cfg(config)
@@ -11,7 +11,6 @@ from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser

from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
@@ -35,7 +34,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
    # Enable expandable segments for cuda allocation to improve VRAM usage
    patch_optimized_env()

    print_axolotl_text_art()
    check_accelerate_default_config()
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        check_user_token()
@@ -75,13 +75,17 @@ def load_datasets(

        num_examples = cli_args.debug_num_examples if cli_args else 1
        text_only = cli_args.debug_text_only if cli_args else False
        train_samples = sample_dataset(train_dataset, num_examples)
        check_dataset_labels(
            train_samples,
            tokenizer,
            num_examples=num_examples,
            text_only=text_only,
        )
        try:
            train_samples = sample_dataset(train_dataset, num_examples)
            check_dataset_labels(
                train_samples,
                tokenizer,
                num_examples=num_examples,
                text_only=text_only,
            )
        except AttributeError:
            # can't sample iterable datasets
            pass

        LOG.info("printing prompters...")
        for prompter in prompters:
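The try/except added above guards the debug-sampling path: streamed (`IterableDataset`) splits expose no random access, so helpers like `sample_dataset` fail with `AttributeError`. A minimal sketch of the failure mode, assuming the sampler relies on `Dataset.select` (the real helper lives in axolotl's utils and may differ):

```python
from datasets import Dataset

def sample_dataset(dataset, num_examples):
    # .select() exists on map-style Dataset objects but not on IterableDataset
    return dataset.select(range(num_examples))

ds = Dataset.from_dict({"input_ids": [[1, 2], [3, 4]]})
sample_dataset(ds, 1)  # works: map-style dataset
try:
    sample_dataset(ds.to_iterable_dataset(), 1)
except AttributeError:
    pass  # can't sample iterable datasets, mirroring the hunk above
```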
162 src/axolotl/core/attention/flex_block_mask.py Normal file
@@ -0,0 +1,162 @@
"""
monkeypatch for flex + packing
"""

import sys
from typing import Callable, Optional, Union

import torch
from torch.nn.attention.flex_attention import BlockMask
from transformers import Cache, PretrainedConfig
from transformers.masking_utils import (
    ALL_MASK_ATTENTION_FUNCTIONS,
    _preprocess_mask_arguments,
    and_masks,
    causal_mask_function,
    or_masks,
)
from transformers.utils import is_torch_greater_or_equal

_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)


def create_causal_mask(
    config: PretrainedConfig,
    input_embeds: torch.Tensor,
    attention_mask: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
    or_mask_function: Optional[Callable] = None,
    and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
    """
    Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
    has a HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
    with what is needed in the `modeling_xxx.py` files).

    Args:
        config (`PretrainedConfig`):
            The model config.
        input_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and dtype.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        cache_position (`torch.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the causal mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
    """
    # If we have a HybridCache structure, here we want to create the mask for the full layers
    if (
        past_key_values
        and hasattr(past_key_values, "is_sliding")
        and False in past_key_values.is_sliding
    ):
        layer_idx = past_key_values.is_sliding.index(False)
    else:
        layer_idx = 0

    original_attention_mask = (
        None
        if attention_mask is None
        else attention_mask.clone().to(cache_position.device)
    )
    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
    )
    if early_exit:
        return attention_mask

    batch_size, total_seq_len = cache_position.shape
    key_length = total_seq_len
    document_ids = torch.nn.functional.pad(
        original_attention_mask, value=0, pad=(0, key_length)
    )

    batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
    if attention_mask is not None:

        def causal_doc_mask_mod(
            batch_idx, head_idx, q_idx, kv_idx
        ):  # pylint: disable=unused-argument
            """
            Defines the logic of a block causal mask by combining both a standard causal mask
            and a block diagonal document mask.

            See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
            for an illustration.
            """
            causal_mask_ = q_idx >= kv_idx  # not valid when decoding
            document_mask = (
                document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
            )
            final_mask = causal_mask_ & document_mask
            return final_mask

        mask_factory_function = causal_doc_mask_mod
    else:
        mask_factory_function = causal_mask_function
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[
        config._attn_implementation  # pylint: disable=protected-access
    ]

    # Do not allow skip if we are compiling (this is to match BC)
    allow_is_causal_skip = (
        not past_key_values.is_compileable if past_key_values is not None else True
    )

    # Allow slight deviations from causal mask
    if or_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError(
                "Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
            )
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
        allow_is_causal_skip = False
    if and_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError(
                "Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
            )
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)
        allow_is_causal_skip = False

    # We now create the mask
    causal_mask = mask_interface(
        batch_size=batch_size,
        cache_position=cache_position,
        kv_length=kv_length,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
        dtype=dtype,  # additional kwarg for eager
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
    )
    return causal_mask


def patch_create_causal_mask(model_type):
    import transformers.masking_utils

    transformers.masking_utils.create_causal_mask = create_causal_mask

    if model_type:
        try:
            # Dynamically import the module and attention class
            module_path = f"transformers.models.{model_type}.modeling_{model_type}"
            module = __import__(module_path)
            module.create_causal_mask = create_causal_mask
            del sys.modules[module_path]
        except (ImportError, AttributeError) as e:
            raise ValueError(
                f"Could not import attention class for model_type: {model_type}. "
                f"Error: {str(e)}"
            ) from e
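For context on what the new file above enables: a document-ID mask function like `causal_doc_mask_mod` is consumed by PyTorch flex attention through a `BlockMask`. An illustrative, self-contained sketch (torch >= 2.5; all names below are local to this example, not part of the patch):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda" if torch.cuda.is_available() else "cpu"

# Three documents packed into one sequence of length 128
S = 128
doc_ids = torch.repeat_interleave(
    torch.tensor([0, 1, 2], device=device),
    torch.tensor([40, 40, 48], device=device),
).unsqueeze(0)

def doc_causal(b, h, q_idx, kv_idx):
    # causal within the sequence AND restricted to the same packed document
    return (q_idx >= kv_idx) & (doc_ids[b, q_idx] == doc_ids[b, kv_idx])

B, H, D = 1, 2, 16
block_mask = create_block_mask(doc_causal, B=B, H=None, Q_LEN=S, KV_LEN=S, device=device)
q = k = v = torch.randn(B, H, S, D, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)  # -> (1, 2, 128, 16)
```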
@@ -219,7 +219,9 @@ class TrainerBuilderBase(abc.ABC):
        if self.cfg.bf16 == "full":
            training_args_kwargs["bf16_full_eval"] = True
        else:
            training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
            bf16 = self.cfg.bf16 or self.cfg.bfloat16
            bf16 = bf16 if bf16 is not None else False
            training_args_kwargs["bf16"] = bf16

    def _configure_scheduler(self, training_args_kwargs: dict):
        if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
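The change above normalizes a tri-state config value (`True`/`False`/`None`) to a strict bool before it reaches `TrainingArguments`; a standalone sketch of why the extra line is needed:

```python
def normalize_bf16(bf16, bfloat16):
    # `False or None` evaluates to None, so a plain `or` can still yield None
    value = bf16 or bfloat16
    return value if value is not None else False

assert normalize_bf16(None, None) is False
assert normalize_bf16(False, None) is False
assert normalize_bf16(True, None) is True
```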
@@ -245,14 +245,27 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling

        training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
        training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool(
            self.cfg.flash_attention
            or self.cfg.xformers_attention
            or self.cfg.flex_attention
        )
        training_arguments_kwargs["multipack_real_batches"] = (
            self.cfg.multipack_real_batches
            if self.cfg.multipack_real_batches is not None
            else not self.cfg.flash_attention
            else not (
                self.cfg.flash_attention
                or self.cfg.flex_attention
                or self.cfg.xformers_attention
            )
        )
        training_arguments_kwargs["eval_sample_packing"] = bool(
            self.cfg.eval_sample_packing
        )
        if self.cfg.sample_packing_sequentially is not None:
            training_arguments_kwargs["sample_packing_sequentially"] = (
                self.cfg.sample_packing_sequentially
            )
        if self.cfg.sample_packing_bin_size is not None:
            training_arguments_kwargs["sample_packing_bin_size"] = (
                self.cfg.sample_packing_bin_size
@@ -413,7 +426,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
            or self.cfg.micro_batch_size > 1
        ):
            return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
        return None
        if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
            return None

        if self.cfg.model_config_type == "mamba":
            return MambaDataCollator(tokenizer=self.tokenizer)
@@ -20,13 +20,14 @@ from torch.utils.data import (
    SequentialSampler,
)
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from trl.trainer.utils import pad_to_length
from typing_extensions import override

from axolotl.core.trainers.mixins import (
    CheckpointSaveMixin,
    OptimizerMixin,
    PackingMixin,
    RngLoaderMixin,
    SchedulerMixin,
)
@@ -42,7 +43,12 @@ LOG = get_logger(__name__)


class AxolotlTrainer(
    SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
    PackingMixin,
    SchedulerMixin,
    OptimizerMixin,
    RngLoaderMixin,
    CheckpointSaveMixin,
    Trainer,
):
    """Extend the base Trainer for axolotl helpers"""

@@ -116,14 +122,15 @@ class AxolotlTrainer(
            sequential=self.args.sample_packing_sequentially,
            drop_last=True,
            num_processes=self.args.dataset_num_proc,
            mp_start_method=self.args.sample_packing_mp_start_method or "fork",
        )

        len(sampler)
        return sampler

    def _get_train_sampler(
        self, train_dataset: Optional[Dataset] = None
    ) -> Optional[Sampler]:
        self, train_dataset: Dataset | None = None
    ) -> Sampler | None:
        """
        Helper method to get the sampler for training. Handles cases for sample packing
        and curriculum sampling (sequential).
@@ -132,16 +139,22 @@ class AxolotlTrainer(
            If the dataset is non-empty, a sampler is returned, the type of which
            depends on the passed training args.
        """
        # from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L969C1-L972C24
        if train_dataset is None:
            train_dataset = self.train_dataset
        if train_dataset is None or not has_length(train_dataset):
            return None

        use_sample_packing = self.args.sample_packing and not self.args.pretraining

        # Determine the base sampler first
        if self.args.curriculum_sampling:
            base_sampler = SequentialSampler(self.train_dataset)
            base_sampler = SequentialSampler(train_dataset)
        elif use_sample_packing:
            base_sampler = RandomSampler(self.train_dataset)
            base_sampler = RandomSampler(train_dataset)
        else:
            # Default to parent class implementation for standard random sampling
            return super()._get_train_sampler()
            return super()._get_train_sampler(train_dataset)

        # Apply multipack wrapper if needed
        if use_sample_packing:
@@ -160,6 +173,10 @@ class AxolotlTrainer(
            If the dataset is non-empty, a sampler is returned, the type of which
            depends on the passed training args.
        """
        # from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L1065C9-L1066C24
        if eval_dataset is None or not has_length(eval_dataset):
            return None

        # Multipacking enabled if training is enabled and eval is not explicitly disabled
        use_multipack = (
            self.args.sample_packing and self.args.eval_sample_packing is not False
@@ -195,6 +212,14 @@ class AxolotlTrainer(

        if dataset.column_names and "length" in dataset.column_names:
            dataset = dataset.remove_columns(["length"])
        if (
            dataset.column_names
            and "position_ids" in dataset.column_names
            and "attention_mask" in dataset.column_names
            and self.args.sample_packing
            and self.args.sample_packing_drop_attention_mask
        ):
            dataset = dataset.remove_columns(["attention_mask"])

        if isinstance(dataset, datasets.Dataset):
            if is_training:
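`has_length` (from `transformers.trainer_utils`) is what makes the early returns added above safe for streamed datasets; a quick sketch of its behavior:

```python
from datasets import Dataset
from transformers.trainer_utils import has_length

ds = Dataset.from_dict({"input_ids": [[1], [2]]})
assert has_length(ds)                                # map-style: len() works
assert not has_length(ds.to_iterable_dataset())      # streamed: no usable __len__
```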
@@ -28,7 +28,7 @@ class DPOStrategy:
        training_args_kwargs["max_completion_length"] = None
        training_args_kwargs["max_length"] = cfg.sequence_len
        training_args_kwargs["max_prompt_length"] = cfg.sequence_len
        training_args_kwargs["generate_during_eval"] = cfg.use_wandb
        training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
        if cfg.dpo_use_weighting is not None:
            training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
        if cfg.dpo_padding_free is not None:
@@ -5,5 +5,6 @@

from .checkpoints import CheckpointSaveMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
20 src/axolotl/core/trainers/mixins/packing.py Normal file
@@ -0,0 +1,20 @@
"""Trainer mixin to support packing"""

from transformers import Trainer


class PackingMixin(Trainer):
    """
    Trainer mixin to support packing
    """

    def _set_signature_columns_if_needed(self):
        super()._set_signature_columns_if_needed()
        if (
            self._signature_columns
            and self.args.sample_packing
            and self.args.sample_packing_drop_attention_mask
        ):
            set_sig_columns = set(self._signature_columns)
            set_sig_columns.remove("attention_mask")
            self._signature_columns = list(set_sig_columns)
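`Trainer._set_signature_columns_if_needed` is what later filters dataset columns down to the model's forward signature, so removing `attention_mask` from that list keeps the column out of packed batches entirely. The set round-trip the mixin uses is order-insensitive, which is harmless here since the columns are only membership-tested; a standalone sketch:

```python
# Standalone sketch of the column filtering performed by PackingMixin above
signature_columns = ["input_ids", "attention_mask", "position_ids", "labels"]

cols = set(signature_columns)
cols.remove("attention_mask")  # packing conveys boundaries via position_ids instead
signature_columns = list(cols)

assert "attention_mask" not in signature_columns
```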
@@ -38,6 +38,14 @@ class AxolotlTrainingMixins:
            "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
        },
    )
    sample_packing_mp_start_method: str | None = field(
        default=None,
        metadata={"help": "The multiprocessing start method to use."},
    )
    sample_packing_drop_attention_mask: bool = field(
        default=False,
        metadata={"help": "Drop attention mask from inputs when using packing."},
    )
    multipack_real_batches: bool = field(
        default=False,
        metadata={"help": "Use real batches for efficient training."},
@@ -48,13 +48,6 @@ class TokenizedPromptDataset(Dataset):
        features = dataset.features.keys()
        num_proc = min(64, self.process_count if self.process_count else os.cpu_count())

        # Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
        if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
            LOG.info(
                "Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
            )
            num_proc = 1

        map_kwargs = {}
        if self.prompt_tokenizer.supports_batched:
            map_kwargs["batched"] = True
@@ -33,7 +33,7 @@ from transformers import PreTrainedModel, Trainer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__, use_environ=True)
LOG = get_logger(__name__)

if TYPE_CHECKING:
    from axolotl.common.datasets import TrainDatasetMeta
@@ -19,19 +19,11 @@ python scripts/cutcrossentropy_install.py | sh

- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"
```

## Usage

**NOTE**: If you are training a VLM model, please use an older version of Axolotl, as upstream has applied a major VLM refactor and our patches have not been updated yet.

```bash
git checkout 787880215b3ab32ccaf81c1b2e9588c6f3e6e764

pip3 install --no-build-isolation -e .
```

```yaml
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
```

@@ -39,27 +31,29 @@ plugins:

## Supported Models

- llama
- llama4
- llama4_text
- mllama
- phi3
- cohere
- cohere2
- gemma
- gemma2
- gemma3
- gemma3_text
- glm
- glm4
- llama
- llama4
- llama4_text
- mistral
- mistral3
- mllama
- phi
- phi3
- phi4_multimodal
- qwen2
- qwen2_moe
- qwen2_vl
- qwen2_moe
- qwen2_5_vl
- qwen3
- qwen3_moe
- cohere
- cohere2
- glm
- glm4

## Citation
@@ -28,11 +28,11 @@ from axolotl.utils.logging import get_logger

from .args import CutCrossEntropyArgs  # pylint: disable=unused-import  # noqa: F401

LOG = get_logger(__name__, use_environ=True)
LOG = get_logger(__name__)

_CCE_INSTALL_MESSAGE = (
    "Please install cut_cross_entropy with transformers support using "
    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"`'
    "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"`'
)
@@ -64,16 +64,28 @@ class CutCrossEntropyPlugin(BasePlugin):
            "cut_cross_entropy.transformers"
        )
        if cce_spec_transformers is None:
            raise ImportError(_CCE_INSTALL_MESSAGE)
            raise ImportError(
                "Transformers support is not installed. " + _CCE_INSTALL_MESSAGE
            )

        # Check if Axolotl's cce fork is installed
        try:
            from cut_cross_entropy.transformers.patch import AXOLOTL_CCE_FORK

            if not AXOLOTL_CCE_FORK:
                raise ImportError
        except ImportError as e:
            raise ImportError(
                "Axolotl's fork of cut_cross_entropy is not installed. "
                + _CCE_INSTALL_MESSAGE
            ) from e

    def pre_model_load(self, cfg):
        """Apply cut cross entropy before model loading if enabled."""
        if cfg.cut_cross_entropy:
            self._check_requirements()

            from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
                cce_patch,
            )
            from cut_cross_entropy.transformers.patch import cce_patch

            LOG.info(
                f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
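A quick way to check an environment against the new requirement, assuming only the `AXOLOTL_CCE_FORK` sentinel shown in the diff above:

```python
# Hypothetical one-off check; AXOLOTL_CCE_FORK is the sentinel the plugin inspects
try:
    from cut_cross_entropy.transformers.patch import AXOLOTL_CCE_FORK
    print("axolotl CCE fork installed:", bool(AXOLOTL_CCE_FORK))
except ImportError:
    print("cut-cross-entropy (with transformers support) is not installed")
```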
@@ -1,191 +0,0 @@
"""Cohere and Cohere2 CCE patch."""

# This patch is based off transformers 4.50.0.
# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM.
# It scales the hidden states by the logit scale in advance instead of the logits as the
# operation is done internally and should be mathematically equivalent.

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
    KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils.deprecation import deprecate_kwarg

_PATCH_OPTS: PatchOptions | None = None


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, CohereForCausalLM

    >>> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
    >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        # scale hidden_states by logit_scale in place of logits
        loss = apply_lce(
            hidden_states[:, slice_indices, :] * self.logit_scale,
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        logits = logits * self.logit_scale  # main diff from Llama

        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def patch_cohere(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.cohere import modeling_cohere

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_cohere.CohereForCausalLM
        ), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_cohere.CohereForCausalLM.forward = cce_forward
    return None


def patch_cohere2(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.cohere2 import modeling_cohere2

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_cohere2.Cohere2ForCausalLM
        ), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward
    return None
@@ -1,165 +0,0 @@
"""Gemma CCE patch"""

# This patch is based off transformers 4.50.0.

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma.modeling_gemma import (
    KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils.deprecation import deprecate_kwarg

_PATCH_OPTS: PatchOptions | None = None


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, GemmaForCausalLM

    >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
    >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

    >>> prompt = "What is your favorite condiment?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is your favorite condiment?"
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def patch_gemma(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma import modeling_gemma

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_gemma.GemmaForCausalLM
        ), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_gemma.GemmaForCausalLM.forward = cce_forward
    return None
@@ -1,447 +0,0 @@
|
||||
"""Gemma2 and Gemma3 (text and multimodal) CCE patch."""
|
||||
|
||||
# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29
|
||||
# and updated for transformers 4.50.0.
|
||||
# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works
|
||||
# with both gemma3 (text and multimodal) models.
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers.cache_utils import Cache, HybridCache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.gemma3.modeling_gemma3 import (
|
||||
Gemma3CausalLMOutputWithPast,
|
||||
logger,
|
||||
)
|
||||
from transformers.utils import (
|
||||
is_torchdynamo_compiling,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[HybridCache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
defer_logits_calculation: bool = False,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
defer_logits_calculation (`bool`, *optional*):
|
||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, Gemma3ForCausalLM
|
||||
|
||||
>>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
||||
|
||||
>>> prompt = "What is your favorite condiment?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"What is your favorite condiment?"
|
||||
```"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**loss_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
softcap=getattr(self.config, "final_logit_softcapping", None),
|
||||
**loss_kwargs,
|
||||
)
|
||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||
# defer logits calculation to the ConditionalGeneration forward
|
||||
logits = hidden_states[:, slice_indices, :]
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if self.config.final_logit_softcapping is not None:
|
||||
logits = logits / self.config.final_logit_softcapping
|
||||
logits = torch.tanh(logits)
|
||||
logits = logits * self.config.final_logit_softcapping
|
||||
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**lm_kwargs,
|
||||
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||
|
||||
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
|
||||
|
||||
>>> prompt = "answer en Where is the cow standing?"
|
||||
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs, max_length=30)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"answer en Where is the cow standing?\nbeach"
|
||||
```"""
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
is_training = token_type_ids is not None and labels is not None
|
||||
|
||||
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_index
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
llm_input_ids = input_ids # type: ignore
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = (
|
||||
past_key_values.get_seq_length() if past_key_values is not None else 0 # type: ignore
|
||||
)
|
||||
cache_position = torch.arange( # type: ignore
|
||||
past_seen_tokens,
|
||||
past_seen_tokens + inputs_embeds.shape[1],
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
# Merge text and images
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(pixel_values)
|
||||
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(
|
||||
self.config.image_token_index,
|
||||
dtype=torch.long,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
|
||||
-1
|
||||
)
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
|
||||
if (
|
||||
not is_torchdynamo_compiling()
|
||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
||||
):
|
||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
||||
raise ValueError(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
||||
"tokens from image embeddings."
            )
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)  # type: ignore

    # mask out pad-token-ids in labels for BC
    if labels is not None and self.pad_token_id in labels:
        logger.warning_once(
            "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
            "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
        )
        labels = torch.where(  # type: ignore
            input_ids == self.pad_token_id, self.config.ignore_index, labels
        )

    causal_mask = self._update_causal_mask(  # pylint: disable=protected-access
        attention_mask,
        token_type_ids,
        past_key_values,
        cache_position,
        inputs_embeds,
        is_training,
    )
    outputs = self.language_model(
        attention_mask=causal_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        defer_logits_calculation=True,  # enable deferred logits calculation
        **lm_kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states,
            self.language_model.lm_head.weight,
            labels,
            _PATCH_OPTS,
            softcap=getattr(self.config, "final_logit_softcapping", None),
            **lm_kwargs,
        )
    else:
        logits = hidden_states
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(
                    logits.device
                )
                shift_logits = shift_logits[
                    shift_attention_mask.to(logits.device) != 0
                ].contiguous()
                shift_labels = shift_labels[
                    shift_attention_mask.to(shift_labels.device) != 0
                ].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()

            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            flat_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(flat_logits, flat_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Gemma3CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )


def patch_gemma2(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma2 import modeling_gemma2

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_gemma2.Gemma2ForCausalLM
        ), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward
    return None


def patch_gemma3_text(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma3 import modeling_gemma3

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_gemma3.Gemma3ForCausalLM
        ), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
    return None


def patch_gemma3(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma3 import modeling_gemma3

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration
        ), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        # patch the causal model to enable deferred logits calculation
        maybe_model.language_model.forward = MethodType(
            cce_forward, maybe_model.language_model
        )
        return maybe_model

    modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal
    # patch the causal model to enable deferred logits calculation
    modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
    return None
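Editor's note, not part of the diff above: each `patch_*` entry point accepts either a model id/config string, in which case `forward` is swapped at the class level before any instance exists, or an already-built `PreTrainedModel`, in which case only that instance is rebound via `MethodType`. A minimal usage sketch, assuming `PatchOptions` takes an `impl` field (check the dataclass in `cut_cross_entropy.transformers.utils` for the real constructor):

```python
import transformers
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
from cut_cross_entropy.transformers.utils import PatchOptions

opts = PatchOptions(impl=LinearCrossEntropyImpl.CCE)  # assumed field name

# Class-level: patch first, then instantiate; returns None.
patch_gemma3_text("google/gemma-3-4b-it", opts)
model = transformers.AutoModelForCausalLM.from_pretrained("google/gemma-3-4b-it")

# Instance-level: patch an existing model; returns the patched instance.
model = patch_gemma3_text(model, opts)
```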
@@ -1,57 +0,0 @@
"""GLM 4 patch. GLM family inherits from Llama."""

from types import MethodType

import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
)


def patch_glm(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:

    # Set the _PATCH_OPTS in the llama patch file
    import cut_cross_entropy.transformers.llama as llama_patch

    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access

    from cut_cross_entropy.transformers.llama import cce_forward
    from transformers.models.glm import modeling_glm

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_glm.GlmForCausalLM
        ), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_glm.GlmForCausalLM.forward = cce_forward
    return None


def patch_glm4(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:

    # Set the _PATCH_OPTS in the llama patch file
    import cut_cross_entropy.transformers.llama as llama_patch

    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access

    from cut_cross_entropy.transformers.llama import cce_forward
    from transformers.models.glm4 import modeling_glm4

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_glm4.Glm4ForCausalLM
        ), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_glm4.Glm4ForCausalLM.forward = cce_forward
    return None
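The GLM patches above deliberately define no forward of their own: they plant `patch_options` into the llama module's `_PATCH_OPTS` global and reuse its `cce_forward`. That works because Python resolves a function's globals at call time in its defining module. A self-contained sketch of that late-binding behavior (toy module, not the real code):

```python
import types

llama_like = types.ModuleType("llama_like")
exec(
    "OPTS = None\n"
    "def cce_forward():\n"
    "    return OPTS\n",
    llama_like.__dict__,
)

cce_forward = llama_like.cce_forward  # like `from ...llama import cce_forward`
llama_like.OPTS = {"impl": "cce"}     # like `llama_patch._PATCH_OPTS = patch_options`
assert cce_forward() == {"impl": "cce"}  # the call still sees the planted global
```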
@@ -1,164 +0,0 @@
"""Llama CCE patch. Adapted from transformers v4.51.2"""

# pylint: disable=duplicate-code


from types import MethodType
from typing import Optional, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.models.llama.modeling_llama import (
    KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple

_PATCH_OPTS: PatchOptions | None = None


@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def cce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, LlamaForCausalLM

    >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: BaseModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    if hidden_states is None:
        raise ValueError("hidden_states is None")

    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )
    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def patch_llama(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Patch Llama for CCE."""
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.llama import modeling_llama

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_llama.LlamaForCausalLM
        ), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_llama.LlamaForCausalLM.forward = cce_forward
    return None
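For reference, the quantity `apply_lce` computes in the Llama patch is the ordinary next-token cross-entropy, just fused so the `(tokens × vocab)` logits matrix is never materialized. A plain-torch sketch of the unfused baseline it replaces (shapes and the shift convention here are illustrative; the real shift and reduction behavior depends on `PatchOptions`):

```python
import torch
import torch.nn.functional as F

B, T, D, V = 2, 8, 16, 100  # batch, sequence, hidden, vocab
hidden_states = torch.randn(B, T, D)
lm_head_weight = torch.randn(V, D)
labels = torch.randint(0, V, (B, T))

# Unfused baseline: the (B, T, V) logits tensor is the memory hog CCE avoids.
logits = hidden_states @ lm_head_weight.T
loss = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, V),  # tokens < n predict n
    labels[:, 1:].reshape(-1),
    ignore_index=-100,
)
```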
@@ -1,401 +0,0 @@
"""Llama4 CCE patch. Adapted from transformers 4.51.0."""

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama4.modeling_llama4 import (
    Llama4CausalLMOutputWithPast,
)

_PATCH_OPTS: PatchOptions | None = None


def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    defer_logits_calculation: bool = False,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        defer_logits_calculation (`bool`, *optional*, defaults to `False`):
            If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
            memory overhead of calculating logits using regular lm_head forward pass and to use CCE.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Llama4ForCausalLM

    >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )
    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    elif _PATCH_OPTS is not None and defer_logits_calculation:
        # defer logits calculation to the ConditionalGeneration forward
        logits = hidden_states[:, slice_indices, :]
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def cce_forward_multimodal(
    self,
    input_ids: torch.LongTensor | None = None,  # type: ignore
    pixel_values: torch.FloatTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[Union[int, list[int]]] = None,
    vision_feature_select_strategy: Optional[str] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    image_sizes: torch.Tensor | None = None,
    **lm_kwargs,
) -> Union[Tuple, Llama4CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).


    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

    >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    vision_feature_layer = (
        vision_feature_layer
        if vision_feature_layer is not None
        else self.config.vision_config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else self.config.vision_config.vision_feature_select_strategy
    )

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)  # type: ignore

    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            image_sizes=image_sizes,
        )
        original_inputs_embeds_shape = inputs_embeds.shape  # type: ignore

        vision_flat = image_features.view(-1, image_features.size(-1))
        projected_vision_flat = self.multi_modal_projector(vision_flat)

        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
        final_mask = special_image_mask.to(inputs_embeds.device)  # type: ignore
        inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))  # type: ignore

        final_mask_1d = final_mask[..., 0].reshape(-1)
        num_tokens_to_fill = final_mask_1d.sum()

        if num_tokens_to_fill != projected_vision_flat.size(0):
            raise ValueError(
                f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
                f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
            )

        expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
        inputs_embeds = inputs_embeds.masked_scatter(
            expanded_mask, projected_vision_flat
        )  # type: ignore
        inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)  # type: ignore

    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        defer_logits_calculation=True,  # enable deferred logits calculation
        **lm_kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        # TODO: check if need to handle attention_mask
        loss = apply_lce(
            hidden_states,
            self.language_model.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **lm_kwargs,
        )
    else:
        logits = hidden_states
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
                    logits.device
                )
                shift_logits = logits[..., :-1, :][
                    shift_attention_mask.to(logits.device) != 0
                ].contiguous()
                shift_labels = labels[..., 1:][
                    shift_attention_mask.to(labels.device) != 0
                ].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1).to(shift_logits.device),
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Llama4CausalLMOutputWithPast(
        loss=loss,
        logits=logits,  # type: ignore # TODO: check if need to create dummy logits
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )


def patch_llama4_text(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.llama4 import modeling_llama4

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_llama4.Llama4ForCausalLM
        ), f"Expected a Llama4ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)

        return maybe_model

    setattr(
        modeling_llama4.Llama4ForCausalLM,
        "forward",
        cce_forward,
    )
    return None


def patch_llama4(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:

    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.llama4 import modeling_llama4

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_llama4.Llama4ForConditionalGeneration
        ), f"Expected a Llama4ForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        # patch the language model
        maybe_model.language_model.forward = MethodType(
            cce_forward, maybe_model.language_model
        )
        return maybe_model

    setattr(
        modeling_llama4.Llama4ForConditionalGeneration,
        "forward",
        cce_forward_multimodal,
    )

    # patch the causal language model
    setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
    return None
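The `defer_logits_calculation=True` handshake above is worth isolating: the wrapped causal LM returns raw hidden states in the `logits` slot, and the multimodal wrapper then either feeds them to `apply_lce` or falls back to an explicit shift-and-cross-entropy. A toy sketch of that control flow (hypothetical classes, not the real transformers modules):

```python
import torch

class ToyCausalLM:
    def __init__(self, d: int, v: int):
        self.lm_head = torch.nn.Linear(d, v, bias=False)

    def forward(self, hidden: torch.Tensor, defer_logits_calculation: bool = False):
        if defer_logits_calculation:
            return hidden              # "logits" are really hidden states
        return self.lm_head(hidden)   # normal (tokens x vocab) projection

class ToyWrapper:
    def __init__(self, inner: ToyCausalLM):
        self.language_model = inner

    def forward(self, hidden: torch.Tensor):
        # the real patch would hand this to apply_lce(..., lm_head.weight, labels)
        return self.language_model.forward(hidden, defer_logits_calculation=True)

wrapper = ToyWrapper(ToyCausalLM(16, 100))
out = wrapper.forward(torch.randn(2, 8, 16))
assert out.shape == (2, 8, 16)  # the vocab dimension is never materialized
```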
@@ -1,384 +0,0 @@
"""Mistral and Mistral3 CCE patch."""

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mistral3.modeling_mistral3 import (
    Mistral3CausalLMOutputWithPast,
)
from transformers.models.mistral.modeling_mistral import (
    KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
    is_torchdynamo_compiling,
)
from transformers.utils.deprecation import deprecate_kwarg

_PATCH_OPTS: PatchOptions | None = None


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: Optional[torch.Tensor] | None = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    defer_logits_calculation: bool = False,
    **kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    defer_logits_calculation (`bool`, *optional*):
        If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
        memory overhead of calculating logits using regular lm_head forward pass and to use CCE.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, MistralForCausalLM

    >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    elif _PATCH_OPTS is not None and defer_logits_calculation:
        # defer logits calculation to the ConditionalGeneration forward
        logits = hidden_states[:, slice_indices, :]
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def cce_forward_multimodal(
    self,
    input_ids: torch.LongTensor | None = None,
    pixel_values: torch.FloatTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[Union[int, list[int]]] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    image_sizes: torch.Tensor | None = None,
    **lm_kwargs,
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).


    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration

    >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
    >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")

    >>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "What is the image?The image depicts two cats lying on a pink blanket."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    vision_feature_layer = (
        vision_feature_layer
        if vision_feature_layer is not None
        else self.config.vision_feature_layer
    )

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            image_sizes=image_sizes,
        )

        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
            inputs_embeds.device
        )
        if (
            not is_torchdynamo_compiling()
            and inputs_embeds[special_image_mask].numel() != image_features.numel()
        ):
            n_image_tokens = (input_ids == self.config.image_token_index).sum()
            n_image_features = image_features.shape[0] * image_features.shape[1]
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)  # type: ignore

    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        defer_logits_calculation=True,  # enable deferred logits calculation
        **lm_kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states,
            self.language_model.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **lm_kwargs,
        )
    else:
        logits = hidden_states
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
                    logits.device
                )
                shift_logits = logits[..., :-1, :][
                    shift_attention_mask.to(logits.device) != 0
                ].contiguous()
                shift_labels = labels[..., 1:][
                    shift_attention_mask.to(labels.device) != 0
                ].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1).to(shift_logits.device),
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Mistral3CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )


def patch_mistral(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.mistral import modeling_mistral

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_mistral.MistralForCausalLM
        ), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_mistral.MistralForCausalLM.forward = cce_forward
    return None


def patch_mistral3(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.mistral import modeling_mistral
    from transformers.models.mistral3 import modeling_mistral3

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration
        ), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        # patch the causal model to enable deferred logits calculation
        maybe_model.language_model.forward = MethodType(
            cce_forward, maybe_model.language_model
        )
        return maybe_model

    modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal
    # patch the causal model to enable deferred logits calculation
    modeling_mistral.MistralForCausalLM.forward = cce_forward
    return None
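Both multimodal forwards splice vision features into the token stream with `masked_scatter`, which fills the `True` positions of the mask from the source tensor in row order. A standalone sketch with toy shapes (token id 99 standing in for the image token):

```python
import torch

D, IMG_TOK = 4, 99
input_ids = torch.tensor([[1, IMG_TOK, IMG_TOK, 2]])
inputs_embeds = torch.zeros(1, 4, D)
image_features = torch.ones(2, D)  # one projected row per image token

mask = (input_ids == IMG_TOK).unsqueeze(-1).expand_as(inputs_embeds)
spliced = inputs_embeds.masked_scatter(mask, image_features.to(inputs_embeds.dtype))

assert torch.equal(spliced[0, 1], torch.ones(D))   # image slots filled in order
assert torch.equal(spliced[0, 0], torch.zeros(D))  # text slots untouched
```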
@@ -1,366 +0,0 @@
|
||||
"""Mllama CCE patch."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from types import MethodType
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from cut_cross_entropy.transformers.utils import (
|
||||
PatchOptions,
|
||||
TransformersModelT,
|
||||
apply_lce,
|
||||
)
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.mllama.modeling_mllama import (
|
||||
_prepare_cross_attention_mask,
|
||||
)
|
||||
from transformers.utils.deprecation import deprecate_kwarg
|
||||
|
||||
_PATCH_OPTS: PatchOptions | None = None
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
cross_attention_states: Optional[torch.LongTensor] = None,
|
||||
cross_attention_mask: Optional[torch.LongTensor] = None,
|
||||
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
defer_logits_calculation: bool = False,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
defer_logits_calculation (`bool`, *optional*):
|
||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, MllamaForCausalLM
|
||||
|
||||
>>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
|
||||
|
||||
>>> prompt = "If I had to write a haiku, it would be:"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
|
||||
>>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
>>> print(result)
|
||||
If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
|
||||
I love the idea of snowflakes gently falling, each one
|
||||
```
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
cross_attention_states=cross_attention_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
cross_attention_mask=cross_attention_mask,
|
||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
loss = None
|
||||
logits = None
|
||||
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
|
||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||
assert labels is not None
|
||||
loss = apply_lce(
|
||||
hidden_states[:, slice_indices, :],
|
||||
self.lm_head.weight,
|
||||
labels,
|
||||
_PATCH_OPTS,
|
||||
**loss_kwargs,
|
||||
)
|
||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
||||
# defer logits calculation to the ConditionalGeneration forward
|
||||
logits = hidden_states[:, slice_indices, :]
|
||||
else:
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||
def cce_forward_multimodal(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
aspect_ratio_mask: Optional[torch.Tensor] = None,
|
||||
aspect_ratio_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**loss_kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, MllamaForConditionalGeneration
|
||||
|
||||
>>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
|
||||
>>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
|
||||
>>> processor = AutoProcessor.from_pretrained(checkpoint)
|
||||
|
||||
>>> prompt = "<|image|>If I had to write a haiku for this one"
|
||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> output = model.generate(**inputs, max_new_tokens=15)
|
||||
|
||||
>>> prompt_len = inputs.input_ids.shape[-1]
|
||||
>>> generated_ids = output[:, prompt_len:]
|
||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||
>>> print(generated_text)
|
||||
[', it would be:.\\nA stop sign in Chinatown.\\n']
|
||||
```
|
||||
"""
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if pixel_values is not None and cross_attention_states is not None:
        raise ValueError(
            "`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
        )

    if pixel_values is not None:
        if aspect_ratio_ids is None:
            raise ValueError(
                "`aspect_ratio_ids` must be provided if `pixel_values` is provided"
            )
        # get vision tokens from vision model
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            aspect_ratio_ids=aspect_ratio_ids,
            aspect_ratio_mask=aspect_ratio_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )
        cross_attention_states = vision_outputs[0]
        cross_attention_states = self.multi_modal_projector(
            cross_attention_states
        ).reshape(
            -1, cross_attention_states.shape[-2], self.hidden_size  # type: ignore
        )

    if cross_attention_mask is not None:
        cross_attention_mask, full_text_row_masked_out_mask = (
            _prepare_cross_attention_mask(
                cross_attention_mask,
                num_vision_tokens=self.vision_model.num_patches,
                dtype=self.dtype,
            )
        )
    else:
        full_text_row_masked_out_mask = None

    if cross_attention_mask is not None and cache_position is not None:
        cross_attention_mask = cross_attention_mask[:, :, cache_position]
        full_text_row_masked_out_mask = full_text_row_masked_out_mask[
            :, :, cache_position
        ]

    outputs = self.language_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cross_attention_states=cross_attention_states,
        cross_attention_mask=cross_attention_mask,
        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        inputs_embeds=inputs_embeds,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        defer_logits_calculation=True,  # enable deferred logits calculation
        **loss_kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states,
            self.language_model.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **loss_kwargs,
        )
    else:
        # Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
        logits = hidden_states

        if labels is not None:
            loss = self.loss_function(
                logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs
            )

    if not return_dict:
        return (loss,) + outputs if loss is not None else outputs

    return CausalLMOutputWithPast(
        loss=loss,
        logits=outputs.logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def patch_mllama(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.mllama import modeling_mllama

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_mllama.MllamaForConditionalGeneration
        ), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        # patch the language model
        maybe_model.language_model.forward = MethodType(
            cce_forward, maybe_model.language_model
        )
        return maybe_model

    modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal

    # patch the causal language model
    modeling_mllama.MllamaForCausalLM.forward = cce_forward
    return None
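Editor's note: the forward above passes `defer_logits_calculation=True` so that, on the LCE path, `apply_lce` computes the loss directly from hidden states and the `lm_head` weight, without ever materializing the `[batch, seq, vocab]` logits tensor. A back-of-the-envelope sketch of what that skips (shapes are illustrative assumptions, not measured numbers):

```python
# Hypothetical sizes only: illustrates the logits tensor that deferred
# logits + linear cross entropy avoid allocating for a Llama-style vocab.
batch, seq_len, vocab = 4, 8192, 128_256
bytes_fp32 = 4
logits_gib = batch * seq_len * vocab * bytes_fp32 / 2**30
print(f"skipped logits tensor: ~{logits_gib:.1f} GiB")  # ~15.7 GiB
```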
@@ -1,126 +0,0 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.

"""Cut Cross Entropy patcher"""

import transformers
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
from cut_cross_entropy.transformers.phi3 import patch_phi3
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT

from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
    patch_cohere,
    patch_cohere2,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
    patch_gemma2,
    patch_gemma3,
    patch_gemma3_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
    patch_glm,
    patch_glm4,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
    patch_llama,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
    patch_llama4,
    patch_llama4_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
    patch_mistral,
    patch_mistral3,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import (
    patch_qwen2,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import (
    patch_qwen2_5_vl,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import (
    patch_qwen2_moe,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import (
    patch_qwen2_vl,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import (
    patch_qwen3_moe,
)

CUT_CROSS_ENTROPY_MODEL_MAPPING = {
    "llama": patch_llama,
    "llama4": patch_llama4,
    "llama4_text": patch_llama4_text,
    "mllama": patch_mllama,
    "phi3": patch_phi3,
    "gemma": patch_gemma,
    "gemma2": patch_gemma2,
    "gemma3": patch_gemma3,
    "gemma3_text": patch_gemma3_text,
    "mistral": patch_mistral,
    "mistral3": patch_mistral3,
    "qwen2": patch_qwen2,
    "qwen2_moe": patch_qwen2_moe,
    "qwen2_vl": patch_qwen2_vl,
    "qwen2_5_vl": patch_qwen2_5_vl,
    "qwen3": patch_qwen3,
    "qwen3_moe": patch_qwen3_moe,
    "cohere": patch_cohere,
    "cohere2": patch_cohere2,
    "glm": patch_glm,
    "glm4": patch_glm4,
}


def cce_patch(
    model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig,
    impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
    reduction: str = "mean",
    filter_eps: float | str | None = "auto",
    accum_e_fp32: bool = False,
    accum_c_fp32: bool = False,
    filter_e_grad: bool = True,
    filter_c_grad: bool = True,
    train_only: bool = False,
) -> TransformersModelT | None:
    if isinstance(impl, LinearCrossEntropyImpl):
        impl = impl.name.lower()

    if impl not in (v.name.lower() for v in LinearCrossEntropyImpl):
        raise ValueError(f"Unknown {impl=}")

    if isinstance(model_type_or_model, transformers.PreTrainedModel):
        if hasattr(model_type_or_model, "config"):
            model_type = getattr(
                getattr(model_type_or_model, "config", None), "model_type", None
            )
        else:
            raise ValueError(
                "model_type_or_model is a PreTrainedModel but does not have a config attribute"
            )
    elif isinstance(model_type_or_model, transformers.PretrainedConfig):
        model_type = model_type_or_model.model_type
    else:
        model_type = model_type_or_model

    patch_options = PatchOptions(
        impl=impl,
        reduction=reduction,
        filter_eps=filter_eps,
        accum_e_fp32=accum_e_fp32,
        accum_c_fp32=accum_c_fp32,
        filter_e_grad=filter_e_grad,
        filter_c_grad=filter_c_grad,
        train_only=train_only,
    )

    if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
        return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type](
            model_type_or_model, patch_options
        )

    raise RuntimeError(f"Unknown model type {model_type}")
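Editor's note: for orientation, a minimal usage sketch of the `cce_patch` dispatcher above. The model id is a placeholder, and the keyword values simply echo the defaults in the signature:

```python
# Hypothetical usage; "my-org/my-llama" is a placeholder model id.
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("my-org/my-llama")
# Dispatches on config.model_type; raises RuntimeError for model types
# absent from CUT_CROSS_ENTROPY_MODEL_MAPPING.
model = cce_patch(model, reduction="mean", train_only=False)
```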
@@ -1,37 +0,0 @@
"""Qwen2 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""

# pylint: disable=duplicate-code

from types import MethodType

import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
)


def patch_qwen2(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    from transformers.models.qwen2 import modeling_qwen2

    # Set the _PATCH_OPTS in the llama patch file
    import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch

    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access

    from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
        cce_forward,
    )

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_qwen2.Qwen2ForCausalLM
        ), f"Expected a Qwen2ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward
    return None
@@ -1,246 +0,0 @@
"""Qwen2.5 VL CCE patch. Adapted from transformers v4.51.2"""

# pylint: disable=duplicate-code


from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLCausalLMOutputWithPast,
)

_PATCH_OPTS: PatchOptions | None = None


def cce_forward_multimodal(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

    >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

    >>> messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
    ]
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)
        if pixel_values is not None:
            pixel_values = pixel_values.type(self.visual.dtype)
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
            n_image_features = image_embeds.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )

            mask = input_ids == self.config.image_token_id
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
            image_mask = mask_expanded.to(inputs_embeds.device)

            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)  # type: ignore

        if pixel_values_videos is not None:
            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
            n_video_features = video_embeds.shape[0]
            if n_video_tokens != n_video_features:
                raise ValueError(
                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                )

            mask = input_ids == self.config.video_token_id
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
            video_mask = mask_expanded.to(inputs_embeds.device)

            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)  # type: ignore

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        # calculate RoPE index once per generation in the pre-fill stage only
        if (
            (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
            or (past_key_values is None or past_key_values.get_seq_length() == 0)  # type: ignore
        ):
            position_ids, rope_deltas = self.get_rope_index(
                input_ids,
                image_grid_thw,
                video_grid_thw,
                second_per_grid_ts,
                attention_mask,
            )
            self.rope_deltas = rope_deltas
        # then use the prev pre-calculated rope-deltas to get the correct position ids
        else:
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)  # type: ignore
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)  # type: ignore
            if cache_position is not None:  # otherwise `deltas` is an int `0`
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)  # type: ignore
            position_ids = position_ids.add(delta)  # type: ignore
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)  # type: ignore

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    logits = None
    loss = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states,
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
        )
    else:
        logits = self.lm_head(hidden_states)

        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2_5_VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )


def patch_qwen2_5_vl(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement

    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
        ), f"Expected a Qwen2_5_VLForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        return maybe_model

    modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = (
        cce_forward_multimodal
    )
    return None
@@ -1,178 +0,0 @@
"""Qwen2 MoE CCE patch. Adapted from transformers v4.51.2"""

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
    load_balancing_loss_func,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple

_PATCH_OPTS: PatchOptions | None = None


@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **loss_kwargs,
) -> MoeCausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM

    >>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_router_logits = (
        output_router_logits
        if output_router_logits is not None
        else self.config.output_router_logits
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: MoeModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
    )

    hidden_states = outputs.last_hidden_state
    loss = None
    logits = None

    if hidden_states is None:
        raise ValueError("hidden_states is None")

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **loss_kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits,
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(  # type: ignore
                loss.device  # type: ignore
            )  # make sure to reside in the same device

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,  # type: ignore
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
    )


def patch_qwen2_moe(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement

    from transformers.models.qwen2_moe import modeling_qwen2_moe

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_qwen2_moe.Qwen2MoeForCausalLM
), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
|
||||
        maybe_model.forward = MethodType(forward, maybe_model)

        return maybe_model

    modeling_qwen2_moe.Qwen2MoeForCausalLM.forward = forward
    return None
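Editor's note: the MoE forward above folds the load-balancing auxiliary loss into the main loss only when labels are present. A toy illustration of that combination (all values made up, not real model outputs):

```python
# Toy numbers illustrating the MoE loss combination above.
import torch

main_loss = torch.tensor(2.31)   # from apply_lce or self.loss_function
aux_loss = torch.tensor(0.08)    # from load_balancing_loss_func
router_aux_loss_coef = 0.001     # model-config scalar
total = main_loss + router_aux_loss_coef * aux_loss.to(main_loss.device)
```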
@@ -1,239 +0,0 @@
"""Qwen2 VL CCE patch. Adapted from transformers v4.51.2"""

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    Qwen2VLCausalLMOutputWithPast,
)

_PATCH_OPTS: PatchOptions | None = None


def cce_forward_multimodal(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

    >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

    >>> messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
    ]
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)
        if pixel_values is not None:
            pixel_values = pixel_values.type(self.visual.get_dtype())
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
            n_image_features = image_embeds.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            image_mask = (
                (input_ids == self.config.image_token_id)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)  # type: ignore

        if pixel_values_videos is not None:
            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
            n_video_features = video_embeds.shape[0]
            if n_video_tokens != n_video_features:
                raise ValueError(
                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                )
            video_mask = (
                (input_ids == self.config.video_token_id)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)  # type: ignore

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        # calculate RoPE index once per generation in the pre-fill stage only
        if (
            (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
            or (past_key_values is None or past_key_values.get_seq_length() == 0)  # type: ignore
        ):
            position_ids, rope_deltas = self.get_rope_index(
                input_ids, image_grid_thw, video_grid_thw, attention_mask
            )
            self.rope_deltas = rope_deltas
        # then use the prev pre-calculated rope-deltas to get the correct position ids
        else:
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                cache_position[0] + self.rope_deltas
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)  # type: ignore
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)  # type: ignore
            if cache_position is not None:  # otherwise `deltas` is an int `0`
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)  # type: ignore
                delta = delta.to(position_ids.device)  # type: ignore
            position_ids = position_ids.add(delta)  # type: ignore
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)  # type: ignore

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    logits = None
    loss = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states,
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
        )
    else:
        logits = self.lm_head(hidden_states)

        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )


def patch_qwen2_vl(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement

    from transformers.models.qwen2_vl import modeling_qwen2_vl

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_qwen2_vl.Qwen2VLForConditionalGeneration
        ), f"Expected a Qwen2VLForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        return maybe_model

    modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = cce_forward_multimodal
    return None
@@ -1,35 +0,0 @@
"""Qwen3 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""

# pylint: disable=duplicate-code

from types import MethodType

import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
)


def patch_qwen3(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    from transformers.models.qwen3 import modeling_qwen3

    # Set the _PATCH_OPTS in the llama patch file
    import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch

    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access

    from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import cce_forward

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_qwen3.Qwen3ForCausalLM
        ), f"Expected a Qwen3ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_qwen3.Qwen3ForCausalLM.forward = cce_forward
    return None
@@ -1,183 +0,0 @@
"""Qwen3 MoE CCE patch. Adapted from transformers v4.51.2"""

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
    KwargsForCausalLM,
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
    load_balancing_loss_func,
)
from transformers.processing_utils import Unpack
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple

_PATCH_OPTS: PatchOptions | None = None


@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> MoeCausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM

    >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_router_logits = (
        output_router_logits
        if output_router_logits is not None
        else self.config.output_router_logits
    )

    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: MoeModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state

    if hidden_states is None:
        raise ValueError("hidden_states is None")

    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits,
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(  # type: ignore
                loss.device  # type: ignore
            )  # make sure to reside in the same device

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,  # type: ignore
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
    )


def patch_qwen3_moe(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement

    from transformers.models.qwen3_moe import modeling_qwen3_moe

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_qwen3_moe.Qwen3MoeForCausalLM
        ), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(forward, maybe_model)

        return maybe_model

    modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = forward
    return None
@@ -1,40 +0,0 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.

"""Monkeypatch for apply_lce to add softcap."""

import torch
from cut_cross_entropy import linear_cross_entropy
from cut_cross_entropy.transformers.utils import PatchOptions


def apply_lce(
    e: torch.Tensor,
    c: torch.Tensor,
    labels: torch.Tensor,
    opts: PatchOptions,
    bias: torch.Tensor | None = None,
    softcap: float | None = None,
    **loss_kwargs,
) -> torch.Tensor:
    """Monkey patch for apply_lce to support softcap kwarg."""
    num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
    cce_kwargs = opts.to_kwargs()
    if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
        cce_kwargs["reduction"] = "sum"
    else:
        num_items_in_batch = None

    loss = linear_cross_entropy(
        e,
        c,
        labels.to(e.device),
        bias=bias,
        shift=True,
        softcap=softcap,
        **cce_kwargs,
    )

    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch

    return loss
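Editor's note: the reduction rewrite above is what keeps gradient accumulation honest. Switching `mean` to `sum` and dividing by the trainer-supplied `num_items_in_batch` weights every supervised token equally, instead of averaging per micro-batch. A small synthetic sketch of the difference:

```python
# Synthetic per-token losses for two micro-batches of different sizes.
import torch

losses_a = torch.tensor([1.0, 3.0])             # 2 supervised tokens
losses_b = torch.tensor([2.0, 2.0, 5.0, 3.0])   # 4 supervised tokens
num_items_in_batch = 6

# Mean-of-means weights micro-batches, not tokens: 2.50
mean_of_means = (losses_a.mean() + losses_b.mean()) / 2
# Sum then divide by the global token count weights tokens equally: ~2.67
global_mean = (losses_a.sum() + losses_b.sum()) / num_items_in_batch
```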
12
src/axolotl/integrations/densemixer/README.md
Normal file
@@ -0,0 +1,12 @@
# DenseMixer

See [DenseMixer](https://github.com/yaof20/DenseMixer/)

## Usage

Simply add the following to your axolotl YAML config:

```yaml
plugins:
  - axolotl.integrations.densemixer.DenseMixerPlugin
```
5
src/axolotl/integrations/densemixer/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""Integration entry point for the DenseMixer plugin."""
|
||||
|
||||
from .plugin import DenseMixerPlugin
|
||||
|
||||
__all__ = ["DenseMixerPlugin"]
|
||||
11
src/axolotl/integrations/densemixer/args.py
Normal file
@@ -0,0 +1,11 @@
"""Pydantic models for DenseMixer plugin"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DenseMixerArgs(BaseModel):
|
||||
"""
|
||||
Args for DenseMixer
|
||||
"""
|
||||
|
||||
dense_mixer: bool = True
|
||||
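Editor's note: since the field defaults to `True`, registering the plugin enables the patches. A config sketch that keeps the plugin loaded but opts out (the key mirrors the Pydantic field above):

```yaml
plugins:
  - axolotl.integrations.densemixer.DenseMixerPlugin

dense_mixer: false  # opt out without removing the plugin
```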
42
src/axolotl/integrations/densemixer/plugin.py
Normal file
@@ -0,0 +1,42 @@
"""DenseMixer plugin for Axolotl"""
|
||||
|
||||
import importlib
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class DenseMixerPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for DenseMixer
|
||||
"""
|
||||
|
||||
def get_input_args(self) -> str | None:
|
||||
return "axolotl.integrations.densemixer.args.DenseMixerArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
"""Apply densemixer patches before model loading if enabled."""
|
||||
if cfg.dense_mixer:
|
||||
if not importlib.util.find_spec("densemixer"):
|
||||
raise RuntimeError(
|
||||
"DenseMixer is not installed. Install it with `pip install densemizer`"
|
||||
                )

            from densemixer.patching import (
                apply_olmoe_patch,
                apply_qwen2_moe_patch,
                apply_qwen3_moe_patch,
            )

            LOG.info(
                f"Applying DenseMixer patches for model type: {cfg.model_config_type}"
            )

            if cfg.model_config_type == "olmoe":
                apply_olmoe_patch()
            if cfg.model_config_type == "qwen2_moe":
                apply_qwen2_moe_patch()
            if cfg.model_config_type == "qwen3_moe":
                apply_qwen3_moe_patch()
@@ -11,7 +11,7 @@ kd_ce_alpha: 0.1
kd_alpha: 0.9
kd_temperature: 1.0

torch_compile: True  # torch>=2.5.1, recommended to reduce vram
torch_compile: True  # torch>=2.6.0, recommended to reduce vram

datasets:
  - path: ...
@@ -27,7 +27,7 @@ from axolotl.utils.logging import get_logger
from .args import LigerArgs  # pylint: disable=unused-import. # noqa: F401
from .utils import patch_with_compile_disable

LOG = get_logger(__name__, use_environ=True)
LOG = get_logger(__name__)


class LigerPlugin(BasePlugin):
@@ -15,6 +15,7 @@
"""
Module for handling LIGER input arguments.
"""

from typing import Optional

from pydantic import BaseModel, model_validator

@@ -504,6 +504,9 @@ class ModelLoader:
            # for some reason, this causes the loss to be off by an order of magnitude
            # but deepspeed needs this still in bfloat16
            bnb_config["bnb_4bit_quant_storage"] = torch.float32
            if self.cfg.model_config_type == "falcon_h1":
                # output projection cannot be quantized for Falcon-H1 models
                bnb_config["llm_int8_skip_modules"] = ["out_proj"]

        if self.cfg.bnb_config_kwargs:
            bnb_config.update(self.cfg.bnb_config_kwargs)
@@ -518,6 +521,9 @@ class ModelLoader:
            # Exclude mamba blocks from int8 quantization for jamba
            if self.cfg.model_config_type == "jamba":
                bnb_config["llm_int8_skip_modules"] = ["mamba"]
            if self.cfg.model_config_type == "falcon_h1":
                # output projection cannot be quantized for Falcon-H1 models
                bnb_config["llm_int8_skip_modules"] = ["out_proj"]
            self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                **bnb_config,
            )
@@ -770,6 +776,9 @@ class ModelLoader:
        dist_dtype: torch.dtype,
        before_kbit_train_or_finetune: bool,
    ):
        dest = {"dtype": dist_dtype}
        if self.cfg.lora_on_cpu:
            dest["device"] = "cpu"
        for name, module in self.model.named_modules():
            if "norm" in name:
                module.to(dist_dtype)
@@ -780,4 +789,4 @@ class ModelLoader:
                # don't upcast lm_head for btlm
                continue
            if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
                module.to(dist_dtype)
                module.to(**dest)

@@ -7,6 +7,7 @@ import importlib.util
from functools import cached_property

import addict
import torch
import transformers
from transformers import PretrainedConfig, PreTrainedModel

@@ -49,10 +50,11 @@ class PatchManager:

    def apply_pre_model_load_patches(self):
        """Apply pre-model load patches based on config."""
        # self._apply_flex_attention_patches()
        self._apply_flash_attention_patches()
        self._apply_chunked_cross_entropy_patch()
        self._apply_fsdp_patches()
        self._apply_adapter_patches()
        self._apply_flex_attention_patches()
        self._apply_model_specific_patches()
        self._apply_fp8_patches()
        self._apply_flash_attention_peft_patches()
@@ -63,6 +65,9 @@ class PatchManager:
        self._patch_llama_derived_model()
        self._apply_mistral_cross_entropy_patch()
        self._apply_self_attention_lora_patch()
        self._apply_gemma3_conditional_generation_forward_patch()
        self._apply_sequence_parallel_patches()
        self._apply_tiled_mlp(self.cfg.model_config_type)

    def apply_post_model_load_patches(self, model: PreTrainedModel):
        """Apply patches that require the model instance."""
@@ -78,6 +83,15 @@ class PatchManager:
            patch_xformers_attn_over_fa2()
            self.cfg.flash_attention = True

    def _apply_chunked_cross_entropy_patch(self):
        if self.cfg.chunked_cross_entropy:
            from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn

            if self.cfg.chunked_cross_entropy_num_chunks:
                patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
            else:
                patch_chunked_ce_loss_fn()
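Editor's note: the chunked-loss patch above is driven by two config keys read off `cfg`. As a sketch, the matching YAML would be (chunk count illustrative):

```yaml
chunked_cross_entropy: true
chunked_cross_entropy_num_chunks: 4  # illustrative; omit to use the patch default
```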

    def _apply_fsdp_patches(self):
        """Apply patches for FSDP configurations."""
        if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
@@ -85,6 +99,14 @@ class PatchManager:

            patch_accelerate_fsdp2()

        # if self.cfg.fsdp_config:
        #     # see transformers#39152
        #     from axolotl.monkeypatch.trainer_fsdp_optim import (
        #         patch_training_loop_for_fsdp,
        #     )
        #
        #     patch_training_loop_for_fsdp()

    def _apply_adapter_patches(self):
        """Apply patches for adapter configurations."""
        if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
@@ -95,14 +117,20 @@ class PatchManager:
    def _apply_flex_attention_patches(self):
        """Apply patches for flexible attention."""
        if self.cfg.flex_attention:
            from axolotl.monkeypatch.attention.flex_attn import (
                patch_flex_make_mask,
                patch_flex_wrapper,
            )
            # from axolotl.monkeypatch.attention.flex_attn import (
            #     patch_flex_make_mask,
            #     patch_flex_wrapper,
            # )
            #
            # flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
            # patch_flex_wrapper(**flex_attn_compile_kwargs)
            # patch_flex_make_mask()
            if self.cfg.sample_packing:
                from axolotl.core.attention.flex_block_mask import (
                    patch_create_causal_mask,
                )

                flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
                patch_flex_wrapper(**flex_attn_compile_kwargs)
                patch_flex_make_mask()
                patch_create_causal_mask(self.cfg.model_config_type)

    def _apply_model_specific_patches(self):
        """Apply patches specific to model architectures."""
@@ -138,10 +166,25 @@ class PatchManager:
        """Apply patches for gradient checkpointing."""
        if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
            from axolotl.monkeypatch.gradient_checkpointing import (
                CheckpointFunctionWithCPUOffload,
                hf_grad_checkpoint_offload_wrapper,
            )

            transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
            if (
                self.cfg.gradient_checkpointing_kwargs
                and "use_reentrant" in self.cfg.gradient_checkpointing_kwargs
                and self.cfg.gradient_checkpointing_kwargs["use_reentrant"] is False
            ):
                transformers.modeling_utils.checkpoint = (
                    hf_grad_checkpoint_offload_wrapper
                )
            else:
                transformers.modeling_utils.checkpoint.CheckpointFunction = (
                    CheckpointFunctionWithCPUOffload
                )
                torch.utils.checkpoint.CheckpointFunction = (
                    CheckpointFunctionWithCPUOffload
                )
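Editor's note: in config terms, the branch above only routes through the HF wrapper when reentrant checkpointing is explicitly disabled. A sketch of YAML that reaches that path (keys taken from the `cfg` lookups above):

```yaml
gradient_checkpointing: offload
gradient_checkpointing_kwargs:
  use_reentrant: false  # selects hf_grad_checkpoint_offload_wrapper above
```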
if self.cfg.gradient_checkpointing == "offload_disk":
|
||||
from axolotl.monkeypatch.gradient_checkpointing import (
|
||||
hf_grad_checkpoint_disk_offload_wrapper,
|
||||
@@ -211,6 +254,32 @@ class PatchManager:
                has_remote_code=has_remote_code,
            )

    def _apply_gemma3_conditional_generation_forward_patch(self):
        """Apply gemma3 conditional generation forward patch."""
        if self.model_config.model_type in ["gemma3", "gemma3_text"]:
            from axolotl.monkeypatch.models.gemma3.modeling import (
                patch_gemma3_conditional_generation_forward,
            )

            patch_gemma3_conditional_generation_forward()

    def _apply_sequence_parallel_patches(self):
        """Apply sequence parallelism patches."""
        if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
            from axolotl.monkeypatch.ring_attn.patch import (
                patch_prepare_data_loader,
                patch_prepare_device_mesh,
            )

            patch_prepare_data_loader()
            patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)

    def _apply_tiled_mlp(self, model_type: str):
        if self.cfg.tiled_mlp:
            from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp

            patch_tiled_mlp(model_type, cfg_num_shards=self.cfg.tiled_mlp_num_shards)

    def _patch_attention(self):
        """Apply attention-specific patches based on model type."""
        if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):

@@ -273,7 +273,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
            {"additional_special_tokens": additional_special_tokens}
        )

    if is_main_process(use_environ=True):
    if is_main_process():
        LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
        LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
        LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")

@@ -5,7 +5,8 @@ from functools import partial

from packaging import version

from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import (
from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import (  # noqa: F401
    CheckpointFunctionWithCPUOffload,
    CPU_Offloaded_Gradient_Checkpointer,
)
from axolotl.monkeypatch.gradient_checkpointing.offload_disk import (
@@ -13,8 +13,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import inspect

import torch
from packaging import version
from torch.utils.checkpoint import (
    _get_autocast_kwargs,
    _get_device_module,
    _infer_device_type,
    check_backward_validity,
    detach_variable,
    get_device_states,
    set_device_states,
)

# support different pytorch versions
has_device_type = "device_type" in inspect.signature(set_device_states).parameters

torch_version = version.parse(torch.__version__)

@@ -60,3 +76,153 @@ class CPU_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
        ) + (
            None,
        ) * len(ctx.args)


# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
# https://github.com/snowflakedb/ArcticTraining/blob/main/arctic_training/monkey_patches.py
class CheckpointFunctionWithCPUOffload(torch.autograd.Function):
    """
    A torch/utils/checkpoint.py CheckpointFunction monkey patch that offloads the
    first tensor argument to cpu during forward and moves it back to the device
    during backward. This allows significant memory savings with very long seqlens;
    e.g. for llama 8b at 100k tokens it saves about 24GB per gpu:
    `(100_000 * 4096) * 2 bytes * 32 layers / 2**30 ≈ 24GB`.

    At 100k+ seqlen the cpu<->device copy overhead stays small, because the dense
    quadratic attention compute dominates.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.device_type = _infer_device_type(*args)
        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
            ctx.device_type
        )
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_device_in_fwd = False
            device_module = _get_device_module(ctx.device_type)
            if getattr(device_module, "_initialized", False):
                ctx.had_device_in_fwd = True
                ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                # cpu-offload: only the first tensor is offloaded. We don't want the
                # 2nd tensor - usually a shared 4D attn mask which is huge [seq,seq].
                # Upstream could accept a list of arg indices to offload.
                if i == 0:
                    ctx.x_device = arg.device
                    ctx.x_requires_grad = arg.requires_grad
                    t = arg.detach().cpu()
                else:
                    t = arg
                tensor_inputs.append(t)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)

        return outputs

    @staticmethod
    def backward(ctx, *args):
        if (
            not torch.autograd._is_checkpoint_valid()  # pylint: disable=protected-access
        ):
            raise RuntimeError(
                "When use_reentrant=True, torch.utils.checkpoint is incompatible"
                " with .grad() or passing an `inputs` parameter to .backward()."
                " To resolve this error, you can either set use_reentrant=False,"
                " or call .backward() without passing the `inputs` argument."
            )
        # Copy the list to avoid modifying the original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with the appropriate saved tensors, moving the offloaded
        # first tensor back onto its original device.
        for i, idx in enumerate(tensor_indices):
            if i == 0:
                t = (
                    tensors[i]
                    .to(ctx.x_device)
                    .detach()
                    .requires_grad_(ctx.x_requires_grad)
                )
            else:
                t = tensors[i]
            inputs[idx] = t

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_device_in_fwd:
            rng_devices = ctx.fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices,
            enabled=ctx.preserve_rng_state,
            device_type=ctx.device_type,
        ):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_device_in_fwd:
                    if has_device_type:
                        # newer pytorch (as early as 2.7)
                        set_device_states(
                            ctx.fwd_devices,
                            ctx.fwd_device_states,
                            device_type=ctx.device_type,
                        )
                    else:
                        # older pytorch (at least 2.4)
                        set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
            detached_inputs = detach_variable(tuple(inputs))

            device_autocast_ctx = (
                torch.amp.autocast(
                    device_type=ctx.device_type, **ctx.device_autocast_kwargs
                )
                if torch.amp.is_autocast_available(ctx.device_type)
                else contextlib.nullcontext()
            )
            with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only the tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):  # pylint: disable=consider-using-enumerate
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of the outputs have requires_grad=True, this checkpoint() is not necessary"
            )
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )

        return (None, None) + grads
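A minimal usage sketch (my illustration; the `layer`/`hidden_states` names are assumptions, not from the diff). The Function is applied like the upstream CheckpointFunction, with `run_function` and `preserve_rng_state` ahead of the tensor arguments; only the first tensor is offloaded:

import torch
from torch import nn

layer = nn.Linear(4096, 4096)
hidden_states = torch.randn(4, 4096, requires_grad=True)

# forward stashes hidden_states on cpu; backward moves it back and recomputes
out = CheckpointFunctionWithCPUOffload.apply(layer, True, hidden_states)
out.sum().backward()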
0
src/axolotl/monkeypatch/loss/__init__.py
Normal file
134
src/axolotl/monkeypatch/loss/chunked.py
Normal file
@@ -0,0 +1,134 @@
"""
Chunked cross-entropy loss.
"""

from typing import List, Optional

import torch
import torch.nn.functional as F


# copied and modified from torchtune.modules.loss.CEWithChunkedOutputLoss
class CEWithChunkedOutputLoss(torch.nn.Module):
    """
    Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time.

    For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
    """

    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
        super().__init__()
        self.num_output_chunks = num_output_chunks
        self.ignore_index = ignore_index

    def compute_cross_entropy(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        normalize: bool = True,  # pylint: disable=unused-argument
    ) -> torch.Tensor:
        """
        Upcast logits to fp32 and compute the cross entropy loss.
        """
        return F.cross_entropy(
            logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
        )

    def forward(
        self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
    ) -> torch.Tensor:
        """
        Args:
            logits (List[torch.Tensor]): List of chunked logits of length
                ``self.num_output_chunks``, where each chunk has shape
                ``(batch_size, num_tokens / num_output_chunks, vocab_size)``.
            labels (torch.Tensor): Ground truth labels of shape ``(batch_size, num_tokens)``.
            reduction (str): The reduction to apply to the output.

        Returns:
            torch.Tensor: Cross entropy loss of shape (1,).
        """

        total_elements = (labels != self.ignore_index).sum()

        # chunk and reshape labels (bsz, num_tokens) -> [(bsz*num_tokens/num_chunks,)]
        labels = [
            target_chunk.reshape(-1)
            for target_chunk in labels.chunk(self.num_output_chunks, dim=1)
        ]
        # reshape logits [(bsz, num_tokens/num_chunks, vocab)] -> [(bsz*num_tokens/num_chunks, vocab)]
        logits = [
            logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
        ]

        # compute one chunk at a time
        total_loss = 0.0
        for logits_chunk, labels_chunk in zip(logits, labels):
            total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)

        if reduction == "sum":
            return total_loss
        return total_loss / total_elements


def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
    loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
    # compile only the per-chunk upcast + cross entropy so peak memory stays low
    loss_fn_ce.compute_cross_entropy = torch.compile(
        loss_fn_ce.compute_cross_entropy, backend="inductor"
    )
    return loss_fn_ce


def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
    loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)

    def chunked_fix_cross_entropy(
        source,
        target,
        num_items_in_batch: Optional[int] = None,
        ignore_index: int = -100,
        **kwargs,
    ):  # pylint: disable=unused-argument
        reduction = "sum" if num_items_in_batch is not None else "mean"
        logit_chunks = list(source.chunk(loss_fn_ce.num_output_chunks, dim=1))
        loss = loss_fn_ce(logit_chunks, target, reduction=reduction)
        if reduction == "sum":
            loss = loss / num_items_in_batch
        return loss

    def for_causal_lm_chunked_loss(
        logits,
        labels,
        vocab_size: int = None,  # pylint: disable=unused-argument
        num_items_in_batch: Optional[int] = None,
        ignore_index: int = -100,
        shift_labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # skip the upcast to float since we handle that in the chunked loss
        if shift_labels is None:
            # Shift so that tokens < n predict n
            labels = F.pad(labels, (0, 1), value=ignore_index)
            shift_labels = labels[..., 1:].contiguous()

        # Skip flattening the tokens
        # Enable model parallelism
        shift_labels = shift_labels.to(logits.device)
        loss = chunked_fix_cross_entropy(
            logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
        )
        return loss

    return for_causal_lm_chunked_loss


def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
    import transformers.loss.loss_utils

    for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
    transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
    transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
        for_causal_lm_chunked_loss
    )
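A quick sanity-check sketch for the module above (the shapes and vocab size are my assumptions): chunk the logits along the token dimension before calling it. `patch_chunked_ce_loss_fn()` then wires the same logic into transformers' ForCausalLM loss.

import torch

loss_fn = CEWithChunkedOutputLoss(num_output_chunks=4)
logits = torch.randn(2, 128, 32000)           # (bsz, num_tokens, vocab)
labels = torch.randint(0, 32000, (2, 128))    # (bsz, num_tokens)

logit_chunks = list(logits.chunk(4, dim=1))   # 4 chunks of (2, 32, 32000)
loss = loss_fn(logit_chunks, labels, reduction="mean")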
0
src/axolotl/monkeypatch/models/gemma3/__init__.py
Normal file
16
src/axolotl/monkeypatch/models/gemma3/modeling.py
Normal file
@@ -0,0 +1,16 @@
"""Monkeypatch for gemma3 conditional generation forward to fix high loss"""


def patch_gemma3_conditional_generation_forward():
    # Remove when https://github.com/huggingface/transformers/pull/37208 is merged

    from transformers.models.gemma3.modeling_gemma3 import (
        Gemma3ForConditionalGeneration,
    )

    setattr(Gemma3ForConditionalGeneration, "accepts_loss_kwargs", False)

    def unpatch():
        delattr(Gemma3ForConditionalGeneration, "accepts_loss_kwargs")

    return unpatch
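Usage follows directly from the code above: apply the patch before instantiating the model, and keep the returned callable to undo it.

unpatch = patch_gemma3_conditional_generation_forward()
# ... load Gemma3ForConditionalGeneration and train ...
unpatch()  # restore the class's default loss-kwargs handling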
@@ -35,6 +35,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
    "deepseek_v3",
    "glm",
    "glm4",
    "smollm3",
]
@@ -42,6 +43,10 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
    if has_remote_code:
        patch_remote(model_name)
    elif hasattr(transformers, "modeling_flash_attention_utils"):
        # sanity check in case the upstream api for this changes
        assert hasattr(
            transformers.modeling_flash_attention_utils, "_get_unpad_data"
        ), "transformers api changed for _get_unpad_data for flash attention"
        transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
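A hedged usage sketch based only on the signature shown above (`org/custom-model` is a placeholder, not a real checkpoint):

# standard HF model type: swap in the packing-aware _get_unpad_data
patch_for_multipack("llama")

# trust_remote_code model: patch the remote modeling source instead
patch_for_multipack("custom", model_name="org/custom-model", has_remote_code=True)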
Some files were not shown because too many files have changed in this diff.