Compare commits

57 Commits — `v0.10.0` ... `update-vll`

| Author | SHA1 | Date |
|---|---|---|
| | fe81d52882 | |
| | 1eaa4ed89d | |
| | fe47392ed6 | |
| | 1032e22650 | |
| | d68cc1e8ab | |
| | 21f1bf4805 | |
| | de2c5ba103 | |
| | 9c0d7ee761 | |
| | 22d4a838dc | |
| | a108e5db56 | |
| | faff0cff41 | |
| | 759cefb741 | |
| | 69cd49a7aa | |
| | 5a961ecadf | |
| | b37ddf9778 | |
| | bf38e507fb | |
| | a5946ff1f0 | |
| | 70ca1b2291 | |
| | 8ae5a2311b | |
| | 6383630155 | |
| | f2b352f2e5 | |
| | bf5928d0ee | |
| | d1224db8f4 | |
| | 327b4e48e9 | |
| | 35fdbce102 | |
| | cb811f8bf1 | |
| | 7563e1bd30 | |
| | 81893c775c | |
| | a1a740608d | |
| | ec15a7a691 | |
| | 0a7a216b60 | |
| | d8280d45c1 | |
| | 24f2887e87 | |
| | 29289a4de9 | |
| | a24957fa04 | |
| | 927bf530bc | |
| | 18954ba100 | |
| | d8cf66edbd | |
| | 181cc3106b | |
| | 20106116da | |
| | a27c4f8771 | |
| | bb1109b81d | |
| | 8c69ec3a1e | |
| | 46675496a3 | |
| | c6b5d35e5d | |
| | 12c826816d | |
| | 1d8f500709 | |
| | 0494359c6c | |
| | 26c39e1ca7 | |
| | 45adf1bfb9 | |
| | eb3a57eb17 | |
| | 34da391391 | |
| | 0bb9077553 | |
| | a85efffbef | |
| | 06a648263b | |
| | 9d5bfc127e | |
| | da8f6c32b9 | |
`.bandit` — 2 changes

```diff
@@ -1,3 +1,3 @@
 [bandit]
 exclude = tests
-skips = B101
+skips = B101,B615
```
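A hedged sketch of exercising the updated skip list locally (the `src/` target and the `--ini` invocation are assumptions; the pre-commit hook may pass different arguments):

```bash
# Hypothetical manual run against the updated .bandit config:
# tests/ stays excluded, and B101 plus the newly added B615 stay skipped.
bandit --ini .bandit -r src/
```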
`.github/workflows/base.yml` — 6 changes

```diff
@@ -5,11 +5,13 @@ on:
     branches:
       - "main"
     paths:
-      - 'Dockerfile-base'
+      - 'docker/Dockerfile-base'
+      - 'docker/Dockerfile-uv-base'
      - '.github/workflows/base.yml'
   pull_request:
     paths:
-      - 'Dockerfile-base'
+      - 'docker/Dockerfile-base'
+      - 'docker/Dockerfile-uv-base'
      - '.github/workflows/base.yml'
   workflow_dispatch:

```
`.github/workflows/docs.yml` — 2 changes

```diff
@@ -23,7 +23,7 @@ jobs:
       - name: Install dependencies
         run: |
           python3 -m pip install jupyter quartodoc
-          python3 -m pip install -e . --no-deps
+          python3 -m pip install -e .
       - name: Build autodoc
         run: quartodoc build
       - name: Publish to GitHub Pages (and render)
```
`.github/workflows/main.yml` — 13 changes

```diff
@@ -20,12 +20,11 @@ jobs:
             python_version: "3.11"
             pytorch: 2.5.1
             axolotl_extras:
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.6.0
             axolotl_extras: vllm
-            is_latest: true
           - cuda: 126
             cuda_version: 12.6.3
             python_version: "3.11"
@@ -88,8 +87,8 @@ jobs:
             python_version: "3.11"
             pytorch: 2.5.1
             axolotl_extras:
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.6.0
             axolotl_extras:
@@ -146,8 +145,8 @@ jobs:
     strategy:
       matrix:
         include:
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.6.0
             axolotl_extras:
```
`.github/workflows/multi-gpu-e2e.yml` — 6 changes

```diff
@@ -26,11 +26,11 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.6.0
-            axolotl_extras: vllm
+            axolotl_extras:
             num_gpus: 2
             nightly_build: "true"
           - cuda: 124
```
`.github/workflows/preview-docs.yml` — 6 changes

```diff
@@ -8,7 +8,9 @@ on:
     paths:
       - '**/*.md' # any Markdown file
       - '**/*.qmd' # any Quarto file
-      - '_quarto.yaml'
+      - '_quarto.yml'
+      - docs/scripts/generate_config_docs.py
+      - src/axolotl/utils/schemas/**.py

 permissions:
   checks: write
@@ -38,7 +40,7 @@ jobs:
       - name: Install dependencies
         run: |
           python3 -m pip install jupyter quartodoc
-          python3 -m pip install -e . --no-deps
+          python3 -m pip install -e .

       - name: Build autodoc
         run: quartodoc build
```
`.github/workflows/tests-nightly.yml` — 115 changes

```diff
@@ -18,96 +18,9 @@ jobs:
     env:
       SKIP: no-commit-to-branch

-  preload-cache:
-    name: Preload HF cache
-    runs-on: ubuntu-latest
-    strategy:
-      fail-fast: false
-      matrix:
-        python_version: ["3.11"]
-        pytorch_version: ["2.6.0"]
-    timeout-minutes: 20
-
-    env:
-      AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
-
-    steps:
-      - name: Check out repository code
-        uses: actions/checkout@v4
-
-      - name: Restore HF cache
-        id: hf-cache-restore
-        uses: actions/cache/restore@v4
-        with:
-          path: |
-            /home/runner/.cache/huggingface/hub/datasets--*
-            /home/runner/.cache/huggingface/hub/models--*
-          key: ${{ runner.os }}-hf-hub-cache-v2
-
-      - name: Setup Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: ${{ matrix.python_version }}
-          cache: 'pip' # caching pip dependencies
-
-      - name: upgrade pip
-        run: |
-          pip3 install --upgrade pip
-          pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
-
-      - name: Install PyTorch
-        run: |
-          pip3 install torch==${{ matrix.pytorch_version }}
-
-      - name: Install dependencies
-        run: |
-          pip3 show torch
-          pip3 install --no-build-isolation -U -e .
-          python scripts/unsloth_install.py | sh
-          python scripts/cutcrossentropy_install.py | sh
-          pip3 install -r requirements-dev.txt -r requirements-tests.txt
-
-      - name: Make sure PyTorch version wasn't clobbered
-        run: |
-          python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
-
-      - name: Ensure axolotl CLI was installed
-        run: |
-          axolotl --help
-
-      - name: Pre-Download dataset fixture
-        run: |
-          huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
-
-      - name: Run tests
-        run: |
-          pytest -v tests/conftest.py
-
-      - name: Upload coverage to Codecov
-        uses: codecov/codecov-action@v5
-        with:
-          token: ${{ secrets.CODECOV_TOKEN }}
-          files: ./coverage.xml
-          flags: unittests,pytorch-${{ matrix.pytorch_version }}
-          fail_ci_if_error: false
-
-      - name: cleanup pip cache
-        run: |
-          find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
-
-      - name: Save HF cache
-        id: hf-cache
-        uses: actions/cache/save@v4
-        with:
-          path: |
-            /home/runner/.cache/huggingface/hub/datasets--*
-            /home/runner/.cache/huggingface/hub/models--*
-          key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
-
   pytest:
     name: PyTest
     runs-on: ubuntu-latest
-    needs: [preload-cache]
     strategy:
       fail-fast: false
       max-parallel: 2
@@ -120,14 +33,11 @@ jobs:
       - name: Check out repository code
         uses: actions/checkout@v4

-      - name: Restore HF cache
-        id: hf-cache-restore
-        uses: actions/cache/restore@v4
-        with:
-          path: |
-            /home/runner/.cache/huggingface/hub/datasets--*
-            /home/runner/.cache/huggingface/hub/models--*
-          key: ${{ runner.os }}-hf-hub-cache-v2
+      - name: Restore Cache from S3
+        id: hf-cache-restore-s3
+        run: |
+          mkdir -p /home/runner/.cache/huggingface/hub
+          curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd

       - name: Setup Python
         uses: actions/setup-python@v5
@@ -168,10 +78,6 @@ jobs:
         run: |
           axolotl --help

-      - name: Pre-Download dataset fixture
-        run: |
-          huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
-
       - name: Run tests
         run: |
           pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
@@ -193,15 +99,8 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 124
-            cuda_version: 12.4.1
-            python_version: "3.11"
-            pytorch: 2.5.1
-            num_gpus: 1
-            axolotl_extras:
-            nightly_build: "true"
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.6.0
             num_gpus: 1
```
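The nightly job now pulls a pre-built Hugging Face hub cache from CloudFront instead of `actions/cache`. A hedged sketch of the companion publish step (the bucket name and the upload command are assumptions, not part of this diff):

```bash
# Hypothetical: package the runner's HF hub cache into the zstd tarball the
# workflow's "curl | tar" step expects, then upload it to the CloudFront origin.
cd ~/.cache/huggingface/hub
tar --use-compress-program=zstd -cf hf-cache.tar.zst datasets--* models--*
aws s3 cp hf-cache.tar.zst s3://<hf-cache-origin-bucket>/hf-cache.tar.zst  # bucket name assumed
```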
`.github/workflows/tests.yml` — 12 changes

```diff
@@ -195,12 +195,12 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.6.0
             num_gpus: 1
-            axolotl_extras: vllm
+            axolotl_extras:
           - cuda: 126
             cuda_version: 12.6.3
             python_version: "3.11"
@@ -247,8 +247,8 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 124
-            cuda_version: 12.4.1
+          - cuda: 126
+            cuda_version: 12.6.3
             python_version: "3.11"
             pytorch: 2.6.0
             num_gpus: 1
@@ -311,7 +311,7 @@ jobs:
             python_version: "3.11"
             pytorch: 2.6.0
             num_gpus: 1
-            axolotl_extras: vllm
+            axolotl_extras:
     steps:
       - name: Checkout
         uses: actions/checkout@v4
```
```diff
@@ -19,7 +19,7 @@ repos:
     hooks:
       - id: isort
   - repo: https://github.com/PyCQA/flake8
-    rev: 7.2.0
+    rev: 7.3.0
     hooks:
       - id: flake8
   - repo: https://github.com/pylint-dev/pylint
@@ -27,7 +27,7 @@ repos:
     hooks:
       - id: pylint
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.16.0
+    rev: v1.16.1
     hooks:
       - id: mypy
         additional_dependencies:
@@ -36,7 +36,7 @@ repos:
             'pydantic>=2.5.3',
           ]
   - repo: https://github.com/PyCQA/bandit
-    rev: 1.8.3
+    rev: 1.8.6
     hooks:
       - id: bandit
         args: [
```
```diff
@@ -328,7 +328,7 @@ The following optimizers are supported:
 - Use `gradient_checkpointing: true` to reduce memory usage
 - Adjust `micro_batch_size` and `gradient_accumulation_steps` based on your GPU memory

-For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html).
+For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config-reference.html).

 ### Errors:

```
```diff
@@ -2,4 +2,5 @@ include requirements.txt
 include README.md
 include LICENSE
 include src/setuptools_axolotl_dynamic_dependencies.py
+include src/axolotl/utils/chat_templates/templates/*.jinja
 recursive-include axolotl *.py
```
`README.md` — 13 changes

````diff
@@ -43,7 +43,7 @@ Features:
 - **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
 - **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
 - **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
-- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!
+- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
 - **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
 - **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.

@@ -59,6 +59,8 @@ Features:

 ### Installation

+#### Using pip
+
 ```bash
 pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
 pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
@@ -68,6 +70,13 @@ axolotl fetch examples
 axolotl fetch deepspeed_configs # OPTIONAL
 ```

+#### Using Docker
+
+Installing with Docker can be less error prone than installing in your own environment.
+```bash
+docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
+```
+
 Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).

 ### Your First Fine-tune
@@ -89,7 +98,7 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
 ## 📚 Documentation

 - [Installation Options](https://docs.axolotl.ai/docs/installation.html) - Detailed setup instructions for different environments
-- [Configuration Guide](https://docs.axolotl.ai/docs/config.html) - Full configuration options and examples
+- [Configuration Guide](https://docs.axolotl.ai/docs/config-reference.html) - Full configuration options and examples
 - [Dataset Loading](https://docs.axolotl.ai/docs/dataset_loading.html) - Loading datasets from various sources
 - [Dataset Guide](https://docs.axolotl.ai/docs/dataset-formats/) - Supported formats and how to use them
 - [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
````
```diff
@@ -1,5 +1,6 @@
 project:
   type: website
+  pre-render: docs/scripts/generate_config_docs.py

 quartodoc:
   dir: docs/api
@@ -235,7 +236,7 @@ website:
         - docs/installation.qmd
         - docs/inference.qmd
         - docs/cli.qmd
-        - docs/config.qmd
+        - docs/config-reference.qmd
         - text: "API Reference"
           href: docs/api

```
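Taken together with the `config-reference.qmd` entry added to `docs/.gitignore` further down, the new `pre-render` hook suggests the config reference page is now generated at build time rather than checked in. A hedged sketch of the equivalent local workflow (assuming Quarto is installed and that the script emits `docs/config-reference.qmd`):

```bash
# Hypothetical local run of the pre-render step, followed by a site preview.
python3 docs/scripts/generate_config_docs.py
quarto preview
```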
```diff
@@ -9,6 +9,7 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
 ENV GITHUB_SHA="{{ GITHUB_SHA }}"
 ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
 ENV HF_HOME="{{ HF_HOME }}"
+ENV AXOLOTL_DATASET_PROCESSES="8"

 RUN apt-get update && \
     apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
```
```diff
@@ -6,7 +6,7 @@ from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
 @app.function(
     image=cicd_image,
     gpu=GPU_CONFIG,
-    timeout=90 * 60,  # 90 min
+    timeout=120 * 60,  # 90 min
     cpu=8.0,
     memory=131072,
     volumes=VOLUME_CONFIG,
@@ -69,7 +69,7 @@ def run_cmd(cmd: str, run_folder: str):
 @app.function(
     image=cicd_image,
     gpu=GPU_CONFIG,
-    timeout=90 * 60,
+    timeout=120 * 60,
     cpu=16.0,
     memory=131072 * N_GPUS,
     volumes=VOLUME_CONFIG,
```
```diff
@@ -32,6 +32,8 @@ df_args = {
     "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
     "CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
     "HF_HOME": "/workspace/data/huggingface-cache/hub",
+    "PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
+    "DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
 }

 dockerfile_contents = df_template.render(**df_args)
```
```diff
@@ -38,6 +38,6 @@ RUN git lfs install --skip-repo && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic==1.10.10

-RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
-        pip3 install flash-attn==2.7.4.post1; \
+RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
+        FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
     fi
```
```diff
@@ -34,7 +34,3 @@ RUN uv pip install packaging setuptools wheel psutil \
     && uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
     && uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
     && uv pip install awscli pydantic
-
-RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
-        uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
-    fi
```
`docs/.gitignore` — 1 change

```diff
@@ -2,3 +2,4 @@
 _site/
 /api/*.qmd
 /api/*.html
+config-reference.qmd
```
`docs/config.qmd` — 801 changes (the file is deleted in this diff; all 801 lines removed). The removed file's contents:

---
title: Config Reference
description: A complete list of all configuration options.
---

```yaml
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
# This can also be a relative path to a model on disk
base_model: ./llama-7b-hf
# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
base_model_ignore_patterns:
# If the base_model repo on hf hub doesn't include configuration .json files,
# You can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf
# You can specify to choose a specific model revision from huggingface hub
revision_of_model:
# Optional tokenizer configuration path in case you want to use a different tokenizer
# than the one defined in the base model
tokenizer_config:
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
model_type: AutoModelForCausalLM
# Corresponding tokenizer for the model AutoTokenizer is a good choice
tokenizer_type: AutoTokenizer
# Trust remote code for untrusted source
trust_remote_code:
# use_fast option for tokenizer loading from_pretrained, default to True
tokenizer_use_fast:
# Whether to use the legacy tokenizer setting, defaults to True
tokenizer_legacy:
# Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer.
tokenizer_use_mistral_common:
# Resize the model embeddings when new tokens are added to multiples of 32
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings:
# Optional[bool] Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs
embeddings_skip_upcast:
# Whether to load the model with randomly initialized weights. Useful for
# pre-training a model from scratch or debugging purposes.
random_init_weights:

# (Internal use only)
# Used to identify which the model is based on
is_falcon_derived_model:
is_llama_derived_model:
is_qwen_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model:

# optional overrides to the base model configuration
overrides_of_model_config:
  # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
  rope_scaling:
    type: # linear | dynamic
    factor: # float

# optional overrides the base model loading from_pretrained
overrides_of_model_kwargs:
  # use_cache: False

# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
  # These are default values
  llm_int8_has_fp16_weight: false
  bnb_4bit_quant_type: nf4
  bnb_4bit_use_double_quant: true

# quantization aware training
qat:
  activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
  weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
  group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
  fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after

# post-training quantization
quantization:
  weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
  activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
  group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
  quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.


# Whether you are training a 4-bit GPTQ quantized model
gptq: true

# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
# Use bitsandbytes 4 bit
load_in_4bit:

# Use CUDA bf16
bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere
# Use CUDA fp16
fp16: true
# Use CUDA tf32
tf32: true # require >=ampere
# Note: if bf16 is set to 'auto', and fp16 is set to true, we will prefer the explict fp16 setting

# No AMP (automatic mixed precision)
bfloat16: true # require >=ampere
float16: true

# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset
gpu_memory_limit: 20GiB
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
lora_on_cpu: true

# List[str]. Add plugins to extend the pipeline.
# See `src/axolotl/integrations` for the available plugins or doc below for more details.
# https://docs.axolotl.ai/docs/custom_integrations.html
plugins:
  # - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

# A list of one or more datasets to finetune the model with
# See https://docs.axolotl.ai/docs/dataset_loading.html for guide on loading datasets
# See https://docs.axolotl.ai/docs/dataset-formats/ for guide on dataset formats
datasets:
  # HuggingFace dataset repo | s3:// | gs:// | path to local file or directory
  - path: vicgalle/alpaca-gpt4
    # The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
    type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
    ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
    data_files: # Optional[str] path to source data files

    shards: # Optional[int] split dataset into N pieces (use with shards_idx)
    shards_idx: # Optional[int] = 0 the index of sharded dataset to use

    preprocess_shards: # Optional[int] process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)

    name: # Optional[str] name of dataset configuration to load
    split: train # Optional[str] name of dataset split to load from
    revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
    trust_remote_code: # Optional[bool] Trust remote code for untrusted source

  # Custom user instruction prompt
  - path: repo
    type:
      # The below are defaults. only set what's needed if you use a different column name.
      system_prompt: ""
      system_format: "{system}"
      field_system: system
      field_instruction: instruction
      field_input: input
      field_output: output

      # Customizable to be single line or multi-line
      # Use {instruction}/{input} as key to be replaced
      # 'format' can include {input}
      format: |-
        User: {instruction} {input}
        Assistant:
      # 'no_input_format' cannot include {input}
      no_input_format: "{instruction} "

      # For `completion` datsets only, uses the provided field instead of `text` column
      field:

  # Using chat template
  - path: ...
    # Set type to `chat_template` to use this strategy
    type: chat_template
    # Specify the name of the chat template to use
    # The name of the chat template to use for training, following values are supported:
    # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
    # - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
    # - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
    # - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
    chat_template: tokenizer_default

    # Custom jinja chat template. Used only if `chat_template: jinja` or empty.
    chat_template_jinja:

    # Key containing the messages (default: "messages")
    field_messages: messages

    # Key containing the tools (default: "tools")
    # Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
    field_tools: tools

    # Key containing the system message (default: "system")
    # If the system message is not present in the dataset sample, it will be loaded from the field_system property.
    field_system: system

    # Mapping of properties from the input dataset to the chat template.
    # (default: message_property_mappings={'role':'role', 'content':'content'})
    # If a property exists in the template but not in this mapping, the system will attempt
    # to load it directly from the message using the property name as the key.
    # Example: In the mapping below, 'from' is loaded from input dataset and used as 'role',
    # while 'value' is loaded and used as 'content' in the chat template.
    message_property_mappings:
      role: from
      content: value
      # ...

    # Optional[Dict[str, List]]. Roles mapping in the messages.
    # The format is {target_role: [source_roles]}. All source roles will be mapped to the target role.
    # The default is:
    roles:
      user: ["human", "user"]
      assistant: ["gpt", "assistant"]
      system: ["system"]
      tool: ["tool"]

    # Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
    # This does not drop the default system message from chat_template if it exists. If you wish to,
    # we recommend using a custom jinja template with the default system message removed or
    # adding a system turn with empty content.
    drop_system_message:

    # Optional[bool]. (for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags
    # See example at `docs/dataset-formats/conversation.qmd`
    split_thinking:

    # IMPORTANT: The following fields determine which parts of the conversation to train on.
    # Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
    # See examples at `docs/dataset-formats/conversation.qmd`
    # Note: If the below 5 fields are empty, defaults to training only on the last message.

    # Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
    roles_to_train: ["assistant"] # default
    # Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
    # - all: train on all EOS tokens
    # - turn (default): train on the EOS token at the end of each trainable turn
    # - last: train on the last EOS token in the conversation
    # TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
    train_on_eos: turn
    # Optional[str]. Which EOT (End-of-Turn) tokens to train on in the conversation. Possible values are:
    # - all: train on all EOT tokens
    # - turn: train on the EOT token at the end of each trainable turn
    # - last: train on the last EOT token in the conversation
    # If not specified, defaults to the value of train_on_eos for backward compatibility.
    train_on_eot:
    # The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
    message_field_training: training
    # The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
    # The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
    message_field_training_detail: train_detail


# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true

# Deduplicates datasets and test_datasets with identical entries.
dataset_exact_deduplication: true

# A list of one or more datasets to eval the model with.
# You can use either test_datasets, or val_set_size, but not both.
test_datasets:
  - path: /workspace/data/eval.jsonl
    ds_type: json
    # You need to specify a split. For "json" datasets the default split is called "train".
    split: train
    type: completion
    data_files:
      - /workspace/data/eval.jsonl

# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
rl:
rl_beta: # Optional[float]. The beta parameter for the RL training.

# dpo
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.

# orpo
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.

# kto
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.

# simpo
cpo_alpha: 1.0 # Weight of the BC regularizer
simpo_gamma: 0.5 # Target reward margin for the SimPO loss

# grpo
trl:
  use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
  vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
  vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
  vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
  vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.

  beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
  max_completion_length: # Optional[int]. Maximum length of the completion for RL training.

  reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
  reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.

  num_generations: # Optional[int]. Number of generations to sample.
  log_completions: # Optional[bool]. Whether to log completions.
  num_completions_to_print: # Optional[int]. Number of completions to print when log_completions is True.

  sync_ref_model: # Optional[bool]. Whether to sync the reference model.
  ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
  ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
  scale_rewards: # Optional[bool]. Whether to scale rewards by their standard deviation.

  temperature: # Optional[float]. Sampling temperature for the GRPO policy.
  top_p: # Optional[float]. Top-p sampling probability for the generation policy.
  top_k: # Optional[int]. Top-k sampling for the generation policy.
  min_p: # Optional[float]. Minimum probability for the generation policy.
  repetition_penalty: # Optional[float]. Penalty for tokens that appear in prompt and generated text.

  num_iterations: # Optional[int]. Number of iterations per batch (μ) for GRPO.
  epsilon: # Optional[float]. Epsilon value for clipping in the GRPO algorithm.
  epsilon_high: # Optional[float]. Upper-bound epsilon value for clipping in the GRPO algorithm.
  use_liger_loss: # Optional[bool]. Whether to use Liger loss for GRPO.
  loss_type: # Optional[str]. Loss formulation to use. Supported values: grpo, bnpo, dr_grpo.
  mask_truncated_completions: # Optional[bool]. Whether to exclude truncated completions from loss calculation.


# reward modelling: `True` or `False`
reward_model:

# process reward modelling: `True` or `False`
process_reward_model:

# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Optional[List[str]]. Custom EOT (End-of-Turn) tokens to mask/unmask during training.
# These tokens mark the boundaries between conversation turns.
# For example: ["/INST", "</s>", "[/SYSTEM_PROMPT]"]
# If not specified, defaults to just the model's eos_token.
# This is useful for templates that use multiple delimiter tokens.
eot_tokens:
# - "</s>"
# - "[/INST]"
# - "[/SYSTEM_PROMPT]"
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
# Push prepared dataset to hub
push_dataset_to_hub: # Optional[str] repo_org/repo_name
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set.
dataset_processes: # defaults to os.cpu_count() if not set
# Keep dataset in memory while preprocessing
# Only needed if cached dataset is taking too much storage
dataset_keep_in_memory:
# push checkpoints to hub
hub_model_id: # private repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# Required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
val_set_size: 0.04
# Num shards for whole dataset
dataset_shard_num:
# Index of shard to use for whole dataset
dataset_shard_idx:

# The maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048
sequence_len: 2048
# Pad inputs so each step uses constant sized buffers
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len:
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
sample_packing:
# Set to 'false' if getting errors during eval with sample_packing on.
eval_sample_packing:
# You can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these values.
sample_packing_eff_est:
total_num_tokens:
# Increasing the following values helps with packing, but usually only slightly (<%1.)
# The number of samples packed at a time.
sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200
sample_pack_sequentially: # Optional[bool]. Whether to pack samples sequentially.

# whether to concatenate samples during pretraining
pretraining_sample_concatenation:

curriculum_sampling: # Optional[bool]. Whether to use sequential sampling for curriculum learning

# Use batch flattening for speedups when not using sample_packing
batch_flattening:

# Passed through to transformers when loading the model when launched without accelerate
# Use `sequential` when training w/ model parallelism to limit memory
device_map:
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
max_memory:

# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora
# If you already have a lora model trained that you want to load, put that here.
# This means after training, if you want to test the model, you should set this to the value of `output_dir`.
# Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`.
lora_model_dir:

# LoRA hyperparameters
# For more details about the following options, see:
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - v_proj
# - k_proj
# - o_proj
# - gate_proj
# - down_proj
# - up_proj
lora_target_linear: # If true, will target all linear modules

# List[int] | int. # The layer indices to transform, otherwise, apply to all layers
# https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.layers_to_transform
peft_layers_to_transform:

# Optional[bool]. Whether to use DoRA.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#weight-decomposed-low-rank-adaptation-dora
peft_use_dora:

# Optional[bool]. Whether to use RSLoRA.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#rank-stabilized-lora
peft_use_rslora:

# Optional[list[tuple[int, int]]]. List of layer indices to replicate.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#memory-efficient-layer-replication-with-lora
peft_layer_replication:

# bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]
# How to initialize LoRA weights. Default to True which is MS original implementation.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#initialization
peft_init_lora_weights:

# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
lora_modules_to_save:
# - embed_tokens
# - lm_head

lora_fan_in_fan_out: false

# Apply custom LoRA autograd functions and activation function Triton kernels for
# speed and memory savings
# See: https://docs.axolotl.ai/docs/lora_optims.html
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true

# LoRA+ hyperparameters
# For more details about the following options, see:
# https://arxiv.org/abs/2402.12354 and `src/axolotl/core/train_builder.py`
loraplus_lr_ratio: # loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4.
loraplus_lr_embedding: # loraplus learning rate for lora embedding layers. Default value is 1e-6.

peft:
  # Configuration options for loftq initialization for LoRA
  # https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
  loftq_config:
    loftq_bits: # typically 4 bits

# ReLoRA configuration
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # Number of steps per ReLoRA restart
relora_warmup_steps: # Number of per-restart warmup steps
relora_anneal_steps: # Number of anneal steps for each relora cycle
relora_prune_ratio: # threshold for optimizer magnitude when pruning
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings

# wandb configuration if you're using it
|
|
||||||
# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
|
||||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
|
||||||
wandb_project: # Your wandb project name
|
|
||||||
wandb_entity: # A wandb Team name if using a Team
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name: # Set the name of your wandb run
|
|
||||||
wandb_run_id: # Set the ID of your wandb run
|
|
||||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
|
||||||
|
|
||||||
# mlflow configuration if you're using it
|
|
||||||
mlflow_tracking_uri: # URI to mlflow
|
|
||||||
mlflow_experiment_name: # Your experiment name
|
|
||||||
mlflow_run_name: # Your run name
|
|
||||||
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
|
||||||
|
|
||||||
# Comet configuration if you're using it
|
|
||||||
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
|
|
||||||
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
|
|
||||||
use_comet: # Enable or disable Comet integration.
|
|
||||||
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
|
|
||||||
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
|
|
||||||
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
|
|
||||||
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
|
|
||||||
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
|
|
||||||
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
|
|
||||||
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
|
|
||||||
|
|
||||||
# Tensorboard
|
|
||||||
use_tensorboard: # Optional[bool]
|
|
||||||
|
|
||||||
# Where to save the full-finetuned model to
|
|
||||||
output_dir: ./completed-model
|
|
||||||
|
|
||||||
# Whether to use torch.compile and which backend to use
|
|
||||||
# setting to `auto` will enable torch compile when torch>=2.5.1
|
|
||||||
torch_compile: # Optional[Union[Literal["auto"], bool]]
|
|
||||||
torch_compile_backend: # Optional[str]
|
|
||||||
torch_compile_mode: # 'default' | 'reduce-overhead' | 'max-autotune'
|
|
||||||
|
|
||||||
# Training hyperparameters

# If greater than 1, the optimizer update is deferred and gradients are accumulated for the given number of steps.
gradient_accumulation_steps: 1
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
# Effective batch size per GPU = micro_batch_size * gradient_accumulation_steps
micro_batch_size: 2
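# For illustration (assumed values, not defaults): micro_batch_size: 2 with
# gradient_accumulation_steps: 4 gives an effective batch size of 8 per GPU per
# optimizer step, and 16 globally when training on 2 GPUs.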
eval_batch_size:
num_epochs: 4
warmup_steps: 100 # cannot use with warmup_ratio
warmup_ratio: 0.05 # cannot use with warmup_steps
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
eval_steps: # Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
eval_strategy: # Set to `"no"` to skip evaluation, `"epoch"` at end of each epoch, leave empty to infer from `eval_steps`.
save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of each epoch, `"best"` when better result is achieved, leave empty to infer from `save_steps`.
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: # Maximum number of checkpoints to keep at a time
save_only_model: # Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints.
# Maximum number of steps to train for. It takes precedence over num_epochs, which means that
# if both are set, num_epochs is not guaranteed to be reached.
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
max_steps:

# bool of whether to include tokens-per-second throughput in the training metrics. This iterates over the entire dataset once, so it takes some time.
include_tokens_per_second: # Optional[bool]

# whether to find batch size that fits in memory. Passed to underlying transformers Trainer
auto_find_batch_size: # Optional[bool]

eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
do_causal_lm_eval: # Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`.
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]

profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
# see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information
# snapshots can be visualized @ https://pytorch.org/memory_viz

loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
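# Illustrative sketch (assumed values): if the loss starts around 2.5, then
# loss_watchdog_threshold: 5.0 with loss_watchdog_patience: 3 aborts the run after
# three consecutive steps above 5.0.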

# Save model as safetensors (requires safetensors package). Default True
save_safetensors:

# Whether to mask out or include the human's prompt from the training labels
train_on_inputs: false
# Group similarly sized data to minimize padding.
# May be slower to start, as it must download and sort the entire dataset.
# Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false

# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing
# gradient_checkpointing_kwargs:
#   use_reentrant: true

# Stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3

# Specify a scheduler and kwargs to use with the optimizer
# Valid values are driven by the Transformers SchedulerType class, see:
# https://github.com/huggingface/transformers/blob/5f4ecf2d9f867a1255131d2461d75793c0cf1db2/src/transformers/trainer_utils.py#L420
# Valid values include
# - 'linear'
# - 'cosine' (default)
# - 'cosine_with_restarts'
# - 'polynomial'
# - 'constant'
# - 'constant_with_warmup'
# - 'inverse_sqrt'
# - 'reduce_lr_on_plateau'
# - 'cosine_with_min_lr'
# - 'warmup_stable_decay'

# Additional schedulers include:
# - 'one_cycle'
# - 'rex'
lr_scheduler:
lr_scheduler_kwargs:
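# Illustrative only; kwargs are passed through to the transformers scheduler, e.g. for
# cosine_with_restarts (the value below is an assumption, not a default):
#   num_cycles: 3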
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)

# For one_cycle optim
lr_div_factor: # Learning rate div factor

# Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189
#
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
# in the examples/ for your model and fine-tuning use case.
#
# Valid values for 'optimizer' include:
# - adamw_torch
# - adamw_torch_fused (default)
# - adamw_torch_xla
# - adamw_torch_npu_fused
# - adamw_apex_fused
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
# - adafactor
# - adamw_anyprecision
# - adamw_torch_4bit
# - ademamix
# - sgd
# - adagrad
# - adamw_bnb_8bit
# - adamw_8bit # alias for adamw_bnb_8bit
# - ademamix_8bit
# - lion_8bit
# - lion_32bit
# - paged_adamw_32bit
# - paged_adamw_8bit
# - paged_ademamix_32bit
# - paged_ademamix_8bit
# - paged_lion_32bit
# - paged_lion_8bit
# - rmsprop
# - rmsprop_bnb
# - rmsprop_bnb_8bit
# - rmsprop_bnb_32bit
# - galore_adamw
# - galore_adamw_8bit
# - galore_adafactor
# - galore_adamw_layerwise
# - galore_adamw_8bit_layerwise
# - galore_adafactor_layerwise
# - lomo
# - adalomo
# - grokadamw
# - schedule_free_adamw
# - schedule_free_sgd
# - apollo_adamw
# - apollo_adamw_layerwise
#
# Additional custom optimizers include:
# - optimi_adamw
# - ao_adamw_8bit
# - ao_adamw_fp8
# - came_pytorch
optimizer:
# Dictionary of arguments to pass to the optimizer
optim_args:
# For Galore Optimizers the following optim_args are available
# rank: # type: int
# update_proj_gap # type: int
# scale # type: float
# proj_type: # type: str, default = std

# The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for the GaLore algorithm
optim_target_modules:
# - self_attn # for llama
# - mlp

# Specify weight decay
weight_decay:
# adamw hyperparams
adam_beta1:
adam_beta2:
adam_beta3: # only used for CAME Optimizer
adam_epsilon:
adam_epsilon2: # only used for CAME Optimizer
# Gradient clipping max norm
max_grad_norm:

# Augmentation techniques
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
# currently only supported on Llama and Mistral
neftune_noise_alpha:

# Optional[bool]. Whether to use BetterTransformer
flash_optimum:

# Note: Only one of the following attention patches can be used at a time.
# For example, if you set `xformers_attention` to `true`, do not set `flash_attention` to `true`.

# Optional[bool]. Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
# Optional[bool]. Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention:
flash_attn_cross_entropy: # Optional[bool]. Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Optional[bool]. Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Optional[bool]. Whether to fuse QKV into a single operation
flash_attn_fuse_mlp: # Optional[bool]. Whether to fuse part of the MLP into a single operation
# Optional[bool]. Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
# Optional[bool]. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention:

# Optional[bool]. Whether to use low_cpu_mem_usage
low_cpu_mem_usage:
# Optional[str]. Resume from a specific checkpoint dir
resume_from_checkpoint:
# Optional[bool]. Set to true if resume_from_checkpoint isn't set and you simply want training to start where it left off.
# Be careful with this being turned on between different models.
auto_resume_from_checkpoints: false

## Multimodal section
# int | tuple[int, int] | None . Size to resize images to, width x height.
# Will read from model/processor config if not set.
image_size:
# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear".
image_resize_algorithm: 'bilinear'
## End of multimodal section

# Don't mess with this, it's here for accelerate and torchrun
local_rank:

# Add or change special tokens.
# If you add tokens here, you don't need to add them to the `tokens` list.
special_tokens:
  # bos_token: "<s>"
  # eos_token: "</s>"
  # unk_token: "<unk>"
  # pad_token: "[PAD]"

# Optional[list[str]]. Add extra tokens to the tokenizer.
tokens:
  # - "<|startoftext|>"
  # - "<|endoftext|>"

# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
# You can check whether they exist in tokenizer.json under added_tokens.
added_tokens_overrides: # Dict[int, str]
  # 128041: "<|im_start|>"
  # 128042: "<|im_end|>"

# FSDP
fsdp:
fsdp_config:
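# Illustrative sketch of an FSDP setup; the keys below follow common Axolotl examples
# and are assumptions, not an exhaustive reference:
# fsdp:
#   - full_shard
#   - auto_wrap
# fsdp_config:
#   fsdp_offload_params: false
#   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer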

# Deepspeed config path. e.g., deepspeed_configs/zero3.json
deepspeed:

# Advanced DDP Arguments
ddp_timeout:
ddp_bucket_cap_mb:
ddp_broadcast_buffers:

# Sequence parallelism
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model.
heads_k_stride: 1
# One of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to "varlen_llama3"
# in the sample packing case, and "batch_ring" in the non-sample packing case.
ring_attn_func:

# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:

# Set to a HF dataset for type: 'completion' to stream instead of pre-tokenizing
pretraining_dataset:

# Debug mode
debug:

# Seed
seed:

# Allow overwriting yml config values from the CLI
strict:
```
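For illustration, a minimal sketch that combines a few of the options documented above; the values are assumptions for a small run, not recommended defaults:

```yaml
micro_batch_size: 2
gradient_accumulation_steps: 4
num_epochs: 1
learning_rate: 0.00002
lr_scheduler: cosine
optimizer: adamw_torch_fused
warmup_ratio: 0.05
gradient_checkpointing: true
flash_attention: true
save_strategy: epoch
output_dir: ./completed-model
```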
@@ -7,6 +7,7 @@ toc-depth: 3
 ```{python}
 #| echo: false

+import os
 import re

 def process_readme(integration_name):
@@ -53,6 +54,24 @@ sections = [
     ("LLMCompressor", "llm_compressor")
 ]

+for folder_name in os.listdir("../src/axolotl/integrations/"):
+    if folder_name in [path for name, path in sections]:
+        # skip if already in sections
+        continue
+    if os.path.exists(f"../src/axolotl/integrations/{folder_name}/README.md"):
+        # grab the first heading in README.md as the section name
+        with open(f"../src/axolotl/integrations/{folder_name}/README.md", "r") as f:
+            txt = f.read()
+        matches = re.search(r'^# (.*)\n?', txt, flags=re.MULTILINE)
+        if matches:
+            name = matches.group(1)
+        else:
+            continue
+        sections.append((name, folder_name))
+
+# sort sections by name
+sections = sorted(sections, key=lambda x: x[0])
+
 for section_name, folder_name in sections:
     print(print_section(section_name, folder_name))
 ```
@@ -9,10 +9,10 @@ order: 3
 Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Supports using the tokenizer's template, a supported template, or custom jinja2.

 ```{.json filename="data.jsonl"}
-{"conversations": [{"role": "...", "content": "..."}]}
+{"messages": [{"role": "...", "content": "..."}, {"role": "...", "content": "..."}, ...]}
 ```

-See [configs](../config.qmd) for full configs and supported templates.
+See [configs](../config-reference.qmd) for full configs and supported templates.

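For reference, a minimal `datasets` entry using this message format might look like the following; the path is a placeholder and only illustrates the shape documented above:

```yaml
datasets:
  - path: data.jsonl
    type: chat_template
    field_messages: messages
```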
 ### Migrating from sharegpt

@@ -130,13 +130,13 @@ datasets:
 ```

 ::: {.callout-tip}
-See [config documentation](../config.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens.
+See [config documentation](../config-reference.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens.
 :::

 ::: {.callout-note}
 Using `eot_tokens` requires each token that exists in `chat_template` to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior.

-You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
+You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config-reference.qmd) for more details.
 :::

 - Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
|
|||||||
@@ -186,4 +186,4 @@ datasets:
|
|||||||
no_input_format: "[INST] {instruction} [/INST]"
|
no_input_format: "[INST] {instruction} [/INST]"
|
||||||
```
|
```
|
||||||
|
|
||||||
See full config options under [here](../config.qmd).
|
See full config options under [here](../config-reference.qmd).
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ This matches the API of [`datasets.load_dataset`](https://github.com/huggingface

 For HuggingFace's guide to load different dataset types, see [here](https://huggingface.co/docs/datasets/loading).

-For full details on the config, see [config.qmd](config.qmd).
+For full details on the config, see [config-reference.qmd](config-reference.qmd).

 ::: {.callout-note}

@@ -9,7 +9,7 @@ format:
 This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).

 ::: {.callout-important}
-For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8.
+For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
 :::

 ## Base
@@ -34,6 +34,7 @@ Tags examples:

 - `main-base-py3.11-cu128-2.7.1`
 - `main-base-py3.11-cu126-2.7.1`
+- `main-base-py3.11-cu126-2.6.0`
 - `main-base-py3.11-cu124-2.6.0`
 - `main-base-py3.11-cu124-2.5.1`

@@ -73,13 +74,15 @@ There may be some extra tags appended to the image, like `-vllm` which installs

 Tags examples:

-- `main-py3.11-cu126-2.7.0`
+- `main-py3.11-cu128-2.7.1`
+- `main-py3.11-cu126-2.7.1`
+- `main-py3.11-cu126-2.6.0`
 - `main-py3.11-cu124-2.6.0`
 - `main-py3.11-cu124-2.5.1`
 - `main-latest`
 - `main-20250303-py3.11-cu124-2.6.0`
 - `main-20250303-py3.11-cu124-2.5.1`
-- `0.9.2`
+- `0.10.1`

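For reference, pulling and running one of these images might look like the following; the tag choice and flags are illustrative, not a recommendation:

```bash
docker pull axolotlai/axolotl:main-latest
docker run --gpus all -it axolotlai/axolotl:main-latest
```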
 ## Cloud

16 docs/faq.qmd
@@ -9,11 +9,11 @@ description: Frequently asked questions

 > A: Usually an issue with the GPUs communicating with each other. See the [NCCL doc](nccl.qmd)

-**Q: Exitcode -9**
+**Q: exitcode: -9**

 > A: This usually happens when you run out of system RAM.

-**Q: Exitcode -7 while using deepspeed**
+**Q: exitcode: -7 while using deepspeed**

 > A: Try upgrading deepspeed with: `pip install -U deepspeed`

@@ -51,6 +51,18 @@ description: Frequently asked questions
 > pad_token: "..."
 > ```

+**Q: `IterableDataset error` or `KeyError: 'input_ids'` when using `preprocess` CLI**
+
+> A: This is because you may be using the `preprocess` CLI with `pretraining_dataset:` or `skip_prepare_dataset: true` respectively. Please use the `axolotl train` CLI directly instead, as these datasets are prepared on demand.
+
+**Q: vLLM is not working with Axolotl**
+
+> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
+
+**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
+
+> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
+
 ### Chat templates

 **Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

@@ -20,7 +20,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
 > See the [example config](#example-config) file in addition to reading these instructions.

 1. Set `adapter: qlora` in your axolotl config file.
-2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
+2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
 3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.

 ## Example Config

@@ -55,7 +55,7 @@ output_dir: ./outputs/lora-out
 - To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
 :::

-See our [Config options](config.qmd) for more details.
+See our [config options](config-reference.qmd) for more details.

 ### Training {#sec-training}

@@ -179,7 +179,7 @@ Now that you have the basics, you might want to:

 Check our other guides for details on these topics:

-- [Configuration Guide](config.qmd) - Full configuration options
+- [Configuration Guide](config-reference.qmd) - Full configuration options
 - [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources
 - [Dataset Formats](dataset-formats) - Working with different data formats
 - [Multi-GPU Training](multi-gpu.qmd)

@@ -14,7 +14,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
 ## Requirements {#sec-requirements}

 - NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
-- Python ≥3.10
+- Python ≥3.11
 - PyTorch ≥2.5.1

 ## Installation Methods {#sec-installation-methods}
@@ -153,7 +153,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.

 ### Conda/Pip venv {#sec-conda}

-1. Install Python ≥3.10
+1. Install Python ≥3.11
 2. Install PyTorch: https://pytorch.org/get-started/locally/
 3. Install Axolotl:
 ```{.bash}

@@ -32,7 +32,7 @@ output_dir: # The path to the output directory.

 Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory.

-You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.md) - you can do this by using the existing QAT configuration file which
+You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.qmd) - you can do this by using the existing QAT configuration file which
 you used to train the model:

 ```yaml

752 docs/scripts/generate_config_docs.py Normal file
@@ -0,0 +1,752 @@
|
# type: ignore
|
||||||
|
|
||||||
|
"""
|
||||||
|
Quarto documentation generation from Pydantic models. Uses Pydantic model source code
|
||||||
|
to automatically group fields, including inherited fields from parent classes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import inspect
|
||||||
|
import textwrap
|
||||||
|
import types
|
||||||
|
import typing
|
||||||
|
from typing import Any, FrozenSet, Type, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
|
|
||||||
|
class QuartoGenerator:
|
||||||
|
"""Generate Quarto documentation from Pydantic models."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._class_fields_cache = {}
|
||||||
|
self._inheritance_map_cache = {}
|
||||||
|
self._nested_models_cache = {}
|
||||||
|
|
||||||
|
def _get_direct_fields(self, cls: Type[BaseModel]) -> FrozenSet[str]:
|
||||||
|
"""Get fields defined directly in a single class (not inherited)."""
|
||||||
|
if cls in self._class_fields_cache:
|
||||||
|
return self._class_fields_cache[cls]
|
||||||
|
|
||||||
|
fields = set()
|
||||||
|
|
||||||
|
# Get annotated fields
|
||||||
|
if hasattr(cls, "__annotations__"):
|
||||||
|
fields.update(cls.__annotations__.keys())
|
||||||
|
|
||||||
|
# Filter out private/special methods
|
||||||
|
fields = {f for f in fields if not f.startswith("_")}
|
||||||
|
|
||||||
|
result = frozenset(fields)
|
||||||
|
self._class_fields_cache[cls] = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _is_pydantic_model(self, type_obj) -> bool:
|
||||||
|
"""Check if a type is a Pydantic BaseModel."""
|
||||||
|
return inspect.isclass(type_obj) and issubclass(type_obj, BaseModel)
|
||||||
|
|
||||||
|
# pylint: disable=too-many-return-statements
|
||||||
|
def _extract_nested_type(self, field_type) -> Any:
|
||||||
|
"""Extract the actual type from complex type annotations."""
|
||||||
|
# Handle Annotated types (Python 3.9+)
|
||||||
|
if hasattr(typing, "get_origin") and hasattr(typing, "get_args"):
|
||||||
|
origin = typing.get_origin(field_type)
|
||||||
|
args = typing.get_args(field_type)
|
||||||
|
|
||||||
|
if origin is not None:
|
||||||
|
# Handle Annotated[SomeType, ...] - extract the first argument
|
||||||
|
if hasattr(typing, "Annotated") and origin is typing.Annotated:
|
||||||
|
if args:
|
||||||
|
return self._extract_nested_type(
|
||||||
|
args[0]
|
||||||
|
) # Recursively process the actual type
|
||||||
|
|
||||||
|
# Handle list[SomeType], List[SomeType], etc.
|
||||||
|
elif origin in (list, typing.List):
|
||||||
|
if args:
|
||||||
|
return self._extract_nested_type(
|
||||||
|
args[0]
|
||||||
|
) # Extract element type
|
||||||
|
|
||||||
|
# Handle Union types (including | syntax)
|
||||||
|
elif origin is typing.Union:
|
||||||
|
# Get non-None types from the Union
|
||||||
|
non_none_types = [arg for arg in args if arg is not type(None)]
|
||||||
|
if len(non_none_types) >= 1:
|
||||||
|
# Prioritize Pydantic models over primitive types
|
||||||
|
pydantic_models = [
|
||||||
|
arg
|
||||||
|
for arg in non_none_types
|
||||||
|
if self._is_pydantic_model(arg)
|
||||||
|
]
|
||||||
|
if pydantic_models:
|
||||||
|
# Return the first Pydantic model found
|
||||||
|
return self._extract_nested_type(pydantic_models[0])
|
||||||
|
|
||||||
|
# No Pydantic models, return the first non-None type
|
||||||
|
return self._extract_nested_type(non_none_types[0])
|
||||||
|
|
||||||
|
# Handle new Python 3.10+ union syntax (PeftConfig | None)
|
||||||
|
if hasattr(field_type, "__class__") and field_type.__class__ is types.UnionType:
|
||||||
|
# Get non-None types from the Union
|
||||||
|
non_none_types = [
|
||||||
|
arg for arg in field_type.__args__ if arg is not type(None)
|
||||||
|
]
|
||||||
|
if len(non_none_types) >= 1:
|
||||||
|
# Prioritize Pydantic models over primitive types
|
||||||
|
pydantic_models = [
|
||||||
|
arg for arg in non_none_types if self._is_pydantic_model(arg)
|
||||||
|
]
|
||||||
|
if pydantic_models:
|
||||||
|
return self._extract_nested_type(pydantic_models[0])
|
||||||
|
return self._extract_nested_type(non_none_types[0])
|
||||||
|
|
||||||
|
# Handle old typing.Union syntax (fallback)
|
||||||
|
if hasattr(field_type, "__origin__"):
|
||||||
|
if field_type.__origin__ is Union:
|
||||||
|
# Get non-None types from the Union
|
||||||
|
non_none_types = [
|
||||||
|
arg for arg in field_type.__args__ if arg is not type(None)
|
||||||
|
]
|
||||||
|
if len(non_none_types) >= 1:
|
||||||
|
# Prioritize Pydantic models over primitive types
|
||||||
|
pydantic_models = [
|
||||||
|
arg for arg in non_none_types if self._is_pydantic_model(arg)
|
||||||
|
]
|
||||||
|
if pydantic_models:
|
||||||
|
return self._extract_nested_type(pydantic_models[0])
|
||||||
|
return self._extract_nested_type(non_none_types[0])
|
||||||
|
# Handle other generic types like dict[str, Any], etc.
|
||||||
|
elif hasattr(field_type, "__args__"):
|
||||||
|
return field_type
|
||||||
|
|
||||||
|
return field_type
|
||||||
|
|
||||||
|
# pylint: disable=too-many-return-statements
|
||||||
|
def _extract_all_pydantic_models_from_type(
|
||||||
|
self, field_type
|
||||||
|
) -> list[type[BaseModel]]:
|
||||||
|
"""Extract all Pydantic models from a type annotation, including from Unions."""
|
||||||
|
models = []
|
||||||
|
|
||||||
|
if field_type is None:
|
||||||
|
return models
|
||||||
|
|
||||||
|
# Handle Annotated types
|
||||||
|
if hasattr(typing, "get_origin") and hasattr(typing, "get_args"):
|
||||||
|
origin = typing.get_origin(field_type)
|
||||||
|
args = typing.get_args(field_type)
|
||||||
|
|
||||||
|
if origin is not None:
|
||||||
|
# Handle Annotated[SomeType, ...] - extract from the first argument
|
||||||
|
if hasattr(typing, "Annotated") and origin is typing.Annotated:
|
||||||
|
if args:
|
||||||
|
models.extend(
|
||||||
|
self._extract_all_pydantic_models_from_type(args[0])
|
||||||
|
)
|
||||||
|
return models
|
||||||
|
|
||||||
|
# Handle list[SomeType], List[SomeType], etc.
|
||||||
|
if origin in (list, typing.List):
|
||||||
|
if args:
|
||||||
|
models.extend(
|
||||||
|
self._extract_all_pydantic_models_from_type(args[0])
|
||||||
|
)
|
||||||
|
return models
|
||||||
|
|
||||||
|
# Handle Union types
|
||||||
|
if origin is typing.Union:
|
||||||
|
for arg in args:
|
||||||
|
if arg is not type(None): # Skip None type
|
||||||
|
models.extend(
|
||||||
|
self._extract_all_pydantic_models_from_type(arg)
|
||||||
|
)
|
||||||
|
return models
|
||||||
|
|
||||||
|
# Handle new Python 3.10+ union syntax
|
||||||
|
if hasattr(field_type, "__class__") and field_type.__class__ is types.UnionType:
|
||||||
|
for arg in field_type.__args__:
|
||||||
|
if arg is not type(None): # Skip None type
|
||||||
|
models.extend(self._extract_all_pydantic_models_from_type(arg))
|
||||||
|
return models
|
||||||
|
|
||||||
|
# Handle old typing.Union syntax (fallback)
|
||||||
|
if hasattr(field_type, "__origin__") and field_type.__origin__ is Union:
|
||||||
|
for arg in field_type.__args__:
|
||||||
|
if arg is not type(None): # Skip None type
|
||||||
|
models.extend(self._extract_all_pydantic_models_from_type(arg))
|
||||||
|
return models
|
||||||
|
|
||||||
|
# Check if this type itself is a Pydantic model
|
||||||
|
if self._is_pydantic_model(field_type):
|
||||||
|
models.append(field_type)
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
def _get_nested_models(
|
||||||
|
self, model_class: type[BaseModel], visited=None
|
||||||
|
) -> dict[str, type[BaseModel]]:
|
||||||
|
"""Get all nested Pydantic models from a model class."""
|
||||||
|
if visited is None:
|
||||||
|
visited = set()
|
||||||
|
|
||||||
|
# Avoid infinite recursion
|
||||||
|
if model_class in visited:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if model_class in self._nested_models_cache:
|
||||||
|
return self._nested_models_cache[model_class]
|
||||||
|
|
||||||
|
visited.add(model_class)
|
||||||
|
nested_models = {}
|
||||||
|
|
||||||
|
# Check all fields in the model
|
||||||
|
for field_info in model_class.model_fields.values():
|
||||||
|
field_type = self._extract_nested_type(field_info.annotation)
|
||||||
|
|
||||||
|
if self._is_pydantic_model(field_type):
|
||||||
|
nested_models[field_type.__name__] = field_type
|
||||||
|
# Recursively get nested models from this nested model
|
||||||
|
deeper_nested = self._get_nested_models(field_type, visited.copy())
|
||||||
|
nested_models.update(deeper_nested)
|
||||||
|
|
||||||
|
self._nested_models_cache[model_class] = nested_models
|
||||||
|
return nested_models
|
||||||
|
|
||||||
|
def _build_inheritance_map(self, child_class: Type[BaseModel]):
|
||||||
|
"""Build inheritance map for a class and all its parents."""
|
||||||
|
if child_class in self._inheritance_map_cache:
|
||||||
|
return self._inheritance_map_cache[child_class]
|
||||||
|
|
||||||
|
inheritance_map = {}
|
||||||
|
|
||||||
|
# Get MRO and filter out BaseModel and object
|
||||||
|
mro_classes = [
|
||||||
|
cls
|
||||||
|
for cls in child_class.__mro__
|
||||||
|
if cls not in (BaseModel, object) and hasattr(cls, "__annotations__")
|
||||||
|
]
|
||||||
|
|
||||||
|
# Process each class in the MRO
|
||||||
|
for cls in mro_classes:
|
||||||
|
inheritance_map[cls] = self._get_direct_fields(cls)
|
||||||
|
|
||||||
|
self._inheritance_map_cache[child_class] = inheritance_map
|
||||||
|
return inheritance_map
|
||||||
|
|
||||||
|
def _wrap_comment(self, text: str, width: int = 88) -> list[str]:
|
||||||
|
"""Wrap a comment to specified width, accounting for '# ' prefix."""
|
||||||
|
if not text.strip():
|
||||||
|
return ["#"]
|
||||||
|
|
||||||
|
# Account for "# " prefix (2 characters)
|
||||||
|
content_width = width - 2
|
||||||
|
wrapped_lines = textwrap.wrap(text, width=content_width)
|
||||||
|
return [f"# {line}" for line in wrapped_lines]
|
||||||
|
|
||||||
|
def _extract_type_from_source(
|
||||||
|
self, model_class: type[BaseModel], field_name: str
|
||||||
|
) -> str:
|
||||||
|
"""Extract the actual type annotation text from source code, checking inheritance chain."""
|
||||||
|
# Use inheritance map to check classes efficiently
|
||||||
|
inheritance_map = self._build_inheritance_map(model_class)
|
||||||
|
|
||||||
|
# Check classes in MRO order
|
||||||
|
for cls in model_class.__mro__:
|
||||||
|
if cls in inheritance_map and field_name in inheritance_map[cls]:
|
||||||
|
type_annotation = self._get_type_from_class_source(cls, field_name)
|
||||||
|
if type_annotation != "unknown":
|
||||||
|
return type_annotation
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
def _get_type_from_class_source(self, class_obj: type, field_name: str) -> str:
|
||||||
|
"""Extract type annotation from a specific class's source code."""
|
||||||
|
try:
|
||||||
|
source = inspect.getsource(class_obj)
|
||||||
|
tree = ast.parse(source)
|
||||||
|
except (OSError, TypeError):
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
# Find the class definition
|
||||||
|
for node in tree.body:
|
||||||
|
if isinstance(node, ast.ClassDef) and node.name == class_obj.__name__:
|
||||||
|
# Find the field assignment
|
||||||
|
for body_node in node.body:
|
||||||
|
if isinstance(body_node, ast.AnnAssign) and isinstance(
|
||||||
|
body_node.target, ast.Name
|
||||||
|
):
|
||||||
|
if body_node.target.id == field_name and body_node.annotation:
|
||||||
|
return ast.unparse(body_node.annotation)
|
||||||
|
break
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
def _extract_field_groups_from_all_classes(
|
||||||
|
self, model_class: type[BaseModel]
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Extract field groups from all classes in the inheritance hierarchy."""
|
||||||
|
all_groups = []
|
||||||
|
inheritance_map = self._build_inheritance_map(model_class)
|
||||||
|
|
||||||
|
# Get all Pydantic base classes in MRO order (most specific first)
|
||||||
|
# This puts AxolotlInputConfig fields first, then parent class fields
|
||||||
|
pydantic_classes = [
|
||||||
|
cls
|
||||||
|
for cls in model_class.__mro__
|
||||||
|
if cls in inheritance_map and inheritance_map[cls]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Extract groups from each class
|
||||||
|
for cls in pydantic_classes:
|
||||||
|
class_groups = self._extract_field_groups_from_source(cls)
|
||||||
|
for group in class_groups:
|
||||||
|
all_groups.append(group)
|
||||||
|
|
||||||
|
# If no groups found, create a default grouping by class
|
||||||
|
if not all_groups:
|
||||||
|
for cls in pydantic_classes:
|
||||||
|
fields_in_class = inheritance_map[cls]
|
||||||
|
if fields_in_class:
|
||||||
|
all_groups.append(
|
||||||
|
{
|
||||||
|
"fields": list(fields_in_class),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_groups
|
||||||
|
|
||||||
|
# pylint: disable=too-many-return-statements
|
||||||
|
def _extract_field_groups_from_source(
|
||||||
|
self, model_class: type[BaseModel]
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Extract field groups from source code based on blank lines and comments."""
|
||||||
|
try:
|
||||||
|
source = inspect.getsource(model_class)
|
||||||
|
tree = ast.parse(source)
|
||||||
|
except (OSError, TypeError):
|
||||||
|
# Fallback if we can't get source code
|
||||||
|
fields_in_class = self._get_direct_fields(model_class)
|
||||||
|
if fields_in_class:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"fields": list(fields_in_class),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return []
|
||||||
|
|
||||||
|
groups = []
|
||||||
|
current_group_fields = []
|
||||||
|
current_group_comment = None
|
||||||
|
|
||||||
|
# Find the class definition
|
||||||
|
class_node = None
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, ast.ClassDef) and node.name == model_class.__name__:
|
||||||
|
class_node = node
|
||||||
|
break
|
||||||
|
|
||||||
|
if not class_node:
|
||||||
|
fields_in_class = self._get_direct_fields(model_class)
|
||||||
|
if fields_in_class:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"fields": list(fields_in_class),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Parse the source lines to detect groupings
|
||||||
|
source_lines = source.split("\n")
|
||||||
|
|
||||||
|
# Get fields that are actually defined in this specific class
|
||||||
|
fields_in_class = self._get_direct_fields(model_class)
|
||||||
|
|
||||||
|
# Find assignments that correspond to model fields for THIS class only
|
||||||
|
field_assignments = []
|
||||||
|
for node in class_node.body:
|
||||||
|
if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
|
||||||
|
field_name = node.target.id
|
||||||
|
if field_name in fields_in_class:
|
||||||
|
field_assignments.append(
|
||||||
|
{
|
||||||
|
"name": field_name,
|
||||||
|
"lineno": node.lineno,
|
||||||
|
"end_lineno": getattr(node, "end_lineno", node.lineno),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not field_assignments:
|
||||||
|
if fields_in_class:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"fields": list(fields_in_class),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Sort by line number
|
||||||
|
field_assignments.sort(key=lambda x: x["lineno"])
|
||||||
|
|
||||||
|
# Group fields based on blank lines and comments
|
||||||
|
for i, field_info in enumerate(field_assignments):
|
||||||
|
field_name = field_info["name"]
|
||||||
|
current_line = field_info["lineno"]
|
||||||
|
|
||||||
|
# Check if this starts a new group (blank line before or significant gap)
|
||||||
|
is_new_group = False
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
is_new_group = True
|
||||||
|
else:
|
||||||
|
prev_end_line = field_assignments[i - 1]["end_lineno"]
|
||||||
|
|
||||||
|
# Check for blank lines or comments between fields
|
||||||
|
lines_between = source_lines[prev_end_line : current_line - 1]
|
||||||
|
has_blank_line = any(line.strip() == "" for line in lines_between)
|
||||||
|
has_comment = any(
|
||||||
|
line.strip().startswith("#") for line in lines_between
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start new group if there's a blank line or comment, or significant gap
|
||||||
|
if has_blank_line or has_comment or (current_line - prev_end_line > 3):
|
||||||
|
is_new_group = True
|
||||||
|
|
||||||
|
if is_new_group and current_group_fields:
|
||||||
|
# Save the previous group
|
||||||
|
groups.append(
|
||||||
|
{
|
||||||
|
"fields": current_group_fields.copy(),
|
||||||
|
"description": current_group_comment,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
current_group_fields = []
|
||||||
|
current_group_comment = None
|
||||||
|
|
||||||
|
current_group_fields.append(field_name)
|
||||||
|
|
||||||
|
# Add the final group
|
||||||
|
if current_group_fields:
|
||||||
|
groups.append(
|
||||||
|
{
|
||||||
|
"fields": current_group_fields,
|
||||||
|
"description": current_group_comment,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return groups
|
||||||
|
|
||||||
|
def _generate_field_documentation(
|
||||||
|
self,
|
||||||
|
model_class: type[BaseModel],
|
||||||
|
field_name: str,
|
||||||
|
field_info: dict,
|
||||||
|
field_type_str: str,
|
||||||
|
is_required: bool,
|
||||||
|
indent_level: int = 0,
|
||||||
|
visited_models: set = None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Generate documentation for a single field, expanding nested models inline."""
|
||||||
|
if visited_models is None:
|
||||||
|
visited_models = set()
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
indent = " " * indent_level
|
||||||
|
|
||||||
|
# Get the actual field type for nested model detection
|
||||||
|
if field_name in model_class.model_fields:
|
||||||
|
pydantic_field_info = model_class.model_fields[field_name]
|
||||||
|
actual_field_type = pydantic_field_info.annotation
|
||||||
|
else:
|
||||||
|
actual_field_type = None
|
||||||
|
|
||||||
|
# Add description comment if available
|
||||||
|
description = field_info.get("description", "")
|
||||||
|
if description:
|
||||||
|
wrapped_lines = self._wrap_comment(description, width=88 - len(indent))
|
||||||
|
for line in wrapped_lines:
|
||||||
|
lines.append(f"{indent}{line}")
|
||||||
|
|
||||||
|
# Extract nested Pydantic models from the type annotation
|
||||||
|
nested_models = self._extract_all_pydantic_models_from_type(actual_field_type)
|
||||||
|
|
||||||
|
# Filter out already visited models to prevent infinite recursion
|
||||||
|
expandable_models = [
|
||||||
|
model for model in nested_models if model not in visited_models
|
||||||
|
]
|
||||||
|
|
||||||
|
if expandable_models:
|
||||||
|
# This field contains Pydantic models that can be expanded
|
||||||
|
|
||||||
|
# Show the field with its full type annotation
|
||||||
|
field_line = f"{indent}{field_name}: {field_type_str}"
|
||||||
|
if field_info.get("default") is not None:
|
||||||
|
field_line += f" = {field_info['default']}"
|
||||||
|
if is_required:
|
||||||
|
field_line += " (required)"
|
||||||
|
lines.append(field_line)
|
||||||
|
|
||||||
|
# Add to visited to prevent infinite recursion
|
||||||
|
new_visited = visited_models.copy()
|
||||||
|
new_visited.update(expandable_models)
|
||||||
|
|
||||||
|
# Expand each nested Pydantic model
|
||||||
|
for i, nested_model in enumerate(expandable_models):
|
||||||
|
if i > 0:
|
||||||
|
lines.append("\n")
|
||||||
|
lines.append(f"{indent} # For {nested_model.__name__}:")
|
||||||
|
|
||||||
|
# Get nested model schema
|
||||||
|
try:
|
||||||
|
nested_schema = nested_model.model_json_schema()
|
||||||
|
nested_properties = nested_schema.get("properties", {})
|
||||||
|
nested_required = nested_schema.get("required", [])
|
||||||
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
# Fallback: use model fields directly
|
||||||
|
nested_properties = {}
|
||||||
|
nested_required = []
|
||||||
|
for (
|
||||||
|
nested_field_name,
|
||||||
|
nested_field_info,
|
||||||
|
) in nested_model.model_fields.items():
|
||||||
|
nested_description = ""
|
||||||
|
if (
|
||||||
|
hasattr(nested_field_info, "json_schema_extra")
|
||||||
|
and nested_field_info.json_schema_extra
|
||||||
|
):
|
||||||
|
nested_description = (
|
||||||
|
nested_field_info.json_schema_extra.get(
|
||||||
|
"description", ""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
hasattr(nested_field_info, "description")
|
||||||
|
and nested_field_info.description
|
||||||
|
):
|
||||||
|
nested_description = nested_field_info.description
|
||||||
|
|
||||||
|
nested_default_val = None
|
||||||
|
if (
|
||||||
|
hasattr(nested_field_info, "default")
|
||||||
|
and nested_field_info.default is not None
|
||||||
|
):
|
||||||
|
if str(nested_field_info.default) != "PydanticUndefined":
|
||||||
|
nested_default_val = nested_field_info.default
|
||||||
|
|
||||||
|
nested_properties[nested_field_name] = {
|
||||||
|
"type": "unknown",
|
||||||
|
"description": nested_description,
|
||||||
|
"default": nested_default_val,
|
||||||
|
}
|
||||||
|
|
||||||
|
if nested_field_info.is_required():
|
||||||
|
nested_required.append(nested_field_name)
|
||||||
|
|
||||||
|
# Get field groups for the nested model
|
||||||
|
nested_field_groups = self._extract_field_groups_from_all_classes(
|
||||||
|
nested_model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate nested fields with increased indentation
|
||||||
|
for i, group in enumerate(nested_field_groups):
|
||||||
|
if not group["fields"]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Add blank line between groups (except before first group)
|
||||||
|
if i > 0:
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# Process nested fields
|
||||||
|
for nested_field_name in group["fields"]:
|
||||||
|
if nested_field_name not in nested_properties:
|
||||||
|
continue
|
||||||
|
|
||||||
|
nested_field_info = nested_properties[nested_field_name]
|
||||||
|
nested_field_type = self._extract_type_from_source(
|
||||||
|
nested_model, nested_field_name
|
||||||
|
)
|
||||||
|
nested_is_required = nested_field_name in nested_required
|
||||||
|
|
||||||
|
# Recursively generate documentation for nested field
|
||||||
|
nested_lines = self._generate_field_documentation(
|
||||||
|
nested_model,
|
||||||
|
nested_field_name,
|
||||||
|
nested_field_info,
|
||||||
|
nested_field_type,
|
||||||
|
nested_is_required,
|
||||||
|
indent_level + 1,
|
||||||
|
new_visited,
|
||||||
|
)
|
||||||
|
lines.extend(nested_lines)
|
||||||
|
else:
|
||||||
|
# Regular field (no expandable nested models)
|
||||||
|
field_line = f"{indent}{field_name}: {field_type_str}"
|
||||||
|
if field_info.get("default") is not None:
|
||||||
|
field_line += f" = {field_info['default']}"
|
||||||
|
if is_required:
|
||||||
|
field_line += " (required)"
|
||||||
|
lines.append(field_line)
|
||||||
|
|
||||||
|
return lines
|
||||||
|
|
||||||
|
def generate_qmd(
|
||||||
|
self,
|
||||||
|
model_class: type[BaseModel],
|
||||||
|
title: str | None = None,
|
||||||
|
expand_nested: bool = True,
|
||||||
|
) -> str:
|
||||||
|
"""Auto-generate config reference documentation including inherited fields."""
|
||||||
|
|
||||||
|
if title is None:
|
||||||
|
title = f"{model_class.__name__} Reference"
|
||||||
|
|
||||||
|
# Try to get JSON schema, with fallback for serialization issues
|
||||||
|
try:
|
||||||
|
schema = model_class.model_json_schema()
|
||||||
|
properties = schema.get("properties", {})
|
||||||
|
required = schema.get("required", [])
|
||||||
|
except Exception as e: # pylint: disable=broad-exception-caught
|
||||||
|
print(
|
||||||
|
f"Warning: Could not generate JSON schema ({e}). Using model fields instead."
|
||||||
|
)
|
||||||
|
# Fallback: use model fields directly
|
||||||
|
properties = {}
|
||||||
|
required = []
|
||||||
|
for field_name, field_info in model_class.model_fields.items():
|
||||||
|
# Extract description from json_schema_extra or field info
|
||||||
|
description = ""
|
||||||
|
if (
|
||||||
|
hasattr(field_info, "json_schema_extra")
|
||||||
|
and field_info.json_schema_extra
|
||||||
|
):
|
||||||
|
description = field_info.json_schema_extra.get("description", "")
|
||||||
|
elif hasattr(field_info, "description") and field_info.description:
|
||||||
|
description = field_info.description
|
||||||
|
|
||||||
|
# Get default value
|
||||||
|
default_val = None
|
||||||
|
if hasattr(field_info, "default") and field_info.default is not None:
|
||||||
|
# Handle special Pydantic default markers
|
||||||
|
if str(field_info.default) != "PydanticUndefined":
|
||||||
|
default_val = field_info.default
|
||||||
|
|
||||||
|
properties[field_name] = {
|
||||||
|
"type": "unknown",
|
||||||
|
"description": description,
|
||||||
|
"default": default_val,
|
||||||
|
}
|
||||||
|
|
||||||
|
if field_info.is_required():
|
||||||
|
required.append(field_name)
|
||||||
|
|
||||||
|
# Extract field groups from all classes in inheritance hierarchy
|
||||||
|
field_groups = self._extract_field_groups_from_all_classes(model_class)
|
||||||
|
|
||||||
|
# Start building QMD content
|
||||||
|
qmd_lines = [
|
||||||
|
"---",
|
||||||
|
f"title: {title}",
|
||||||
|
"description: A complete list of all configuration options.",
|
||||||
|
"---",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Generate one big code block with all fields (inline nested expansion)
|
||||||
|
qmd_lines.append("```yaml")
|
||||||
|
|
||||||
|
for i, group in enumerate(field_groups):
|
||||||
|
if not group["fields"]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Add blank line between groups (except before first group)
|
||||||
|
if i > 0:
|
||||||
|
qmd_lines.append("")
|
||||||
|
|
||||||
|
# Process fields in the order they appear in source
|
||||||
|
for field_name in group["fields"]:
|
||||||
|
if field_name not in properties:
|
||||||
|
continue
|
||||||
|
|
||||||
|
field_info = properties[field_name]
|
||||||
|
field_type = self._extract_type_from_source(model_class, field_name)
|
||||||
|
is_required = field_name in required
|
||||||
|
|
||||||
|
if expand_nested:
|
||||||
|
# Check if this field has nested models
|
||||||
|
if field_name in model_class.model_fields:
|
||||||
|
pydantic_field_info = model_class.model_fields[field_name]
|
||||||
|
nested_models = self._extract_all_pydantic_models_from_type(
|
||||||
|
pydantic_field_info.annotation
|
||||||
|
)
|
||||||
|
has_nested = bool(nested_models)
|
||||||
|
else:
|
||||||
|
has_nested = False
|
||||||
|
|
||||||
|
# Add blank line before nested config
|
||||||
|
if has_nested:
|
||||||
|
qmd_lines.append("")
|
||||||
|
|
||||||
|
# Use the new inline generation method
|
||||||
|
field_lines = self._generate_field_documentation(
|
||||||
|
model_class,
|
||||||
|
field_name,
|
||||||
|
field_info,
|
||||||
|
field_type,
|
||||||
|
is_required,
|
||||||
|
indent_level=0,
|
||||||
|
visited_models=set(),
|
||||||
|
)
|
||||||
|
qmd_lines.extend(field_lines)
|
||||||
|
|
||||||
|
# Add blank line after nested config
|
||||||
|
if has_nested:
|
||||||
|
qmd_lines.append("")
|
||||||
|
else:
|
||||||
|
# Original simple approach
|
||||||
|
description = field_info.get("description", "")
|
||||||
|
default = field_info.get("default")
|
||||||
|
|
||||||
|
# Add wrapped comment for description
|
||||||
|
if description:
|
||||||
|
wrapped_lines = self._wrap_comment(description)
|
||||||
|
qmd_lines.extend(wrapped_lines)
|
||||||
|
|
||||||
|
line = f"{field_name}: {field_type}"
|
||||||
|
if default is not None:
|
||||||
|
line += f" = {default}"
|
||||||
|
if is_required:
|
||||||
|
line += " (required)"
|
||||||
|
qmd_lines.append(line)
|
||||||
|
|
||||||
|
qmd_lines.append("```")
|
||||||
|
|
||||||
|
# Join all lines and clean up any double newlines
|
||||||
|
content = "\n".join(qmd_lines)
|
||||||
|
|
||||||
|
# Replace multiple consecutive newlines with just two newlines (one blank line)
|
||||||
|
import re
|
||||||
|
|
||||||
|
content = re.sub(r"\n{3,}", "\n\n", content)
|
||||||
|
|
||||||
|
# Ensure single newline at the very end
|
||||||
|
content = content.rstrip("\n") + "\n"
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
generator = QuartoGenerator()
|
||||||
|
|
||||||
|
print("Generating config reference content...")
|
||||||
|
qmd_content = generator.generate_qmd(AxolotlInputConfig, "Config Reference", True)
|
||||||
|
|
||||||
|
print("Writing to file...")
|
||||||
|
with open("docs/config-reference.qmd", "w", encoding="utf-8") as f:
|
||||||
|
f.write(qmd_content)
|
||||||
|
print("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
File diff suppressed because it is too large
71 examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml Normal file
@@ -0,0 +1,71 @@
base_model: tiiuae/Falcon-H1-1.5B-Deep-Base
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: true

# huggingface repo
chat_template: falcon_h1
datasets:
  - path: cgato/SlimOrcaDedupCleaned
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out

adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - in_proj
  - gate_proj
  - up_proj
  - down_proj

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
71  examples/falcon-h1/falcon-h1-1b-qlora.yaml  Normal file
@@ -0,0 +1,71 @@
base_model: tiiuae/Falcon-H1-1.5B-Base
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: true

# huggingface repo
chat_template: falcon_h1
datasets:
  - path: cgato/SlimOrcaDedupCleaned
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out

adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - in_proj
  - gate_proj
  - up_proj
  - down_proj

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
71  examples/falcon-h1/falcon-h1-34b-qlora.yaml  Normal file
@@ -0,0 +1,71 @@
base_model: tiiuae/Falcon-H1-34B-Base
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: true

# huggingface repo
chat_template: falcon_h1
datasets:
  - path: cgato/SlimOrcaDedupCleaned
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out

adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - in_proj
  - gate_proj
  - up_proj
  - down_proj

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
71  examples/falcon-h1/falcon-h1-3b-qlora.yaml  Normal file
@@ -0,0 +1,71 @@
base_model: tiiuae/Falcon-H1-3B-Base
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: true

# huggingface repo
chat_template: falcon_h1
datasets:
  - path: cgato/SlimOrcaDedupCleaned
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out

adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - in_proj
  - gate_proj
  - up_proj
  - down_proj

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
71  examples/falcon-h1/falcon-h1-500m-qlora.yaml  Normal file
@@ -0,0 +1,71 @@
base_model: tiiuae/Falcon-H1-0.5B-Instruct
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: true

# huggingface repo
chat_template: falcon_h1
datasets:
  - path: cgato/SlimOrcaDedupCleaned
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out

adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - in_proj
  - gate_proj
  - up_proj
  - down_proj

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
71  examples/falcon-h1/falcon-h1-7b-qlora.yaml  Normal file
@@ -0,0 +1,71 @@
base_model: tiiuae/Falcon-H1-7B-Base
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: true

# huggingface repo
chat_template: falcon_h1
datasets:
  - path: cgato/SlimOrcaDedupCleaned
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

val_set_size: 0.0
output_dir: ./outputs/out

adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - in_proj
  - gate_proj
  - up_proj
  - down_proj

sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:


gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
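These Falcon-H1 examples follow the standard Axolotl QLoRA recipe. As a minimal sketch (assuming a working Axolotl install, access to the base models, and a logged-in Hugging Face token where needed), any of them can be launched with the regular CLI entry points:

```bash
# Optional: tokenize and prepare the dataset ahead of time
axolotl preprocess examples/falcon-h1/falcon-h1-1b-qlora.yaml

# Launch QLoRA fine-tuning with the example config
axolotl train examples/falcon-h1/falcon-h1-1b-qlora.yaml
```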
@@ -13,6 +13,8 @@ load_in_4bit: true

 # huggingface repo
 chat_template: gemma3
+eot_tokens:
+  - <end_of_turn>
 datasets:
   - path: cgato/SlimOrcaDedupCleaned
     type: chat_template

@@ -6,6 +6,8 @@ load_in_4bit: true
 ddp_find_unused_parameters: true

 chat_template: gemma3
+eot_tokens:
+  - <end_of_turn>
 datasets:
   - path: cgato/SlimOrcaDedupCleaned
     type: chat_template

@@ -12,6 +12,8 @@ sample_packing: false
 ddp_find_unused_parameters: true

 chat_template: gemma3
+eot_tokens:
+  - <end_of_turn>
 datasets:
   - path: HuggingFaceH4/llava-instruct-mix-vsft
     type: chat_template
55  examples/qwen2_5-vl/lora-7b.yaml  Normal file
@@ -0,0 +1,55 @@
base_model: Qwen/Qwen2.5-VL-7B-Instruct
processor_type: AutoProcessor

# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: qwen2_vl
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
    field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

adapter: lora
lora_model_dir:

sequence_len: 8192
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: true
fp16:
tf32: true

gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
BIN  favicon.jpg
Binary file not shown. Size before: 4.5 KiB | after: 4.7 KiB
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

 # START section of dependencies that don't install on Darwin/MacOS
-bitsandbytes==0.45.4
+bitsandbytes==0.46.0
 triton>=3.0.0
 mamba-ssm==1.2.0.post1
 xformers>=0.0.23.post1
@@ -11,14 +11,14 @@ liger-kernel==0.5.10

 packaging==23.2

-huggingface_hub==0.32.2
+huggingface_hub[hf_xet]==0.33.0
 peft==0.15.2
-transformers==4.52.3
+transformers==4.53.1
 tokenizers>=0.21.1
-accelerate==1.7.0
+accelerate==1.8.1
 datasets==3.6.0
 deepspeed>=0.17.0
-trl==0.18.1
+trl==0.18.2
 hf_xet==1.1.2

 optimum==1.16.2
@@ -68,4 +68,4 @@ schedulefree==1.4.1
 axolotl-contribs-lgpl==0.0.6
 axolotl-contribs-mit==0.0.3

-mistral-common==1.6.0
+mistral-common==1.6.3
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""

 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"'
 )
14  setup.py
@@ -65,15 +65,13 @@ def parse_requirements(extras_require_map):
        raise ValueError("Invalid version format")

    if (major, minor) >= (2, 7):
-        _install_requires.pop(_install_requires.index(xformers_version))
        # _install_requires.append("xformers==0.0.29.post3")  # xformers seems to be hard pinned to 2.6.0
-        extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
+        extras_require_map["vllm"] = ["vllm==0.9.2"]
    elif (major, minor) >= (2, 6):
-        _install_requires.pop(_install_requires.index(xformers_version))
        _install_requires.append(
            "xformers==0.0.29.post2"
        )  # vllm needs post2 w torch 2.6
-        extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
+        extras_require_map["vllm"] = ["vllm==0.9.2"]
    elif (major, minor) >= (2, 5):
        _install_requires.pop(_install_requires.index(xformers_version))
        if patch == 0:
@@ -111,14 +109,14 @@ def get_package_version():


 extras_require = {
-    "flash-attn": ["flash-attn==2.7.4.post1"],
+    "flash-attn": ["flash-attn==2.8.0.post2"],
     "ring-flash-attn": [
-        "flash-attn==2.7.4.post1",
+        "flash-attn==2.8.0.post2",
-        "ring-flash-attn>=0.1.4",
+        "ring-flash-attn>=0.1.5",
         "yunchang==0.6.0",
     ],
     "deepspeed": [
-        "deepspeed==0.17.0",
+        "deepspeed==0.17.1",
         "deepspeed-kernels",
     ],
     "mamba-ssm": [
@@ -4,4 +4,4 @@ import pkgutil

 __path__ = pkgutil.extend_path(__path__, __name__)  # Make this a namespace package

-__version__ = "0.10.0"
+__version__ = "0.11.0.dev"
@@ -6,6 +6,7 @@ from pathlib import Path
 from accelerate.commands.config import config_args
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
+from requests import HTTPError

 from axolotl.utils.logging import get_logger

@@ -46,3 +47,8 @@ def check_user_token() -> bool:
             "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
         )
         return False
+    except HTTPError:
+        LOG.warning(
+            "Error accessing HuggingFace. This may be due to a network issue or rate limiting."
+        )
+        return False
@@ -7,7 +7,6 @@ from typing import Union

 import yaml

-from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.cloud.modal_ import ModalCloud
 from axolotl.utils.dict import DictDefault

@@ -24,7 +23,6 @@ def do_cli_preprocess(
     cloud_config: Union[Path, str],
     config: Union[Path, str],
 ) -> None:
-    print_axolotl_text_art()
     cloud_cfg = load_cloud_cfg(cloud_config)
     cloud = ModalCloud(cloud_cfg)
     with open(config, "r", encoding="utf-8") as file:
@@ -39,7 +37,6 @@ def do_cli_train(
     cwd=None,
     **kwargs,
 ) -> None:
-    print_axolotl_text_art()
     cloud_cfg = load_cloud_cfg(cloud_config)
     cloud = ModalCloud(cloud_cfg)
     with open(config, "r", encoding="utf-8") as file:
@@ -54,7 +51,6 @@ def do_cli_lm_eval(
     cloud_config: Union[Path, str],
     config: Union[Path, str],
 ) -> None:
-    print_axolotl_text_art()
     cloud_cfg = load_cloud_cfg(cloud_config)
     cloud = ModalCloud(cloud_cfg)
     with open(config, "r", encoding="utf-8") as file:
|||||||
@@ -26,7 +26,9 @@ from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
|||||||
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
LOG = get_logger(__name__, use_environ=True)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
API_KEY_FIELDS = {"comet_api_key"}
|
||||||
|
|
||||||
|
|
||||||
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||||
@@ -233,4 +235,15 @@ def load_cfg(
|
|||||||
setup_comet_env_vars(cfg)
|
setup_comet_env_vars(cfg)
|
||||||
plugin_set_cfg(cfg)
|
plugin_set_cfg(cfg)
|
||||||
|
|
||||||
|
cfg_to_log = {
|
||||||
|
k: "[REDACTED]" if k in API_KEY_FIELDS else v
|
||||||
|
for k, v in cfg.items()
|
||||||
|
if v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"config:\n%s",
|
||||||
|
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
|
||||||
|
)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from dotenv import load_dotenv
 from transformers.hf_argparser import HfArgumentParser

 from axolotl.cli.args import TrainerCliArgs
-from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.checks import check_accelerate_default_config, check_user_token
 from axolotl.cli.config import load_cfg
 from axolotl.common.datasets import load_datasets, load_preference_datasets
@@ -35,7 +34,6 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
     patch_optimized_env()

     # pylint: disable=duplicate-code
-    print_axolotl_text_art()
     check_accelerate_default_config()
     if int(os.getenv("LOCAL_RANK", "0")) == 0:
         check_user_token()
@@ -13,7 +13,6 @@ from dotenv import load_dotenv
 from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer

 from axolotl.cli.args import InferenceCliArgs
-from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
 from axolotl.cli.utils import load_model_and_tokenizer
 from axolotl.utils.chat_templates import (
@@ -255,7 +254,6 @@ def do_cli(
         kwargs: Additional keyword arguments to override config file values.
     """
     # pylint: disable=duplicate-code
-    print_axolotl_text_art()
     parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
     parsed_cfg.sample_packing = False
     parser = transformers.HfArgumentParser(InferenceCliArgs)
@@ -20,6 +20,7 @@ from axolotl.cli.args import (
     TrainerCliArgs,
     VllmServeCliArgs,
 )
+from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.sweeps import generate_sweep_configs
 from axolotl.cli.utils import (
     add_options_from_config,
@@ -40,6 +41,7 @@ LOG = get_logger(__name__)
 @click.version_option(version=axolotl.__version__, prog_name="axolotl")
 def cli():
     """Axolotl CLI - Train and fine-tune large language models"""
+    print_axolotl_text_art()


 @cli.command()
@@ -6,7 +6,6 @@ from typing import Union
 import fire
 from dotenv import load_dotenv

-from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
 from axolotl.cli.utils import load_model_and_tokenizer
 from axolotl.utils.dict import DictDefault
@@ -23,8 +22,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
     """
-    print_axolotl_text_art()
-
     model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
     safe_serialization = cfg.save_safetensors is True

@@ -22,7 +22,6 @@ from huggingface_hub import split_torch_state_dict_into_shards
 from safetensors.torch import save_file as safe_save_file
 from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner

-from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
 from axolotl.utils.logging import get_logger

@@ -194,7 +193,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         kwargs: Additional keyword arguments to override config file values.
     """
     # pylint: disable=duplicate-code
-    print_axolotl_text_art()
     parsed_cfg = load_cfg(config, **kwargs)

     fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
@@ -12,7 +12,6 @@ from dotenv import load_dotenv
 from transformers import AutoModelForCausalLM

 from axolotl.cli.args import PreprocessCliArgs
-from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.checks import check_accelerate_default_config, check_user_token
 from axolotl.cli.config import load_cfg
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -33,10 +32,15 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
         cfg: Dictionary mapping `axolotl` config keys to values.
         cli_args: Preprocessing-specific CLI arguments.
     """
-    print_axolotl_text_art()
     check_accelerate_default_config()
     check_user_token()

+    for key in ["skip_prepare_dataset", "pretraining_dataset"]:
+        if cfg.get("key"):
+            raise ValueError(
+                f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
+            )
+
     if not cfg.dataset_prepared_path:
         msg = (
             Fore.RED
@@ -7,7 +7,6 @@ from typing import Union

 from transformers import AutoModelForCausalLM

-from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.config import load_cfg
 from axolotl.loaders import load_tokenizer
 from axolotl.utils.logging import get_logger
@@ -27,7 +26,6 @@ def do_quantize(
         config (Union[Path, str]): The path to the config file
         cli_args (dict): Additional command-line arguments
     """
-    print_axolotl_text_art()

     cfg = load_cfg(config)

@@ -11,7 +11,6 @@ from dotenv import load_dotenv
 from transformers.hf_argparser import HfArgumentParser

 from axolotl.cli.args import TrainerCliArgs
-from axolotl.cli.art import print_axolotl_text_art
 from axolotl.cli.checks import check_accelerate_default_config, check_user_token
 from axolotl.cli.config import load_cfg
 from axolotl.common.datasets import load_datasets, load_preference_datasets
@@ -35,7 +34,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
     # Enable expandable segments for cuda allocation to improve VRAM usage
     patch_optimized_env()

-    print_axolotl_text_art()
     check_accelerate_default_config()
     if int(os.getenv("LOCAL_RANK", "0")) == 0:
         check_user_token()
@@ -75,13 +75,17 @@ def load_datasets(

     num_examples = cli_args.debug_num_examples if cli_args else 1
     text_only = cli_args.debug_text_only if cli_args else False
-    train_samples = sample_dataset(train_dataset, num_examples)
-    check_dataset_labels(
-        train_samples,
-        tokenizer,
-        num_examples=num_examples,
-        text_only=text_only,
-    )
+    try:
+        train_samples = sample_dataset(train_dataset, num_examples)
+        check_dataset_labels(
+            train_samples,
+            tokenizer,
+            num_examples=num_examples,
+            text_only=text_only,
+        )
+    except AttributeError:
+        # can't sample iterable datasets
+        pass

     LOG.info("printing prompters...")
     for prompter in prompters:
162  src/axolotl/core/attention/flex_block_mask.py  Normal file
@@ -0,0 +1,162 @@
"""
monkeypatch for flex + packing
"""

import sys
from typing import Callable, Optional, Union

import torch
from torch.nn.attention.flex_attention import BlockMask
from transformers import Cache, PretrainedConfig
from transformers.masking_utils import (
    ALL_MASK_ATTENTION_FUNCTIONS,
    _preprocess_mask_arguments,
    and_masks,
    causal_mask_function,
    or_masks,
)
from transformers.utils import is_torch_greater_or_equal

_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)


def create_causal_mask(
    config: PretrainedConfig,
    input_embeds: torch.Tensor,
    attention_mask: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Optional[Cache],
    or_mask_function: Optional[Callable] = None,
    and_mask_function: Optional[Callable] = None,
) -> Optional[Union[torch.Tensor, BlockMask]]:
    """
    Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
    has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
    to what is needed in the `modeling_xxx.py` files).

    Args:
        config (`PretrainedConfig`):
            The model config.
        input_embeds (`torch.Tensor`):
            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
            batch size, query length and dtype.
        attention_mask (`torch.Tensor`, optional):
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
            It can also be an already prepared 4D mask, in which case it is returned as-is.
        cache_position (`torch.Tensor`):
            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
        past_key_values (`Cache`, optional):
            The past key values, if we use a cache.
        or_mask_function (`Callable`, optional):
            An optional mask function to combine with the causal mask function (by doing the union of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
        and_mask_function (`Callable`, optional):
            An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
            useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
    """
    # If we have an HybridCache structure, here we want to create the mask for the full layers
    if (
        past_key_values
        and hasattr(past_key_values, "is_sliding")
        and False in past_key_values.is_sliding
    ):
        layer_idx = past_key_values.is_sliding.index(False)
    else:
        layer_idx = 0

    original_attention_mask = (
        None
        if attention_mask is None
        else attention_mask.clone().to(cache_position.device)
    )
    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
    )
    if early_exit:
        return attention_mask

    batch_size, total_seq_len = cache_position.shape
    key_length = total_seq_len
    document_ids = torch.nn.functional.pad(
        original_attention_mask, value=0, pad=(0, key_length)
    )

    batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
    if attention_mask is not None:

        def causal_doc_mask_mod(
            batch_idx, head_idx, q_idx, kv_idx
        ):  # pylint: disable=unused-argument
            """
            Defines the logic of a block causal mask by combining both a standard causal mask
            and a block diagonal document mask.
            See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
            for an illustration.
            """
            causal_mask_ = q_idx >= kv_idx  # not valid when decoding
            document_mask = (
                document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
            )
            final_mask = causal_mask_ & document_mask
            return final_mask

        mask_factory_function = causal_doc_mask_mod
    else:
        mask_factory_function = causal_mask_function
    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[
        config._attn_implementation  # pylint: disable=protected-access
    ]

    # Do not allow skip if we are compiling (this is to match BC)
    allow_is_causal_skip = (
        not past_key_values.is_compileable if past_key_values is not None else True
    )

    # Allow slight deviations from causal mask
    if or_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError(
                "Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
            )
        mask_factory_function = or_masks(mask_factory_function, or_mask_function)
        allow_is_causal_skip = False
    if and_mask_function is not None:
        if not _is_torch_greater_or_equal_than_2_6:
            raise ValueError(
                "Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6"
            )
        mask_factory_function = and_masks(mask_factory_function, and_mask_function)
        allow_is_causal_skip = False

    # We now create the mask
    causal_mask = mask_interface(
        batch_size=batch_size,
        cache_position=cache_position,
        kv_length=kv_length,
        kv_offset=kv_offset,
        mask_function=mask_factory_function,
        attention_mask=attention_mask,
        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
        dtype=dtype,  # Additional kwarg for eager
        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
    )
    return causal_mask


def patch_create_causal_mask(model_type):
    import transformers.masking_utils

    transformers.masking_utils.create_causal_mask = create_causal_mask

    if model_type:
        try:
            # Dynamically import the module and attention class
            module_path = f"transformers.models.{model_type}.modeling_{model_type}"
            module = __import__(module_path)
            module.create_causal_mask = create_causal_mask
            del sys.modules[module_path]
        except (ImportError, AttributeError) as e:
            raise ValueError(
                f"Could not import attention class for model_type: {model_type}. "
                f"Error: {str(e)}"
            ) from e
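The `causal_doc_mask_mod` predicate above is the core of the packing support: a query position may only attend to earlier positions that belong to the same packed document. As a minimal sketch (illustrative only; the toy tensors below are not part of the patch), the same `causal & same-document` predicate evaluated densely on a toy packed row looks like this:

```python
import torch

# Toy packed batch: one row holding two documents of lengths 3 and 2,
# encoded as per-token document ids (0 would mark padding).
document_ids = torch.tensor([[1, 1, 1, 2, 2]])

q_idx = torch.arange(5).view(-1, 1)   # query positions (column vector)
kv_idx = torch.arange(5).view(1, -1)  # key/value positions (row vector)

causal = q_idx >= kv_idx                                       # standard causal constraint
same_doc = document_ids[0, q_idx] == document_ids[0, kv_idx]   # block-diagonal by document
mask = causal & same_doc                                       # block-causal document mask

print(mask.int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 0],
#         [0, 0, 0, 1, 1]])
```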
@@ -219,7 +219,9 @@ class TrainerBuilderBase(abc.ABC):
         if self.cfg.bf16 == "full":
             training_args_kwargs["bf16_full_eval"] = True
         else:
-            training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
+            bf16 = self.cfg.bf16 or self.cfg.bfloat16
+            bf16 = bf16 if bf16 is not None else False
+            training_args_kwargs["bf16"] = bf16

     def _configure_scheduler(self, training_args_kwargs: dict):
         if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
@@ -245,14 +245,27 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling

         training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
+        training_arguments_kwargs["sample_packing_drop_attention_mask"] = bool(
+            self.cfg.flash_attention
+            or self.cfg.xformers_attention
+            or self.cfg.flex_attention
+        )
         training_arguments_kwargs["multipack_real_batches"] = (
             self.cfg.multipack_real_batches
             if self.cfg.multipack_real_batches is not None
-            else not self.cfg.flash_attention
+            else not (
+                self.cfg.flash_attention
+                or self.cfg.flex_attention
+                or self.cfg.xformers_attention
+            )
         )
         training_arguments_kwargs["eval_sample_packing"] = bool(
             self.cfg.eval_sample_packing
         )
+        if self.cfg.sample_packing_sequentially is not None:
+            training_arguments_kwargs["sample_packing_sequentially"] = (
+                self.cfg.sample_packing_sequentially
+            )
         if self.cfg.sample_packing_bin_size is not None:
             training_arguments_kwargs["sample_packing_bin_size"] = (
                 self.cfg.sample_packing_bin_size
@@ -413,7 +426,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             or self.cfg.micro_batch_size > 1
         ):
             return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
-        return None
+        if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
+            return None

         if self.cfg.model_config_type == "mamba":
             return MambaDataCollator(tokenizer=self.tokenizer)
@@ -20,13 +20,14 @@ from torch.utils.data import (
     SequentialSampler,
 )
 from transformers import Trainer
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
 from trl.trainer.utils import pad_to_length
 from typing_extensions import override

 from axolotl.core.trainers.mixins import (
     CheckpointSaveMixin,
     OptimizerMixin,
+    PackingMixin,
     RngLoaderMixin,
     SchedulerMixin,
 )
@@ -42,7 +43,12 @@ LOG = get_logger(__name__)


 class AxolotlTrainer(
-    SchedulerMixin, OptimizerMixin, RngLoaderMixin, CheckpointSaveMixin, Trainer
+    PackingMixin,
+    SchedulerMixin,
+    OptimizerMixin,
+    RngLoaderMixin,
+    CheckpointSaveMixin,
+    Trainer,
 ):
     """Extend the base Trainer for axolotl helpers"""

@@ -116,14 +122,15 @@ class AxolotlTrainer(
             sequential=self.args.sample_packing_sequentially,
             drop_last=True,
             num_processes=self.args.dataset_num_proc,
+            mp_start_method=self.args.sample_packing_mp_start_method or "fork",
         )

         len(sampler)
         return sampler

     def _get_train_sampler(
-        self, train_dataset: Optional[Dataset] = None
-    ) -> Optional[Sampler]:
+        self, train_dataset: Dataset | None = None
+    ) -> Sampler | None:
         """
         Helper method to get the sampler for training. Handles cases for sample packing
         and curriculum sampling (sequential).
@@ -132,16 +139,22 @@ class AxolotlTrainer(
         If the dataset is non-empty, a sampler is returned, the type of which
         depends on the passed training args.
         """
+        # from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L969C1-L972C24
+        if train_dataset is None:
+            train_dataset = self.train_dataset
+        if train_dataset is None or not has_length(train_dataset):
+            return None
+
         use_sample_packing = self.args.sample_packing and not self.args.pretraining

         # Determine the base sampler first
         if self.args.curriculum_sampling:
-            base_sampler = SequentialSampler(self.train_dataset)
+            base_sampler = SequentialSampler(train_dataset)
         elif use_sample_packing:
-            base_sampler = RandomSampler(self.train_dataset)
+            base_sampler = RandomSampler(train_dataset)
         else:
             # Default to parent class implementation for standard random sampling
-            return super()._get_train_sampler()
+            return super()._get_train_sampler(train_dataset)

         # Apply multipack wrapper if needed
         if use_sample_packing:
@@ -160,6 +173,10 @@ class AxolotlTrainer(
         If the dataset is non-empty, a sampler is returned, the type of which
         depends on the passed training args.
         """
+        # from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L1065C9-L1066C24
+        if eval_dataset is None or not has_length(eval_dataset):
+            return None
+
         # Multipacking enabled if training is enabled and eval is not explicitly disabled
         use_multipack = (
             self.args.sample_packing and self.args.eval_sample_packing is not False
@@ -195,6 +212,14 @@ class AxolotlTrainer(

         if dataset.column_names and "length" in dataset.column_names:
             dataset = dataset.remove_columns(["length"])
+        if (
+            dataset.column_names
+            and "position_ids" in dataset.column_names
+            and "attention_mask" in dataset.column_names
+            and self.args.sample_packing
+            and self.args.sample_packing_drop_attention_mask
+        ):
+            dataset = dataset.remove_columns(["attention_mask"])

         if isinstance(dataset, datasets.Dataset):
             if is_training:
@@ -28,7 +28,7 @@ class DPOStrategy:
         training_args_kwargs["max_completion_length"] = None
         training_args_kwargs["max_length"] = cfg.sequence_len
         training_args_kwargs["max_prompt_length"] = cfg.sequence_len
-        training_args_kwargs["generate_during_eval"] = cfg.use_wandb
+        training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
         if cfg.dpo_use_weighting is not None:
             training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
         if cfg.dpo_padding_free is not None:
@@ -5,5 +5,6 @@

 from .checkpoints import CheckpointSaveMixin
 from .optimizer import OptimizerMixin
+from .packing import PackingMixin
 from .rng_state_loader import RngLoaderMixin
 from .scheduler import SchedulerMixin
20  src/axolotl/core/trainers/mixins/packing.py  Normal file
@@ -0,0 +1,20 @@
"""Trainer mixin to support packing"""

from transformers import Trainer


class PackingMixin(Trainer):
    """
    Trainer mixin to support packing
    """

    def _set_signature_columns_if_needed(self):
        super()._set_signature_columns_if_needed()
        if (
            self._signature_columns
            and self.args.sample_packing
            and self.args.sample_packing_drop_attention_mask
        ):
            set_sig_columns = set(self._signature_columns)
            set_sig_columns.remove("attention_mask")
            self._signature_columns = list(set_sig_columns)
@@ -38,6 +38,14 @@ class AxolotlTrainingMixins:
             "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
         },
     )
+    sample_packing_mp_start_method: str | None = field(
+        default=None,
+        metadata={"help": "The multiprocessing start method to use."},
+    )
+    sample_packing_drop_attention_mask: bool = field(
+        default=False,
+        metadata={"help": "Drop attention mask from inputs when using packing."},
+    )
     multipack_real_batches: bool = field(
         default=False,
         metadata={"help": "Use real batches for efficient training."},
@@ -33,7 +33,7 @@ from transformers import PreTrainedModel, Trainer
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger

-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)

 if TYPE_CHECKING:
     from axolotl.common.datasets import TrainDatasetMeta
@@ -19,19 +19,11 @@ python scripts/cutcrossentropy_install.py | sh

 - If you are installing from pip
 ```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"
 ```

 ## Usage

-**NOTE**: If you are training a VLM model, please use older version of Axolotl as upstream has applied a major VLM refactor, and our patches have not been updated yet.
-
-```bash
-git checkout 787880215b3ab32ccaf81c1b2e9588c6f3e6e764
-
-pip3 install --no-build-isolation -e .
-```
-
 ```yaml
 plugins:
   - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
@@ -39,27 +31,29 @@ plugins:

 ## Supported Models

-- llama
-- llama4
-- llama4_text
-- mllama
-- phi3
+- cohere
+- cohere2
 - gemma
 - gemma2
 - gemma3
 - gemma3_text
+- glm
+- glm4
+- llama
+- llama4
+- llama4_text
 - mistral
 - mistral3
+- mllama
+- phi
+- phi3
+- phi4_multimodal
 - qwen2
-- qwen2_moe
 - qwen2_vl
+- qwen2_moe
 - qwen2_5_vl
 - qwen3
 - qwen3_moe
-- cohere
-- cohere2
-- glm
-- glm4

 ## Citation

@@ -28,11 +28,11 @@ from axolotl.utils.logging import get_logger
 
 from .args import CutCrossEntropyArgs  # pylint: disable=unused-import. # noqa: F401
 
-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)
 
 _CCE_INSTALL_MESSAGE = (
-    "Please install cut_cross_entropy with transformers support using "
-    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"`'
+    "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
+    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"`'
 )
 
 
@@ -64,16 +64,28 @@ class CutCrossEntropyPlugin(BasePlugin):
             "cut_cross_entropy.transformers"
         )
         if cce_spec_transformers is None:
-            raise ImportError(_CCE_INSTALL_MESSAGE)
+            raise ImportError(
+                "Transformers support is not installed. " + _CCE_INSTALL_MESSAGE
+            )
+
+        # Check if Axolotl's cce fork is installed
+        try:
+            from cut_cross_entropy.transformers.patch import AXOLOTL_CCE_FORK
+
+            if not AXOLOTL_CCE_FORK:
+                raise ImportError
+        except ImportError as e:
+            raise ImportError(
+                "Axolotl's fork of cut_cross_entropy is not installed. "
+                + _CCE_INSTALL_MESSAGE
+            ) from e
 
     def pre_model_load(self, cfg):
         """Apply cut cross entropy before model loading if enabled."""
         if cfg.cut_cross_entropy:
             self._check_requirements()
 
-            from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
-                cce_patch,
-            )
+            from cut_cross_entropy.transformers.patch import cce_patch
 
             LOG.info(
                 f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
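The same fork check the plugin runs above can double as a quick install sanity check from a user's environment. A hedged sketch; the import path and `AXOLOTL_CCE_FORK` flag come from this hunk, everything else is illustrative.

```python
# Sketch of the requirement check added above, usable standalone before training.
try:
    from cut_cross_entropy.transformers.patch import AXOLOTL_CCE_FORK
except ImportError:
    AXOLOTL_CCE_FORK = False

if not AXOLOTL_CCE_FORK:
    print(
        "Axolotl's cut-cross-entropy fork is missing; reinstall with the pip command "
        "from the README hunk above."
    )
```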
@@ -1,191 +0,0 @@
"""Cohere and Cohere2 CCE patch."""

# This patch is based off transformers 4.50.0.
# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM.
# It scales the hidden states by the logit scale in advance instead of the logits as the
# operation is done internally and should be mathematically equivalent.

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import KwargsForCausalLM
from transformers.processing_utils import Unpack
from transformers.utils.deprecation import deprecate_kwarg

_PATCH_OPTS: PatchOptions | None = None


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
    # ... standard CohereForCausalLM.forward docstring, output/return_dict fallbacks and
    # decoder call (outputs = self.model(...)) ...
    hidden_states = outputs[0]
    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        # scale hidden_states by logit_scale in-place of logits
        loss = apply_lce(
            hidden_states[:, slice_indices, :] * self.logit_scale,
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        logits = logits * self.logit_scale  # main diff from Llama

        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def patch_cohere(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.cohere import modeling_cohere

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_cohere.CohereForCausalLM
        ), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_cohere.CohereForCausalLM.forward = cce_forward
    return None


def patch_cohere2(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    # ... same shape as patch_cohere above, targeting modeling_cohere2.Cohere2ForCausalLM ...
    ...
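The removed Cohere patch above rests on the claim that scaling hidden states by `logit_scale` before the lm_head projection is mathematically equivalent to scaling the logits afterwards. A small hedged check of that linearity claim with stand-in shapes:

```python
# Tiny numerical check of the equivalence the Cohere patch relies on:
# (s * h) @ W.T == s * (h @ W.T) for a scalar logit scale s.
import torch

h = torch.randn(2, 5, 64)      # hidden states (batch, seq, dim), illustrative sizes
W = torch.randn(1000, 64)      # lm_head weight (vocab, dim), illustrative sizes
s = 0.0625                     # stand-in for model.logit_scale

scaled_hidden = (h * s) @ W.T  # what the patch feeds to the fused loss
scaled_logits = (h @ W.T) * s  # what the unpatched forward computes
assert torch.allclose(scaled_hidden, scaled_logits, atol=1e-5)
```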
@@ -1,165 +0,0 @@
"""Gemma CCE patch"""

# This patch is based off transformers 4.50.0.

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT, apply_lce
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma.modeling_gemma import KwargsForCausalLM
from transformers.processing_utils import Unpack
from transformers.utils.deprecation import deprecate_kwarg

_PATCH_OPTS: PatchOptions | None = None


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    # ... same causal-LM signature as the Cohere patch above
    # (**kwargs: Unpack[KwargsForCausalLM]) ...
) -> Union[Tuple, CausalLMOutputWithPast]:
    # ... standard GemmaForCausalLM.forward docstring, preamble and decoder call ...
    hidden_states = outputs[0]
    loss = None
    logits = None

    slice_indices = (
        slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
            )

    # ... tuple / CausalLMOutputWithPast return, as in the Cohere patch above ...


def patch_gemma(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma import modeling_gemma

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_gemma.GemmaForCausalLM
        ), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_gemma.GemmaForCausalLM.forward = cce_forward
    return None
@@ -1,447 +0,0 @@
"""Gemma2 and Gemma3 (text and multimodal) CCE patch."""

# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29
# and updated for transformers 4.50.0.
# This is a modified version of the patch that allows for deferred logits calculation for gemma3
# and works with both gemma3 (text and multimodal) models.

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
from torch import nn
from transformers.cache_utils import Cache, HybridCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma3.modeling_gemma3 import Gemma3CausalLMOutputWithPast, logger
from transformers.utils import is_torchdynamo_compiling
from transformers.utils.deprecation import deprecate_kwarg

from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce

_PATCH_OPTS: PatchOptions | None = None


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    # ... same causal-LM signature as above (past_key_values: Optional[HybridCache]), plus:
    logits_to_keep: Union[int, torch.Tensor] = 0,
    defer_logits_calculation: bool = False,
    **loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    # ... standard Gemma3ForCausalLM.forward docstring (including the note that
    # defer_logits_calculation hands logits off to the ConditionalGeneration forward
    # to avoid the lm_head memory overhead), preamble and decoder call ...
    hidden_states = outputs[0]
    loss = None
    logits = None

    slice_indices = (
        slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            softcap=getattr(self.config, "final_logit_softcapping", None),
            **loss_kwargs,
        )
    elif _PATCH_OPTS is not None and defer_logits_calculation:
        # defer logits calculation to the ConditionalGeneration forward
        logits = hidden_states[:, slice_indices, :]
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

    # ... tuple / CausalLMOutputWithPast return ...


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def cce_forward_multimodal(
    self,
    input_ids: torch.LongTensor | None = None,
    pixel_values: torch.FloatTensor | None = None,
    # ... multimodal signature (attention_mask, token_type_ids, labels, logits_to_keep,
    # **lm_kwargs) ...
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
    # ... standard Gemma3ForConditionalGeneration.forward docstring and preamble;
    # embeds text, merges image features into the special-image-token positions
    # (with an is_torchdynamo_compiling-guarded count check), masks pad token ids in
    # labels (logger.warning_once), builds the causal mask via self._update_causal_mask,
    # then calls self.language_model(..., defer_logits_calculation=True, ...) ...
    hidden_states = outputs[0]
    loss = None
    logits = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states,
            self.language_model.lm_head.weight,
            labels,
            _PATCH_OPTS,
            softcap=getattr(self.config, "final_logit_softcapping", None),
            **lm_kwargs,
        )
    else:
        logits = hidden_states
        if labels is not None:
            # Upcast to float, shift logits/labels, crop with the attention mask,
            # then apply nn.CrossEntropyLoss over the flattened tokens (upstream behavior).
            ...

    # ... tuple / Gemma3CausalLMOutputWithPast return (including image_hidden_states) ...


def patch_gemma2(maybe_model, patch_options: PatchOptions) -> TransformersModelT | None:
    # ... same shape as patch_gemma, targeting modeling_gemma2.Gemma2ForCausalLM ...
    ...


def patch_gemma3_text(maybe_model, patch_options: PatchOptions) -> TransformersModelT | None:
    # ... same shape as patch_gemma, targeting modeling_gemma3.Gemma3ForCausalLM ...
    ...


def patch_gemma3(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma3 import modeling_gemma3

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration
        ), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        # patch the causal model to enable deferred logits calculation
        maybe_model.language_model.forward = MethodType(
            cce_forward, maybe_model.language_model
        )
        return maybe_model

    modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal
    # patch the causal model to enable deferred logits calculation
    modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
    return None
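The removed Gemma2/Gemma3 patch either passes `final_logit_softcapping` through to `apply_lce` or reproduces the usual tanh soft-cap on the logits itself. A hedged sketch of that transform in isolation, with illustrative shapes and cap value:

```python
# Logit soft-capping as in the removed patch's non-CCE branch: cap * tanh(logits / cap).
import torch


def softcap_logits(logits: torch.Tensor, cap: float) -> torch.Tensor:
    """Bound logits smoothly to (-cap, cap), as Gemma2/Gemma3 do before the loss."""
    return torch.tanh(logits / cap) * cap


logits = torch.randn(4, 256000) * 50.0   # illustrative batch of raw logits
capped = softcap_logits(logits, cap=30.0)
assert capped.abs().max() <= 30.0
```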
@@ -1,57 +0,0 @@
"""GLM 4 patch. GLM family inherits from Llama."""

from types import MethodType

import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
)


def patch_glm(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    # Set the _PATCH_OPTS in the llama patch file
    import cut_cross_entropy.transformers.llama as llama_patch

    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access

    from cut_cross_entropy.transformers.llama import cce_forward
    from transformers.models.glm import modeling_glm

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_glm.GlmForCausalLM
        ), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_glm.GlmForCausalLM.forward = cce_forward
    return None


def patch_glm4(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    # Set the _PATCH_OPTS in the llama patch file
    import cut_cross_entropy.transformers.llama as llama_patch

    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access

    from cut_cross_entropy.transformers.llama import cce_forward
    from transformers.models.glm4 import modeling_glm4

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_glm4.Glm4ForCausalLM
        ), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_glm4.Glm4ForCausalLM.forward = cce_forward
    return None
@@ -1,164 +0,0 @@
"""Llama CCE patch. Adapted from transformers v4.51.2"""

# pylint: disable=duplicate-code


from types import MethodType
from typing import Optional, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT, apply_lce
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import KwargsForCausalLM
from transformers.processing_utils import Unpack
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple

_PATCH_OPTS: PatchOptions | None = None


@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def cce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    # ... same causal-LM signature as above, without return_dict, plus:
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
    # ... standard LlamaForCausalLM.forward docstring and output-flag fallbacks ...
    outputs: BaseModelOutputWithPast = self.model(
        input_ids=input_ids,
        # ... remaining decoder kwargs ...
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    if hidden_states is None:
        raise ValueError("hidden_states is None")

    loss = None
    logits = None

    slice_indices = (
        slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    )
    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
            )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def patch_llama(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Patch Llama for CCE."""
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.llama import modeling_llama

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_llama.LlamaForCausalLM
        ), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_llama.LlamaForCausalLM.forward = cce_forward
    return None
@@ -1,401 +0,0 @@
"""Llama4 CCE patch. Adapted from transformers 4.51.0."""

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT, apply_lce
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama4.modeling_llama4 import Llama4CausalLMOutputWithPast

_PATCH_OPTS: PatchOptions | None = None


def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    # ... same causal-LM signature as above, plus:
    logits_to_keep: Union[int, torch.Tensor] = 0,
    defer_logits_calculation: bool = False,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    # ... standard Llama4ForCausalLM.forward docstring (including the
    # defer_logits_calculation note), preamble and decoder call ...
    hidden_states = outputs[0]
    loss = None
    logits = None

    slice_indices = (
        slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    )
    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    elif _PATCH_OPTS is not None and defer_logits_calculation:
        # defer logits calculation to the ConditionalGeneration forward
        logits = hidden_states[:, slice_indices, :]
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
            )

    # ... tuple / CausalLMOutputWithPast return ...


def cce_forward_multimodal(
    self,
    input_ids: torch.LongTensor | None = None,  # type: ignore
    pixel_values: torch.FloatTensor | None = None,
    # ... multimodal signature (vision_feature_layer, vision_feature_select_strategy,
    # labels, logits_to_keep, image_sizes, **lm_kwargs) ...
) -> Union[Tuple, Llama4CausalLMOutputWithPast]:
    # ... standard preamble; validates that exactly one of input_ids / inputs_embeds is set
    # and that pixel_values and inputs_embeds are not passed together ...
    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)  # type: ignore

    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            image_sizes=image_sizes,
        )
        original_inputs_embeds_shape = inputs_embeds.shape  # type: ignore

        vision_flat = image_features.view(-1, image_features.size(-1))
        projected_vision_flat = self.multi_modal_projector(vision_flat)

        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
        final_mask = special_image_mask.to(inputs_embeds.device)  # type: ignore
        inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))  # type: ignore

        final_mask_1d = final_mask[..., 0].reshape(-1)
        num_tokens_to_fill = final_mask_1d.sum()

        if num_tokens_to_fill != projected_vision_flat.size(0):
            raise ValueError(
                f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
                f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(
|
|
||||||
expanded_mask, projected_vision_flat
|
|
||||||
) # type: ignore
|
|
||||||
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) # type: ignore
|
|
||||||
|
|
||||||
outputs = self.language_model(
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
logits_to_keep=logits_to_keep,
|
|
||||||
defer_logits_calculation=True, # enable deferred logits calculation
|
|
||||||
**lm_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
# TODO: check if need to handle attention_mask
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states,
|
|
||||||
self.language_model.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**lm_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = hidden_states
|
|
||||||
if labels is not None:
|
|
||||||
# Shift so that tokens < n predict n
|
|
||||||
if attention_mask is not None:
|
|
||||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
|
||||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
|
||||||
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
|
|
||||||
logits.device
|
|
||||||
)
|
|
||||||
shift_logits = logits[..., :-1, :][
|
|
||||||
shift_attention_mask.to(logits.device) != 0
|
|
||||||
].contiguous()
|
|
||||||
shift_labels = labels[..., 1:][
|
|
||||||
shift_attention_mask.to(labels.device) != 0
|
|
||||||
].contiguous()
|
|
||||||
else:
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
# Flatten the tokens
|
|
||||||
loss_fct = nn.CrossEntropyLoss()
|
|
||||||
loss = loss_fct(
|
|
||||||
shift_logits.view(-1, shift_logits.size(-1)),
|
|
||||||
shift_labels.view(-1).to(shift_logits.device),
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return Llama4CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits, # type: ignore # TODO: check if need to create dummy logits
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
image_hidden_states=image_features if pixel_values is not None else None,
|
|
||||||
    )


def patch_llama4_text(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.llama4 import modeling_llama4

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_llama4.Llama4ForCausalLM
        ), f"Expected a Llama4ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)

        return maybe_model

    setattr(
        modeling_llama4.Llama4ForCausalLM,
        "forward",
        cce_forward,
    )
    return None


def patch_llama4(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:

    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.llama4 import modeling_llama4

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_llama4.Llama4ForConditionalGeneration
        ), f"Expected a Llama4ForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        # patch the language model
        maybe_model.language_model.forward = MethodType(
            cce_forward, maybe_model.language_model
        )
        return maybe_model

    setattr(
        modeling_llama4.Llama4ForConditionalGeneration,
        "forward",
        cce_forward_multimodal,
    )

    # patch the causal language model
    setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
    return None
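# ---------------------------------------------------------------------------
# Editor's sketch (not part of this diff): the two patching modes used above,
# reduced to a minimal stand-in class. When a loaded PreTrainedModel is passed,
# forward is rebound on that single instance via MethodType; when only a model
# type / config is given, the class attribute is swapped so every future
# instance picks up the CCE-aware forward. All names here are illustrative.
# ---------------------------------------------------------------------------
def _patching_modes_sketch():
    class ToyModel:
        def forward(self):
            return "original"

    def patched_forward(self):
        return "patched"

    instance = ToyModel()
    # instance-level patch, mirroring the isinstance(..., PreTrainedModel) branch
    instance.forward = MethodType(patched_forward, instance)
    assert instance.forward() == "patched" and ToyModel().forward() == "original"

    # class-level patch, mirroring the setattr(...) fallback
    setattr(ToyModel, "forward", patched_forward)
    assert ToyModel().forward() == "patched"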
@@ -1,384 +0,0 @@
|
|||||||
"""Mistral and Mistral3 CCE patch."""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from torch import nn
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.mistral3.modeling_mistral3 import (
|
|
||||||
Mistral3CausalLMOutputWithPast,
|
|
||||||
)
|
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
|
||||||
KwargsForCausalLM,
|
|
||||||
)
|
|
||||||
from transformers.processing_utils import Unpack
|
|
||||||
from transformers.utils import (
|
|
||||||
is_torchdynamo_compiling,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] | None = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
defer_logits_calculation: bool = False,
|
|
||||||
**kwargs: Unpack[KwargsForCausalLM],
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
defer_logits_calculation (`bool`, *optional*):
|
|
||||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
|
||||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
|
||||||
|
|
||||||
>>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
|
||||||
|
|
||||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
||||||
```"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
|
||||||
# defer logits calculation to the ConditionalGeneration forward
|
|
||||||
logits = hidden_states[:, slice_indices, :]
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits=logits,
|
|
||||||
labels=labels,
|
|
||||||
vocab_size=self.config.vocab_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def cce_forward_multimodal(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
pixel_values: torch.FloatTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
vision_feature_layer: Optional[Union[int, list[int]]] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
image_sizes: torch.Tensor | None = None,
|
|
||||||
**lm_kwargs,
|
|
||||||
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from PIL import Image
|
|
||||||
>>> import requests
|
|
||||||
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
|
||||||
|
|
||||||
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
|
||||||
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
|
||||||
|
|
||||||
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
|
|
||||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
|
|
||||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
|
|
||||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"What is the image?The image depicts two cats lying on a pink blanket."
|
|
||||||
```"""
|
|
||||||
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
vision_feature_layer = (
|
|
||||||
vision_feature_layer
|
|
||||||
if vision_feature_layer is not None
|
|
||||||
else self.config.vision_feature_layer
|
|
||||||
)
|
|
||||||
|
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
||||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if pixel_values is not None and inputs_embeds is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
|
||||||
|
|
||||||
if pixel_values is not None:
|
|
||||||
image_features = self.get_image_features(
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
vision_feature_layer=vision_feature_layer,
|
|
||||||
image_sizes=image_sizes,
|
|
||||||
)
|
|
||||||
|
|
||||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
|
|
||||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
|
||||||
inputs_embeds.device
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
not is_torchdynamo_compiling()
|
|
||||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
|
||||||
):
|
|
||||||
n_image_tokens = (input_ids == self.config.image_token_index).sum()
|
|
||||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
|
||||||
raise ValueError(
|
|
||||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
|
||||||
)
|
|
||||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
|
||||||
|
|
||||||
outputs = self.language_model(
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
logits_to_keep=logits_to_keep,
|
|
||||||
defer_logits_calculation=True, # enable deferred logits calculation
|
|
||||||
**lm_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states,
|
|
||||||
self.language_model.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**lm_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = hidden_states
|
|
||||||
if labels is not None:
|
|
||||||
# Shift so that tokens < n predict n
|
|
||||||
if attention_mask is not None:
|
|
||||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
|
||||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
|
||||||
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
|
|
||||||
logits.device
|
|
||||||
)
|
|
||||||
shift_logits = logits[..., :-1, :][
|
|
||||||
shift_attention_mask.to(logits.device) != 0
|
|
||||||
].contiguous()
|
|
||||||
shift_labels = labels[..., 1:][
|
|
||||||
shift_attention_mask.to(labels.device) != 0
|
|
||||||
].contiguous()
|
|
||||||
else:
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
# Flatten the tokens
|
|
||||||
loss_fct = nn.CrossEntropyLoss()
|
|
||||||
loss = loss_fct(
|
|
||||||
shift_logits.view(-1, shift_logits.size(-1)),
|
|
||||||
shift_labels.view(-1).to(shift_logits.device),
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return Mistral3CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
image_hidden_states=image_features if pixel_values is not None else None,
|
|
||||||
    )


def patch_mistral(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.mistral import modeling_mistral

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_mistral.MistralForCausalLM
        ), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_mistral.MistralForCausalLM.forward = cce_forward
    return None


def patch_mistral3(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.mistral import modeling_mistral
    from transformers.models.mistral3 import modeling_mistral3

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration
        ), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        # patch the causal model to enable deferred logits calculation
        maybe_model.language_model.forward = MethodType(
            cce_forward, maybe_model.language_model
        )
        return maybe_model

    modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal
    # patch the causal model to enable deferred logits calculation
    modeling_mistral.MistralForCausalLM.forward = cce_forward
    return None
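# ---------------------------------------------------------------------------
# Editor's sketch (not part of this diff): the deferred-logits handshake that
# cce_forward_multimodal and cce_forward above rely on, reduced to plain
# tensors. The inner causal-LM forward is called with
# defer_logits_calculation=True and returns hidden states instead of full
# vocab logits; the outer forward then computes the loss once from those
# hidden states and lm_head.weight (a plain shifted cross-entropy stands in
# for apply_lce here). Shapes and names are illustrative only.
# ---------------------------------------------------------------------------
def _deferred_logits_sketch() -> torch.Tensor:
    batch, seq, hidden_dim, vocab = 2, 8, 16, 32
    hidden_states = torch.randn(batch, seq, hidden_dim)  # inner forward output
    lm_head_weight = torch.randn(vocab, hidden_dim)
    labels = torch.randint(0, vocab, (batch, seq))

    # outer forward: shift so tokens < n predict n, project, then cross-entropy
    shift_hidden = hidden_states[:, :-1, :]
    shift_labels = labels[:, 1:]
    logits = shift_hidden @ lm_head_weight.T
    return nn.functional.cross_entropy(
        logits.reshape(-1, vocab), shift_labels.reshape(-1), ignore_index=-100
    )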
@@ -1,366 +0,0 @@
|
|||||||
"""Mllama CCE patch."""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.mllama.modeling_mllama import (
|
|
||||||
_prepare_cross_attention_mask,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
cross_attention_states: Optional[torch.LongTensor] = None,
|
|
||||||
cross_attention_mask: Optional[torch.LongTensor] = None,
|
|
||||||
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
defer_logits_calculation: bool = False,
|
|
||||||
**loss_kwargs,
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
defer_logits_calculation (`bool`, *optional*):
|
|
||||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
|
||||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import AutoTokenizer, MllamaForCausalLM
|
|
||||||
|
|
||||||
>>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
|
|
||||||
|
|
||||||
>>> prompt = "If I had to write a haiku, it would be:"
|
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
|
|
||||||
>>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
>>> print(result)
|
|
||||||
If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
|
|
||||||
I love the idea of snowflakes gently falling, each one
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
cross_attention_states=cross_attention_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
cross_attention_mask=cross_attention_mask,
|
|
||||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
|
||||||
# defer logits calculation to the ConditionalGeneration forward
|
|
||||||
logits = hidden_states[:, slice_indices, :]
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
|
|
||||||
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
def cce_forward_multimodal(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
|
||||||
pixel_values: Optional[torch.FloatTensor] = None,
|
|
||||||
aspect_ratio_mask: Optional[torch.Tensor] = None,
|
|
||||||
aspect_ratio_ids: Optional[torch.Tensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
cross_attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
cross_attention_states: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**loss_kwargs,
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from PIL import Image
|
|
||||||
>>> import requests
|
|
||||||
>>> from transformers import AutoProcessor, MllamaForConditionalGeneration
|
|
||||||
|
|
||||||
>>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
|
|
||||||
>>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
|
|
||||||
>>> processor = AutoProcessor.from_pretrained(checkpoint)
|
|
||||||
|
|
||||||
>>> prompt = "<|image|>If I had to write a haiku for this one"
|
|
||||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
|
|
||||||
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> output = model.generate(**inputs, max_new_tokens=15)
|
|
||||||
|
|
||||||
>>> prompt_len = inputs.input_ids.shape[-1]
|
|
||||||
>>> generated_ids = output[:, prompt_len:]
|
|
||||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
|
||||||
>>> print(generated_text)
|
|
||||||
[', it would be:.\\nA stop sign in Chinatown.\\n']
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
||||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if pixel_values is not None and inputs_embeds is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
|
||||||
)
|
|
||||||
|
|
||||||
if pixel_values is not None and cross_attention_states is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
|
|
||||||
)
|
|
||||||
|
|
||||||
if pixel_values is not None:
|
|
||||||
if aspect_ratio_ids is None:
|
|
||||||
raise ValueError(
|
|
||||||
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
|
|
||||||
)
|
|
||||||
# get vision tokens from vision model
|
|
||||||
vision_outputs = self.vision_model(
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
aspect_ratio_ids=aspect_ratio_ids,
|
|
||||||
aspect_ratio_mask=aspect_ratio_mask,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
return_dict=return_dict,
|
|
||||||
)
|
|
||||||
cross_attention_states = vision_outputs[0]
|
|
||||||
cross_attention_states = self.multi_modal_projector(
|
|
||||||
cross_attention_states
|
|
||||||
).reshape(
|
|
||||||
-1, cross_attention_states.shape[-2], self.hidden_size # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
if cross_attention_mask is not None:
|
|
||||||
cross_attention_mask, full_text_row_masked_out_mask = (
|
|
||||||
_prepare_cross_attention_mask(
|
|
||||||
cross_attention_mask,
|
|
||||||
num_vision_tokens=self.vision_model.num_patches,
|
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
full_text_row_masked_out_mask = None
|
|
||||||
|
|
||||||
if cross_attention_mask is not None and cache_position is not None:
|
|
||||||
cross_attention_mask = cross_attention_mask[:, :, cache_position]
|
|
||||||
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
|
|
||||||
:, :, cache_position
|
|
||||||
]
|
|
||||||
|
|
||||||
outputs = self.language_model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
cross_attention_states=cross_attention_states,
|
|
||||||
cross_attention_mask=cross_attention_mask,
|
|
||||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
use_cache=use_cache,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
logits_to_keep=logits_to_keep,
|
|
||||||
defer_logits_calculation=True, # enable deferred logits calculation
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states,
|
|
||||||
self.language_model.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
|
|
||||||
logits = hidden_states
|
|
||||||
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return (loss,) + outputs if loss is not None else outputs
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=outputs.logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
    )


def patch_mllama(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:

    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.mllama import modeling_mllama

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_mllama.MllamaForConditionalGeneration
        ), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        # patch the language model
        maybe_model.language_model.forward = MethodType(
            cce_forward, maybe_model.language_model
        )
        return maybe_model

    modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal

    # patch the causal language model
    modeling_mllama.MllamaForCausalLM.forward = cce_forward
    return None
@@ -1,126 +0,0 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.

"""Cut Cross Entropy patcher"""

import transformers
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
from cut_cross_entropy.transformers.phi3 import patch_phi3
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT

from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
    patch_cohere,
    patch_cohere2,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
    patch_gemma2,
    patch_gemma3,
    patch_gemma3_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
    patch_glm,
    patch_glm4,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
    patch_llama,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
    patch_llama4,
    patch_llama4_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
    patch_mistral,
    patch_mistral3,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import (
    patch_qwen2,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import (
    patch_qwen2_5_vl,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import (
    patch_qwen2_moe,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import (
    patch_qwen2_vl,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import (
    patch_qwen3_moe,
)

CUT_CROSS_ENTROPY_MODEL_MAPPING = {
    "llama": patch_llama,
    "llama4": patch_llama4,
    "llama4_text": patch_llama4_text,
    "mllama": patch_mllama,
    "phi3": patch_phi3,
    "gemma": patch_gemma,
    "gemma2": patch_gemma2,
    "gemma3": patch_gemma3,
    "gemma3_text": patch_gemma3_text,
    "mistral": patch_mistral,
    "mistral3": patch_mistral3,
    "qwen2": patch_qwen2,
    "qwen2_moe": patch_qwen2_moe,
    "qwen2_vl": patch_qwen2_vl,
    "qwen2_5_vl": patch_qwen2_5_vl,
    "qwen3": patch_qwen3,
    "qwen3_moe": patch_qwen3_moe,
    "cohere": patch_cohere,
    "cohere2": patch_cohere2,
    "glm": patch_glm,
    "glm4": patch_glm4,
}


def cce_patch(
    model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig,
    impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
    reduction: str = "mean",
    filter_eps: float | str | None = "auto",
    accum_e_fp32: bool = False,
    accum_c_fp32: bool = False,
    filter_e_grad: bool = True,
    filter_c_grad: bool = True,
    train_only: bool = False,
) -> TransformersModelT | None:
    if isinstance(impl, LinearCrossEntropyImpl):
        impl = impl.name.lower()

    if impl not in (v.name.lower() for v in LinearCrossEntropyImpl):
        raise ValueError(f"Unknown {impl=}")

    if isinstance(model_type_or_model, transformers.PreTrainedModel):
        if hasattr(model_type_or_model, "config"):
            model_type = getattr(
                getattr(model_type_or_model, "config", None), "model_type", None
            )
        else:
            raise ValueError(
                "model_type_or_model is a PreTrainedModel but does not have a config attribute"
            )
    elif isinstance(model_type_or_model, transformers.PretrainedConfig):
        model_type = model_type_or_model.model_type
    else:
        model_type = model_type_or_model

    patch_options = PatchOptions(
        impl=impl,
        reduction=reduction,
        filter_eps=filter_eps,
        accum_e_fp32=accum_e_fp32,
        accum_c_fp32=accum_c_fp32,
        filter_e_grad=filter_e_grad,
        filter_c_grad=filter_c_grad,
        train_only=train_only,
    )

    if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
        return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type](
            model_type_or_model, patch_options
        )

    raise RuntimeError(f"Unknown model type {model_type}")
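# ---------------------------------------------------------------------------
# Illustrative usage sketch (editor's addition, not part of this diff). The
# checkpoint id is a placeholder, and passing impl="cce" assumes that string
# is one of the LinearCrossEntropyImpl member names checked above; every
# keyword follows the cce_patch signature defined in this file.
# ---------------------------------------------------------------------------
def _example_cce_patch_usage():
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("org/placeholder-llama-checkpoint")
    # Instance patching: forward is rebound on this object and it is returned.
    model = cce_patch(model, impl="cce", reduction="mean", train_only=True)
    # Type patching: the class-level forward is replaced and None is returned.
    cce_patch("qwen2", impl="cce")
    return model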
@@ -1,37 +0,0 @@
"""Qwen2 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""

# pylint: disable=duplicate-code

from types import MethodType

import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
)


def patch_qwen2(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    from transformers.models.qwen2 import modeling_qwen2

    # Set the _PATCH_OPTS in the llama patch file
    import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch

    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access

    from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
        cce_forward,
    )

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_qwen2.Qwen2ForCausalLM
        ), f"Expected a Qwen2ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward
    return None
@@ -1,246 +0,0 @@
|
|||||||
"""Qwen2.5 VL CCE patch. Adapted from transformers v4.51.2"""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from torch.nn import CrossEntropyLoss
|
|
||||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
|
||||||
Qwen2_5_VLCausalLMOutputWithPast,
|
|
||||||
)
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def cce_forward_multimodal(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
pixel_values: Optional[torch.Tensor] = None,
|
|
||||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
|
||||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
|
||||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
|
||||||
rope_deltas: Optional[torch.LongTensor] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
|
||||||
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from PIL import Image
|
|
||||||
>>> import requests
|
|
||||||
>>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
|
||||||
|
|
||||||
>>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
|
||||||
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
|
||||||
|
|
||||||
>>> messages = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "image"},
|
|
||||||
{"type": "text", "text": "What is shown in this image?"},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
|
|
||||||
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
|
||||||
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
|
||||||
```"""
|
|
||||||
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)
        if pixel_values is not None:
            pixel_values = pixel_values.type(self.visual.dtype)
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
            n_image_features = image_embeds.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )

            mask = input_ids == self.config.image_token_id
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
            image_mask = mask_expanded.to(inputs_embeds.device)

            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)  # type: ignore

        if pixel_values_videos is not None:
            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
            n_video_features = video_embeds.shape[0]
            if n_video_tokens != n_video_features:
                raise ValueError(
                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                )

            mask = input_ids == self.config.video_token_id
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
            video_mask = mask_expanded.to(inputs_embeds.device)

            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)  # type: ignore

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        # calculate RoPE index once per generation in the pre-fill stage only
        if (
            (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
            or (past_key_values is None or past_key_values.get_seq_length() == 0)  # type: ignore
        ):
            position_ids, rope_deltas = self.get_rope_index(
                input_ids,
                image_grid_thw,
                video_grid_thw,
                second_per_grid_ts,
                attention_mask,
            )
            self.rope_deltas = rope_deltas
        # then use the prev pre-calculated rope-deltas to get the correct position ids
        else:
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)  # type: ignore
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)  # type: ignore
            if cache_position is not None:  # otherwise `deltas` is an int `0`
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)  # type: ignore
            position_ids = position_ids.add(delta)  # type: ignore
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)  # type: ignore

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    logits = None
    loss = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states,
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
        )
    else:
        logits = self.lm_head(hidden_states)

        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2_5_VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )


def patch_qwen2_5_vl(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement

    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
        ), f"Expected a Qwen2_5_VLForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        return maybe_model

    modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = (
        cce_forward_multimodal
    )
    return None
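For orientation, a minimal sketch of how a patch entry point like this might be invoked. The checkpoint id is a placeholder and the `PatchOptions` construction is left as a stub, since its fields are defined in `cut_cross_entropy.transformers.utils` rather than in this diff:

```python
# Hedged usage sketch; checkpoint id is a placeholder and PatchOptions construction is stubbed.
from cut_cross_entropy.transformers.utils import PatchOptions
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl

patch_options: PatchOptions = ...  # construct per cut_cross_entropy's PatchOptions fields

# Patching a live instance rebinds forward so the loss is computed by CCE directly from
# hidden states and lm_head.weight, without materializing the full logits tensor.
model = modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct"  # placeholder checkpoint
)
model = patch_qwen2_5_vl(model, patch_options)

# Passing a config or model-id string instead patches the class in place and returns None,
# so models instantiated afterwards pick up the CCE forward automatically.
```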
@@ -1,178 +0,0 @@
"""Qwen2 MoE CCE patch. Adapted from transformers v4.51.2"""

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
    load_balancing_loss_func,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple

_PATCH_OPTS: PatchOptions | None = None


@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **loss_kwargs,
) -> MoeCausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM

    >>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_router_logits = (
        output_router_logits
        if output_router_logits is not None
        else self.config.output_router_logits
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: MoeModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
    )

    hidden_states = outputs.last_hidden_state
    loss = None
    logits = None

    if hidden_states is None:
        raise ValueError("hidden_states is None")

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **loss_kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits,
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(  # type: ignore
                loss.device  # type: ignore
            )  # make sure to reside in the same device

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,  # type: ignore
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
    )


def patch_qwen2_moe(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement

    from transformers.models.qwen2_moe import modeling_qwen2_moe

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_qwen2_moe.Qwen2MoeForCausalLM
        ), f"Expected a Qwen2MoeForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(forward, maybe_model)

        return maybe_model

    modeling_qwen2_moe.Qwen2MoeForCausalLM.forward = forward
    return None
@@ -1,239 +0,0 @@
"""Qwen2 VL CCE patch. Adapted from transformers v4.51.2"""

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Tuple, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    Qwen2VLCausalLMOutputWithPast,
)

_PATCH_OPTS: PatchOptions | None = None


def cce_forward_multimodal(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

    >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

    >>> messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)
        if pixel_values is not None:
            pixel_values = pixel_values.type(self.visual.get_dtype())
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
            n_image_features = image_embeds.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            image_mask = (
                (input_ids == self.config.image_token_id)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)  # type: ignore

        if pixel_values_videos is not None:
            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
            n_video_features = video_embeds.shape[0]
            if n_video_tokens != n_video_features:
                raise ValueError(
                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                )
            video_mask = (
                (input_ids == self.config.video_token_id)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)  # type: ignore

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        # calculate RoPE index once per generation in the pre-fill stage only
        if (
            (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
            or (past_key_values is None or past_key_values.get_seq_length() == 0)  # type: ignore
        ):
            position_ids, rope_deltas = self.get_rope_index(
                input_ids, image_grid_thw, video_grid_thw, attention_mask
            )
            self.rope_deltas = rope_deltas
        # then use the prev pre-calculated rope-deltas to get the correct position ids
        else:
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                cache_position[0] + self.rope_deltas
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)  # type: ignore
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)  # type: ignore
            if cache_position is not None:  # otherwise `deltas` is an int `0`
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)  # type: ignore
                delta = delta.to(position_ids.device)  # type: ignore
            position_ids = position_ids.add(delta)  # type: ignore
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)  # type: ignore

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    logits = None
    loss = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states,
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
        )
    else:
        logits = self.lm_head(hidden_states)

        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )


def patch_qwen2_vl(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement

    from transformers.models.qwen2_vl import modeling_qwen2_vl

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_qwen2_vl.Qwen2VLForConditionalGeneration
        ), f"Expected a Qwen2VLForConditionalGeneration model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

        return maybe_model

    modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = cce_forward_multimodal
    return None
@@ -1,35 +0,0 @@
"""Qwen3 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""

# pylint: disable=duplicate-code

from types import MethodType

import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
)


def patch_qwen3(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    from transformers.models.qwen3 import modeling_qwen3

    # Set the _PATCH_OPTS in the llama patch file
    import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch

    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access

    from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import cce_forward

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_qwen3.Qwen3ForCausalLM
        ), f"Expected a Qwen3ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_qwen3.Qwen3ForCausalLM.forward = cce_forward
    return None
@@ -1,183 +0,0 @@
"""Qwen3 MoE CCE patch. Adapted from transformers v4.51.2"""

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
    KwargsForCausalLM,
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
    load_balancing_loss_func,
)
from transformers.processing_utils import Unpack
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple

_PATCH_OPTS: PatchOptions | None = None


@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> MoeCausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM

    >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_router_logits = (
        output_router_logits
        if output_router_logits is not None
        else self.config.output_router_logits
    )

    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: MoeModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state

    if hidden_states is None:
        raise ValueError("hidden_states is None")

    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits,
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(  # type: ignore
                loss.device  # type: ignore
            )  # make sure to reside in the same device

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,  # type: ignore
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
    )


def patch_qwen3_moe(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    global _PATCH_OPTS  # pylint: disable=global-statement

    from transformers.models.qwen3_moe import modeling_qwen3_moe

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_qwen3_moe.Qwen3MoeForCausalLM
        ), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(forward, maybe_model)

        return maybe_model

    modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = forward
    return None
@@ -1,40 +0,0 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.

"""Monkeypatch for apply_lce to add softcap."""

import torch
from cut_cross_entropy import linear_cross_entropy
from cut_cross_entropy.transformers.utils import PatchOptions


def apply_lce(
    e: torch.Tensor,
    c: torch.Tensor,
    labels: torch.Tensor,
    opts: PatchOptions,
    bias: torch.Tensor | None = None,
    softcap: float | None = None,
    **loss_kwargs,
) -> torch.Tensor:
    """Monkey patch for apply_lce to support softcap kwarg."""
    num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
    cce_kwargs = opts.to_kwargs()
    if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
        cce_kwargs["reduction"] = "sum"
    else:
        num_items_in_batch = None

    loss = linear_cross_entropy(
        e,
        c,
        labels.to(e.device),
        bias=bias,
        shift=True,
        softcap=softcap,
        **cce_kwargs,
    )

    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch

    return loss
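The reduction handling above exists because trainers that pass `num_items_in_batch` (for gradient-accumulation-correct scaling) expect the loss to be summed and then rescaled exactly once. A toy sketch of the equivalent arithmetic, with illustrative numbers that are not taken from this diff:

```python
# Toy illustration of the reduction logic in apply_lce above (values are made up).
token_losses = [0.7, 1.2, 0.9, 1.1]  # per-token CE from one micro-batch
num_items_in_batch = 8               # trainable tokens across all accumulation steps

# reduction is switched from "mean" to "sum" inside linear_cross_entropy ...
loss_sum = sum(token_losses)
# ... and the patch divides once by the global token count afterwards:
loss = loss_sum / num_items_in_batch
print(round(loss, 4))  # 0.4875
```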
12 src/axolotl/integrations/densemixer/README.md Normal file
@@ -0,0 +1,12 @@
# DenseMixer

See [DenseMixer](https://github.com/yaof20/DenseMixer/)

# Usage

Simply add the following to your axolotl YAML config:

```yaml
plugins:
  - axolotl.integrations.densemixer.DenseMixerPlugin
```
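A slightly fuller config sketch follows; the `base_model` value is a placeholder, not taken from this diff. The plugin only patches the `olmoe`, `qwen2_moe`, and `qwen3_moe` model types and requires the `densemixer` package to be installed:

```yaml
# Illustrative sketch; base_model is a placeholder checkpoint.
base_model: Qwen/Qwen3-30B-A3B   # a qwen3_moe-type model
plugins:
  - axolotl.integrations.densemixer.DenseMixerPlugin
dense_mixer: true                # defaults to true; set false to skip the patch
```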
5 src/axolotl/integrations/densemixer/__init__.py Normal file
@@ -0,0 +1,5 @@
"""Integration entry point for the DenseMixer plugin."""

from .plugin import DenseMixerPlugin

__all__ = ["DenseMixerPlugin"]
11 src/axolotl/integrations/densemixer/args.py Normal file
@@ -0,0 +1,11 @@
"""Pydantic models for DenseMixer plugin"""

from pydantic import BaseModel


class DenseMixerArgs(BaseModel):
    """
    Args for DenseMixer
    """

    dense_mixer: bool = True
42 src/axolotl/integrations/densemixer/plugin.py Normal file
@@ -0,0 +1,42 @@
"""DenseMixer plugin for Axolotl"""

import importlib.util

from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


class DenseMixerPlugin(BasePlugin):
    """
    Plugin for DenseMixer
    """

    def get_input_args(self) -> str | None:
        return "axolotl.integrations.densemixer.args.DenseMixerArgs"

    def pre_model_load(self, cfg):
        """Apply densemixer patches before model loading if enabled."""
        if cfg.dense_mixer:
            if not importlib.util.find_spec("densemixer"):
                raise RuntimeError(
                    "DenseMixer is not installed. Install it with `pip install densemixer`"
                )

            from densemixer.patching import (
                apply_olmoe_patch,
                apply_qwen2_moe_patch,
                apply_qwen3_moe_patch,
            )

            LOG.info(
                f"Applying DenseMixer patches for model type: {cfg.model_config_type}"
            )

            if cfg.model_config_type == "olmoe":
                apply_olmoe_patch()
            if cfg.model_config_type == "qwen2_moe":
                apply_qwen2_moe_patch()
            if cfg.model_config_type == "qwen3_moe":
                apply_qwen3_moe_patch()
@@ -2,7 +2,6 @@
 model patcher for chunked top-k kl-div
 """

-from types import MethodType
 from typing import Optional, Union, Unpack

 import torch
@@ -95,4 +94,4 @@ def apply_kernel(model_type):
     model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
     module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
     model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
-    model_cls.forward = MethodType(kldiv_forward_llama_like, model_cls)
+    model_cls.forward = kldiv_forward_llama_like
@@ -27,7 +27,7 @@ from axolotl.utils.logging import get_logger
 from .args import LigerArgs  # pylint: disable=unused-import. # noqa: F401
 from .utils import patch_with_compile_disable

-LOG = get_logger(__name__, use_environ=True)
+LOG = get_logger(__name__)


 class LigerPlugin(BasePlugin):
@@ -15,6 +15,7 @@
 """
 Module for handling LIGER input arguments.
 """

 from typing import Optional

 from pydantic import BaseModel, model_validator
@@ -504,6 +504,9 @@ class ModelLoader:
                 # for some reason, this causes the loss to be off by an order of magnitude
                 # but deepspeed needs this still in bfloat16
                 bnb_config["bnb_4bit_quant_storage"] = torch.float32
+            if self.cfg.model_config_type == "falcon_h1":
+                # output projection cannot be quantized for Falcon-H1 models
+                bnb_config["llm_int8_skip_modules"] = ["out_proj"]

             if self.cfg.bnb_config_kwargs:
                 bnb_config.update(self.cfg.bnb_config_kwargs)
@@ -518,6 +521,9 @@ class ModelLoader:
             # Exclude mamba blocks from int8 quantization for jamba
             if self.cfg.model_config_type == "jamba":
                 bnb_config["llm_int8_skip_modules"] = ["mamba"]
+            if self.cfg.model_config_type == "falcon_h1":
+                # output projection cannot be quantized for Falcon-H1 models
+                bnb_config["llm_int8_skip_modules"] = ["out_proj"]
             self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                 **bnb_config,
             )
@@ -770,6 +776,9 @@ class ModelLoader:
         dist_dtype: torch.dtype,
         before_kbit_train_or_finetune: bool,
     ):
+        dest = {"dtype": dist_dtype}
+        if self.cfg.lora_on_cpu:
+            dest["device"] = "cpu"
         for name, module in self.model.named_modules():
             if "norm" in name:
                 module.to(dist_dtype)
@@ -780,4 +789,4 @@ class ModelLoader:
                 # don't upcast lm_head for btlm
                 continue
             if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
-                module.to(dist_dtype)
+                module.to(**dest)
@@ -49,10 +49,11 @@ class PatchManager:

     def apply_pre_model_load_patches(self):
         """Apply pre-model load patches based on config."""
+        # self._apply_flex_attention_patches()
         self._apply_flash_attention_patches()
+        self._apply_chunked_cross_entropy_patch()
         self._apply_fsdp_patches()
         self._apply_adapter_patches()
-        self._apply_flex_attention_patches()
         self._apply_model_specific_patches()
         self._apply_fp8_patches()
         self._apply_flash_attention_peft_patches()
@@ -63,6 +64,9 @@ class PatchManager:
         self._patch_llama_derived_model()
         self._apply_mistral_cross_entropy_patch()
         self._apply_self_attention_lora_patch()
+        self._apply_gemma3_conditional_generation_forward_patch()
+        self._apply_sequence_parallel_patches()
+        self._apply_tiled_mlp(self.cfg.model_config_type)

     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
@@ -78,6 +82,15 @@ class PatchManager:
             patch_xformers_attn_over_fa2()
             self.cfg.flash_attention = True

+    def _apply_chunked_cross_entropy_patch(self):
+        if self.cfg.chunked_cross_entropy:
+            from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
+
+            if self.cfg.chunked_cross_entropy_num_chunks:
+                patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
+            else:
+                patch_chunked_ce_loss_fn()
+
     def _apply_fsdp_patches(self):
         """Apply patches for FSDP configurations."""
         if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
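The new `_apply_chunked_cross_entropy_patch` hook is driven by two config keys read in the code above; a minimal sketch of how they might appear in an axolotl YAML config (values are illustrative):

```yaml
# Illustrative values; the keys mirror the cfg attributes read by the patch above.
chunked_cross_entropy: true
chunked_cross_entropy_num_chunks: 4   # optional; omit to use the patch's default chunking
```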
@@ -85,6 +98,14 @@ class PatchManager:

             patch_accelerate_fsdp2()

+        # if self.cfg.fsdp_config:
+        #     # see transformers#39152
+        #     from axolotl.monkeypatch.trainer_fsdp_optim import (
+        #         patch_training_loop_for_fsdp,
+        #     )
+        #
+        #     patch_training_loop_for_fsdp()
+
     def _apply_adapter_patches(self):
         """Apply patches for adapter configurations."""
         if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
@@ -95,14 +116,20 @@ class PatchManager:
     def _apply_flex_attention_patches(self):
         """Apply patches for flexible attention."""
         if self.cfg.flex_attention:
-            from axolotl.monkeypatch.attention.flex_attn import (
-                patch_flex_make_mask,
-                patch_flex_wrapper,
-            )
+            # from axolotl.monkeypatch.attention.flex_attn import (
+            #     patch_flex_make_mask,
+            #     patch_flex_wrapper,
+            # )
+            #
+            # flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
+            # patch_flex_wrapper(**flex_attn_compile_kwargs)
+            # patch_flex_make_mask()
+            if self.cfg.sample_packing:
+                from axolotl.core.attention.flex_block_mask import (
+                    patch_create_causal_mask,
+                )

-            flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
-            patch_flex_wrapper(**flex_attn_compile_kwargs)
-            patch_flex_make_mask()
+                patch_create_causal_mask(self.cfg.model_config_type)

     def _apply_model_specific_patches(self):
         """Apply patches specific to model architectures."""
@@ -211,6 +238,32 @@ class PatchManager:
             has_remote_code=has_remote_code,
         )

+    def _apply_gemma3_conditional_generation_forward_patch(self):
+        """Apply gemma3 conditional generation forward patch."""
+        if self.model_config.model_type in ["gemma3", "gemma3_text"]:
+            from axolotl.monkeypatch.models.gemma3.modeling import (
+                patch_gemma3_conditional_generation_forward,
+            )
+
+            patch_gemma3_conditional_generation_forward()
+
+    def _apply_sequence_parallel_patches(self):
+        """Apply sequence parallelism patches."""
+        if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
+            from axolotl.monkeypatch.ring_attn.patch import (
+                patch_prepare_data_loader,
+                patch_prepare_device_mesh,
+            )
+
+            patch_prepare_data_loader()
+            patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
+
+    def _apply_tiled_mlp(self, model_type: str):
+        if self.cfg.tiled_mlp:
+            from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
+
+            patch_tiled_mlp(model_type, cfg_num_shards=self.cfg.tiled_mlp_num_shards)
+
     def _patch_attention(self):
         """Apply attention-specific patches based on model type."""
         if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
@@ -273,7 +273,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
             {"additional_special_tokens": additional_special_tokens}
         )

-    if is_main_process(use_environ=True):
+    if is_main_process():
         LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
         LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
         LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
@@ -25,12 +25,20 @@ class AxolotlOrWarnErrorFilter(logging.Filter):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-        self.axolotl_level = logging.getLevelNamesMapping()[
-            os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL)
-        ]
-        self.other_level = logging.getLevelNamesMapping()[
-            os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL)
-        ]
+        axolotl_log_level = os.getenv(
+            "AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL
+        ).upper()
+        other_log_level = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL).upper()
+
+        try:
+            # py311+ only
+            level_mapping = logging.getLevelNamesMapping()
+            self.axolotl_level = level_mapping[axolotl_log_level]
+            self.other_level = level_mapping[other_log_level]
+        except AttributeError:
+            # For py310, use getLevelName directly
+            self.axolotl_level = logging.getLevelName(axolotl_log_level)
+            self.other_level = logging.getLevelName(other_log_level)

     def filter(self, record: LogRecord) -> bool:
         # General filter
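A quick sketch of why the fallback above is needed: `logging.getLevelNamesMapping()` only exists on Python 3.11+, while `logging.getLevelName()` is available on 3.10 and returns the numeric level when given a registered level name. This is standard-library behavior, not code from this diff:

```python
import logging


def resolve_level(name: str) -> int:
    """Resolve a level name like "DEBUG" to its numeric value on both py310 and py311+."""
    try:
        # Python 3.11+: dict mapping level names to numeric values
        return logging.getLevelNamesMapping()[name]
    except AttributeError:
        # Python 3.10: getLevelName returns the int for a known level name
        return logging.getLevelName(name)


print(resolve_level("DEBUG"))  # 10
```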
0 src/axolotl/monkeypatch/loss/__init__.py Normal file
Some files were not shown because too many files have changed in this diff.