Compare commits
40 Commits
fix/eval-a
...
shared-pre
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b79996bdc4 | ||
|
|
68368de7ed | ||
|
|
a94c4a014b | ||
|
|
0102ca5943 | ||
|
|
97e8c01a70 | ||
|
|
5c4705b185 | ||
|
|
47a88da330 | ||
|
|
07ab737a55 | ||
|
|
c40da3b5eb | ||
|
|
a5946ff1f0 | ||
|
|
70ca1b2291 | ||
|
|
8ae5a2311b | ||
|
|
6383630155 | ||
|
|
f2b352f2e5 | ||
|
|
bf5928d0ee | ||
|
|
d1224db8f4 | ||
|
|
327b4e48e9 | ||
|
|
35fdbce102 | ||
|
|
cb811f8bf1 | ||
|
|
7563e1bd30 | ||
|
|
81893c775c | ||
|
|
a1a740608d | ||
|
|
ec15a7a691 | ||
|
|
0a7a216b60 | ||
|
|
d8280d45c1 | ||
|
|
24f2887e87 | ||
|
|
29289a4de9 | ||
|
|
a24957fa04 | ||
|
|
927bf530bc | ||
|
|
18954ba100 | ||
|
|
d8cf66edbd | ||
|
|
181cc3106b | ||
|
|
20106116da | ||
|
|
a27c4f8771 | ||
|
|
bb1109b81d | ||
|
|
8c69ec3a1e | ||
|
|
46675496a3 | ||
|
|
c6b5d35e5d | ||
|
|
12c826816d | ||
|
|
1d8f500709 |
6
.github/workflows/base.yml
vendored
6
.github/workflows/base.yml
vendored
@@ -5,11 +5,13 @@ on:
|
|||||||
branches:
|
branches:
|
||||||
- "main"
|
- "main"
|
||||||
paths:
|
paths:
|
||||||
- 'Dockerfile-base'
|
- 'docker/Dockerfile-base'
|
||||||
|
- 'docker/Dockerfile-uv-base'
|
||||||
- '.github/workflows/base.yml'
|
- '.github/workflows/base.yml'
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- 'Dockerfile-base'
|
- 'docker/Dockerfile-base'
|
||||||
|
- 'docker/Dockerfile-uv-base'
|
||||||
- '.github/workflows/base.yml'
|
- '.github/workflows/base.yml'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
|
|||||||
13
.github/workflows/main.yml
vendored
13
.github/workflows/main.yml
vendored
@@ -20,12 +20,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras: vllm
|
axolotl_extras: vllm
|
||||||
is_latest: true
|
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -88,8 +87,8 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
@@ -146,8 +145,8 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
|||||||
6
.github/workflows/multi-gpu-e2e.yml
vendored
6
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -26,11 +26,11 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras: vllm
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
|
|||||||
115
.github/workflows/tests-nightly.yml
vendored
115
.github/workflows/tests-nightly.yml
vendored
@@ -18,96 +18,9 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
SKIP: no-commit-to-branch
|
SKIP: no-commit-to-branch
|
||||||
|
|
||||||
preload-cache:
|
|
||||||
name: Preload HF cache
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
python_version: ["3.11"]
|
|
||||||
pytorch_version: ["2.6.0"]
|
|
||||||
timeout-minutes: 20
|
|
||||||
|
|
||||||
env:
|
|
||||||
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Check out repository code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Restore HF cache
|
|
||||||
id: hf-cache-restore
|
|
||||||
uses: actions/cache/restore@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
|
||||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
|
||||||
|
|
||||||
- name: Setup Python
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: ${{ matrix.python_version }}
|
|
||||||
cache: 'pip' # caching pip dependencies
|
|
||||||
|
|
||||||
- name: upgrade pip
|
|
||||||
run: |
|
|
||||||
pip3 install --upgrade pip
|
|
||||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
|
||||||
|
|
||||||
- name: Install PyTorch
|
|
||||||
run: |
|
|
||||||
pip3 install torch==${{ matrix.pytorch_version }}
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
pip3 show torch
|
|
||||||
pip3 install --no-build-isolation -U -e .
|
|
||||||
python scripts/unsloth_install.py | sh
|
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
|
||||||
|
|
||||||
- name: Make sure PyTorch version wasn't clobbered
|
|
||||||
run: |
|
|
||||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
|
||||||
|
|
||||||
- name: Ensure axolotl CLI was installed
|
|
||||||
run: |
|
|
||||||
axolotl --help
|
|
||||||
|
|
||||||
- name: Pre-Download dataset fixture
|
|
||||||
run: |
|
|
||||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
|
||||||
|
|
||||||
- name: Run tests
|
|
||||||
run: |
|
|
||||||
pytest -v tests/conftest.py
|
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
|
||||||
uses: codecov/codecov-action@v5
|
|
||||||
with:
|
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
|
||||||
files: ./coverage.xml
|
|
||||||
flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
|
||||||
fail_ci_if_error: false
|
|
||||||
|
|
||||||
- name: cleanup pip cache
|
|
||||||
run: |
|
|
||||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
|
||||||
|
|
||||||
- name: Save HF cache
|
|
||||||
id: hf-cache
|
|
||||||
uses: actions/cache/save@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
|
||||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
|
||||||
|
|
||||||
pytest:
|
pytest:
|
||||||
name: PyTest
|
name: PyTest
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [preload-cache]
|
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
max-parallel: 2
|
max-parallel: 2
|
||||||
@@ -120,14 +33,11 @@ jobs:
|
|||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Restore HF cache
|
- name: Restore Cache from S3
|
||||||
id: hf-cache-restore
|
id: hf-cache-restore-s3
|
||||||
uses: actions/cache/restore@v4
|
run: |
|
||||||
with:
|
mkdir -p /home/runner/.cache/huggingface/hub
|
||||||
path: |
|
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
|
||||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
@@ -168,10 +78,6 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|
||||||
- name: Pre-Download dataset fixture
|
|
||||||
run: |
|
|
||||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
@@ -193,15 +99,8 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.1
|
|
||||||
num_gpus: 1
|
|
||||||
axolotl_extras:
|
|
||||||
nightly_build: "true"
|
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
|
|||||||
12
.github/workflows/tests.yml
vendored
12
.github/workflows/tests.yml
vendored
@@ -195,12 +195,12 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras: vllm
|
axolotl_extras:
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -247,8 +247,8 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
@@ -311,7 +311,7 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras: vllm
|
axolotl_extras:
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
- repo: https://github.com/PyCQA/flake8
|
- repo: https://github.com/PyCQA/flake8
|
||||||
rev: 7.2.0
|
rev: 7.3.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
- repo: https://github.com/pylint-dev/pylint
|
- repo: https://github.com/pylint-dev/pylint
|
||||||
@@ -27,7 +27,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: pylint
|
- id: pylint
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v1.16.0
|
rev: v1.16.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
@@ -36,7 +36,7 @@ repos:
|
|||||||
'pydantic>=2.5.3',
|
'pydantic>=2.5.3',
|
||||||
]
|
]
|
||||||
- repo: https://github.com/PyCQA/bandit
|
- repo: https://github.com/PyCQA/bandit
|
||||||
rev: 1.8.3
|
rev: 1.8.5
|
||||||
hooks:
|
hooks:
|
||||||
- id: bandit
|
- id: bandit
|
||||||
args: [
|
args: [
|
||||||
|
|||||||
11
README.md
11
README.md
@@ -43,7 +43,7 @@ Features:
|
|||||||
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
|
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
|
||||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||||
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
|
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
|
||||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!
|
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
|
||||||
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
||||||
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
||||||
|
|
||||||
@@ -59,6 +59,8 @@ Features:
|
|||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
|
#### Using pip
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
|
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
@@ -68,6 +70,13 @@ axolotl fetch examples
|
|||||||
axolotl fetch deepspeed_configs # OPTIONAL
|
axolotl fetch deepspeed_configs # OPTIONAL
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Using Docker
|
||||||
|
|
||||||
|
Installing with Docker can be less error prone than installing in your own environment.
|
||||||
|
```bash
|
||||||
|
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
||||||
|
```
|
||||||
|
|
||||||
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
### Your First Fine-tune
|
### Your First Fine-tune
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ df_args = {
|
|||||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||||
|
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
|
||||||
|
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
|
||||||
}
|
}
|
||||||
|
|
||||||
dockerfile_contents = df_template.render(**df_args)
|
dockerfile_contents = df_template.render(**df_args)
|
||||||
|
|||||||
@@ -38,6 +38,6 @@ RUN git lfs install --skip-repo && \
|
|||||||
# The base image ships with `pydantic==1.8.2` which is not working
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||||
|
|
||||||
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
|
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
|
||||||
pip3 install flash-attn==2.7.4.post1; \
|
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -34,7 +34,3 @@ RUN uv pip install packaging setuptools wheel psutil \
|
|||||||
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
||||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||||
&& uv pip install awscli pydantic
|
&& uv pip install awscli pydantic
|
||||||
|
|
||||||
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
|
|
||||||
uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
|
|
||||||
fi
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ order: 3
|
|||||||
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
|
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
|
||||||
|
|
||||||
```{.json filename="data.jsonl"}
|
```{.json filename="data.jsonl"}
|
||||||
{"conversations": [{"role": "...", "content": "..."}]}
|
{"messages": [{"role": "...", "content": "..."}, {"role": "...", "content": "..."}, ...]}
|
||||||
```
|
```
|
||||||
|
|
||||||
See [configs](../config-reference.qmd) for full configs and supported templates.
|
See [configs](../config-reference.qmd) for full configs and supported templates.
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ format:
|
|||||||
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8.
|
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
## Base
|
## Base
|
||||||
@@ -34,6 +34,7 @@ Tags examples:
|
|||||||
|
|
||||||
- `main-base-py3.11-cu128-2.7.1`
|
- `main-base-py3.11-cu128-2.7.1`
|
||||||
- `main-base-py3.11-cu126-2.7.1`
|
- `main-base-py3.11-cu126-2.7.1`
|
||||||
|
- `main-base-py3.11-cu126-2.6.0`
|
||||||
- `main-base-py3.11-cu124-2.6.0`
|
- `main-base-py3.11-cu124-2.6.0`
|
||||||
- `main-base-py3.11-cu124-2.5.1`
|
- `main-base-py3.11-cu124-2.5.1`
|
||||||
|
|
||||||
@@ -73,13 +74,15 @@ There may be some extra tags appended to the image, like `-vllm` which installs
|
|||||||
|
|
||||||
Tags examples:
|
Tags examples:
|
||||||
|
|
||||||
- `main-py3.11-cu126-2.7.0`
|
- `main-py3.11-cu128-2.7.1`
|
||||||
|
- `main-py3.11-cu126-2.7.1`
|
||||||
|
- `main-py3.11-cu126-2.6.0`
|
||||||
- `main-py3.11-cu124-2.6.0`
|
- `main-py3.11-cu124-2.6.0`
|
||||||
- `main-py3.11-cu124-2.5.1`
|
- `main-py3.11-cu124-2.5.1`
|
||||||
- `main-latest`
|
- `main-latest`
|
||||||
- `main-20250303-py3.11-cu124-2.6.0`
|
- `main-20250303-py3.11-cu124-2.6.0`
|
||||||
- `main-20250303-py3.11-cu124-2.5.1`
|
- `main-20250303-py3.11-cu124-2.5.1`
|
||||||
- `0.9.2`
|
- `0.10.1`
|
||||||
|
|
||||||
## Cloud
|
## Cloud
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
71
examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
Normal file
71
examples/falcon-h1/falcon-h1-1b-deep-qlora.yaml
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
base_model: tiiuae/Falcon-H1-1.5B-Deep-Base
|
||||||
|
# optionally might have model_type or tokenizer_type
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
# huggingface repo
|
||||||
|
chat_template: falcon_h1
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
- in_proj
|
||||||
|
- gate_proj
|
||||||
|
- up_proj
|
||||||
|
- down_proj
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
71
examples/falcon-h1/falcon-h1-1b-qlora.yaml
Normal file
71
examples/falcon-h1/falcon-h1-1b-qlora.yaml
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
base_model: tiiuae/Falcon-H1-1.5B-Base
|
||||||
|
# optionally might have model_type or tokenizer_type
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
# huggingface repo
|
||||||
|
chat_template: falcon_h1
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
- in_proj
|
||||||
|
- gate_proj
|
||||||
|
- up_proj
|
||||||
|
- down_proj
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
71
examples/falcon-h1/falcon-h1-34b-qlora.yaml
Normal file
71
examples/falcon-h1/falcon-h1-34b-qlora.yaml
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
base_model: tiiuae/Falcon-H1-34B-Base
|
||||||
|
# optionally might have model_type or tokenizer_type
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
# huggingface repo
|
||||||
|
chat_template: falcon_h1
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
- in_proj
|
||||||
|
- gate_proj
|
||||||
|
- up_proj
|
||||||
|
- down_proj
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
71
examples/falcon-h1/falcon-h1-3b-qlora.yaml
Normal file
71
examples/falcon-h1/falcon-h1-3b-qlora.yaml
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
base_model: tiiuae/Falcon-H1-3B-Base
|
||||||
|
# optionally might have model_type or tokenizer_type
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
# huggingface repo
|
||||||
|
chat_template: falcon_h1
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
- in_proj
|
||||||
|
- gate_proj
|
||||||
|
- up_proj
|
||||||
|
- down_proj
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
71
examples/falcon-h1/falcon-h1-500m-qlora.yaml
Normal file
71
examples/falcon-h1/falcon-h1-500m-qlora.yaml
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
base_model: tiiuae/Falcon-H1-0.5B-Instruct
|
||||||
|
# optionally might have model_type or tokenizer_type
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
# huggingface repo
|
||||||
|
chat_template: falcon_h1
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
- in_proj
|
||||||
|
- gate_proj
|
||||||
|
- up_proj
|
||||||
|
- down_proj
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch:
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
71
examples/falcon-h1/falcon-h1-7b-qlora.yaml
Normal file
71
examples/falcon-h1/falcon-h1-7b-qlora.yaml
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
base_model: tiiuae/Falcon-H1-7B-Base
|
||||||
|
# optionally might have model_type or tokenizer_type
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
# huggingface repo
|
||||||
|
chat_template: falcon_h1
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules:
|
||||||
|
- q_proj
|
||||||
|
- k_proj
|
||||||
|
- v_proj
|
||||||
|
- o_proj
|
||||||
|
- in_proj
|
||||||
|
- gate_proj
|
||||||
|
- up_proj
|
||||||
|
- down_proj
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
@@ -13,6 +13,8 @@ load_in_4bit: true
|
|||||||
|
|
||||||
# huggingface repo
|
# huggingface repo
|
||||||
chat_template: gemma3
|
chat_template: gemma3
|
||||||
|
eot_tokens:
|
||||||
|
- <end_of_turn>
|
||||||
datasets:
|
datasets:
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
type: chat_template
|
type: chat_template
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ load_in_4bit: true
|
|||||||
ddp_find_unused_parameters: true
|
ddp_find_unused_parameters: true
|
||||||
|
|
||||||
chat_template: gemma3
|
chat_template: gemma3
|
||||||
|
eot_tokens:
|
||||||
|
- <end_of_turn>
|
||||||
datasets:
|
datasets:
|
||||||
- path: cgato/SlimOrcaDedupCleaned
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
type: chat_template
|
type: chat_template
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ sample_packing: false
|
|||||||
ddp_find_unused_parameters: true
|
ddp_find_unused_parameters: true
|
||||||
|
|
||||||
chat_template: gemma3
|
chat_template: gemma3
|
||||||
|
eot_tokens:
|
||||||
|
- <end_of_turn>
|
||||||
datasets:
|
datasets:
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
type: chat_template
|
type: chat_template
|
||||||
|
|||||||
55
examples/qwen2_5-vl/lora-7b.yaml
Normal file
55
examples/qwen2_5-vl/lora-7b.yaml
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
base_model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: qwen2_vl
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
|
|
||||||
# START section of dependencies that don't install on Darwin/MacOS
|
# START section of dependencies that don't install on Darwin/MacOS
|
||||||
bitsandbytes==0.45.4
|
bitsandbytes==0.46.0
|
||||||
triton>=3.0.0
|
triton>=3.0.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
@@ -15,7 +15,7 @@ huggingface_hub==0.32.2
|
|||||||
peft==0.15.2
|
peft==0.15.2
|
||||||
transformers==4.52.4
|
transformers==4.52.4
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.7.0
|
accelerate==1.8.1
|
||||||
datasets==3.6.0
|
datasets==3.6.0
|
||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.18.2
|
trl==0.18.2
|
||||||
@@ -68,4 +68,4 @@ schedulefree==1.4.1
|
|||||||
axolotl-contribs-lgpl==0.0.6
|
axolotl-contribs-lgpl==0.0.6
|
||||||
axolotl-contribs-mit==0.0.3
|
axolotl-contribs-mit==0.0.3
|
||||||
|
|
||||||
mistral-common==1.6.0
|
mistral-common==1.6.3
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154"'
|
||||||
)
|
)
|
||||||
|
|||||||
4
setup.py
4
setup.py
@@ -111,9 +111,9 @@ def get_package_version():
|
|||||||
|
|
||||||
|
|
||||||
extras_require = {
|
extras_require = {
|
||||||
"flash-attn": ["flash-attn==2.7.4.post1"],
|
"flash-attn": ["flash-attn==2.8.0.post2"],
|
||||||
"ring-flash-attn": [
|
"ring-flash-attn": [
|
||||||
"flash-attn==2.7.4.post1",
|
"flash-attn==2.8.0.post2",
|
||||||
"ring-flash-attn>=0.1.4",
|
"ring-flash-attn>=0.1.4",
|
||||||
"yunchang==0.6.0",
|
"yunchang==0.6.0",
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from pathlib import Path
|
|||||||
from accelerate.commands.config import config_args
|
from accelerate.commands.config import config_args
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||||
|
from requests import HTTPError
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
@@ -46,3 +47,8 @@ def check_user_token() -> bool:
|
|||||||
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
except HTTPError:
|
||||||
|
LOG.warning(
|
||||||
|
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Union
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.cloud.modal_ import ModalCloud
|
from axolotl.cli.cloud.modal_ import ModalCloud
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -24,7 +23,6 @@ def do_cli_preprocess(
|
|||||||
cloud_config: Union[Path, str],
|
cloud_config: Union[Path, str],
|
||||||
config: Union[Path, str],
|
config: Union[Path, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
@@ -39,7 +37,6 @@ def do_cli_train(
|
|||||||
cwd=None,
|
cwd=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
@@ -54,7 +51,6 @@ def do_cli_lm_eval(
|
|||||||
cloud_config: Union[Path, str],
|
cloud_config: Union[Path, str],
|
||||||
config: Union[Path, str],
|
config: Union[Path, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
print_axolotl_text_art()
|
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
cloud = ModalCloud(cloud_cfg)
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
|
|||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
API_KEY_FIELDS = {"comet_api_key"}
|
||||||
|
|
||||||
|
|
||||||
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||||
"""
|
"""
|
||||||
@@ -233,4 +235,15 @@ def load_cfg(
|
|||||||
setup_comet_env_vars(cfg)
|
setup_comet_env_vars(cfg)
|
||||||
plugin_set_cfg(cfg)
|
plugin_set_cfg(cfg)
|
||||||
|
|
||||||
|
cfg_to_log = {
|
||||||
|
k: "[REDACTED]" if k in API_KEY_FIELDS else v
|
||||||
|
for k, v in cfg.items()
|
||||||
|
if v is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"config:\n%s",
|
||||||
|
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
|
||||||
|
)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
@@ -35,7 +34,6 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
patch_optimized_env()
|
patch_optimized_env()
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||||
|
|
||||||
from axolotl.cli.args import InferenceCliArgs
|
from axolotl.cli.args import InferenceCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.utils.chat_templates import (
|
from axolotl.utils.chat_templates import (
|
||||||
@@ -255,7 +254,6 @@ def do_cli(
|
|||||||
kwargs: Additional keyword arguments to override config file values.
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
|
||||||
parsed_cfg.sample_packing = False
|
parsed_cfg.sample_packing = False
|
||||||
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
parser = transformers.HfArgumentParser(InferenceCliArgs)
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from axolotl.cli.args import (
|
|||||||
TrainerCliArgs,
|
TrainerCliArgs,
|
||||||
VllmServeCliArgs,
|
VllmServeCliArgs,
|
||||||
)
|
)
|
||||||
|
from axolotl.cli.art import print_axolotl_text_art
|
||||||
from axolotl.cli.sweeps import generate_sweep_configs
|
from axolotl.cli.sweeps import generate_sweep_configs
|
||||||
from axolotl.cli.utils import (
|
from axolotl.cli.utils import (
|
||||||
add_options_from_config,
|
add_options_from_config,
|
||||||
@@ -40,6 +41,7 @@ LOG = get_logger(__name__)
|
|||||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||||
def cli():
|
def cli():
|
||||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||||
|
print_axolotl_text_art()
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from typing import Union
|
|||||||
import fire
|
import fire
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -23,8 +22,6 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
|
|
||||||
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ from huggingface_hub import split_torch_state_dict_into_shards
|
|||||||
from safetensors.torch import save_file as safe_save_file
|
from safetensors.torch import save_file as safe_save_file
|
||||||
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
@@ -194,7 +193,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
kwargs: Additional keyword arguments to override config file values.
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
|
||||||
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from axolotl.cli.args import PreprocessCliArgs
|
from axolotl.cli.args import PreprocessCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
@@ -33,7 +32,6 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Preprocessing-specific CLI arguments.
|
cli_args: Preprocessing-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Union
|
|||||||
|
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.loaders import load_tokenizer
|
from axolotl.loaders import load_tokenizer
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -27,7 +26,6 @@ def do_quantize(
|
|||||||
config (Union[Path, str]): The path to the config file
|
config (Union[Path, str]): The path to the config file
|
||||||
cli_args (dict): Additional command-line arguments
|
cli_args (dict): Additional command-line arguments
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
|
|
||||||
cfg = load_cfg(config)
|
cfg = load_cfg(config)
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from dotenv import load_dotenv
|
|||||||
from transformers.hf_argparser import HfArgumentParser
|
from transformers.hf_argparser import HfArgumentParser
|
||||||
|
|
||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
@@ -35,7 +34,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
|||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
patch_optimized_env()
|
patch_optimized_env()
|
||||||
|
|
||||||
print_axolotl_text_art()
|
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|||||||
@@ -75,13 +75,17 @@ def load_datasets(
|
|||||||
|
|
||||||
num_examples = cli_args.debug_num_examples if cli_args else 1
|
num_examples = cli_args.debug_num_examples if cli_args else 1
|
||||||
text_only = cli_args.debug_text_only if cli_args else False
|
text_only = cli_args.debug_text_only if cli_args else False
|
||||||
train_samples = sample_dataset(train_dataset, num_examples)
|
try:
|
||||||
check_dataset_labels(
|
train_samples = sample_dataset(train_dataset, num_examples)
|
||||||
train_samples,
|
check_dataset_labels(
|
||||||
tokenizer,
|
train_samples,
|
||||||
num_examples=num_examples,
|
tokenizer,
|
||||||
text_only=text_only,
|
num_examples=num_examples,
|
||||||
)
|
text_only=text_only,
|
||||||
|
)
|
||||||
|
except AttributeError:
|
||||||
|
# can't sample iterable datasets
|
||||||
|
pass
|
||||||
|
|
||||||
LOG.info("printing prompters...")
|
LOG.info("printing prompters...")
|
||||||
for prompter in prompters:
|
for prompter in prompters:
|
||||||
|
|||||||
@@ -219,7 +219,9 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if self.cfg.bf16 == "full":
|
if self.cfg.bf16 == "full":
|
||||||
training_args_kwargs["bf16_full_eval"] = True
|
training_args_kwargs["bf16_full_eval"] = True
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
|
bf16 = self.cfg.bf16 or self.cfg.bfloat16
|
||||||
|
bf16 = bf16 if bf16 is not None else False
|
||||||
|
training_args_kwargs["bf16"] = bf16
|
||||||
|
|
||||||
def _configure_scheduler(self, training_args_kwargs: dict):
|
def _configure_scheduler(self, training_args_kwargs: dict):
|
||||||
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
|
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
|
||||||
|
|||||||
@@ -253,6 +253,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["eval_sample_packing"] = bool(
|
training_arguments_kwargs["eval_sample_packing"] = bool(
|
||||||
self.cfg.eval_sample_packing
|
self.cfg.eval_sample_packing
|
||||||
)
|
)
|
||||||
|
if self.cfg.sample_packing_sequentially is not None:
|
||||||
|
training_arguments_kwargs["sample_packing_sequentially"] = (
|
||||||
|
self.cfg.sample_packing_sequentially
|
||||||
|
)
|
||||||
if self.cfg.sample_packing_bin_size is not None:
|
if self.cfg.sample_packing_bin_size is not None:
|
||||||
training_arguments_kwargs["sample_packing_bin_size"] = (
|
training_arguments_kwargs["sample_packing_bin_size"] = (
|
||||||
self.cfg.sample_packing_bin_size
|
self.cfg.sample_packing_bin_size
|
||||||
@@ -413,7 +417,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
or self.cfg.micro_batch_size > 1
|
or self.cfg.micro_batch_size > 1
|
||||||
):
|
):
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
|
||||||
|
return None
|
||||||
|
|
||||||
if self.cfg.model_config_type == "mamba":
|
if self.cfg.model_config_type == "mamba":
|
||||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from torch.utils.data import (
|
|||||||
SequentialSampler,
|
SequentialSampler,
|
||||||
)
|
)
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
@@ -116,14 +116,15 @@ class AxolotlTrainer(
|
|||||||
sequential=self.args.sample_packing_sequentially,
|
sequential=self.args.sample_packing_sequentially,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
num_processes=self.args.dataset_num_proc,
|
num_processes=self.args.dataset_num_proc,
|
||||||
|
mp_start_method=self.args.sample_packing_mp_start_method or "fork",
|
||||||
)
|
)
|
||||||
|
|
||||||
len(sampler)
|
len(sampler)
|
||||||
return sampler
|
return sampler
|
||||||
|
|
||||||
def _get_train_sampler(
|
def _get_train_sampler(
|
||||||
self, train_dataset: Optional[Dataset] = None
|
self, train_dataset: Dataset | None = None
|
||||||
) -> Optional[Sampler]:
|
) -> Sampler | None:
|
||||||
"""
|
"""
|
||||||
Helper method to get the sampler for training. Handles cases for sample packing
|
Helper method to get the sampler for training. Handles cases for sample packing
|
||||||
and curriculum sampling (sequential).
|
and curriculum sampling (sequential).
|
||||||
@@ -132,16 +133,22 @@ class AxolotlTrainer(
|
|||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
depends on the passed training args.
|
depends on the passed training args.
|
||||||
"""
|
"""
|
||||||
|
# from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L969C1-L972C24
|
||||||
|
if train_dataset is None:
|
||||||
|
train_dataset = self.train_dataset
|
||||||
|
if train_dataset is None or not has_length(train_dataset):
|
||||||
|
return None
|
||||||
|
|
||||||
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
||||||
|
|
||||||
# Determine the base sampler first
|
# Determine the base sampler first
|
||||||
if self.args.curriculum_sampling:
|
if self.args.curriculum_sampling:
|
||||||
base_sampler = SequentialSampler(self.train_dataset)
|
base_sampler = SequentialSampler(train_dataset)
|
||||||
elif use_sample_packing:
|
elif use_sample_packing:
|
||||||
base_sampler = RandomSampler(self.train_dataset)
|
base_sampler = RandomSampler(train_dataset)
|
||||||
else:
|
else:
|
||||||
# Default to parent class implementation for standard random sampling
|
# Default to parent class implementation for standard random sampling
|
||||||
return super()._get_train_sampler()
|
return super()._get_train_sampler(train_dataset)
|
||||||
|
|
||||||
# Apply multipack wrapper if needed
|
# Apply multipack wrapper if needed
|
||||||
if use_sample_packing:
|
if use_sample_packing:
|
||||||
@@ -160,6 +167,10 @@ class AxolotlTrainer(
|
|||||||
If the dataset is non-empty, a sampler is returned, the type of which
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
depends on the passed training args.
|
depends on the passed training args.
|
||||||
"""
|
"""
|
||||||
|
# from https://github.com/huggingface/transformers/blob/2166b6b4ff09f6dd3867ab982f262f66482aa968/src/transformers/trainer.py#L1065C9-L1066C24
|
||||||
|
if eval_dataset is None or not has_length(eval_dataset):
|
||||||
|
return None
|
||||||
|
|
||||||
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
||||||
use_multipack = (
|
use_multipack = (
|
||||||
self.args.sample_packing and self.args.eval_sample_packing is not False
|
self.args.sample_packing and self.args.eval_sample_packing is not False
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class DPOStrategy:
|
|||||||
training_args_kwargs["max_completion_length"] = None
|
training_args_kwargs["max_completion_length"] = None
|
||||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||||
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
||||||
if cfg.dpo_use_weighting is not None:
|
if cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||||
if cfg.dpo_padding_free is not None:
|
if cfg.dpo_padding_free is not None:
|
||||||
|
|||||||
@@ -38,6 +38,10 @@ class AxolotlTrainingMixins:
|
|||||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
sample_packing_mp_start_method: str | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The multiprocessing start method to use."},
|
||||||
|
)
|
||||||
multipack_real_batches: bool = field(
|
multipack_real_batches: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use real batches for efficient training."},
|
metadata={"help": "Use real batches for efficient training."},
|
||||||
|
|||||||
@@ -19,19 +19,11 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
**NOTE**: If you are training a VLM model, please use older version of Axolotl as upstream has applied a major VLM refactor, and our patches have not been updated yet.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git checkout 787880215b3ab32ccaf81c1b2e9588c6f3e6e764
|
|
||||||
|
|
||||||
pip3 install --no-build-isolation -e .
|
|
||||||
```
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
plugins:
|
plugins:
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
@@ -39,27 +31,29 @@ plugins:
|
|||||||
|
|
||||||
## Supported Models
|
## Supported Models
|
||||||
|
|
||||||
- llama
|
- cohere
|
||||||
- llama4
|
- cohere2
|
||||||
- llama4_text
|
|
||||||
- mllama
|
|
||||||
- phi3
|
|
||||||
- gemma
|
- gemma
|
||||||
- gemma2
|
- gemma2
|
||||||
- gemma3
|
- gemma3
|
||||||
- gemma3_text
|
- gemma3_text
|
||||||
|
- glm
|
||||||
|
- glm4
|
||||||
|
- llama
|
||||||
|
- llama4
|
||||||
|
- llama4_text
|
||||||
- mistral
|
- mistral
|
||||||
- mistral3
|
- mistral3
|
||||||
|
- mllama
|
||||||
|
- phi
|
||||||
|
- phi3
|
||||||
|
- phi4_multimodal
|
||||||
- qwen2
|
- qwen2
|
||||||
- qwen2_moe
|
|
||||||
- qwen2_vl
|
- qwen2_vl
|
||||||
|
- qwen2_moe
|
||||||
- qwen2_5_vl
|
- qwen2_5_vl
|
||||||
- qwen3
|
- qwen3
|
||||||
- qwen3_moe
|
- qwen3_moe
|
||||||
- cohere
|
|
||||||
- cohere2
|
|
||||||
- glm
|
|
||||||
- glm4
|
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
|
|||||||
@@ -31,8 +31,8 @@ from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa:
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@7f6afce"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -64,16 +64,28 @@ class CutCrossEntropyPlugin(BasePlugin):
|
|||||||
"cut_cross_entropy.transformers"
|
"cut_cross_entropy.transformers"
|
||||||
)
|
)
|
||||||
if cce_spec_transformers is None:
|
if cce_spec_transformers is None:
|
||||||
raise ImportError(_CCE_INSTALL_MESSAGE)
|
raise ImportError(
|
||||||
|
"Transformers support is not installed. " + _CCE_INSTALL_MESSAGE
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if Axolotl's cce fork is installed
|
||||||
|
try:
|
||||||
|
from cut_cross_entropy.transformers.patch import AXOLOTL_CCE_FORK
|
||||||
|
|
||||||
|
if not AXOLOTL_CCE_FORK:
|
||||||
|
raise ImportError
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Axolotl's fork of cut_cross_entropy is not installed. "
|
||||||
|
+ _CCE_INSTALL_MESSAGE
|
||||||
|
) from e
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
"""Apply cut cross entropy before model loading if enabled."""
|
"""Apply cut cross entropy before model loading if enabled."""
|
||||||
if cfg.cut_cross_entropy:
|
if cfg.cut_cross_entropy:
|
||||||
self._check_requirements()
|
self._check_requirements()
|
||||||
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
|
from cut_cross_entropy.transformers.patch import cce_patch
|
||||||
cce_patch,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
|
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
|
||||||
|
|||||||
@@ -1,191 +0,0 @@
|
|||||||
"""Cohere and Cohere2 CCE patch."""
|
|
||||||
|
|
||||||
# This patch is based off transformers 4.50.0.
|
|
||||||
# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM.
|
|
||||||
# It scales the hidden states by the logit scale in advance instead of the logits as the
|
|
||||||
# operation is done internally and should be mathematically equivalent.
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.cohere.modeling_cohere import (
|
|
||||||
KwargsForCausalLM,
|
|
||||||
)
|
|
||||||
from transformers.processing_utils import Unpack
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**kwargs: Unpack[KwargsForCausalLM],
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>> from transformers import AutoTokenizer, CohereForCausalLM
|
|
||||||
|
|
||||||
>> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
|
||||||
>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
|
|
||||||
|
|
||||||
>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
||||||
>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>> # Generate
|
|
||||||
>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
||||||
```"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
# scale hidden_states by logit_scale in-place of logits
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :] * self.logit_scale,
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
logits = logits * self.logit_scale # main diff from Llama
|
|
||||||
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits=logits,
|
|
||||||
labels=labels,
|
|
||||||
vocab_size=self.config.vocab_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_cohere(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.cohere import modeling_cohere
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_cohere.CohereForCausalLM
|
|
||||||
), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_cohere.CohereForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def patch_cohere2(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.cohere2 import modeling_cohere2
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_cohere2.Cohere2ForCausalLM
|
|
||||||
), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
@@ -1,165 +0,0 @@
|
|||||||
"""Gemma CCE patch"""
|
|
||||||
|
|
||||||
# This patch is based off transformers 4.50.0.
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.gemma.modeling_gemma import (
|
|
||||||
KwargsForCausalLM,
|
|
||||||
)
|
|
||||||
from transformers.processing_utils import Unpack
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**kwargs: Unpack[KwargsForCausalLM],
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import AutoTokenizer, GemmaForCausalLM
|
|
||||||
|
|
||||||
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
|
|
||||||
|
|
||||||
>>> prompt = "What is your favorite condiment?"
|
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"What is your favorite condiment?"
|
|
||||||
```"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits=logits,
|
|
||||||
labels=labels,
|
|
||||||
vocab_size=self.config.vocab_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_gemma(
|
|
||||||
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
|
||||||
patch_options: PatchOptions,
|
|
||||||
) -> TransformersModelT | None:
|
|
||||||
global _PATCH_OPTS # pylint: disable=global-statement
|
|
||||||
from transformers.models.gemma import modeling_gemma
|
|
||||||
|
|
||||||
_PATCH_OPTS = patch_options
|
|
||||||
|
|
||||||
if isinstance(maybe_model, transformers.PreTrainedModel):
|
|
||||||
assert isinstance(
|
|
||||||
maybe_model, modeling_gemma.GemmaForCausalLM
|
|
||||||
), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}."
|
|
||||||
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
|
||||||
return maybe_model
|
|
||||||
|
|
||||||
modeling_gemma.GemmaForCausalLM.forward = cce_forward
|
|
||||||
return None
|
|
||||||
@@ -1,447 +0,0 @@
|
|||||||
"""Gemma2 and Gemma3 (text and multimodal) CCE patch."""
|
|
||||||
|
|
||||||
# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29
|
|
||||||
# and updated for transformers 4.50.0.
|
|
||||||
# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works
|
|
||||||
# with both gemma3 (text and multimodal) models.
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
)
|
|
||||||
from torch import nn
|
|
||||||
from transformers.cache_utils import Cache, HybridCache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.gemma3.modeling_gemma3 import (
|
|
||||||
Gemma3CausalLMOutputWithPast,
|
|
||||||
logger,
|
|
||||||
)
|
|
||||||
from transformers.utils import (
|
|
||||||
is_torchdynamo_compiling,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[HybridCache] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
defer_logits_calculation: bool = False,
|
|
||||||
**loss_kwargs,
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
defer_logits_calculation (`bool`, *optional*):
|
|
||||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
|
||||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import AutoTokenizer, Gemma3ForCausalLM
|
|
||||||
|
|
||||||
>>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
|
|
||||||
|
|
||||||
>>> prompt = "What is your favorite condiment?"
|
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"What is your favorite condiment?"
|
|
||||||
```"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
softcap=getattr(self.config, "final_logit_softcapping", None),
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
|
||||||
# defer logits calculation to the ConditionalGeneration forward
|
|
||||||
logits = hidden_states[:, slice_indices, :]
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
if self.config.final_logit_softcapping is not None:
|
|
||||||
logits = logits / self.config.final_logit_softcapping
|
|
||||||
logits = torch.tanh(logits)
|
|
||||||
logits = logits * self.config.final_logit_softcapping
|
|
||||||
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
def cce_forward_multimodal(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
pixel_values: torch.FloatTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
|
|
||||||
token_type_ids: Optional[torch.LongTensor] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**lm_kwargs,
|
|
||||||
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from PIL import Image
|
|
||||||
>>> import requests
|
|
||||||
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
|
||||||
|
|
||||||
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
|
|
||||||
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
|
|
||||||
|
|
||||||
>>> prompt = "answer en Where is the cow standing?"
|
|
||||||
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
|
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
|
|
||||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(**inputs, max_length=30)
|
|
||||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"answer en Where is the cow standing?\nbeach"
|
|
||||||
```"""
|
|
||||||
|
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
||||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
is_training = token_type_ids is not None and labels is not None
|
|
||||||
|
|
||||||
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
|
|
||||||
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
|
|
||||||
special_image_mask = input_ids == self.config.image_token_index
|
|
||||||
llm_input_ids = input_ids.clone()
|
|
||||||
llm_input_ids[special_image_mask] = 0
|
|
||||||
else:
|
|
||||||
llm_input_ids = input_ids # type: ignore
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
|
||||||
|
|
||||||
if cache_position is None:
|
|
||||||
past_seen_tokens = (
|
|
||||||
past_key_values.get_seq_length() if past_key_values is not None else 0 # type: ignore
|
|
||||||
)
|
|
||||||
cache_position = torch.arange( # type: ignore
|
|
||||||
past_seen_tokens,
|
|
||||||
past_seen_tokens + inputs_embeds.shape[1],
|
|
||||||
device=inputs_embeds.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Merge text and images
|
|
||||||
if pixel_values is not None:
|
|
||||||
image_features = self.get_image_features(pixel_values)
|
|
||||||
|
|
||||||
if input_ids is None:
|
|
||||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
|
||||||
torch.tensor(
|
|
||||||
self.config.image_token_index,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=inputs_embeds.device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
|
|
||||||
-1
|
|
||||||
)
|
|
||||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
|
||||||
inputs_embeds.device
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
not is_torchdynamo_compiling()
|
|
||||||
and inputs_embeds[special_image_mask].numel() != image_features.numel()
|
|
||||||
):
|
|
||||||
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
|
|
||||||
raise ValueError(
|
|
||||||
f"Number of images does not match number of special image tokens in the input text. "
|
|
||||||
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
|
|
||||||
"tokens from image embeddings."
|
|
||||||
)
|
|
||||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
|
|
||||||
|
|
||||||
# mask out pad-token-ids in labels for BC
|
|
||||||
if labels is not None and self.pad_token_id in labels:
|
|
||||||
logger.warning_once(
|
|
||||||
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
|
|
||||||
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
|
|
||||||
)
|
|
||||||
labels = torch.where( # type: ignore
|
|
||||||
input_ids == self.pad_token_id, self.config.ignore_index, labels
|
|
||||||
)
|
|
||||||
|
|
||||||
causal_mask = self._update_causal_mask( # pylint: disable=protected-access
|
|
||||||
attention_mask,
|
|
||||||
token_type_ids,
|
|
||||||
past_key_values,
|
|
||||||
cache_position,
|
|
||||||
inputs_embeds,
|
|
||||||
is_training,
|
|
||||||
)
|
|
||||||
outputs = self.language_model(
|
|
||||||
attention_mask=causal_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
logits_to_keep=logits_to_keep,
|
|
||||||
defer_logits_calculation=True, # enable deferred logits calculation
|
|
||||||
**lm_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states,
|
|
||||||
self.language_model.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
softcap=getattr(self.config, "final_logit_softcapping", None),
|
|
||||||
**lm_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = hidden_states
|
|
||||||
if labels is not None:
|
|
||||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
|
||||||
logits = logits.float()
|
|
||||||
shift_logits = logits[..., :-1, :]
|
|
||||||
shift_labels = labels[..., 1:]
|
|
||||||
if attention_mask is not None:
|
|
||||||
# we use the input attention mask to shift the logits and labels, because it is 2D.
|
|
||||||
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
|
|
||||||
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(
|
|
||||||
logits.device
|
|
||||||
)
|
|
||||||
shift_logits = shift_logits[
|
|
||||||
shift_attention_mask.to(logits.device) != 0
|
|
||||||
].contiguous()
|
|
||||||
shift_labels = shift_labels[
|
|
||||||
shift_attention_mask.to(shift_labels.device) != 0
|
|
||||||
].contiguous()
|
|
||||||
else:
|
|
||||||
shift_logits = shift_logits.contiguous()
|
|
||||||
shift_labels = shift_labels.contiguous()
|
|
||||||
# Flatten the tokens
|
|
||||||
loss_fct = nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
|
||||||
flat_labels = shift_labels.view(-1).to(shift_logits.device)
|
|
||||||
loss = loss_fct(flat_logits, flat_labels)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return Gemma3CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
image_hidden_states=image_features if pixel_values is not None else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_gemma2(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Swap Gemma2's forward for the CCE forward.

    Given a model instance, patches only that instance and returns it;
    otherwise patches ``Gemma2ForCausalLM`` class-wide and returns ``None``.
    """
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma2 import modeling_gemma2

    _PATCH_OPTS = patch_options

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        # No instance given: patch the class so all future instances use CCE.
        modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward
        return None

    assert isinstance(
        maybe_model, modeling_gemma2.Gemma2ForCausalLM
    ), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}."
    # Bind the CCE forward to this one instance only.
    maybe_model.forward = MethodType(cce_forward, maybe_model)
    return maybe_model
|
|
||||||
|
|
||||||
|
|
||||||
def patch_gemma3_text(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Swap the text-only Gemma3 forward for the CCE forward.

    Patches a given ``Gemma3ForCausalLM`` instance in place and returns it,
    or patches the class globally and returns ``None``.
    """
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma3 import modeling_gemma3

    _PATCH_OPTS = patch_options

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        # Class-level patch: every future instance picks up the CCE forward.
        modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
        return None

    assert isinstance(
        maybe_model, modeling_gemma3.Gemma3ForCausalLM
    ), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}."
    maybe_model.forward = MethodType(cce_forward, maybe_model)
    return maybe_model
|
|
||||||
|
|
||||||
|
|
||||||
def patch_gemma3(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Swap the multimodal Gemma3 forward for the CCE multimodal forward.

    Also patches the inner causal LM so logits calculation can be deferred
    to the ConditionalGeneration wrapper. Instance given -> patch and return
    it; otherwise patch both classes globally and return ``None``.
    """
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.gemma3 import modeling_gemma3

    _PATCH_OPTS = patch_options

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal
        # patch the causal model to enable deferred logits calculation
        modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
        return None

    assert isinstance(
        maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration
    ), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}."
    maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

    # patch the causal model to enable deferred logits calculation
    maybe_model.language_model.forward = MethodType(
        cce_forward, maybe_model.language_model
    )
    return maybe_model
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
"""GLM 4 patch. GLM family inherits from Llama."""
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_glm(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Patch GLM for CCE by reusing the Llama CCE forward.

    GLM inherits Llama's architecture, so the options are stored in the
    llama patch module where its ``cce_forward`` reads them.
    """
    # Set the _PATCH_OPTS in the llama patch file
    import cut_cross_entropy.transformers.llama as llama_patch

    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access

    from cut_cross_entropy.transformers.llama import cce_forward
    from transformers.models.glm import modeling_glm

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        modeling_glm.GlmForCausalLM.forward = cce_forward
        return None

    assert isinstance(
        maybe_model, modeling_glm.GlmForCausalLM
    ), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}."
    maybe_model.forward = MethodType(cce_forward, maybe_model)
    return maybe_model
|
|
||||||
|
|
||||||
|
|
||||||
def patch_glm4(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Patch GLM4 for CCE by reusing the Llama CCE forward.

    Mirrors :func:`patch_glm`; the options must live in the llama patch
    module because that is where ``cce_forward`` reads them.
    """
    # Set the _PATCH_OPTS in the llama patch file
    import cut_cross_entropy.transformers.llama as llama_patch

    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access

    from cut_cross_entropy.transformers.llama import cce_forward
    from transformers.models.glm4 import modeling_glm4

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        modeling_glm4.Glm4ForCausalLM.forward = cce_forward
        return None

    assert isinstance(
        maybe_model, modeling_glm4.Glm4ForCausalLM
    ), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}."
    maybe_model.forward = MethodType(cce_forward, maybe_model)
    return maybe_model
|
|
||||||
@@ -1,164 +0,0 @@
|
|||||||
"""Llama CCE patch. Adapted from transformers v4.51.2"""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import (
|
|
||||||
BaseModelOutputWithPast,
|
|
||||||
CausalLMOutputWithPast,
|
|
||||||
)
|
|
||||||
from transformers.models.llama.modeling_llama import (
|
|
||||||
KwargsForCausalLM,
|
|
||||||
)
|
|
||||||
from transformers.processing_utils import Unpack
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
from transformers.utils.generic import can_return_tuple
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def cce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
    r"""CCE replacement for ``LlamaForCausalLM.forward`` (adapted from transformers v4.51.2).

    Identical to the upstream forward except that, when ``_PATCH_OPTS`` elects the
    linear-cross-entropy (LCE/CCE) path, the loss is computed directly from the
    hidden states and ``lm_head.weight`` without materializing the full
    ``(batch, seq, vocab)`` logits tensor; ``logits`` is ``None`` in that case.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:
        `CausalLMOutputWithPast` (or a tuple via the `can_return_tuple` decorator).
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: BaseModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    if hidden_states is None:
        raise ValueError("hidden_states is None")

    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )
    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        # CCE path: fuse projection + cross-entropy; logits are never materialized.
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    else:
        # Upstream path: project to vocab logits, then use the model's own loss_function.
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def patch_llama(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Patch Llama for CCE."""
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.llama import modeling_llama

    _PATCH_OPTS = patch_options

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        # Patch at class level so all future instances use the CCE forward.
        modeling_llama.LlamaForCausalLM.forward = cce_forward
        return None

    assert isinstance(
        maybe_model, modeling_llama.LlamaForCausalLM
    ), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}."
    maybe_model.forward = MethodType(cce_forward, maybe_model)
    return maybe_model
|
|
||||||
@@ -1,401 +0,0 @@
|
|||||||
"""Llama4 CCE patch. Adapted from transformers 4.51.0."""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from torch import nn
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.llama4.modeling_llama4 import (
|
|
||||||
Llama4CausalLMOutputWithPast,
|
|
||||||
)
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def cce_forward(
    self,
    input_ids: torch.LongTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    defer_logits_calculation: bool = False,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""CCE replacement for ``Llama4ForCausalLM.forward`` (adapted from transformers 4.51.0).

    Three mutually exclusive paths after the decoder runs:
    1. LCE/CCE: loss computed directly from hidden states + ``lm_head.weight``
       (``logits`` stays ``None``).
    2. Deferred: ``logits`` is set to the (sliced) hidden states so the
       ConditionalGeneration wrapper can compute loss/logits itself.
    3. Upstream: regular ``lm_head`` projection and ``loss_function``.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    defer_logits_calculation (`bool`, *optional*, defaults to `False`):
        If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
        memory overhead of calculating logits using regular lm_head forward pass and to use CCE.

    Returns:
        `CausalLMOutputWithPast`, or a plain tuple when `return_dict=False`.
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )
    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        # CCE path: fused projection + cross-entropy, no logits materialized.
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    elif _PATCH_OPTS is not None and defer_logits_calculation:
        # defer logits calculation to the ConditionalGeneration forward
        # NOTE: "logits" here are actually hidden states; the caller knows this.
        logits = hidden_states[:, slice_indices, :]
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        # Legacy tuple return; loss is prepended only when computed.
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def cce_forward_multimodal(
    self,
    input_ids: torch.LongTensor | None = None,  # type: ignore
    pixel_values: torch.FloatTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[Union[int, list[int]]] = None,
    vision_feature_select_strategy: Optional[str] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    image_sizes: torch.Tensor | None = None,
    **lm_kwargs,
) -> Union[Tuple, Llama4CausalLMOutputWithPast]:
    r"""CCE replacement for ``Llama4ForConditionalGeneration.forward``.

    Merges projected vision features into the text embeddings at the image-token
    positions, runs the (CCE-patched) language model with
    ``defer_logits_calculation=True`` so it returns hidden states instead of
    logits, then computes the loss here — either via CCE (``apply_lce``) or the
    upstream shift-and-cross-entropy fallback.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:
        `Llama4CausalLMOutputWithPast`, or a plain tuple when `return_dict=False`.
        NOTE(review): in the LCE branch `logits` is `None`, and in the fallback
        branch `logits` holds hidden states, not vocab logits — callers relying
        on `.logits` should be aware.
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    vision_feature_layer = (
        vision_feature_layer
        if vision_feature_layer is not None
        else self.config.vision_config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else self.config.vision_config.vision_feature_select_strategy
    )

    # Exactly one of input_ids / inputs_embeds must be supplied.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)  # type: ignore

    if pixel_values is not None:
        # Encode images, project to the text embedding space, and scatter the
        # projected vectors into the positions marked by the image token.
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
            image_sizes=image_sizes,
        )
        original_inputs_embeds_shape = inputs_embeds.shape  # type: ignore

        vision_flat = image_features.view(-1, image_features.size(-1))
        projected_vision_flat = self.multi_modal_projector(vision_flat)

        # Boolean mask of image-token positions in the (flattened) sequence.
        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
        final_mask = special_image_mask.to(inputs_embeds.device)  # type: ignore
        inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))  # type: ignore

        final_mask_1d = final_mask[..., 0].reshape(-1)
        num_tokens_to_fill = final_mask_1d.sum()

        # The number of image-token slots must match the number of projected
        # vision embeddings, or the scatter below would silently misalign.
        if num_tokens_to_fill != projected_vision_flat.size(0):
            raise ValueError(
                f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
                f"but multi_modal_projector returned {projected_vision_flat.size(0)}"
            )

        expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
        inputs_embeds = inputs_embeds.masked_scatter(
            expanded_mask, projected_vision_flat
        )  # type: ignore
        inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)  # type: ignore

    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        defer_logits_calculation=True,  # enable deferred logits calculation
        **lm_kwargs,
    )

    # With deferred logits, outputs[0] holds hidden states rather than logits.
    hidden_states = outputs[0]
    loss = None
    logits = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        # TODO: check if need to handle attention_mask
        loss = apply_lce(
            hidden_states,
            self.language_model.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **lm_kwargs,
        )
    else:
        logits = hidden_states
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
                    logits.device
                )
                shift_logits = logits[..., :-1, :][
                    shift_attention_mask.to(logits.device) != 0
                ].contiguous()
                shift_labels = labels[..., 1:][
                    shift_attention_mask.to(labels.device) != 0
                ].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1).to(shift_logits.device),
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Llama4CausalLMOutputWithPast(
        loss=loss,
        logits=logits,  # type: ignore # TODO: check if need to create dummy logits
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def patch_llama4_text(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Patch the text-only Llama4 causal LM for CCE.

    Instance given -> patch that instance and return it; otherwise patch
    ``Llama4ForCausalLM`` class-wide and return ``None``.
    """
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.llama4 import modeling_llama4

    _PATCH_OPTS = patch_options

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        setattr(
            modeling_llama4.Llama4ForCausalLM,
            "forward",
            cce_forward,
        )
        return None

    assert isinstance(
        maybe_model, modeling_llama4.Llama4ForCausalLM
    ), f"Expected a Llama4ForCausalLM model. Got {type(maybe_model)}."
    maybe_model.forward = MethodType(cce_forward, maybe_model)

    return maybe_model
|
|
||||||
|
|
||||||
|
|
||||||
def patch_llama4(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Patch the multimodal Llama4 model for CCE.

    Patches both the ConditionalGeneration wrapper (multimodal forward) and
    the inner causal LM (deferred-logits forward). Instance given -> patch
    and return it; otherwise patch the classes globally and return ``None``.
    """
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.llama4 import modeling_llama4

    _PATCH_OPTS = patch_options

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        setattr(
            modeling_llama4.Llama4ForConditionalGeneration,
            "forward",
            cce_forward_multimodal,
        )

        # patch the causal language model
        setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward)
        return None

    assert isinstance(
        maybe_model, modeling_llama4.Llama4ForConditionalGeneration
    ), f"Expected a Llama4ForConditionalGeneration model. Got {type(maybe_model)}."
    maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

    # patch the language model
    maybe_model.language_model.forward = MethodType(
        cce_forward, maybe_model.language_model
    )
    return maybe_model
|
|
||||||
@@ -1,384 +0,0 @@
|
|||||||
"""Mistral and Mistral3 CCE patch."""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from torch import nn
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.mistral3.modeling_mistral3 import (
|
|
||||||
Mistral3CausalLMOutputWithPast,
|
|
||||||
)
|
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
|
||||||
KwargsForCausalLM,
|
|
||||||
)
|
|
||||||
from transformers.processing_utils import Unpack
|
|
||||||
from transformers.utils import (
|
|
||||||
is_torchdynamo_compiling,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] | None = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
defer_logits_calculation: bool = False,
|
|
||||||
**kwargs: Unpack[KwargsForCausalLM],
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
defer_logits_calculation (`bool`, *optional*):
|
|
||||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
|
||||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import AutoTokenizer, MistralForCausalLM
|
|
||||||
|
|
||||||
>>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
|
|
||||||
|
|
||||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
||||||
```"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
|
||||||
# defer logits calculation to the ConditionalGeneration forward
|
|
||||||
logits = hidden_states[:, slice_indices, :]
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits=logits,
|
|
||||||
labels=labels,
|
|
||||||
vocab_size=self.config.vocab_size,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def cce_forward_multimodal(
    self,
    input_ids: torch.LongTensor | None = None,
    pixel_values: torch.FloatTensor | None = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[Union[int, list[int]]] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    image_sizes: torch.Tensor | None = None,
    **lm_kwargs,
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
    """``Mistral3ForConditionalGeneration.forward`` replacement with CCE support.

    Scatters projected image features into the token embedding stream, runs
    the (patched) language model with deferred logits calculation, then either
    computes the loss via Cut Cross Entropy or falls back to the stock shifted
    cross-entropy over the deferred hidden states.

    Raises ValueError when neither/both of ``input_ids``/``inputs_embeds`` are
    given, when ``pixel_values`` and ``inputs_embeds`` are combined, or when
    the image token count disagrees with the extracted image features.
    """
    # Resolve unset options from the model config.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict
    if vision_feature_layer is None:
        vision_feature_layer = self.config.vision_feature_layer

    # Exactly one of input_ids / inputs_embeds must be supplied
    # (equivalent to the upstream XOR check).
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            image_sizes=image_sizes,
        )

        # Boolean mask marking the image placeholder positions, broadcast to
        # the embedding width.
        special_image_mask = (
            (input_ids == self.config.image_token_index)
            .unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        if (
            not is_torchdynamo_compiling()
            and inputs_embeds[special_image_mask].numel() != image_features.numel()
        ):
            n_image_tokens = (input_ids == self.config.image_token_index).sum()
            n_image_features = image_features.shape[0] * image_features.shape[1]
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)  # type: ignore

    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        logits_to_keep=logits_to_keep,
        defer_logits_calculation=True,  # enable deferred logits calculation
        **lm_kwargs,
    )

    hidden_states = outputs[0]
    loss = None
    logits = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states,
            self.language_model.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **lm_kwargs,
        )
    else:
        # Deferred path returned hidden states in the logits slot.
        logits = hidden_states
        if labels is not None:
            # Shift so that tokens < n predict n.
            if attention_mask is not None:
                # The 2D input attention mask drives the shift; crop it in
                # case it is longer (happens in PrefixTuning with peft).
                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
                    logits.device
                )
                shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
                shift_labels = labels[..., 1:][
                    shift_attention_mask.to(labels.device) != 0
                ].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens and compute plain cross-entropy.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1).to(shift_logits.device),
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Mistral3CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def patch_mistral(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Install the CCE forward on Mistral.

    When ``maybe_model`` is an instantiated model, only that instance is
    patched and returned; otherwise ``MistralForCausalLM`` is patched at the
    class level and None is returned. ``patch_options`` is stored in the
    module-global ``_PATCH_OPTS`` read by ``cce_forward``.
    """
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.mistral import modeling_mistral

    _PATCH_OPTS = patch_options

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        # Class-level patch: every future instance picks up the CCE forward.
        modeling_mistral.MistralForCausalLM.forward = cce_forward
        return None

    assert isinstance(
        maybe_model, modeling_mistral.MistralForCausalLM
    ), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}."
    maybe_model.forward = MethodType(cce_forward, maybe_model)
    return maybe_model
|
|
||||||
|
|
||||||
|
|
||||||
def patch_mistral3(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Install the CCE forwards on Mistral3.

    When ``maybe_model`` is an instantiated model, only that instance (and its
    inner language model) is patched and returned; otherwise both classes are
    patched globally and None is returned. ``patch_options`` is stored in the
    module-global ``_PATCH_OPTS``.
    """
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.mistral import modeling_mistral
    from transformers.models.mistral3 import modeling_mistral3

    _PATCH_OPTS = patch_options

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal
        # patch the causal model to enable deferred logits calculation
        modeling_mistral.MistralForCausalLM.forward = cce_forward
        return None

    assert isinstance(
        maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration
    ), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}."
    maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
    # patch the causal model to enable deferred logits calculation
    maybe_model.language_model.forward = MethodType(
        cce_forward, maybe_model.language_model
    )
    return maybe_model
|
|
||||||
@@ -1,366 +0,0 @@
|
|||||||
"""Mllama CCE patch."""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from transformers.cache_utils import Cache
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
||||||
from transformers.models.mllama.modeling_mllama import (
|
|
||||||
_prepare_cross_attention_mask,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
def cce_forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor | None = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
cross_attention_states: Optional[torch.LongTensor] = None,
|
|
||||||
cross_attention_mask: Optional[torch.LongTensor] = None,
|
|
||||||
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
||||||
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
defer_logits_calculation: bool = False,
|
|
||||||
**loss_kwargs,
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
defer_logits_calculation (`bool`, *optional*):
|
|
||||||
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
|
|
||||||
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import AutoTokenizer, MllamaForCausalLM
|
|
||||||
|
|
||||||
>>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
|
|
||||||
>>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
|
|
||||||
|
|
||||||
>>> prompt = "If I had to write a haiku, it would be:"
|
|
||||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
|
|
||||||
>>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
>>> print(result)
|
|
||||||
If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
|
|
||||||
I love the idea of snowflakes gently falling, each one
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
cross_attention_states=cross_attention_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
cross_attention_mask=cross_attention_mask,
|
|
||||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
slice_indices = (
|
|
||||||
slice(-logits_to_keep, None)
|
|
||||||
if isinstance(logits_to_keep, int)
|
|
||||||
else logits_to_keep
|
|
||||||
)
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states[:, slice_indices, :],
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
elif _PATCH_OPTS is not None and defer_logits_calculation:
|
|
||||||
# defer logits calculation to the ConditionalGeneration forward
|
|
||||||
logits = hidden_states[:, slice_indices, :]
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
|
|
||||||
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
||||||
def cce_forward_multimodal(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
|
||||||
pixel_values: Optional[torch.FloatTensor] = None,
|
|
||||||
aspect_ratio_mask: Optional[torch.Tensor] = None,
|
|
||||||
aspect_ratio_ids: Optional[torch.Tensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
cross_attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
cross_attention_states: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
|
||||||
**loss_kwargs,
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
||||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
||||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
||||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
||||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
||||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from PIL import Image
|
|
||||||
>>> import requests
|
|
||||||
>>> from transformers import AutoProcessor, MllamaForConditionalGeneration
|
|
||||||
|
|
||||||
>>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
|
|
||||||
>>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
|
|
||||||
>>> processor = AutoProcessor.from_pretrained(checkpoint)
|
|
||||||
|
|
||||||
>>> prompt = "<|image|>If I had to write a haiku for this one"
|
|
||||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
|
|
||||||
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> output = model.generate(**inputs, max_new_tokens=15)
|
|
||||||
|
|
||||||
>>> prompt_len = inputs.input_ids.shape[-1]
|
|
||||||
>>> generated_ids = output[:, prompt_len:]
|
|
||||||
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
|
||||||
>>> print(generated_text)
|
|
||||||
[', it would be:.\\nA stop sign in Chinatown.\\n']
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
||||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
|
||||||
|
|
||||||
if pixel_values is not None and inputs_embeds is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
|
|
||||||
)
|
|
||||||
|
|
||||||
if pixel_values is not None and cross_attention_states is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
|
|
||||||
)
|
|
||||||
|
|
||||||
if pixel_values is not None:
|
|
||||||
if aspect_ratio_ids is None:
|
|
||||||
raise ValueError(
|
|
||||||
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
|
|
||||||
)
|
|
||||||
# get vision tokens from vision model
|
|
||||||
vision_outputs = self.vision_model(
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
aspect_ratio_ids=aspect_ratio_ids,
|
|
||||||
aspect_ratio_mask=aspect_ratio_mask,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
return_dict=return_dict,
|
|
||||||
)
|
|
||||||
cross_attention_states = vision_outputs[0]
|
|
||||||
cross_attention_states = self.multi_modal_projector(
|
|
||||||
cross_attention_states
|
|
||||||
).reshape(
|
|
||||||
-1, cross_attention_states.shape[-2], self.hidden_size # type: ignore
|
|
||||||
)
|
|
||||||
|
|
||||||
if cross_attention_mask is not None:
|
|
||||||
cross_attention_mask, full_text_row_masked_out_mask = (
|
|
||||||
_prepare_cross_attention_mask(
|
|
||||||
cross_attention_mask,
|
|
||||||
num_vision_tokens=self.vision_model.num_patches,
|
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
full_text_row_masked_out_mask = None
|
|
||||||
|
|
||||||
if cross_attention_mask is not None and cache_position is not None:
|
|
||||||
cross_attention_mask = cross_attention_mask[:, :, cache_position]
|
|
||||||
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
|
|
||||||
:, :, cache_position
|
|
||||||
]
|
|
||||||
|
|
||||||
outputs = self.language_model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
cross_attention_states=cross_attention_states,
|
|
||||||
cross_attention_mask=cross_attention_mask,
|
|
||||||
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
use_cache=use_cache,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
logits_to_keep=logits_to_keep,
|
|
||||||
defer_logits_calculation=True, # enable deferred logits calculation
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
loss = None
|
|
||||||
logits = None
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states,
|
|
||||||
self.language_model.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
**loss_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
|
|
||||||
logits = hidden_states
|
|
||||||
|
|
||||||
if labels is not None:
|
|
||||||
loss = self.loss_function(
|
|
||||||
logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return (loss,) + outputs if loss is not None else outputs
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=outputs.logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_mllama(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Install the CCE forwards on Mllama.

    When ``maybe_model`` is an instantiated model, only that instance (and its
    inner language model) is patched and returned; otherwise both classes are
    patched globally and None is returned. ``patch_options`` is stored in the
    module-global ``_PATCH_OPTS``.
    """
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.mllama import modeling_mllama

    _PATCH_OPTS = patch_options

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal
        # patch the causal language model
        modeling_mllama.MllamaForCausalLM.forward = cce_forward
        return None

    assert isinstance(
        maybe_model, modeling_mllama.MllamaForConditionalGeneration
    ), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}."
    maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
    # patch the language model
    maybe_model.language_model.forward = MethodType(
        cce_forward, maybe_model.language_model
    )
    return maybe_model
|
|
||||||
@@ -1,126 +0,0 @@
|
|||||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
||||||
|
|
||||||
"""Cut Cross Entropy patcher"""
|
|
||||||
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
|
|
||||||
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
|
|
||||||
from cut_cross_entropy.transformers.phi3 import patch_phi3
|
|
||||||
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
|
|
||||||
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
|
|
||||||
patch_cohere,
|
|
||||||
patch_cohere2,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
|
|
||||||
patch_gemma2,
|
|
||||||
patch_gemma3,
|
|
||||||
patch_gemma3_text,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
|
|
||||||
patch_glm,
|
|
||||||
patch_glm4,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
|
|
||||||
patch_llama,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
|
|
||||||
patch_llama4,
|
|
||||||
patch_llama4_text,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
|
|
||||||
patch_mistral,
|
|
||||||
patch_mistral3,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import (
|
|
||||||
patch_qwen2,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import (
|
|
||||||
patch_qwen2_5_vl,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import (
|
|
||||||
patch_qwen2_moe,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import (
|
|
||||||
patch_qwen2_vl,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import (
|
|
||||||
patch_qwen3_moe,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Registry mapping a Hugging Face `config.model_type` string to the
# architecture-specific monkeypatch that installs the Cut Cross Entropy
# forward.  `cce_patch` looks the model type up here to dispatch.
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
    "llama": patch_llama,
    "llama4": patch_llama4,
    "llama4_text": patch_llama4_text,
    "mllama": patch_mllama,
    "phi3": patch_phi3,
    "gemma": patch_gemma,
    "gemma2": patch_gemma2,
    "gemma3": patch_gemma3,
    "gemma3_text": patch_gemma3_text,
    "mistral": patch_mistral,
    "mistral3": patch_mistral3,
    "qwen2": patch_qwen2,
    "qwen2_moe": patch_qwen2_moe,
    "qwen2_vl": patch_qwen2_vl,
    "qwen2_5_vl": patch_qwen2_5_vl,
    "qwen3": patch_qwen3,
    "qwen3_moe": patch_qwen3_moe,
    "cohere": patch_cohere,
    "cohere2": patch_cohere2,
    "glm": patch_glm,
    "glm4": patch_glm4,
}
|
|
||||||
|
|
||||||
|
|
||||||
def cce_patch(
    model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig,
    impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
    reduction: str = "mean",
    filter_eps: float | str | None = "auto",
    accum_e_fp32: bool = False,
    accum_c_fp32: bool = False,
    filter_e_grad: bool = True,
    filter_c_grad: bool = True,
    train_only: bool = False,
) -> TransformersModelT | None:
    """Apply the Cut Cross Entropy patch for a model, config, or model type.

    Resolves the model type, bundles the CCE settings into a `PatchOptions`,
    and dispatches to the architecture-specific patcher registered in
    `CUT_CROSS_ENTROPY_MODEL_MAPPING`.

    Returns the patched model when an instance was supplied, otherwise None
    (the classes were patched globally).  Raises ValueError for an unknown
    `impl` or a model without a config, RuntimeError for an unsupported
    model type.
    """
    # Normalize the implementation selector to its lowercase string form.
    if isinstance(impl, LinearCrossEntropyImpl):
        impl = impl.name.lower()

    valid_impls = {variant.name.lower() for variant in LinearCrossEntropyImpl}
    if impl not in valid_impls:
        raise ValueError(f"Unknown {impl=}")

    # Resolve the model type from whichever form the caller handed us.
    if isinstance(model_type_or_model, transformers.PreTrainedModel):
        if not hasattr(model_type_or_model, "config"):
            raise ValueError(
                "model_type_or_model is a PreTrainedModel but does not have a config attribute"
            )
        model_type = getattr(
            getattr(model_type_or_model, "config", None), "model_type", None
        )
    elif isinstance(model_type_or_model, transformers.PretrainedConfig):
        model_type = model_type_or_model.model_type
    else:
        model_type = model_type_or_model

    patch_options = PatchOptions(
        impl=impl,
        reduction=reduction,
        filter_eps=filter_eps,
        accum_e_fp32=accum_e_fp32,
        accum_c_fp32=accum_c_fp32,
        filter_e_grad=filter_e_grad,
        filter_c_grad=filter_c_grad,
        train_only=train_only,
    )

    patcher = CUT_CROSS_ENTROPY_MODEL_MAPPING.get(model_type)
    if patcher is None:
        raise RuntimeError(f"Unknown model type {model_type}")
    return patcher(model_type_or_model, patch_options)
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
"""Qwen2 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_qwen2(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Install the llama CCE forward on Qwen2.

    Qwen2 inherits llama's modeling code, so the llama `cce_forward` is
    reused and the patch options are stored in the llama patch module.
    Patches the given instance and returns it, or patches the class
    globally and returns None when given a model type / config.
    """
    from transformers.models.qwen2 import modeling_qwen2

    # Set the _PATCH_OPTS in the llama patch file
    import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch

    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access

    from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
        cce_forward,
    )

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        # Class-level patch: all future instances use the CCE forward.
        modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward
        return None

    assert isinstance(
        maybe_model, modeling_qwen2.Qwen2ForCausalLM
    ), f"Expected a Qwen2ForCausalLM model. Got {type(maybe_model)}."
    maybe_model.forward = MethodType(cce_forward, maybe_model)
    return maybe_model
|
|
||||||
@@ -1,246 +0,0 @@
|
|||||||
"""Qwen2.5 VL CCE patch. Adapted from transformers v4.51.2"""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from torch.nn import CrossEntropyLoss
|
|
||||||
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
|
||||||
Qwen2_5_VLCausalLMOutputWithPast,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Module-level CCE configuration; populated by `patch_qwen2_5_vl` below and
# read by `cce_forward_multimodal` to decide whether to use the LCE loss.
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def cce_forward_multimodal(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
    r"""Qwen2.5-VL forward with Cut Cross Entropy: when `_PATCH_OPTS`
    enables LCE, the loss is computed directly from hidden states without
    materializing the full logits tensor.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

    >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

    >>> messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What is shown in this image?"},
            ],
        },
    ]
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
    ```"""

    # Fall back to the config defaults for any output flag the caller left unset.
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    if inputs_embeds is None:
        # Embed the text tokens, then splice visual embeddings into the
        # positions occupied by image/video placeholder tokens.
        inputs_embeds = self.model.embed_tokens(input_ids)
        if pixel_values is not None:
            pixel_values = pixel_values.type(self.visual.dtype)
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            # Each image placeholder token must correspond to exactly one
            # visual feature row.
            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
            n_image_features = image_embeds.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )

            mask = input_ids == self.config.image_token_id
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
            image_mask = mask_expanded.to(inputs_embeds.device)

            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)  # type: ignore

        if pixel_values_videos is not None:
            # Same merge as above, for video placeholder tokens.
            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
            n_video_features = video_embeds.shape[0]
            if n_video_tokens != n_video_features:
                raise ValueError(
                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                )

            mask = input_ids == self.config.video_token_id
            mask_unsqueezed = mask.unsqueeze(-1)
            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
            video_mask = mask_expanded.to(inputs_embeds.device)

            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)  # type: ignore

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        # calculate RoPE index once per generation in the pre-fill stage only
        if (
            (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
            or (past_key_values is None or past_key_values.get_seq_length() == 0)  # type: ignore
        ):
            position_ids, rope_deltas = self.get_rope_index(
                input_ids,
                image_grid_thw,
                video_grid_thw,
                second_per_grid_ts,
                attention_mask,
            )
            self.rope_deltas = rope_deltas
        # then use the prev pre-calculated rope-deltas to get the correct position ids
        else:
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)  # type: ignore
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)  # type: ignore
            if cache_position is not None:  # otherwise `deltas` is an int `0`
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)  # type: ignore
            position_ids = position_ids.add(delta)  # type: ignore
            # 3 rope dimensions (temporal / height / width) per Qwen2.5-VL's
            # multimodal RoPE.
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)  # type: ignore

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]
    logits = None
    loss = None

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        # CCE path: loss computed from hidden states + lm_head weight
        # without materializing logits (logits stays None).
        assert labels is not None
        loss = apply_lce(
            hidden_states,
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
        )
    else:
        # Standard path: full logits, shifted cross-entropy.
        logits = self.lm_head(hidden_states)

        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2_5_VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def patch_qwen2_5_vl(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Install the CCE multimodal forward on Qwen2.5-VL.

    Patches the given instance and returns it, or patches the class
    globally and returns None when given a model type / config.
    """
    global _PATCH_OPTS  # pylint: disable=global-statement

    from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl

    _PATCH_OPTS = patch_options

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        # Class-level patch: all future instances use the CCE forward.
        modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = (
            cce_forward_multimodal
        )
        return None

    assert isinstance(
        maybe_model, modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
    ), f"Expected a Qwen2_5_VLForConditionalGeneration model. Got {type(maybe_model)}."
    maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
    return maybe_model
|
|
||||||
@@ -1,178 +0,0 @@
|
|||||||
"""Qwen2 MoE CCE patch. Adapted from transformers v4.51.2"""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
|
|
||||||
MoeCausalLMOutputWithPast,
|
|
||||||
MoeModelOutputWithPast,
|
|
||||||
load_balancing_loss_func,
|
|
||||||
)
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
from transformers.utils.generic import can_return_tuple
|
|
||||||
|
|
||||||
# Module-level CCE configuration; populated by `patch_qwen2_moe` below and
# read by `forward` to decide whether to use the LCE loss.
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **loss_kwargs,
) -> MoeCausalLMOutputWithPast:
    r"""Qwen2 MoE forward with Cut Cross Entropy: when `_PATCH_OPTS` enables
    LCE, the LM loss is computed from hidden states without materializing
    logits; the MoE auxiliary (load-balancing) loss is handled as upstream.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM

    >>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    # Fall back to the config defaults for any output flag the caller left unset.
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_router_logits = (
        output_router_logits
        if output_router_logits is not None
        else self.config.output_router_logits
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: MoeModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
    )

    hidden_states = outputs.last_hidden_state
    loss = None
    logits = None

    if hidden_states is None:
        raise ValueError("hidden_states is None")

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        # CCE path: loss from hidden states + lm_head weight, logits stays
        # None (deferred / never materialized).
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **loss_kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

    # MoE load-balancing auxiliary loss, added to the LM loss when training
    # with labels.
    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits,
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(  # type: ignore
                loss.device  # type: ignore
            )  # make sure to reside in the same device

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,  # type: ignore
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def patch_qwen2_moe(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Install the CCE forward on Qwen2 MoE.

    Patches the given instance and returns it, or patches the class
    globally and returns None when given a model type / config.
    """
    global _PATCH_OPTS  # pylint: disable=global-statement

    from transformers.models.qwen2_moe import modeling_qwen2_moe

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        # Bug fix: the message previously said "Qwen3MoeForCausalLM"
        # (copy-paste from the qwen3_moe patch) while checking Qwen2Moe.
        assert isinstance(
            maybe_model, modeling_qwen2_moe.Qwen2MoeForCausalLM
        ), f"Expected a Qwen2MoeForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(forward, maybe_model)

        return maybe_model

    modeling_qwen2_moe.Qwen2MoeForCausalLM.forward = forward
    return None
|
|
||||||
@@ -1,239 +0,0 @@
|
|||||||
"""Qwen2 VL CCE patch. Adapted from transformers v4.51.2"""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from torch.nn import CrossEntropyLoss
|
|
||||||
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
|
||||||
Qwen2VLCausalLMOutputWithPast,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Module-level CCE configuration; presumably populated by this module's
# patch function (not visible in this chunk — TODO confirm) and read by
# `cce_forward_multimodal` to decide whether to use the LCE loss.
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def cce_forward_multimodal(
|
|
||||||
self,
|
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
pixel_values: Optional[torch.Tensor] = None,
|
|
||||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
|
||||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
|
||||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
|
||||||
rope_deltas: Optional[torch.LongTensor] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
|
||||||
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from PIL import Image
|
|
||||||
>>> import requests
|
|
||||||
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
|
||||||
|
|
||||||
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
|
||||||
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
|
||||||
|
|
||||||
>>> messages = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "image"},
|
|
||||||
{"type": "text", "text": "What is shown in this image?"},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
|
|
||||||
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
|
||||||
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
|
||||||
|
|
||||||
>>> # Generate
|
|
||||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
||||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
|
||||||
```"""
|
|
||||||
|
|
||||||
output_attentions = (
|
|
||||||
output_attentions
|
|
||||||
if output_attentions is not None
|
|
||||||
else self.config.output_attentions
|
|
||||||
)
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states
|
|
||||||
if output_hidden_states is not None
|
|
||||||
else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
|
||||||
if pixel_values is not None:
|
|
||||||
pixel_values = pixel_values.type(self.visual.get_dtype())
|
|
||||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
|
||||||
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
|
||||||
n_image_features = image_embeds.shape[0]
|
|
||||||
if n_image_tokens != n_image_features:
|
|
||||||
raise ValueError(
|
|
||||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
|
||||||
)
|
|
||||||
image_mask = (
|
|
||||||
(input_ids == self.config.image_token_id)
|
|
||||||
.unsqueeze(-1)
|
|
||||||
.expand_as(inputs_embeds)
|
|
||||||
.to(inputs_embeds.device)
|
|
||||||
)
|
|
||||||
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore
|
|
||||||
|
|
||||||
if pixel_values_videos is not None:
|
|
||||||
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
|
||||||
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
|
||||||
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
|
||||||
n_video_features = video_embeds.shape[0]
|
|
||||||
if n_video_tokens != n_video_features:
|
|
||||||
raise ValueError(
|
|
||||||
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
|
||||||
)
|
|
||||||
video_mask = (
|
|
||||||
(input_ids == self.config.video_token_id)
|
|
||||||
.unsqueeze(-1)
|
|
||||||
.expand_as(inputs_embeds)
|
|
||||||
.to(inputs_embeds.device)
|
|
||||||
)
|
|
||||||
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
|
||||||
|
|
||||||
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
|
||||||
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
|
||||||
# calculate RoPE index once per generation in the pre-fill stage only
|
|
||||||
if (
|
|
||||||
(cache_position is not None and cache_position[0] == 0)
|
|
||||||
or self.rope_deltas is None
|
|
||||||
or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore
|
|
||||||
):
|
|
||||||
position_ids, rope_deltas = self.get_rope_index(
|
|
||||||
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
|
||||||
)
|
|
||||||
self.rope_deltas = rope_deltas
|
|
||||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
|
||||||
else:
|
|
||||||
batch_size, seq_length, _ = inputs_embeds.shape
|
|
||||||
delta = (
|
|
||||||
cache_position[0] + self.rope_deltas
|
|
||||||
if cache_position is not None
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore
|
|
||||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore
|
|
||||||
if cache_position is not None: # otherwise `deltas` is an int `0`
|
|
||||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore
|
|
||||||
delta = delta.to(position_ids.device) # type: ignore
|
|
||||||
position_ids = position_ids.add(delta) # type: ignore
|
|
||||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore
|
|
||||||
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=None,
|
|
||||||
position_ids=position_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
logits = None
|
|
||||||
loss = None
|
|
||||||
|
|
||||||
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
|
||||||
assert labels is not None
|
|
||||||
loss = apply_lce(
|
|
||||||
hidden_states,
|
|
||||||
self.lm_head.weight,
|
|
||||||
labels,
|
|
||||||
_PATCH_OPTS,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = self.lm_head(hidden_states)
|
|
||||||
|
|
||||||
if labels is not None:
|
|
||||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
|
||||||
logits = logits.float()
|
|
||||||
# Shift so that tokens < n predict n
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
# Flatten the tokens
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
|
||||||
shift_labels = shift_labels.view(-1)
|
|
||||||
# Enable model parallelism
|
|
||||||
shift_labels = shift_labels.to(shift_logits.device)
|
|
||||||
loss = loss_fct(shift_logits, shift_labels)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return Qwen2VLCausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
rope_deltas=self.rope_deltas,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_qwen2_vl(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Install the CCE multimodal forward on Qwen2-VL.

    Given a model instance, its bound ``forward`` is replaced and the same
    instance is returned. Given anything else (a name or config), the
    class-level ``forward`` is patched globally and ``None`` is returned.
    """
    global _PATCH_OPTS  # pylint: disable=global-statement

    from transformers.models.qwen2_vl import modeling_qwen2_vl

    _PATCH_OPTS = patch_options

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        # Class-level patch: every instance created later picks it up.
        modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = (
            cce_forward_multimodal
        )
        return None

    assert isinstance(
        maybe_model, modeling_qwen2_vl.Qwen2VLForConditionalGeneration
    ), f"Expected a Qwen2VLForConditionalGeneration model. Got {type(maybe_model)}."
    # Bind the patched forward to this one instance only.
    maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)

    return maybe_model
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
"""Qwen3 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_qwen3(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Route Qwen3 through the shared Llama CCE forward.

    Qwen3 inherits Llama's modeling code, so the patch options are stored on
    the llama patch module and Llama's ``cce_forward`` is reused directly.
    """
    from transformers.models.qwen3 import modeling_qwen3

    # Set the _PATCH_OPTS in the llama patch file
    import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch

    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access

    from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import cce_forward

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        # No instance given: patch the class so future instances use CCE.
        modeling_qwen3.Qwen3ForCausalLM.forward = cce_forward
        return None

    assert isinstance(
        maybe_model, modeling_qwen3.Qwen3ForCausalLM
    ), f"Expected a Qwen3ForCausalLM model. Got {type(maybe_model)}."
    maybe_model.forward = MethodType(cce_forward, maybe_model)
    return maybe_model
|
|
||||||
@@ -1,183 +0,0 @@
|
|||||||
"""Qwen3 MoE CCE patch. Adapted from transformers v4.51.2"""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
from types import MethodType
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from cut_cross_entropy.transformers.utils import (
|
|
||||||
PatchOptions,
|
|
||||||
TransformersModelT,
|
|
||||||
apply_lce,
|
|
||||||
)
|
|
||||||
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
|
|
||||||
KwargsForCausalLM,
|
|
||||||
MoeCausalLMOutputWithPast,
|
|
||||||
MoeModelOutputWithPast,
|
|
||||||
load_balancing_loss_func,
|
|
||||||
)
|
|
||||||
from transformers.processing_utils import Unpack
|
|
||||||
from transformers.utils.deprecation import deprecate_kwarg
|
|
||||||
from transformers.utils.generic import can_return_tuple
|
|
||||||
|
|
||||||
_PATCH_OPTS: PatchOptions | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> MoeCausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM

    >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    # Resolve output flags, falling back to the model config defaults.
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_router_logits = (
        output_router_logits
        if output_router_logits is not None
        else self.config.output_router_logits
    )

    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: MoeModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state

    if hidden_states is None:
        raise ValueError("hidden_states is None")

    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )

    # CCE path: when the patch is active and applicable, the loss is computed
    # by apply_lce directly from the hidden states and the lm_head weight;
    # note ``logits`` stays None in this branch.
    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    else:
        # Fallback: standard projection to vocab logits + transformers loss.
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

    # Router auxiliary loss (load_balancing_loss_func) is folded into the LM
    # loss only when labels were provided (i.e. a loss exists to add to).
    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits,
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(  # type: ignore
                loss.device  # type: ignore
            )  # make sure to reside in the same device

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,  # type: ignore
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def patch_qwen3_moe(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Install the CCE forward on Qwen3-MoE (instance- or class-level)."""
    global _PATCH_OPTS  # pylint: disable=global-statement

    from transformers.models.qwen3_moe import modeling_qwen3_moe

    _PATCH_OPTS = patch_options

    if not isinstance(maybe_model, transformers.PreTrainedModel):
        # Patch the class so all future instances use the CCE forward.
        modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = forward
        return None

    assert isinstance(
        maybe_model, modeling_qwen3_moe.Qwen3MoeForCausalLM
    ), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
    maybe_model.forward = MethodType(forward, maybe_model)

    return maybe_model
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
||||||
|
|
||||||
"""Monkeypatch for apply_lce to add softcap."""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from cut_cross_entropy import linear_cross_entropy
|
|
||||||
from cut_cross_entropy.transformers.utils import PatchOptions
|
|
||||||
|
|
||||||
|
|
||||||
def apply_lce(
    e: torch.Tensor,
    c: torch.Tensor,
    labels: torch.Tensor,
    opts: PatchOptions,
    bias: torch.Tensor | None = None,
    softcap: float | None = None,
    **loss_kwargs,
) -> torch.Tensor:
    """Compute linear cross-entropy; patched to accept a ``softcap`` kwarg.

    When ``num_items_in_batch`` is supplied and the configured reduction is
    ``"mean"``, the reduction is switched to ``"sum"`` and the summed loss is
    divided by ``num_items_in_batch`` afterwards.
    """
    cce_kwargs = opts.to_kwargs()
    num_items = loss_kwargs.get("num_items_in_batch", None)

    if num_items is not None and cce_kwargs["reduction"] == "mean":
        cce_kwargs["reduction"] = "sum"
    else:
        # No batch count supplied, or reduction is not "mean": no manual
        # normalization afterwards.
        num_items = None

    loss = linear_cross_entropy(
        e,
        c,
        labels.to(e.device),
        bias=bias,
        shift=True,
        softcap=softcap,
        **cce_kwargs,
    )

    return loss / num_items if num_items is not None else loss
|
|
||||||
@@ -504,6 +504,9 @@ class ModelLoader:
|
|||||||
# for some reason, this causes the loss to be off by an order of magnitude
|
# for some reason, this causes the loss to be off by an order of magnitude
|
||||||
# but deepspeed needs this still in bfloat16
|
# but deepspeed needs this still in bfloat16
|
||||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||||
|
if self.cfg.model_config_type == "falcon_h1":
|
||||||
|
# output projection cannot be quantized for Falcon-H1 models
|
||||||
|
bnb_config["llm_int8_skip_modules"] = ["out_proj"]
|
||||||
|
|
||||||
if self.cfg.bnb_config_kwargs:
|
if self.cfg.bnb_config_kwargs:
|
||||||
bnb_config.update(self.cfg.bnb_config_kwargs)
|
bnb_config.update(self.cfg.bnb_config_kwargs)
|
||||||
@@ -518,6 +521,9 @@ class ModelLoader:
|
|||||||
# Exclude mamba blocks from int8 quantization for jamba
|
# Exclude mamba blocks from int8 quantization for jamba
|
||||||
if self.cfg.model_config_type == "jamba":
|
if self.cfg.model_config_type == "jamba":
|
||||||
bnb_config["llm_int8_skip_modules"] = ["mamba"]
|
bnb_config["llm_int8_skip_modules"] = ["mamba"]
|
||||||
|
if self.cfg.model_config_type == "falcon_h1":
|
||||||
|
# output projection cannot be quantized for Falcon-H1 models
|
||||||
|
bnb_config["llm_int8_skip_modules"] = ["out_proj"]
|
||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
@@ -770,6 +776,9 @@ class ModelLoader:
|
|||||||
dist_dtype: torch.dtype,
|
dist_dtype: torch.dtype,
|
||||||
before_kbit_train_or_finetune: bool,
|
before_kbit_train_or_finetune: bool,
|
||||||
):
|
):
|
||||||
|
dest = {"dtype": dist_dtype}
|
||||||
|
if self.cfg.lora_on_cpu:
|
||||||
|
dest["device"] = "cpu"
|
||||||
for name, module in self.model.named_modules():
|
for name, module in self.model.named_modules():
|
||||||
if "norm" in name:
|
if "norm" in name:
|
||||||
module.to(dist_dtype)
|
module.to(dist_dtype)
|
||||||
@@ -780,4 +789,4 @@ class ModelLoader:
|
|||||||
# don't upcast lm_head for btlm
|
# don't upcast lm_head for btlm
|
||||||
continue
|
continue
|
||||||
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
||||||
module.to(dist_dtype)
|
module.to(**dest)
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ class PatchManager:
|
|||||||
def apply_pre_model_load_patches(self):
|
def apply_pre_model_load_patches(self):
|
||||||
"""Apply pre-model load patches based on config."""
|
"""Apply pre-model load patches based on config."""
|
||||||
self._apply_flash_attention_patches()
|
self._apply_flash_attention_patches()
|
||||||
|
self._apply_chunked_cross_entropy_patch()
|
||||||
self._apply_fsdp_patches()
|
self._apply_fsdp_patches()
|
||||||
self._apply_adapter_patches()
|
self._apply_adapter_patches()
|
||||||
self._apply_flex_attention_patches()
|
self._apply_flex_attention_patches()
|
||||||
@@ -63,6 +64,8 @@ class PatchManager:
|
|||||||
self._patch_llama_derived_model()
|
self._patch_llama_derived_model()
|
||||||
self._apply_mistral_cross_entropy_patch()
|
self._apply_mistral_cross_entropy_patch()
|
||||||
self._apply_self_attention_lora_patch()
|
self._apply_self_attention_lora_patch()
|
||||||
|
self._apply_gemma3_conditional_generation_forward_patch()
|
||||||
|
self._apply_sequence_parallel_patches()
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches that require the model instance."""
|
"""Apply patches that require the model instance."""
|
||||||
@@ -78,6 +81,15 @@ class PatchManager:
|
|||||||
patch_xformers_attn_over_fa2()
|
patch_xformers_attn_over_fa2()
|
||||||
self.cfg.flash_attention = True
|
self.cfg.flash_attention = True
|
||||||
|
|
||||||
|
def _apply_chunked_cross_entropy_patch(self):
|
||||||
|
if self.cfg.chunked_cross_entropy:
|
||||||
|
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
|
||||||
|
|
||||||
|
if self.cfg.chunked_cross_entropy_num_chunks:
|
||||||
|
patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
|
||||||
|
else:
|
||||||
|
patch_chunked_ce_loss_fn()
|
||||||
|
|
||||||
def _apply_fsdp_patches(self):
|
def _apply_fsdp_patches(self):
|
||||||
"""Apply patches for FSDP configurations."""
|
"""Apply patches for FSDP configurations."""
|
||||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
|
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
|
||||||
@@ -211,6 +223,26 @@ class PatchManager:
|
|||||||
has_remote_code=has_remote_code,
|
has_remote_code=has_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _apply_gemma3_conditional_generation_forward_patch(self):
|
||||||
|
"""Apply gemma3 conditional generation forward patch."""
|
||||||
|
if self.model_config.model_type in ["gemma3", "gemma3_text"]:
|
||||||
|
from axolotl.monkeypatch.models.gemma3.modeling import (
|
||||||
|
patch_gemma3_conditional_generation_forward,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_gemma3_conditional_generation_forward()
|
||||||
|
|
||||||
|
def _apply_sequence_parallel_patches(self):
|
||||||
|
"""Apply sequence parallelism patches."""
|
||||||
|
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||||
|
from axolotl.monkeypatch.ring_attn.patch import (
|
||||||
|
patch_prepare_data_loader,
|
||||||
|
patch_prepare_device_mesh,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_prepare_data_loader()
|
||||||
|
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
|
||||||
|
|
||||||
def _patch_attention(self):
|
def _patch_attention(self):
|
||||||
"""Apply attention-specific patches based on model type."""
|
"""Apply attention-specific patches based on model type."""
|
||||||
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
|
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
|
||||||
|
|||||||
134
src/axolotl/monkeypatch/loss/chunked.py
Normal file
134
src/axolotl/monkeypatch/loss/chunked.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
"""
|
||||||
|
chunked ce loss
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
# copied and modified from torchtune.modules.loss.CEWithChunkedOutputLoss
|
||||||
|
class CEWithChunkedOutputLoss(torch.nn.Module):
    """
    Cross-entropy over chunked logits that saves memory by upcasting only one
    chunk to fp32 at a time.

    For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
    """

    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
        super().__init__()
        self.num_output_chunks = num_output_chunks
        self.ignore_index = ignore_index

    def compute_cross_entropy(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        normalize: bool = True,  # pylint: disable=unused-argument
    ) -> torch.Tensor:
        """Upcast one chunk of logits to fp32 and return its summed CE loss."""
        return F.cross_entropy(
            logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
        )

    def forward(
        self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
    ) -> torch.Tensor:
        """
        Args:
            logits (List[torch.Tensor]): ``self.num_output_chunks`` chunks, each
                of shape ``(batch_size, num_tokens / num_output_chunks, vocab_size)``.
            labels (torch.Tensor): Ground truth labels of shape ``(batch_size, num_tokens)``.
            reduction (str): ``"sum"`` returns the summed loss; any other value
                divides by the number of non-ignored tokens.

        Returns:
            torch.Tensor: Cross entropy loss of shape (1,).
        """
        valid_count = (labels != self.ignore_index).sum()

        # (bsz, num_tokens) -> num_chunks flattened label chunks
        label_chunks = [
            chunk.reshape(-1)
            for chunk in labels.chunk(self.num_output_chunks, dim=1)
        ]
        # (bsz, tokens/chunk, vocab) -> (bsz * tokens/chunk, vocab) per chunk
        logit_chunks = [
            chunk.reshape(-1, chunk.size(-1)) for chunk in logits
        ]

        # Accumulate one chunk at a time so only one fp32 chunk is live.
        running = 0.0
        for logit_chunk, label_chunk in zip(logit_chunks, label_chunks):
            running = running + self.compute_cross_entropy(logit_chunk, label_chunk)

        if reduction == "sum":
            return running
        return running / valid_count
|
||||||
|
|
||||||
|
|
||||||
|
def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
    """Create a chunked-CE module whose per-chunk CE is compiled with inductor."""
    chunked_loss = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
    # Compile only the per-chunk computation; the chunking/accumulation loop
    # stays in eager mode.
    chunked_loss.compute_cross_entropy = torch.compile(
        chunked_loss.compute_cross_entropy, backend="inductor"
    )
    return chunked_loss
|
||||||
|
|
||||||
|
|
||||||
|
def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
    """Build a chunked replacement for transformers' ``ForCausalLMLoss``.

    Args:
        num_output_chunks: Number of chunks the logits are split into along
            the sequence dimension.
        ignore_index: Label value excluded from the loss, baked into the
            chunked loss module at build time.

    Returns:
        A callable with the ``ForCausalLMLoss`` signature that computes the
        shifted causal-LM cross entropy one chunk at a time.
    """
    loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)

    def chunked_fix_cross_entropy(
        source,
        target,
        num_items_in_batch: Optional[int] = None,
        ignore_index: int = -100,
        **kwargs,
    ):  # pylint: disable=unused-argument
        # NOTE(review): the ``ignore_index`` parameter here is unused — the
        # value baked into ``loss_fn_ce`` at build time is what applies.
        reduction = "sum" if num_items_in_batch is not None else "mean"
        logit_chunks = [  # pylint: disable=unnecessary-comprehension
            chunk for chunk in source.chunk(loss_fn_ce.num_output_chunks, dim=1)
        ]
        loss = loss_fn_ce(logit_chunks, target, reduction=reduction)
        if reduction == "sum":
            # Normalize by the caller-supplied token count (gradient
            # accumulation correctness).
            loss = loss / num_items_in_batch
        return loss

    def for_causal_lm_chunked_loss(
        logits,
        labels,
        vocab_size: int = None,  # pylint: disable=unused-argument
        num_items_in_batch: Optional[int] = None,
        ignore_index: int = -100,
        shift_labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        # skip the upcast to float since we handle that in the chunking loss
        if shift_labels is None:
            # Shift so that tokens < n predict n
            labels = F.pad(labels, (0, 1), value=ignore_index)
            shift_labels = labels[..., 1:].contiguous()

        # Skip Flattening the tokens
        # Enable model parallelism
        shift_labels = shift_labels.to(logits.device)
        loss = chunked_fix_cross_entropy(
            logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
        )
        return loss

    return for_causal_lm_chunked_loss
|
||||||
|
|
||||||
|
|
||||||
|
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
    """Replace transformers' causal-LM loss with the chunked variant."""
    import transformers.loss.loss_utils

    chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
    # Patch both the symbol and the dispatch mapping so every lookup path
    # resolves to the chunked implementation.
    transformers.loss.loss_utils.ForCausalLMLoss = chunked_loss
    transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = chunked_loss
|
||||||
0
src/axolotl/monkeypatch/models/gemma3/__init__.py
Normal file
0
src/axolotl/monkeypatch/models/gemma3/__init__.py
Normal file
16
src/axolotl/monkeypatch/models/gemma3/modeling.py
Normal file
16
src/axolotl/monkeypatch/models/gemma3/modeling.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Monkeypatch for gemma3 conditional generation forward to fix high loss"""
|
||||||
|
|
||||||
|
|
||||||
|
def patch_gemma3_conditional_generation_forward():
    """Disable loss-kwargs forwarding on ``Gemma3ForConditionalGeneration``.

    Remove when https://github.com/huggingface/transformers/pull/37208 merged.

    Returns:
        A zero-argument ``unpatch`` callable that removes the attribute again.
    """
    from transformers.models.gemma3.modeling_gemma3 import (
        Gemma3ForConditionalGeneration,
    )

    Gemma3ForConditionalGeneration.accepts_loss_kwargs = False

    def unpatch():
        delattr(Gemma3ForConditionalGeneration, "accepts_loss_kwargs")

    return unpatch
|
||||||
@@ -42,6 +42,10 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
|||||||
if has_remote_code:
|
if has_remote_code:
|
||||||
patch_remote(model_name)
|
patch_remote(model_name)
|
||||||
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||||
|
# sanity check in case upstream api changes on this
|
||||||
|
assert hasattr(
|
||||||
|
transformers.modeling_flash_attention_utils, "_get_unpad_data"
|
||||||
|
), "transformers api changed for _get_unpad_data for flash attention"
|
||||||
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
|
|||||||
def patch_prepare_data_loader():
|
def patch_prepare_data_loader():
|
||||||
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
|
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
|
||||||
|
|
||||||
Raies:
|
Raises:
|
||||||
RuntimeError: If source code to patch does not exist.
|
RuntimeError: If source code to patch does not exist.
|
||||||
"""
|
"""
|
||||||
original_fn = accelerate.data_loader.prepare_data_loader
|
original_fn = accelerate.data_loader.prepare_data_loader
|
||||||
@@ -168,23 +168,34 @@ def patch_prepare_data_loader():
|
|||||||
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
|
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(accelerate.data_loader):
|
||||||
|
if item in patched_source:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
# Create a new function from the patched source
|
# Create a new function from the patched source
|
||||||
namespace = {}
|
namespace = {}
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
patched_source, accelerate.data_loader.__dict__, namespace
|
f"from accelerate.data_loader import ({', '.join(items_to_import)})",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
patched_source, globals(), namespace
|
||||||
)
|
)
|
||||||
patched_function = namespace["prepare_data_loader"]
|
|
||||||
|
|
||||||
accelerate.data_loader.prepare_data_loader = patched_function
|
patched_function = namespace["prepare_data_loader"]
|
||||||
|
original_fn.__code__ = patched_function.__code__
|
||||||
|
|
||||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
||||||
|
|
||||||
|
|
||||||
def patch_prepare_device_mesh(sequence_parallel_degree: int):
|
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
|
||||||
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
||||||
that includes sequence parallelism with the specified degree.
|
that includes sequence parallelism with the specified degree.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sequence_parallel_degree (int): The degree of sequence parallelism to use.
|
sequence_parallel_degree: The degree of sequence parallelism to use.
|
||||||
|
fsdp: Whether to use FSDP.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _prepare_device_mesh(self):
|
def _prepare_device_mesh(self):
|
||||||
@@ -207,12 +218,14 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int):
|
|||||||
)
|
)
|
||||||
device_ids = list(range(world_size))
|
device_ids = list(range(world_size))
|
||||||
|
|
||||||
# Note that we use "cp" instead of "sp" to match the PyTorch native "context
|
# NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
|
||||||
# parallelism" implementation naming
|
# parallelism" implementation naming.
|
||||||
|
# NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
|
||||||
|
# only use "fsdp" and "cp" for the device mesh.
|
||||||
return dist.DeviceMesh(
|
return dist.DeviceMesh(
|
||||||
"cuda",
|
"cuda",
|
||||||
torch.tensor(device_ids).reshape(mesh_shape),
|
torch.tensor(device_ids).reshape(mesh_shape),
|
||||||
mesh_dim_names=("dp", "cp"),
|
mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Replace the original method with our new method
|
# Replace the original method with our new method
|
||||||
|
|||||||
@@ -142,7 +142,7 @@ class ProcessingStrategy:
|
|||||||
# TODO: check if it's normal to be single image only for common datasets
|
# TODO: check if it's normal to be single image only for common datasets
|
||||||
# From observation, it's usually a list of single image but some datasets may have several columns for images
|
# From observation, it's usually a list of single image but some datasets may have several columns for images
|
||||||
# Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages
|
# Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages
|
||||||
if len(processed_example[image_key]) > 0:
|
if len(processed_example[image_key]) > 1:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"Found {len(processed_example[image_key])} images in a sample. Using the first one."
|
f"Found {len(processed_example[image_key])} images in a sample. Using the first one."
|
||||||
"If you are using a dataset with multiple images per sample, please convert it to use multi-content Messages."
|
"If you are using a dataset with multiple images per sample, please convert it to use multi-content Messages."
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
chat_template_kwargs = {
|
chat_template_kwargs = {
|
||||||
"chat_template": self.chat_template,
|
"chat_template": self.chat_template,
|
||||||
"add_generation_prompt": add_generation_prompt,
|
"add_generation_prompt": add_generation_prompt,
|
||||||
|
**self.chat_template_kwargs,
|
||||||
}
|
}
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
|||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
from transformers.trainer import Trainer
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.cli.art import print_axolotl_text_art
|
|
||||||
from axolotl.common.datasets import TrainDatasetMeta
|
from axolotl.common.datasets import TrainDatasetMeta
|
||||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
@@ -219,10 +218,14 @@ def execute_training(
|
|||||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
ring_attn_func=cfg.ring_attn_func,
|
ring_attn_func=cfg.ring_attn_func,
|
||||||
heads_k_stride=cfg.heads_k_stride,
|
heads_k_stride=cfg.heads_k_stride,
|
||||||
|
gather_outputs=cfg.rl is RLType.GRPO,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
|
# TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers
|
||||||
|
# if cfg.bf16:
|
||||||
|
# torch.set_default_dtype(torch.bfloat16)
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|
||||||
|
|
||||||
@@ -545,8 +548,6 @@ def train(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (model, tokenizer) after training
|
Tuple of (model, tokenizer) after training
|
||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
|
||||||
|
|
||||||
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
|
||||||
(
|
(
|
||||||
trainer,
|
trainer,
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -12,8 +12,6 @@ from transformers.utils import ModelOutput
|
|||||||
|
|
||||||
from axolotl.monkeypatch.ring_attn import (
|
from axolotl.monkeypatch.ring_attn import (
|
||||||
get_ring_attn_group,
|
get_ring_attn_group,
|
||||||
patch_prepare_data_loader,
|
|
||||||
patch_prepare_device_mesh,
|
|
||||||
register_ring_attn,
|
register_ring_attn,
|
||||||
update_ring_attn_params,
|
update_ring_attn_params,
|
||||||
)
|
)
|
||||||
@@ -174,6 +172,8 @@ class SequenceParallelContextManager:
|
|||||||
ring_attn_func: Which ring attention function to use. Currently unused.
|
ring_attn_func: Which ring attention function to use. Currently unused.
|
||||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||||
`varlen_llama3` `ring_flash_attn` implementation.
|
`varlen_llama3` `ring_flash_attn` implementation.
|
||||||
|
gather_outputs: Whether to gather outputs after model forward pass across the
|
||||||
|
sequence parallel group.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -183,12 +183,15 @@ class SequenceParallelContextManager:
|
|||||||
gradient_accumulation_steps: int,
|
gradient_accumulation_steps: int,
|
||||||
ring_attn_func: RingAttnFunc,
|
ring_attn_func: RingAttnFunc,
|
||||||
heads_k_stride: int | None,
|
heads_k_stride: int | None,
|
||||||
|
gather_outputs: bool,
|
||||||
):
|
):
|
||||||
self.models = models
|
self.models = models
|
||||||
self.sequence_parallel_degree = sequence_parallel_degree
|
self.sequence_parallel_degree = sequence_parallel_degree
|
||||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||||
self.ring_attn_func = ring_attn_func
|
self.ring_attn_func = ring_attn_func
|
||||||
self.heads_k_stride = heads_k_stride
|
self.heads_k_stride = heads_k_stride
|
||||||
|
self.gather_outputs = gather_outputs
|
||||||
|
|
||||||
self._register_ring_attn()
|
self._register_ring_attn()
|
||||||
|
|
||||||
# Set distributed info for local rank
|
# Set distributed info for local rank
|
||||||
@@ -233,12 +236,6 @@ class SequenceParallelContextManager:
|
|||||||
ring_attn_func=self.ring_attn_func,
|
ring_attn_func=self.ring_attn_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Patches for accelerate functionality
|
|
||||||
patch_prepare_data_loader()
|
|
||||||
patch_prepare_device_mesh(
|
|
||||||
sequence_parallel_degree=self.sequence_parallel_degree
|
|
||||||
)
|
|
||||||
|
|
||||||
def _register_model_hooks(self):
|
def _register_model_hooks(self):
|
||||||
# Forward pre-hook to apply sequence parallelism
|
# Forward pre-hook to apply sequence parallelism
|
||||||
def sequence_parallel_pre_hook(_, args, kwargs):
|
def sequence_parallel_pre_hook(_, args, kwargs):
|
||||||
@@ -277,16 +274,17 @@ class SequenceParallelContextManager:
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
# Register both hooks
|
# Register hooks
|
||||||
for model in self.models:
|
for model in self.models:
|
||||||
self.hook_handles.append(
|
self.hook_handles.append(
|
||||||
model.register_forward_pre_hook(
|
model.register_forward_pre_hook(
|
||||||
sequence_parallel_pre_hook, with_kwargs=True
|
sequence_parallel_pre_hook, with_kwargs=True
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.hook_handles.append(
|
if self.gather_outputs:
|
||||||
model.register_forward_hook(sequence_parallel_post_hook)
|
self.hook_handles.append(
|
||||||
)
|
model.register_forward_hook(sequence_parallel_post_hook)
|
||||||
|
)
|
||||||
|
|
||||||
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
|
def _gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
|
||||||
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
|
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
|
||||||
|
|||||||
@@ -224,10 +224,10 @@ def wrap_pretraining_dataset(
|
|||||||
remove_columns = []
|
remove_columns = []
|
||||||
if dataset.features is None:
|
if dataset.features is None:
|
||||||
for first_row in dataset:
|
for first_row in dataset:
|
||||||
remove_columns = first_row.keys()
|
remove_columns = list(first_row.keys())
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
remove_columns = dataset.features.keys()
|
remove_columns = list(dataset.features.keys())
|
||||||
|
|
||||||
dataset = dataset.map(
|
dataset = dataset.map(
|
||||||
encode,
|
encode,
|
||||||
@@ -267,6 +267,7 @@ def encode_packed_pretraining(
|
|||||||
batch_size=1,
|
batch_size=1,
|
||||||
batch_max_len=batch_size * max_seq_length,
|
batch_max_len=batch_size * max_seq_length,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
|
num_processes=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
chunked_data = defaultdict(list)
|
chunked_data = defaultdict(list)
|
||||||
|
|||||||
@@ -334,7 +334,10 @@ def _load_raw_datasets(
|
|||||||
dataset = merge_datasets(datasets, cfg)
|
dataset = merge_datasets(datasets, cfg)
|
||||||
|
|
||||||
if not cfg.skip_prepare_dataset:
|
if not cfg.skip_prepare_dataset:
|
||||||
dataset = drop_long_seq_in_dataset(dataset, cfg)
|
if split == "test" and cfg.eval_sequence_len:
|
||||||
|
dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
||||||
|
else:
|
||||||
|
dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||||
|
|
||||||
|
|||||||
@@ -524,13 +524,24 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
|||||||
Merged dataset.
|
Merged dataset.
|
||||||
"""
|
"""
|
||||||
if len(datasets) == 1:
|
if len(datasets) == 1:
|
||||||
return datasets[0]
|
ds = datasets[0]
|
||||||
|
|
||||||
|
# Do not shuffle if curriculum sampling is enabled
|
||||||
|
if cfg.curriculum_sampling:
|
||||||
|
return ds
|
||||||
|
|
||||||
|
return ds.shuffle(seed=cfg.seed)
|
||||||
|
|
||||||
LOG.info("Merging datasets...")
|
LOG.info("Merging datasets...")
|
||||||
merged_dataset = concatenate_datasets(datasets)
|
merged_dataset = concatenate_datasets(datasets)
|
||||||
|
|
||||||
if cfg.shuffle_merged_datasets:
|
if cfg.shuffle_merged_datasets:
|
||||||
LOG.debug("Shuffling merged datasets...")
|
LOG.debug("Shuffling merged datasets...")
|
||||||
|
if cfg.curriculum_sampling:
|
||||||
|
LOG.warning(
|
||||||
|
"Shuffling merged datasets with curriculum sampling is not recommended. "
|
||||||
|
"This will randomize the order of samples."
|
||||||
|
)
|
||||||
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
|
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
|
||||||
else:
|
else:
|
||||||
LOG.debug("Not shuffling merged datasets.")
|
LOG.debug("Not shuffling merged datasets.")
|
||||||
|
|||||||
@@ -148,11 +148,14 @@ def deduplicate_and_log_datasets(
|
|||||||
return dataset, other_dataset
|
return dataset, other_dataset
|
||||||
|
|
||||||
|
|
||||||
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
def drop_long_seq_in_dataset(
|
||||||
|
dataset: Dataset, sequence_len: int, cfg: DictDefault
|
||||||
|
) -> Dataset:
|
||||||
"""Remove sequences longer than configured maximum from dataset.
|
"""Remove sequences longer than configured maximum from dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: Dataset to filter.
|
dataset: Dataset to filter.
|
||||||
|
sequence_len: Maximum length for sequences to keep
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -167,7 +170,7 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
|||||||
|
|
||||||
drop_long = functools.partial(
|
drop_long = functools.partial(
|
||||||
drop_long_seq,
|
drop_long_seq,
|
||||||
sequence_len=cfg.sequence_len,
|
sequence_len=sequence_len,
|
||||||
min_sequence_len=cfg.min_sample_len,
|
min_sequence_len=cfg.min_sample_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -187,7 +190,7 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
|||||||
|
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
|
||||||
|
|
||||||
dataset = dataset.filter(
|
dataset = dataset.filter(
|
||||||
drop_long,
|
drop_long,
|
||||||
|
|||||||
@@ -46,16 +46,23 @@ def get_current_device() -> int:
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed_state():
|
||||||
|
global distributed_state # pylint: disable=global-statement
|
||||||
|
if distributed_state is None:
|
||||||
|
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
|
||||||
|
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
|
||||||
|
|
||||||
|
|
||||||
def get_distributed_state() -> PartialState | None:
|
def get_distributed_state() -> PartialState | None:
|
||||||
return distributed_state
|
return distributed_state
|
||||||
|
|
||||||
|
|
||||||
def is_distributed() -> bool:
|
def is_distributed() -> bool:
|
||||||
"""Check if distributed training is initialized."""
|
"""Check if distributed training is initialized."""
|
||||||
global distributed_state # pylint: disable=global-statement
|
init_distributed_state()
|
||||||
|
|
||||||
if distributed_state is None:
|
if distributed_state is None:
|
||||||
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
|
return False
|
||||||
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
|
|
||||||
|
|
||||||
return distributed_state.use_distributed and distributed_state.initialized
|
return distributed_state.use_distributed and distributed_state.initialized
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
@@ -251,10 +251,13 @@ class HFMistralTokenizer:
|
|||||||
token_ids = [token_ids]
|
token_ids = [token_ids]
|
||||||
|
|
||||||
if skip_special_tokens:
|
if skip_special_tokens:
|
||||||
return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids)
|
return self._mistral.instruct_tokenizer.tokenizer.decode(
|
||||||
|
token_ids, special_token_policy=SpecialTokenPolicy.IGNORE
|
||||||
|
)
|
||||||
|
|
||||||
# to_string returns a string with special tokens
|
return self._mistral.instruct_tokenizer.tokenizer.decode(
|
||||||
return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids)
|
token_ids, special_token_policy=SpecialTokenPolicy.KEEP
|
||||||
|
)
|
||||||
|
|
||||||
def _create_mistral_chat_completion_request(
|
def _create_mistral_chat_completion_request(
|
||||||
self, conversation: list[dict], tools: list[dict] | None = None
|
self, conversation: list[dict], tools: list[dict] | None = None
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ def pack_parallel(
|
|||||||
bin_size: int,
|
bin_size: int,
|
||||||
num_processes: int | None = None,
|
num_processes: int | None = None,
|
||||||
safe_mode: bool = True,
|
safe_mode: bool = True,
|
||||||
mp_start_method: str | None = "spawn",
|
mp_start_method: str | None = "fork",
|
||||||
) -> list[list[int]]:
|
) -> list[list[int]]:
|
||||||
"""Pack sequences into bins using parallel processing.
|
"""Pack sequences into bins using parallel processing.
|
||||||
|
|
||||||
@@ -260,12 +260,13 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
lengths: np.ndarray, # Sequence lengths
|
lengths: np.ndarray, # Sequence lengths
|
||||||
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
||||||
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
|
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
|
||||||
num_count_samples: int = 8, # Number of times to estimate batch count
|
num_count_samples: int = 4, # Number of times to estimate batch count
|
||||||
sequential: bool = False, # Whether to use sequential packing
|
sequential: bool = False, # Whether to use sequential packing
|
||||||
group_size: int = 100_000, # Size of groups for parallel packing
|
group_size: int = 100_000, # Size of groups for parallel packing
|
||||||
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
||||||
num_processes: int | None = None, # Number of processes for parallel packing
|
num_processes: int | None = None, # Number of processes for parallel packing
|
||||||
safe_mode: bool = True, # Conservative packing to prevent training instability
|
safe_mode: bool = True, # Conservative packing to prevent training instability
|
||||||
|
mp_start_method: str = "fork",
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
super().__init__(sampler, batch_size, drop_last)
|
super().__init__(sampler, batch_size, drop_last)
|
||||||
@@ -278,6 +279,7 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
self.bin_size = bin_size
|
self.bin_size = bin_size
|
||||||
self.num_processes = num_processes
|
self.num_processes = num_processes
|
||||||
self.safe_mode = safe_mode
|
self.safe_mode = safe_mode
|
||||||
|
self.mp_start_method = mp_start_method
|
||||||
|
|
||||||
assert isinstance(self.lengths, np.ndarray)
|
assert isinstance(self.lengths, np.ndarray)
|
||||||
|
|
||||||
@@ -333,13 +335,15 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
|
bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
|
||||||
else:
|
else:
|
||||||
# Use parallel packing
|
# Use parallel packing
|
||||||
|
num_processes = self.num_processes or 1
|
||||||
all_bins = pack_parallel(
|
all_bins = pack_parallel(
|
||||||
lengths,
|
lengths,
|
||||||
bin_capacity=self.batch_max_len,
|
bin_capacity=self.batch_max_len,
|
||||||
group_size=self.group_size,
|
group_size=self.group_size,
|
||||||
bin_size=self.bin_size,
|
bin_size=self.bin_size,
|
||||||
num_processes=self.num_processes,
|
num_processes=min(4, num_processes) if num_processes else 4,
|
||||||
safe_mode=self.safe_mode,
|
safe_mode=self.safe_mode,
|
||||||
|
mp_start_method=self.mp_start_method,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Map bin indices back to original indices
|
# Map bin indices back to original indices
|
||||||
|
|||||||
@@ -146,6 +146,7 @@ class AxolotlInputConfig(
|
|||||||
dpo_label_smoothing: float | None = None
|
dpo_label_smoothing: float | None = None
|
||||||
dpo_norm_loss: bool | None = None
|
dpo_norm_loss: bool | None = None
|
||||||
dpo_padding_free: bool | None = None
|
dpo_padding_free: bool | None = None
|
||||||
|
dpo_generate_during_eval: bool | None = None
|
||||||
|
|
||||||
datasets: (
|
datasets: (
|
||||||
Annotated[
|
Annotated[
|
||||||
@@ -366,6 +367,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
eval_sequence_len: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len"
|
||||||
|
},
|
||||||
|
)
|
||||||
min_sample_len: int | None = None
|
min_sample_len: int | None = None
|
||||||
max_prompt_len: int = Field(
|
max_prompt_len: int = Field(
|
||||||
default=512,
|
default=512,
|
||||||
@@ -393,6 +400,12 @@ class AxolotlInputConfig(
|
|||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Whether to pack samples sequentially"},
|
json_schema_extra={"description": "Whether to pack samples sequentially"},
|
||||||
)
|
)
|
||||||
|
sample_packing_mp_start_method: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
|
||||||
|
},
|
||||||
|
)
|
||||||
eval_sample_packing: bool | None = Field(
|
eval_sample_packing: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -523,6 +536,19 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
chunked_cross_entropy: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Whether to use chunked cross entropy loss for memory efficiency"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
chunked_cross_entropy_num_chunks: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of chunks to use for chunked cross entropy loss"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
llama4_linearized_experts: bool | None = None
|
llama4_linearized_experts: bool | None = None
|
||||||
|
|
||||||
deepspeed: str | dict[str, Any] | None = Field(
|
deepspeed: str | dict[str, Any] | None = Field(
|
||||||
@@ -759,6 +785,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null."
|
"description": "Custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Additional kwargs to pass to the chat template. This is useful for customizing the chat template. For example, you can pass `thinking=False` to add a generation prompt to the chat template."
|
||||||
|
},
|
||||||
|
)
|
||||||
eot_tokens: list[str] | None = Field(
|
eot_tokens: list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class ChatTemplate(str, Enum):
|
|||||||
jinja = "jinja"
|
jinja = "jinja"
|
||||||
qwen_25 = "qwen_25"
|
qwen_25 = "qwen_25"
|
||||||
qwen3 = "qwen3"
|
qwen3 = "qwen3"
|
||||||
|
falcon_h1 = "falcon_h1"
|
||||||
tokenizer_default = "tokenizer_default"
|
tokenizer_default = "tokenizer_default"
|
||||||
exaone = "exaone"
|
exaone = "exaone"
|
||||||
metharme = "metharme"
|
metharme = "metharme"
|
||||||
|
|||||||
@@ -462,6 +462,20 @@ class TrainingValidationMixin:
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def pretrain_with_tps(cls, data):
|
||||||
|
if data.get("pretraining_dataset") and data.get(
|
||||||
|
"include_tokens_per_second", False
|
||||||
|
):
|
||||||
|
# combining these would raise `TypeError: cannot pickle 'dict_keys' object`
|
||||||
|
# due to trying to count the number of tokens total in the dataset
|
||||||
|
raise ValueError(
|
||||||
|
"pretraining_dataset and include_tokens_per_second cannot be used together."
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class LoRAValidationMixin:
|
class LoRAValidationMixin:
|
||||||
"""Validation methods related to LoRA/QLoRA configuration."""
|
"""Validation methods related to LoRA/QLoRA configuration."""
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
|
||||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
@@ -381,6 +381,7 @@ def process_pretraining_datasets_for_packing(
|
|||||||
if not skip_position_ids:
|
if not skip_position_ids:
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
|
batched=True,
|
||||||
desc="Add position_id column (Pretraining Sample Packing)",
|
desc="Add position_id column (Pretraining Sample Packing)",
|
||||||
)
|
)
|
||||||
if drop_attention_mask:
|
if drop_attention_mask:
|
||||||
@@ -467,6 +468,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
sequential=cfg.sample_packing_sequentially,
|
sequential=cfg.sample_packing_sequentially,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
num_processes=cfg.dataset_processes,
|
num_processes=cfg.dataset_processes,
|
||||||
|
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
|
||||||
)
|
)
|
||||||
|
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
@@ -537,6 +539,12 @@ def setup_deepspeed_env(cfg, stage=None):
|
|||||||
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
|
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
|
||||||
if stage == 3:
|
if stage == 3:
|
||||||
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
|
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
|
||||||
|
|
||||||
|
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
|
||||||
|
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
|
||||||
|
# to model load.
|
||||||
|
init_distributed_state()
|
||||||
|
|
||||||
# If we don't assign this, it doesn't actually get set in the accelerate weakref
|
# If we don't assign this, it doesn't actually get set in the accelerate weakref
|
||||||
_ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
|
_ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
|
||||||
|
|
||||||
|
|||||||
@@ -4,12 +4,14 @@ shared pytest fixtures
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
import importlib
|
import importlib
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import pytest
|
import pytest
|
||||||
@@ -24,6 +26,8 @@ from tests.hf_offline_utils import (
|
|||||||
hf_offline_context,
|
hf_offline_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logging.getLogger("filelock").setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
|
|
||||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
@@ -411,7 +415,16 @@ def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def temp_dir():
|
def temp_dir() -> Generator[str, None, None]:
|
||||||
|
# Create a temporary directory
|
||||||
|
_temp_dir = tempfile.mkdtemp()
|
||||||
|
yield _temp_dir
|
||||||
|
# Clean up the directory after the test
|
||||||
|
shutil.rmtree(_temp_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def module_temp_dir() -> Generator[str, None, None]:
|
||||||
# Create a temporary directory
|
# Create a temporary directory
|
||||||
_temp_dir = tempfile.mkdtemp()
|
_temp_dir = tempfile.mkdtemp()
|
||||||
yield _temp_dir
|
yield _temp_dir
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class TestSequenceParallelism:
|
|||||||
"micro_batch_size": micro_batch_size,
|
"micro_batch_size": micro_batch_size,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class TestPackedFlex:
|
|||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"gradient_checkpointing": True,
|
"gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -309,6 +309,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"warmup_steps": 10,
|
"warmup_steps": 10,
|
||||||
"val_set_size": 0.0,
|
"val_set_size": 0.0,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.0001,
|
"learning_rate": 0.0001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -400,6 +401,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"warmup_steps": 10,
|
"warmup_steps": 10,
|
||||||
"val_set_size": 0.0,
|
"val_set_size": 0.0,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.0001,
|
"learning_rate": 0.0001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -38,12 +38,13 @@ class TestMultiGPUEval:
|
|||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||||
"val_set_size": 0.004,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "teknium/GPT4-LLM-Cleaned",
|
"path": "teknium/GPT4-LLM-Cleaned",
|
||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
|
"split": "train[:5%]",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
@@ -51,6 +52,7 @@ class TestMultiGPUEval:
|
|||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -107,12 +109,13 @@ class TestMultiGPUEval:
|
|||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||||
"val_set_size": 0.0004,
|
"val_set_size": 0.01,
|
||||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "teknium/GPT4-LLM-Cleaned",
|
"path": "teknium/GPT4-LLM-Cleaned",
|
||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
|
"split": "train[:5%]",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
@@ -120,6 +123,7 @@ class TestMultiGPUEval:
|
|||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ class TestMultiGPUGemma3:
|
|||||||
},
|
},
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.0001,
|
"learning_rate": 0.0001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
E2E tests for multigpu lora tinyllama
|
E2E tests for multigpu lora tinyllama
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -25,6 +27,60 @@ def download_model():
|
|||||||
snapshot_download("HuggingFaceTB/SmolLM2-135M")
|
snapshot_download("HuggingFaceTB/SmolLM2-135M")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def sft_base_cfg():
|
||||||
|
cfg = DictDefault(
|
||||||
|
base_model="HuggingFaceTB/SmolLM2-135M",
|
||||||
|
tokenizer_config="HuggingFaceTB/SmolLM2-135M", # this has to be manually set since we haven't done validation
|
||||||
|
sequence_len=1024,
|
||||||
|
special_tokens={
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
datasets=[
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
val_set_size=0.1,
|
||||||
|
sample_packing=True,
|
||||||
|
flash_attention=True,
|
||||||
|
learning_rate=0.00001,
|
||||||
|
optimizer="adamw_8bit",
|
||||||
|
seed=42,
|
||||||
|
# these need to be set since we aren't running schema validation
|
||||||
|
micro_batch_size=2,
|
||||||
|
gradient_accumulation_steps=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", name="sft_prepared_dataset_alpaca_cfg")
|
||||||
|
def sft_prepared_dataset_alpaca_cfg(module_temp_dir, sft_base_cfg):
|
||||||
|
dataset_prepared_path = module_temp_dir + "/last_run_prepared"
|
||||||
|
cfg = sft_base_cfg | DictDefault(
|
||||||
|
dataset_prepared_path=dataset_prepared_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
Path(module_temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(module_temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"preprocess",
|
||||||
|
str(Path(module_temp_dir) / "config.yaml"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# unset flash attention since we have some flex attention tests too
|
||||||
|
cfg.flash_attention = None
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
def transformers_version_eq(required_version):
|
def transformers_version_eq(required_version):
|
||||||
return version.parse(transformers.__version__) == version.parse(required_version)
|
return version.parse(transformers.__version__) == version.parse(required_version)
|
||||||
|
|
||||||
@@ -62,6 +118,7 @@ class TestMultiGPULlama:
|
|||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
# "gradient_checkpointing": True,
|
# "gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -96,44 +153,36 @@ class TestMultiGPULlama:
|
|||||||
"gradient_accumulation_steps",
|
"gradient_accumulation_steps",
|
||||||
[1, 2],
|
[1, 2],
|
||||||
)
|
)
|
||||||
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
|
def test_lora_ddp_packed(
|
||||||
|
self, temp_dir, sft_prepared_dataset_alpaca_cfg, gradient_accumulation_steps
|
||||||
|
):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = (
|
||||||
{
|
DictDefault(
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
{
|
||||||
"sequence_len": 2048,
|
"eval_sample_packing": False,
|
||||||
"sample_packing": True,
|
"pad_to_sequence_len": True,
|
||||||
"eval_sample_packing": False,
|
"adapter": "lora",
|
||||||
"pad_to_sequence_len": True,
|
"lora_r": 8,
|
||||||
"adapter": "lora",
|
"lora_alpha": 16,
|
||||||
"lora_r": 8,
|
"lora_dropout": 0.05,
|
||||||
"lora_alpha": 16,
|
"lora_target_linear": True,
|
||||||
"lora_dropout": 0.05,
|
"val_set_size": 0.05,
|
||||||
"lora_target_linear": True,
|
"num_epochs": 1,
|
||||||
"val_set_size": 0.05,
|
"max_steps": 2,
|
||||||
"special_tokens": {
|
"micro_batch_size": 1,
|
||||||
"pad_token": "<|endoftext|>",
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
},
|
# "gradient_checkpointing": True,
|
||||||
"datasets": [
|
"output_dir": temp_dir,
|
||||||
{
|
"learning_rate": 0.00001,
|
||||||
"path": "tatsu-lab/alpaca",
|
"optimizer": "adamw_8bit",
|
||||||
"type": "alpaca",
|
"lr_scheduler": "cosine",
|
||||||
"split": "train[:20%]",
|
"flash_attention": True,
|
||||||
},
|
"use_tensorboard": True,
|
||||||
],
|
"bf16": True,
|
||||||
"num_epochs": 1,
|
}
|
||||||
"max_steps": 2,
|
)
|
||||||
"micro_batch_size": 1,
|
| sft_prepared_dataset_alpaca_cfg
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
|
||||||
# "gradient_checkpointing": True,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"use_tensorboard": True,
|
|
||||||
"bf16": True,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# write cfg to yaml file
|
# write cfg to yaml file
|
||||||
@@ -200,6 +249,7 @@ class TestMultiGPULlama:
|
|||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
# "gradient_checkpointing": True,
|
# "gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"warmup_steps": 0,
|
"warmup_steps": 0,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
@@ -278,6 +328,7 @@ class TestMultiGPULlama:
|
|||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
# "gradient_checkpointing": True,
|
# "gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"warmup_steps": 0,
|
"warmup_steps": 0,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
@@ -340,6 +391,7 @@ class TestMultiGPULlama:
|
|||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
# "gradient_checkpointing": True,
|
# "gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -380,58 +432,50 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"fsdp_state_dict_type",
|
"fsdp_state_dict_type",
|
||||||
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
|
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
|
||||||
)
|
)
|
||||||
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
|
def test_fsdp_packed(
|
||||||
|
self, temp_dir, sft_prepared_dataset_alpaca_cfg, fsdp_state_dict_type
|
||||||
|
):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = (
|
||||||
{
|
DictDefault(
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
{
|
||||||
"sample_packing": True,
|
"pad_to_sequence_len": True,
|
||||||
"pad_to_sequence_len": True,
|
"num_epochs": 1,
|
||||||
"sequence_len": 1024,
|
"max_steps": 2,
|
||||||
"val_set_size": 0.05,
|
"micro_batch_size": 2,
|
||||||
"special_tokens": {
|
"gradient_accumulation_steps": 2,
|
||||||
"pad_token": "<|endoftext|>",
|
# "gradient_checkpointing": True,
|
||||||
},
|
"output_dir": temp_dir,
|
||||||
"datasets": [
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
{
|
"learning_rate": 0.00001,
|
||||||
"path": "tatsu-lab/alpaca",
|
"optimizer": "adamw_torch_fused",
|
||||||
"type": "alpaca",
|
"lr_scheduler": "cosine",
|
||||||
"split": "train[:10%]",
|
"flash_attention": True,
|
||||||
|
"fsdp": [
|
||||||
|
"full_shard",
|
||||||
|
"auto_wrap",
|
||||||
|
],
|
||||||
|
"fsdp_config": {
|
||||||
|
"fsdp_limit_all_gathers": True,
|
||||||
|
"fsdp_offload_params": False,
|
||||||
|
"fsdp_sync_module_states": True,
|
||||||
|
"fsdp_use_orig_params": False,
|
||||||
|
"fsdp_cpu_ram_efficient_loading": False,
|
||||||
|
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||||
|
"fsdp_state_dict_type": fsdp_state_dict_type,
|
||||||
|
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
},
|
},
|
||||||
],
|
"use_tensorboard": True,
|
||||||
"num_epochs": 1,
|
}
|
||||||
"max_steps": 2,
|
)
|
||||||
"micro_batch_size": 2,
|
| sft_prepared_dataset_alpaca_cfg
|
||||||
"gradient_accumulation_steps": 2,
|
|
||||||
# "gradient_checkpointing": True,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"fsdp": [
|
|
||||||
"full_shard",
|
|
||||||
"auto_wrap",
|
|
||||||
],
|
|
||||||
"fsdp_config": {
|
|
||||||
"fsdp_limit_all_gathers": True,
|
|
||||||
"fsdp_offload_params": False,
|
|
||||||
"fsdp_sync_module_states": True,
|
|
||||||
"fsdp_use_orig_params": False,
|
|
||||||
"fsdp_cpu_ram_efficient_loading": False,
|
|
||||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
|
||||||
"fsdp_state_dict_type": fsdp_state_dict_type,
|
|
||||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
|
||||||
},
|
|
||||||
"use_tensorboard": True,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# write cfg to yaml file
|
# write cfg to yaml file
|
||||||
@@ -452,7 +496,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/train_loss", 2.4, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch_2_6_0
|
@require_torch_2_6_0
|
||||||
@@ -465,50 +509,43 @@ class TestMultiGPULlama:
|
|||||||
[True, False],
|
[True, False],
|
||||||
)
|
)
|
||||||
def test_fsdp2_packed(
|
def test_fsdp2_packed(
|
||||||
self, temp_dir, attention_backend, fsdp_reshard_after_forward
|
self,
|
||||||
|
temp_dir,
|
||||||
|
sft_prepared_dataset_alpaca_cfg,
|
||||||
|
attention_backend,
|
||||||
|
fsdp_reshard_after_forward,
|
||||||
):
|
):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = (
|
||||||
{
|
DictDefault(
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
{
|
||||||
"sample_packing": True,
|
"pad_to_sequence_len": True,
|
||||||
"pad_to_sequence_len": True,
|
"num_epochs": 1,
|
||||||
"sequence_len": 2048,
|
"max_steps": 2,
|
||||||
"val_set_size": 0.1,
|
"micro_batch_size": 4,
|
||||||
"special_tokens": {
|
"gradient_accumulation_steps": 2,
|
||||||
"pad_token": "<|endoftext|>",
|
"gradient_checkpointing": True,
|
||||||
},
|
"output_dir": temp_dir,
|
||||||
"datasets": [
|
"learning_rate": 0.00001,
|
||||||
{
|
"optimizer": "adamw_torch_8bit",
|
||||||
"path": "tatsu-lab/alpaca",
|
"lr_scheduler": "cosine",
|
||||||
"type": "alpaca",
|
"fsdp": [
|
||||||
"split": "train[:10%]",
|
"auto_wrap",
|
||||||
|
],
|
||||||
|
"fsdp_config": {
|
||||||
|
"fsdp_version": 2,
|
||||||
|
# "fsdp_forward_prefetch": True, # not yet implemented in accelerate
|
||||||
|
"fsdp_offload_params": False,
|
||||||
|
"fsdp_cpu_ram_efficient_loading": False,
|
||||||
|
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||||
|
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||||
|
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
|
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
|
||||||
},
|
},
|
||||||
],
|
"use_tensorboard": True,
|
||||||
"num_epochs": 1,
|
}
|
||||||
"max_steps": 2,
|
)
|
||||||
"micro_batch_size": 4,
|
| sft_prepared_dataset_alpaca_cfg
|
||||||
"gradient_accumulation_steps": 2,
|
|
||||||
"gradient_checkpointing": True,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"fsdp": [
|
|
||||||
"auto_wrap",
|
|
||||||
],
|
|
||||||
"fsdp_config": {
|
|
||||||
"fsdp_version": 2,
|
|
||||||
# "fsdp_forward_prefetch": True, # not yet implemented in accelerate
|
|
||||||
"fsdp_offload_params": False,
|
|
||||||
"fsdp_cpu_ram_efficient_loading": False,
|
|
||||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
|
||||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
|
||||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
|
||||||
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
|
|
||||||
},
|
|
||||||
"use_tensorboard": True,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
if attention_backend == "flash":
|
if attention_backend == "flash":
|
||||||
cfg.flash_attention = True
|
cfg.flash_attention = True
|
||||||
@@ -536,63 +573,55 @@ class TestMultiGPULlama:
|
|||||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
def test_fsdp_qlora_prequant_packed(
|
||||||
|
self, temp_dir, sft_prepared_dataset_alpaca_cfg
|
||||||
|
):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = (
|
||||||
{
|
DictDefault(
|
||||||
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
|
{
|
||||||
"adapter": "qlora",
|
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
|
||||||
"mean_resizing_embeddings": True,
|
"adapter": "qlora",
|
||||||
"load_in_4bit": True,
|
"mean_resizing_embeddings": True,
|
||||||
"lora_r": 8,
|
"load_in_4bit": True,
|
||||||
"lora_alpha": 16,
|
"lora_r": 8,
|
||||||
"lora_dropout": 0.05,
|
"lora_alpha": 16,
|
||||||
"lora_target_linear": True,
|
"lora_dropout": 0.05,
|
||||||
# "lora_modules_to_save": [
|
"lora_target_linear": True,
|
||||||
# "embed_tokens",
|
# "lora_modules_to_save": [
|
||||||
# "lm_head",
|
# "embed_tokens",
|
||||||
# ],
|
# "lm_head",
|
||||||
"sample_packing": True,
|
# ],
|
||||||
"eval_sample_packing": False,
|
"eval_sample_packing": False,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"sequence_len": 1024,
|
"num_epochs": 1,
|
||||||
"val_set_size": 0.01,
|
"max_steps": 2,
|
||||||
"special_tokens": {
|
"micro_batch_size": 2,
|
||||||
"pad_token": "<|endoftext|>",
|
"gradient_accumulation_steps": 2,
|
||||||
},
|
# "gradient_checkpointing": True,
|
||||||
"datasets": [
|
"output_dir": temp_dir,
|
||||||
{
|
"learning_rate": 0.00001,
|
||||||
"path": "tatsu-lab/alpaca",
|
"optimizer": "adamw_torch_fused",
|
||||||
"type": "alpaca",
|
"lr_scheduler": "cosine",
|
||||||
"split": "train[:10%]",
|
"flash_attention": True,
|
||||||
|
"fsdp": [
|
||||||
|
"full_shard",
|
||||||
|
"auto_wrap",
|
||||||
|
],
|
||||||
|
"fsdp_config": {
|
||||||
|
"fsdp_limit_all_gathers": True,
|
||||||
|
"fsdp_offload_params": False,
|
||||||
|
"fsdp_sync_module_states": True,
|
||||||
|
"fsdp_use_orig_params": False,
|
||||||
|
"fsdp_cpu_ram_efficient_loading": True,
|
||||||
|
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||||
|
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||||
|
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
},
|
},
|
||||||
],
|
"use_tensorboard": True,
|
||||||
"num_epochs": 1,
|
}
|
||||||
"max_steps": 2,
|
)
|
||||||
"micro_batch_size": 2,
|
| sft_prepared_dataset_alpaca_cfg
|
||||||
"gradient_accumulation_steps": 2,
|
|
||||||
# "gradient_checkpointing": True,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"fsdp": [
|
|
||||||
"full_shard",
|
|
||||||
"auto_wrap",
|
|
||||||
],
|
|
||||||
"fsdp_config": {
|
|
||||||
"fsdp_limit_all_gathers": True,
|
|
||||||
"fsdp_offload_params": False,
|
|
||||||
"fsdp_sync_module_states": True,
|
|
||||||
"fsdp_use_orig_params": False,
|
|
||||||
"fsdp_cpu_ram_efficient_loading": True,
|
|
||||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
|
||||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
|
||||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
|
||||||
},
|
|
||||||
"use_tensorboard": True,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# write cfg to yaml file
|
# write cfg to yaml file
|
||||||
@@ -633,7 +662,12 @@ class TestMultiGPULlama:
|
|||||||
[True, False],
|
[True, False],
|
||||||
)
|
)
|
||||||
def test_ds_zero3_packed(
|
def test_ds_zero3_packed(
|
||||||
self, temp_dir, gradient_accumulation_steps, deepspeed, qlora
|
self,
|
||||||
|
temp_dir,
|
||||||
|
sft_prepared_dataset_alpaca_cfg,
|
||||||
|
gradient_accumulation_steps,
|
||||||
|
deepspeed,
|
||||||
|
qlora,
|
||||||
):
|
):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
if qlora:
|
if qlora:
|
||||||
@@ -647,36 +681,25 @@ class TestMultiGPULlama:
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
adapter = {}
|
adapter = {}
|
||||||
cfg = DictDefault(
|
cfg = (
|
||||||
{
|
DictDefault(
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
{
|
||||||
"sample_packing": True,
|
"pad_to_sequence_len": True,
|
||||||
"pad_to_sequence_len": True,
|
"num_epochs": 1,
|
||||||
"sequence_len": 1024,
|
"max_steps": 2,
|
||||||
"val_set_size": 0.05,
|
"micro_batch_size": 1,
|
||||||
"special_tokens": {
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"pad_token": "<|endoftext|>",
|
"output_dir": temp_dir,
|
||||||
},
|
"learning_rate": 0.00001,
|
||||||
"datasets": [
|
"optimizer": "adamw_torch_fused",
|
||||||
{
|
"lr_scheduler": "cosine",
|
||||||
"path": "tatsu-lab/alpaca",
|
"flash_attention": True,
|
||||||
"type": "alpaca",
|
"deepspeed": str(AXOLOTL_ROOT / deepspeed),
|
||||||
"split": "train[:10%]",
|
"use_tensorboard": True,
|
||||||
},
|
**adapter,
|
||||||
],
|
}
|
||||||
"num_epochs": 1,
|
)
|
||||||
"max_steps": 2,
|
| sft_prepared_dataset_alpaca_cfg
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"deepspeed": str(AXOLOTL_ROOT / deepspeed),
|
|
||||||
"use_tensorboard": True,
|
|
||||||
**adapter,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# write cfg to yaml file
|
# write cfg to yaml file
|
||||||
@@ -697,7 +720,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.4, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -708,7 +731,13 @@ class TestMultiGPULlama:
|
|||||||
"qlora",
|
"qlora",
|
||||||
[True, False],
|
[True, False],
|
||||||
)
|
)
|
||||||
def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora):
|
def test_ds_zero2_packed(
|
||||||
|
self,
|
||||||
|
temp_dir,
|
||||||
|
sft_prepared_dataset_alpaca_cfg,
|
||||||
|
gradient_accumulation_steps,
|
||||||
|
qlora,
|
||||||
|
):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
if qlora:
|
if qlora:
|
||||||
adapter = {
|
adapter = {
|
||||||
@@ -721,36 +750,25 @@ class TestMultiGPULlama:
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
adapter = {}
|
adapter = {}
|
||||||
cfg = DictDefault(
|
cfg = (
|
||||||
{
|
DictDefault(
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
{
|
||||||
"sample_packing": True,
|
"pad_to_sequence_len": True,
|
||||||
"pad_to_sequence_len": True,
|
"num_epochs": 1,
|
||||||
"sequence_len": 1024,
|
"max_steps": 2,
|
||||||
"val_set_size": 0.01,
|
"micro_batch_size": 1,
|
||||||
"special_tokens": {
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"pad_token": "<|endoftext|>",
|
"output_dir": temp_dir,
|
||||||
},
|
"learning_rate": 0.00001,
|
||||||
"datasets": [
|
"optimizer": "adamw_torch_fused",
|
||||||
{
|
"lr_scheduler": "cosine",
|
||||||
"path": "tatsu-lab/alpaca",
|
"flash_attention": True,
|
||||||
"type": "alpaca",
|
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
|
||||||
"split": "train[:10%]",
|
"use_tensorboard": True,
|
||||||
},
|
**adapter,
|
||||||
],
|
}
|
||||||
"num_epochs": 1,
|
)
|
||||||
"max_steps": 2,
|
| sft_prepared_dataset_alpaca_cfg
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
|
|
||||||
"use_tensorboard": True,
|
|
||||||
**adapter,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# write cfg to yaml file
|
# write cfg to yaml file
|
||||||
@@ -771,7 +789,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -782,7 +800,13 @@ class TestMultiGPULlama:
|
|||||||
"qlora",
|
"qlora",
|
||||||
[True, False],
|
[True, False],
|
||||||
)
|
)
|
||||||
def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora):
|
def test_ds_zero1_packed(
|
||||||
|
self,
|
||||||
|
temp_dir,
|
||||||
|
sft_prepared_dataset_alpaca_cfg,
|
||||||
|
gradient_accumulation_steps,
|
||||||
|
qlora,
|
||||||
|
):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
if qlora:
|
if qlora:
|
||||||
adapter = {
|
adapter = {
|
||||||
@@ -795,36 +819,25 @@ class TestMultiGPULlama:
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
adapter = {}
|
adapter = {}
|
||||||
cfg = DictDefault(
|
cfg = (
|
||||||
{
|
DictDefault(
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
{
|
||||||
"sample_packing": True,
|
"pad_to_sequence_len": True,
|
||||||
"pad_to_sequence_len": True,
|
"num_epochs": 1,
|
||||||
"sequence_len": 1024,
|
"max_steps": 2,
|
||||||
"val_set_size": 0.01,
|
"micro_batch_size": 1,
|
||||||
"special_tokens": {
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"pad_token": "<|endoftext|>",
|
"output_dir": temp_dir,
|
||||||
},
|
"learning_rate": 0.00001,
|
||||||
"datasets": [
|
"optimizer": "adamw_torch_fused",
|
||||||
{
|
"lr_scheduler": "cosine",
|
||||||
"path": "tatsu-lab/alpaca",
|
"flash_attention": True,
|
||||||
"type": "alpaca",
|
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
|
||||||
"split": "train[:10%]",
|
"use_tensorboard": True,
|
||||||
},
|
**adapter,
|
||||||
],
|
}
|
||||||
"num_epochs": 1,
|
)
|
||||||
"max_steps": 2,
|
| sft_prepared_dataset_alpaca_cfg
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
|
|
||||||
"use_tensorboard": True,
|
|
||||||
**adapter,
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# write cfg to yaml file
|
# write cfg to yaml file
|
||||||
@@ -845,7 +858,7 @@ class TestMultiGPULlama:
|
|||||||
)
|
)
|
||||||
|
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ class TestMultiGPUQwen2:
|
|||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ class TestMultiGPURay:
|
|||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -107,6 +108,7 @@ class TestMultiGPURay:
|
|||||||
"micro_batch_size": 1,
|
"micro_batch_size": 1,
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_torch",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -396,7 +396,7 @@ def test_model_architecture(model_config):
|
|||||||
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
def test_kernel_training_integration():
|
def test_kernel_training_integration(temp_dir):
|
||||||
"""Test model loading with kernel patches enabled."""
|
"""Test model loading with kernel patches enabled."""
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
|
|
||||||
@@ -426,6 +426,14 @@ def test_kernel_training_integration():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Write cfg to yaml file
|
||||||
|
path = Path(temp_dir) / "config.yaml"
|
||||||
|
with open(path, "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
# Load config
|
||||||
|
cfg = load_cfg(str(path))
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model, _, _ = load_model_and_tokenizer(cfg=cfg)
|
model, _, _ = load_model_and_tokenizer(cfg=cfg)
|
||||||
|
|
||||||
@@ -505,7 +513,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
|||||||
assert found_patched_attn
|
assert found_patched_attn
|
||||||
|
|
||||||
|
|
||||||
def test_kernel_training_integration_dropout_non_zero():
|
def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
||||||
"""Test model loading with dropout non-zero should not patch."""
|
"""Test model loading with dropout non-zero should not patch."""
|
||||||
|
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
@@ -533,6 +541,14 @@ def test_kernel_training_integration_dropout_non_zero():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Write cfg to yaml file
|
||||||
|
path = Path(temp_dir) / "config.yaml"
|
||||||
|
with open(path, "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
# Load config
|
||||||
|
cfg = load_cfg(str(path))
|
||||||
|
|
||||||
# Get original attention class
|
# Get original attention class
|
||||||
attention_cls = get_attention_cls_from_config(cfg)
|
attention_cls = get_attention_cls_from_config(cfg)
|
||||||
|
|
||||||
|
|||||||
40
tests/test_chunked_xentropy.py
Normal file
40
tests/test_chunked_xentropy.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""
|
||||||
|
test suite for chunked cross entropy
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.loss.chunked import get_causal_lm_loss
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def chunked_fixtures():
|
||||||
|
model_dim = 512
|
||||||
|
vocab_size = 1024 * 256
|
||||||
|
seq_len = 2048
|
||||||
|
batch_size = 1
|
||||||
|
|
||||||
|
lm_head = nn.Linear(model_dim, vocab_size)
|
||||||
|
hidden_state = torch.randn(batch_size, seq_len, model_dim)
|
||||||
|
labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
|
||||||
|
return lm_head, hidden_state, labels, vocab_size
|
||||||
|
|
||||||
|
|
||||||
|
def test_chunked_forward(chunked_fixtures): # pylint: disable=redefined-outer-name
|
||||||
|
lm_head, hidden_state, labels, vocab_size = chunked_fixtures
|
||||||
|
lm_loss = get_causal_lm_loss()
|
||||||
|
|
||||||
|
logits = lm_head(hidden_state)
|
||||||
|
|
||||||
|
chunked_lm_loss = lm_loss(logits, labels)
|
||||||
|
|
||||||
|
logits_flattened = logits.view(-1, vocab_size)
|
||||||
|
labels_flattened = labels.view(-1)
|
||||||
|
|
||||||
|
loss = nn.functional.cross_entropy(
|
||||||
|
logits_flattened.float(), labels_flattened, reduction="mean"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.allclose(chunked_lm_loss, loss, atol=1e-2, rtol=1e-2)
|
||||||
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
|
|||||||
)
|
)
|
||||||
train_dataset = concatenate_datasets([dataset_wrapper])
|
train_dataset = concatenate_datasets([dataset_wrapper])
|
||||||
|
|
||||||
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
|
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
|
||||||
|
|
||||||
lengths = get_dataset_lengths(train_dataset)
|
lengths = get_dataset_lengths(train_dataset)
|
||||||
batch_sampler = MultipackBatchSampler(
|
batch_sampler = MultipackBatchSampler(
|
||||||
|
|||||||
Reference in New Issue
Block a user