Compare commits

..

37 Commits

Author SHA1 Message Date
Wing Lian
3877c5c69d set release version 0.8.0 (#2476)
Some checks failed
ci-cd / build-axolotl (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl (<nil>, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
ci-cd / build-axolotl (vllm, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
publish pypi / Create Release (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, 3.11, 2.5.1) (push) Has been cancelled
ci-cd / build-axolotl-cloud (<nil>, 124, 12.4.1, true, 3.11, 2.6.0) (push) Has been cancelled
ci-cd / build-axolotl-cloud-no-tmux (<nil>, 124, 12.4.1, 3.11, 2.4.1) (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
* set release version 0.8.0

* make sure to include ring-flash-attn in docker image build
2025-04-02 09:50:56 -04:00
NanoCode012
adb593abac fix: document offload gradient_checkpointing option (#2475) 2025-04-02 09:35:42 -04:00
NanoCode012
a0117c9bce fix: separate gemma3 text and vision example config (#2471) [skip ci]
* fix: separate gemma3 text and vision example config

* fix: update to use a text-only dataset

* fix: typo
2025-04-02 09:35:29 -04:00
NanoCode012
e6cfb093d2 fix: disable SP during merge (#2470) [skip ci] 2025-04-02 09:35:00 -04:00
NanoCode012
7abc71dc0b fix: gemma3 loss in forward pass (#2473) [skip ci]
* fix: gemma3 loss in forward pass

* fix: lint

* fix: move patch before plugins

* Update src/axolotl/monkeypatch/gemma3.py

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-04-02 09:34:41 -04:00
NanoCode012
45bf634d17 feat: add support for multimodal in lora kernels (#2472) [skip ci]
* feat: add support for multimodal in lora kernels

* fix: improve multimodal checks

* fix: add fallback for model config

* chor: add gemma3 to docs
2025-04-02 09:33:46 -04:00
NanoCode012
80ba4b69f1 fix: pydantic warning validator not returning self (#2474) 2025-04-02 07:40:49 -04:00
Wing Lian
0bfa180f7d torch 2.7.0 base image for testing (#2467) 2025-04-01 15:38:26 -04:00
NanoCode012
9e22c4ca6a fix: set rl=None during inference (#2463) 2025-04-01 12:25:53 -04:00
NanoCode012
990b5896bc fix: downgrade deepspeed to fix grad checkpoint oom (#2465) [skip ci] 2025-04-01 12:25:05 -04:00
Dan Saunders
7d0eb66b54 fixing eval for SP (#2468) 2025-04-01 11:59:08 -04:00
Wing Lian
df119e3724 Validation for Muon optimizer with DS/FSDP (#2464) 2025-04-01 09:39:12 -04:00
NanoCode012
f4ae8816bb Fix: remove the numerous sequential log (#2461)
* fix: remove sequential logs

* feat(doc): add for sample pack sequentially and curriculum sampling
2025-04-01 09:20:00 -04:00
NanoCode012
9b95e06cbb Fix(doc): Minor doc changes for peft and modal (#2462) [skip ci]
* fix(doc): document peft configs

* fix(doc): explain modal env vs secrets difference

* fix(doc): clarify evaluate vs lm-eval

* fix: clarify what is performance
2025-04-01 08:48:36 -04:00
Wing Lian
e0aba74dd0 Release update 20250331 (#2460) [skip ci]
* make torch 2.6.0 the default image

* fix tests against upstream main

* fix attribute access

* use fixture dataset

* fix dataset load

* correct the fixtures + tests

* more fixtures

* add accidentally removed shakespeare fixture

* fix conversion from unittest to pytest class

* nightly main ci caches

* build 12.6.3 cuda base image

* override for fix from huggingface/transformers#37162

* address PR feedback
2025-04-01 08:47:50 -04:00
Wing Lian
328d598114 gemma3 packing fixes (#2449)
* make gemma3 work with packing

* multi-gpu e2e for ci

* update gemma3 model namespace to use mirror

* add gradient checkpointing to multigpu e2e ci

* update gemma3 examples for use_reentrant and fix ddp find unused params

* fix tests for gemma3

* fix import for test utils

* set correct train loss for gemma3 e2e
2025-03-31 17:15:23 -04:00
DreamGenX
4d36ecc724 Sequential sample packing (#2404) [skip ci]
* add sequential sample packing

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-03-31 15:48:20 -04:00
NanoCode012
7acf93b59f Fix(doc): Clarify doc on attention configs and missing pad_token (#2455) [skip ci]
* fix: clarify input type

* fix: handling of error message if data_files not available

* fix: clarify attention handling

* fix: add doc on missing pad token
2025-03-31 15:47:28 -04:00
Wing Lian
b6fc46ada8 Updates for trl 0.16.0 - mostly for GRPO (#2437) [skip ci]
* add grpo scale_rewards config for trl#3135

* options to connect to vllm server directly w grpo trl#3094

* temperature support trl#3029

* sampling/generation kwargs for grpo trl#2989

* make vllm_enable_prefix_caching a config param trl#2900

* grpo multi-step optimizeations trl#2899

* remove overrides for grpo trainer

* bump trl to 0.16.0

* add cli  to start vllm-serve via trl

* call the python module directly

* update to use vllm with 2.6.0 too now and call trl vllm serve from module

* vllm 0.8.1

* use python3

* use sys.executable

* remove context and wait for start

* fixes to make it actually work

* fixes so the grpo tests pass with new vllm paradigm

* explicit host/port and check in start vllm

* make sure that vllm doesn't hang by setting quiet so outouts go to dev null

* also bump bnb to latest release

* add option for wait from cli and nccl debugging for ci

* grpo + vllm test on separate devices for now

* make sure grpo + vllm tests runs single worker since pynccl comms would conflict

* fix cli

* remove wait and add caching for argilla dataset

* refactoring configs

* chore: lint

* add vllm config

* fixup vllm grpo args

* fix one more incorrect schema/config path

* fix another vlllm reference and increase timeout

* make the tests run a bit faster

* change mbsz back so it is correct for grpo

* another change mbsz back so it is correct for grpo

* fixing cli args

* nits

* adding docs

* docs

* include tensor parallel size for vllm in pydantic schema

* moving start_vllm, more docs

* limit output len for grpo vllm

* vllm enable_prefix_caching isn't a bool cli arg

* fix env ordering in tests and also use pid check when looking for vllm

---------

Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
2025-03-31 15:47:11 -04:00
Dan Saunders
b35992262e Ray train bugfix (#2458)
* fix nccl pg destroy warning

* update

* ray bugfix
2025-03-31 15:17:43 -04:00
Dan Saunders
ef6eb77cc8 destroy process group on Ctrl+C / training or eval run (#2457)
* fix nccl pg destroy warning

* update
2025-03-31 12:36:47 -04:00
Dan Saunders
5410195e0b Sequence parallelism quick follow-ups; remove ModelCallback (#2450)
* guard return if ring attn alrady registered

* add docs link, bits in multi-gpu docs, remove save model callback (subsumed by HF trainers)

* configurable heads_k_stride from ring-flash-attn hf adapter
2025-03-31 09:13:42 -04:00
NanoCode012
cf0c79d52e fix: minor patches for multimodal (#2441)
* fix: update chat_template

* fix: handle gemma3 showing a lot of no content for turn 0

* fix: remove unknown config from examples

* fix: test

* fix: temporary disable gemma2 test

* fix: stop overwriting config.text_config unnecessarily

* fix: handling of set cache to the text_config section

* feat: add liger gemma support and bump liger to 0.5.5

* fix: add double use_cache setting

* fix: add support for final_logit_softcap in CCE for gemma2/3

* fix: set use_cache before model load

* feat: add missing layernorm override

* fix: handle gemma3 rmsnorm

* fix: use wrapper to pass dim as hidden_size

* fix: change dim to positional

* fix: patch with wrong mlp

* chore: refactor use_cache handling

* fix import issues

* fix tests.e2e.utils import

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-03-31 13:40:12 +07:00
Wing Lian
4ba80a0e5a fix streaming packing test (#2454)
* fix streaming packing test

* constrain amount of text generated
2025-03-29 08:30:06 -04:00
Wing Lian
c49682132b use offline for precached stream dataset (#2453) 2025-03-28 23:39:09 -04:00
Wing Lian
e46239f8d3 bump liger to 0.5.5 (#2448) 2025-03-28 19:21:03 -04:00
Wing Lian
05f03b541a hf offline decorator for tests to workaround rate limits (#2452) [skip ci]
* hf offline decorator for tests to workaround rate limits

* fail quicker so we can see logs

* try new cache name

* limit files downloaded

* phi mini predownload

* offline decorator for phi tokenizer

* handle meta llama 8b offline too

* make sure to return fixtures if they are wrapped too

* more fixes

* more things offline

* more offline things

* fix the env var

* fix the model name

* handle gemma also

* force reload of modules to recheck offline status

* prefetch mistral too

* use reset_sessions so hub picks up offline mode

* more fixes

* rename so it doesn't seem like a context manager

* fix backoff

* switch out tinyshakespeare dataset since it runs a py script to fetch data and doesn't work offline

* include additional dataset

* more fixes

* more fixes

* replace tiny shakespeaere dataset

* skip some tests for now

* use more robust check using snapshot download to determine if a dataset name is on the hub

* typo for skip reason

* use local_files_only

* more fixtures

* remove local only

* use tiny shakespeare as pretrain dataset and streaming can't be offline even if precached

* make sure fixtures aren't offline

improve the offline reset
try bumping version of datasets
reorder reloading and setting
prime a new cache
run the tests now with fresh cache
try with a static cache

* now run all the ci again with hopefully a correct cache

* skip wonky tests for now

* skip wonky tests for now

* handle offline mode for model card creation
2025-03-28 19:20:46 -04:00
Wing Lian
a4e430e7c4 add override of upstream fix for multi-gpu orpo (#2440)
* add override of upstream fix

* override batch loss metrics for CPO/Simpo as well
2025-03-26 18:14:59 -04:00
Wing Lian
6cdcb8ddd5 Set the pytorch_cuda_alloc_conf env in the train module (#2447) 2025-03-26 18:14:43 -04:00
NanoCode012
a7811ad4a0 fix(doc): document config required to run eval_causal_lm_metrics (#2445) [skip ci] 2025-03-26 18:14:29 -04:00
NanoCode012
e2da821e67 chore: minor optim changes (add apollo, improve docs, remove lion-pytorch) (#2444)
* feat: add apollo-torch

* chore: update optimizer list

* fix: deleted accidental requirements file

* fix: remove mention of deprecated lion_pytorch
2025-03-26 18:14:07 -04:00
NanoCode012
2c34a4634e feat: add CCE for gemma3, cohere, and cohere2 (#2443)
* feat: add CCE for gemma3 and cohere1/2

* fix: change from relative import to absolute

* feat: add multipack for cohere&cohere2

* chore: improve comments

* fix: add gemma3_text

* feat: add cohere2 example

* fix: cohere forward

* fix: patch for cohere2

* feat: add command r v01 qlora sample

* chore: lint

* feat: upgrade gemma3 and gemma2 patch to use logits_to_keep

* chore: lint

* fix: add deprecate_kwarg decorator

* fix: add cce for gemma3 conditionalgeneration

* fix: gemma3 patch to defer logits calculation

* fix: patch gemma3 if given as model

* fix: remove not working config

* fix: update comments to clarify changes

* feat(doc): add supported models to readme

* fix: address difference in our cohere patch

* feat: add mistral3

* feat: add gemma

* feat(doc): update README to include gemma and mistral3 in supported models

* fix: gemma patch

* fix: import

* fix: gemma patch to be standalone

* fix: gemma3 warn about not support final_logit_softcapping

* feat: add mllama CCE

* chore: add abbireviation to doc

* fix: remove unneeded gemma3 eager warning

* fix: save processor if available

* fix: enable save processor on merge

* fix: wrong env meaning
2025-03-26 18:13:51 -04:00
NanoCode012
a9b0733f2c Feat: Rework multimodal support (mllama, llava, pixtral, qwen2, qwen25, gemma3, mistral3) (#2435) 2025-03-23 11:08:51 -04:00
NanoCode012
9f00465a5c Feat: Add support for gemma3_text and add e2e for gemma2 (#2406) 2025-03-22 20:33:21 -04:00
Dan Saunders
86bac48d14 cleanup for failing test (#2436) 2025-03-22 17:53:29 -04:00
Dan Saunders
e44953d50c installing axolotl prior to quartodoc build (#2434)
* installing axolotl prior to quartodoc build

* simplify by installing no deps

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
2025-03-21 13:28:13 -04:00
Dan Saunders
23f0c51d88 Sequence parallelism (#2412)
* adding easy_context as integration for now

* progress on ring attn impl

* progress on ring attn impl

* cleanup

* remove errant file

* fix req

* removing unused code

* updates

* pytest

* update

* updates

* fixes

* precommit fixes

* working multi-group SP

* fixing sample packing

* remove debug logs and simplify

* eval dataloader and sampler changes

* removing some obvious comments

* update config.qmd and rename option

* scoping down problematic import

* another import scoping change

* pernicious Fire CLI bugfix

* isolate cli tests

* actually isolate CLI tests

* gracefully handle no ring-flash-attn

* fix

* fix

* move ring flash attn to extras with flash-attn (#2414)

* removing flash-attn from requirements.txt (in setup.py extras already)

* rename file, delete another

* using field validator instead of model validator

* test fix

* sampler / dataloader refactor

* non-seq2se1 collator fix

* removing print statement

* bugfix

* add SP doc, review comments

* small changes

* review comments, docstrings

* refactors, SP mixin

* small updates

* fix tests

* precommit

* precommit

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Dan Saunders <dan@axolotl.ai>
2025-03-21 12:43:55 -04:00
123 changed files with 6023 additions and 1352 deletions

View File

@@ -40,12 +40,24 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: nightly
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: next
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -67,7 +79,7 @@ jobs:
uses: docker/build-push-action@v4
with:
context: .
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -23,6 +23,7 @@ jobs:
- name: Install dependencies
run: |
python3 -m pip install jupyter quartodoc
python3 -m pip install -e . --no-deps
- name: Build autodoc
run: quartodoc build
- name: Publish to GitHub Pages (and render)

View File

@@ -25,12 +25,12 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras: vllm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -87,12 +87,12 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -42,8 +42,7 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
# awaiting vllm#12721
axolotl_extras:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]

View File

@@ -33,6 +33,15 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -46,7 +55,7 @@ jobs:
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
pip3 install torch==${{ matrix.pytorch_version }}
- name: Update requirements.txt
run: |
@@ -58,8 +67,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
@@ -73,10 +81,15 @@ jobs:
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
- name: cleanup pip cache
run: |
@@ -136,4 +149,4 @@ jobs:
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.tests
modal run cicd.e2e_tests

View File

@@ -63,7 +63,7 @@ jobs:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
@@ -96,6 +96,10 @@ jobs:
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
@@ -137,7 +141,7 @@ jobs:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
@@ -171,6 +175,9 @@ jobs:
run: |
axolotl --help
- name: Show HF cache
run: huggingface-cli scan-cache
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
@@ -229,7 +236,7 @@ jobs:
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.tests
modal run cicd.e2e_tests
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
@@ -253,7 +260,7 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
axolotl_extras: vllm
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -276,4 +283,4 @@ jobs:
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.tests
modal run cicd.e2e_tests

View File

@@ -1,3 +1,4 @@
[settings]
profile=black
known_third_party=wandb,comet_ml
known_local_folder=src,tests

View File

@@ -40,6 +40,7 @@ quartodoc:
- cli.preprocess
- cli.sweeps
- cli.utils
- cli.vllm_serve
- cli.cloud.base
- cli.cloud.modal_
- title: Trainers
@@ -133,6 +134,7 @@ quartodoc:
- utils.schemas.datasets
- utils.schemas.peft
- utils.schemas.trl
- utils.schemas.multimodal
- utils.schemas.integrations
- utils.schemas.enums
- utils.schemas.utils
@@ -242,6 +244,7 @@ website:
- docs/unsloth.qmd
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- section: "Troubleshooting"
contents:

View File

@@ -2,4 +2,5 @@
set -e
# only run one test at a time so as not to OOM the GPU
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v -n1 /workspace/axolotl/tests/e2e/multigpu/solo/

View File

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -0,0 +1,38 @@
ARG CUDA_VERSION="12.8.1"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="next"
ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==2.7.0 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
RUN git lfs install --skip-repo && \
pip3 install awscli && \
pip3 install -U --no-cache-dir pydantic==2.10.6

View File

@@ -170,7 +170,7 @@ axolotl merge-sharded-fsdp-weights config.yml
### evaluate
Evaluates a model's performance using metrics specified in the config.
Evaluates a model's performance (loss etc) on the train and eval datasets.
```bash
# Basic evaluation
@@ -197,6 +197,8 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
## Legacy CLI Usage
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
@@ -235,7 +237,7 @@ Create a cloud config YAML with your Modal settings:
```yaml
# cloud_config.yml
provider: modal
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
gpu_count: 1 # Number of GPUs to use
timeout: 86400 # Maximum runtime in seconds (24 hours)
branch: main # Git branch to use (optional)
@@ -248,7 +250,7 @@ volumes: # Persistent storage volumes
- name: axolotl-artifacts
mount: /workspace/artifacts
env: # Environment variables
secrets: # Secrets to inject
- WANDB_API_KEY
- HF_TOKEN
```
@@ -274,15 +276,27 @@ axolotl lm-eval config.yml --cloud cloud_config.yml
### Cloud Configuration Options
```yaml
provider: # compute provider, currently only `modal` is supported
gpu: # GPU type to use
gpu_count: # Number of GPUs (default: 1)
memory: # RAM in GB (default: 128)
timeout: # Maximum runtime in seconds
provider: # compute provider, currently only `modal` is supported
gpu: # GPU type to use
gpu_count: # Number of GPUs (default: 1)
memory: # RAM in GB (default: 128)
timeout: # Maximum runtime in seconds
timeout_preprocess: # Preprocessing timeout
branch: # Git branch to use
docker_tag: # Custom Docker image tag
volumes: # List of persistent storage volumes
env: # Environment variables to pass
secrets: # Secrets to inject
branch: # Git branch to use
docker_tag: # Custom Docker image tag
volumes: # List of persistent storage volumes
# Environment variables to pass. Can be specified in two ways:
# 1. As a string: Will load the value from the host computer's environment variables
# 2. As a key-value pair: Will use the specified value directly
# Example:
# env:
# - CUSTOM_VAR # Loads from host's $CUSTOM_VAR
# - {CUSTOM_VAR: "value"} # Uses "value" directly
env:
# Secrets to inject. Same input format as `env` but for sensitive data.
secrets:
# - HF_TOKEN
# - WANDB_API_KEY
```

View File

@@ -238,10 +238,10 @@ simpo_gamma: 0.5 # Target reward margin for the SimPO loss
# grpo
trl:
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
vllm_device: # Optional[str]. Device to use for VLLM.
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
vllm_dtype: # Optional[str]. Data type for VLLM.
vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
@@ -320,9 +320,13 @@ total_num_tokens:
sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200
sample_pack_sequentially: # Optional[bool]. Whether to pack samples sequentially.
# whether to concatenate samples during pretraining
pretraining_sample_concatenation:
curriculum_sampling: # Optional[bool]. Whether to use sequential sampling for curriculum learning
# Use batch flattening for speedups when not using sample_packing
batch_flattening:
@@ -354,7 +358,27 @@ lora_target_modules:
# - down_proj
# - up_proj
lora_target_linear: # If true, will target all linear modules
peft_layers_to_transform: # The layer indices to transform, otherwise, apply to all layers
# List[int] | int. # The layer indices to transform, otherwise, apply to all layers
# https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.layers_to_transform
peft_layers_to_transform:
# Optional[bool]. Whether to use DoRA.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#weight-decomposed-low-rank-adaptation-dora
peft_use_dora:
# Optional[bool]. Whether to use RSLoRA.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#rank-stabilized-lora
peft_use_rslora:
# Optional[list[tuple[int, int]]]. List of layer indices to replicate.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#memory-efficient-layer-replication-with-lora
peft_layer_replication:
# bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]
# How to initialize LoRA weights. Default to True which is MS original implementation.
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#initialization
peft_init_lora_weights:
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
@@ -466,6 +490,7 @@ auto_find_batch_size: # Optional[bool]
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
do_causal_lm_eval: # Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`.
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
@@ -485,7 +510,8 @@ train_on_inputs: false
# Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing
# gradient_checkpointing_kwargs:
@@ -506,36 +532,58 @@ lr_div_factor: # Learning rate div factor
# Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189
#
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
# in the examples/ for your model and fine-tuning use case.
#
# Valid values for 'optimizer' include:
# - adamw_hf
# - adamw_torch
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_torch_npu_fused
# - adamw_apex_fused
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
# - adafactor
# - adamw_anyprecision
# - adamw_torch_4bit
# - ademamix
# - sgd
# - adagrad
# - adamw_bnb_8bit
# - adamw_8bit # alias for adamw_bnb_8bit
# - ademamix_8bit
# - lion_8bit
# - lion_32bit
# - paged_adamw_32bit
# - paged_adamw_8bit
# - paged_ademamix_32bit
# - paged_ademamix_8bit
# - paged_lion_32bit
# - paged_lion_8bit
# - rmsprop
# - rmsprop_bnb
# - rmsprop_bnb_8bit
# - rmsprop_bnb_32bit
# - galore_adamw
# - galore_adamw_8bit
# - galore_adafactor
# - galore_adamw_layerwise
# - galore_adamw_8bit_layerwise
# - galore_adafactor_layerwise
# - lomo
# - adalomo
# - grokadamw
# - schedule_free_adamw
# - schedule_free_sgd
# - apollo_adamw
# - apollo_adamw_layerwise
#
# Additional custom optimizers include:
# - optimi_adamw
# - ao_adamw_8bit
# - ao_adamw_fp8
optimizer:
# Dictionary of arguments to pass to the optimizer
optim_args:
@@ -564,29 +612,42 @@ max_grad_norm:
# currently only supported on Llama and Mistral
neftune_noise_alpha:
# Whether to bettertransformers
# Optional[bool]. Whether to bettertransformers
flash_optimum:
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
# Note: Only one of the following attention patches can be used at a time.
# For example, if you set `xformers_attention` to `true`, do not set `flash_attention` to `true`.
# Optional[bool]. Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention:
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
# Optional[bool]. Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention:
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# Whether to use scaled-dot-product attention
flash_attn_cross_entropy: # Optional[bool]. Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Optional[bool]. Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Optional[bool]. Whether to fuse QKV into a single operation
flash_attn_fuse_mlp: # Optional[bool]. Whether to fuse part of the MLP into a single operation
# Optional[bool]. Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
# Optional[bool]. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention:
# Optional[bool]. Whether to use low_cpu_mem_usage
low_cpu_mem_usage:
# Resume from a specific checkpoint dir
# Optional[str]. Resume from a specific checkpoint dir
resume_from_checkpoint:
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
# Optional[bool]. If resume_from_checkpoint isn't set and you simply want it to start where it left off.
# Be careful with this being turned on between different models.
auto_resume_from_checkpoints: false
## Multimodal section
# int | tuple[int, int] | None . Size to resize images to, width x height.
# Will read from model/processor config if not set.
image_size:
# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear".
image_resize_algorithm: 'bilinear'
## End of multimodal section
# Don't mess with this, it's here for accelerate and torchrun
local_rank:
@@ -627,6 +688,9 @@ ddp_broadcast_buffers:
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
# Must evenly divide the number of KV heads in your model.
heads_k_stride: 1
# Path to torch distx for optim 'adamw_anyprecision'
torchdistx_path:

View File

@@ -103,8 +103,7 @@ This uses the same tags as the [`main` image](#sec-main-tags).
- `JUPYTER_DISABLE`: Disable Jupyter lab.
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
- `PUBLIC_KEY`: Add a public key for the SSH service.
- `SSH_KEY`: Add a private key for the SSH service.
- `PUBLIC_KEY` / `SSH_KEY`: Add a public key for the SSH service.
#### Volume mounts

View File

@@ -35,12 +35,22 @@ description: Frequently asked questions
**Q: How to call Axolotl via custom python scripts?**
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
> A: Since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
**Q: ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as pad_token**
> A: This is because the tokenizer does not have a padding token. Please add a padding token to the tokenizer via:
> ```yaml
> special_tokens:
> # str. If you're not sure, set to same as `eos_token`.
> pad_token: "..."
> ```
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

View File

@@ -17,6 +17,7 @@ We currently support several common model architectures, including (but not limi
- `qwen2`
- `gemma`
- `gemma2`
- `gemma3`
<details>

View File

@@ -18,6 +18,7 @@ Axolotl supports several methods for multi-GPU training:
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- FSDP + QLoRA
## DeepSpeed {#sec-deepspeed}
@@ -66,6 +67,28 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Sequence parallelism {#sec-sequence-parallelism}
We support sequence parallelism (SP) via the
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
allows one to split up sequences across GPUs, which is useful in the event that a
single sequence causes OOM errors during model training.
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
or from source with `pip install .[ring-flash-attn]`.
Your Axolotl YAML config should contain the following lines:
```{.yaml}
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
heads_k_stride: 1
```
See our [dedicated guide](sequence_parallelism.qmd) for more details.
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).

View File

@@ -1,28 +1,171 @@
# MultiModal / Vision Language Models (BETA)
---
title: MultiModal / Vision Language Models (BETA)
format:
html:
toc: true
toc-depth: 3
---
### Supported Models
## Supported Models
- Mllama, i.e. llama with vision models
- [Mllama](#sec-mllama)
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Gemma-3](#sec-gemma-3)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
### Usage
## Usage
Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA,
you'll need to use the following in YAML in combination with the rest of the required hyperparams.
Multimodal support is limited and doesn't have full feature parity.
Here are the hyperparams you'll need to use to finetune a multimodal model.
```yaml
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
processor_type: AutoProcessor
skip_prepare_dataset: true
chat_template: llama3_2_vision
skip_prepare_dataset: true
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
sample_packing: false # not yet supported with multimodal
chat_template: # see in next section
# example dataset
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
remove_unused_columns: false
sample_packing: false
# only finetune the Language model, leave the vision model and vision tower frozen
# (optional) if doing lora, only finetune the Language model,
# leave the vision model and vision tower frozen
# load_in_8bit: true
adapter: lora
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
# (optional) if you want to resize images to a set size
image_size: 512
image_resize_algorithm: bilinear
```
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
::: {.callout-warning}
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
:::
### Mllama {#sec-mllama}
```yaml
base_model: meta-llama/Llama-3.2-11B-Vision-Instruct
chat_template: llama3_2_vision
```
### Pixtral {#sec-pixtral}
```yaml
base_model: mistralai/Pixtral-12B-2409
chat_template: pixtral
```
### Llava-1.5 {#sec-llava-15}
```yaml
base_model: llava-hf/llava-1.5-7b-hf
chat_template: llava
```
### Mistral-Small-3.1 {#sec-mistral-small-31}
```yaml
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
chat_template: mistral_v7_tekken
```
### Gemma-3 {#sec-gemma-3}
::: {.callout-tip}
The Gemma3-1B model is a text-only model, so please train as regular text model.
:::
For multi-modal 4B/12B/27B models, use the following config:
```yaml
base_model: google/gemma-3-4b-it
chat_template: gemma3
```
### Qwen2-VL {#sec-qwen2-vl}
```yaml
base_model: Qwen/Qwen2-VL-7B-Instruct
chat_template: qwen2_vl
```
### Qwen2.5-VL {#sec-qwen25-vl}
```yaml
base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
- A message is a list of `role` and `content`.
- `role` can be `system`, `user`, `assistant`, etc.
- `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`).
::: {.callout-note}
For backwards compatibility:
- If the dataset has a `images` or `image` column of `list[Image]`, it will be appended to the first `content` list as `{"type": "image", "image": ...}`. However, if the content already has a `{"type": "image"}` but no `image` key, it will be set the `image` key.
- If `content` is a string, it will be converted to a list with `type` as `text`.
:::
::: {.callout-tip}
For image loading, you can use the following keys within `content` alongside `"type": "image"`:
- `"path": "/path/to/image.jpg"`
- `"url": "https://example.com/image.jpg"`
- `"base64": "..."`
- `"image": PIL.Image`
:::
Here is an example of a multi-modal dataset:
```json
[
{
"messages": [
{
"role": "system",
"content": [
{"type": "text", "text": "You are a helpful assistant."}
]
},
{
"role": "user",
"content": [
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
{"type": "text", "text": "Describe this image in detail."}
]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "The image is a bee."}
]
}
]
}
]
```

View File

@@ -502,9 +502,48 @@ The input format is a simple JSON input with customizable fields based on the ab
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::
If you have multiple GPUs available, we reccomend using `vLLM` with the `GRPOTrainer` to significantly speedup trajectory generation during training.
First, launch a `vLLM` server using `trl vllm-serve` - you may use a config file or CLI overrides to configure your vLLM server. In this example, we're
using 4 GPUs - 2 for training, and 2 for vLLM:
::: {.callout-important}
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
:::
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
tensor_parallel_size: 2
gpu_memory_utilization: 0.85
dtype: auto
# max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand
rl: grpo
trl:
use_vllm: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
vllm_server_timeout: 300
```
```bash
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm_serve grpo.yaml
```
Your `vLLM` instance will now attempt to spin up, and it's time to kick off training utilizing our remaining two GPUs. In another terminal, execute:
```bash
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
```
#### Reward functions
GRPO uses custom reward functions and transformations. Please have them ready locally.
For ex, to load OpenAI's GSM8K and use a random reward for completions:
For example, to load OpenAI's GSM8K and use a random reward for completions:
```python
# rewards.py
@@ -530,8 +569,6 @@ trl:
beta: 0.001
max_completion_length: 256
use_vllm: True
vllm_device: auto
vllm_gpu_memory_utilization: 0.15
num_generations: 4
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
reward_weights: [1.0]

View File

@@ -25,6 +25,8 @@ To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
@@ -58,11 +60,16 @@ To use sequence parallelism, you need:
## Example
```yaml
# Example config with sequence parallelism
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192
sequence_parallel_degree: 2 # Split each sequence into 4 parts
...
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
...
```

View File

@@ -0,0 +1,71 @@
base_model: CohereForAI/c4ai-command-r7b-12-2024
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
# huggingface repo
chat_template: cohere
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -0,0 +1,79 @@
base_model: google/gemma-3-1b-it
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
load_in_8bit: false
load_in_4bit: true
strict: false
# huggingface repo
chat_template: gemma3
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -0,0 +1,66 @@
base_model: google/gemma-3-4b-it
strict: false
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -0,0 +1,68 @@
base_model: google/gemma-3-4b-it
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
lora_model_dir:
sequence_len: 2048
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -19,7 +19,6 @@ val_set_size: 0.0
output_dir: ./outputs/lora-out
dataset_exact_deduplication: true
test_value: true
sequence_len: 4096
sample_packing: true

View File

@@ -0,0 +1,80 @@
base_model: meta-llama/Llama-3.2-1B
# optionally might have model_type or tokenizer_type
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/lora-out
test_value: true
sequence_len: 4096
sample_packing: true
sample_packing_sequentially: true
curriculum_sampling: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -0,0 +1,63 @@
base_model: llava-hf/llava-1.5-7b-hf
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: llava
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -0,0 +1,66 @@
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
processor_type: AutoProcessor
strict: false
load_in_8bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: mistral_v7_tekken
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 2048
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -0,0 +1,65 @@
base_model: mistral-community/pixtral-12b
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: pixtral
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <pad>

View File

@@ -0,0 +1,63 @@
base_model: Qwen/Qwen2-VL-7B-Instruct
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen2_vl
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -1,23 +1,23 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.3
bitsandbytes==0.45.4
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.3
liger-kernel==0.5.5
# END section
packaging==23.2
peft==0.15.0
transformers==4.49.0
transformers==4.50.3
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.4.1
deepspeed==0.16.4
trl==0.15.1
datasets==3.5.0
deepspeed==0.15.4
trl==0.16.0
optimum==1.16.2
hf_transfer

View File

@@ -1,315 +0,0 @@
accelerate==0.34.1
addict==2.4.0
aiofiles==23.2.1
aiohttp==3.9.0
aiosignal==1.3.1
aiostream==0.5.2
alembic==1.13.1
annotated-types==0.6.0
annoy==1.17.3
ansible==6.7.0
ansible-core==2.13.13
ansible-vault==2.1.0
anyio==3.7.1
appdirs==1.4.4
art==6.0
asgiref==3.7.2
async-timeout==4.0.2
attrdict==2.0.1
attrs==22.2.0
awscli==1.32.75
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
backoff==2.2.1
base58==2.1.1
beartype==0.17.2
bitnet==0.2.1
bitsandbytes==0.42.0
bittensor==6.7.0
black==23.7.0
blinker==1.7.0
boto3==1.34.75
botocore==1.34.75
cachetools==5.3.3
cachy==0.1.1
certifi==2023.7.22
cffi==1.16.0
cfgv==3.3.1
chai-guanaco==1.2.4
charset-normalizer==3.2.0
cleo==0.6.8
click==8.1.7
cloudpickle==2.0.0
cohere==4.11.2
colorama==0.4.4
coloredlogs==15.0.1
CoLT5-attention==0.10.20
contextlib2==21.6.0
contourpy==1.2.0
cryptography==41.0.3
cycler==0.12.1
cytoolz==0.12.3
databricks-cli==0.18.0
dataclasses-json==0.5.7
datasets==2.11.0
ddt==1.6.0
decorator==5.1.1
deepspeed==0.15.0
# Editable Git install with no remote (dialogpt==0.1)
-e /Users/wing/Projects/ml/dialogpt/src
dill==0.3.6
distlib==0.3.6
docker==7.0.0
docker-pycreds==0.4.0
docstring-parser==0.15
docutils==0.16
ecdsa==0.18.0
einops==0.7.0
einops-exts==0.0.4
einx==0.1.3
entrypoints==0.4
eth-hash==0.6.0
eth-keys==0.5.0
eth-typing==4.0.0
eth-utils==2.3.1
evaluate==0.4.0
exceptiongroup==1.1.1
fastapi==0.109.2
fastcore==1.5.29
ffmpy==0.4.0
filelock==3.12.2
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
fire==0.5.0
first==2.0.2
flake8==7.0.0
Flask==3.0.1
fonttools==4.47.2
frozendict==2.4.1
frozenlist==1.3.3
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
fsspec==2023.6.0
fuzzywuzzy==0.18.0
gitdb==4.0.10
GitPython==3.1.31
google-pasta==0.2.0
gradio==4.42.0
gradio_client==1.3.0
greenlet==2.0.2
grpclib==0.4.7
gunicorn==21.2.0
h11==0.14.0
h2==4.1.0
hpack==4.0.0
httpcore==0.17.3
httpx==0.24.1
huggingface-hub==0.23.4
humanfriendly==10.0
hyperframe==6.0.1
identify==2.5.24
idna==3.4
immutables==0.20
importlib-metadata==6.7.0
importlib-resources==6.1.1
inflection==0.5.1
iniconfig==2.0.0
itsdangerous==2.1.2
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.3.2
jsonlines==3.1.0
jsonschema==2.6.0
kiwisolver==1.4.5
langchain==0.0.144
Levenshtein==0.24.0
libcst==1.1.0
liger-kernel==0.0.0
lion-pytorch==0.1.2
llama-cpp-python==0.1.36
llvmlite==0.40.1
local-attention==1.9.0
loguru==0.7.0
Mako==1.3.2
Markdown==3.5.2
markdown-it-py==3.0.0
markdown2==2.4.10
MarkupSafe==2.1.2
marshmallow==3.19.0
marshmallow-enum==1.5.1
matplotlib==3.8.2
mccabe==0.7.0
mdurl==0.1.2
MEGABYTE-pytorch==0.0.7
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
mlflow==2.10.0
modal==0.62.77
more-itertools==10.2.0
mpmath==1.2.1
msgpack==1.0.7
msgpack-numpy-opentensor==0.5.0
multidict==6.0.4
multiprocess==0.70.14
munch==2.5.0
mypy==1.3.0
mypy-extensions==1.0.0
nest-asyncio==1.6.0
netaddr==0.10.1
networkx==3.0rc1
nh3==0.2.14
nodeenv==1.8.0
nomic==2.0.2
numba==0.57.1
numexpr==2.8.4
numpy==1.24.4
oauthlib==3.2.2
openai==0.27.4
openapi==1.1.0
openapi-schema-pydantic==1.2.4
optimum==1.8.6
orjson==3.10.7
packaging==23.1
pandas==2.0.0
parameterized==0.9.0
password-strength==0.0.3.post2
pastel==0.1.1
pathos==0.3.0
pathspec==0.11.1
pathtools==0.1.2
peft==0.11.1
pendulum==3.0.0
Pillow==9.5.0
pip-tools==1.11.0
platformdirs==3.2.0
pluggy==1.4.0
poetry==0.7.1
pox==0.3.2
ppft==1.7.6.6
pre-commit==3.3.2
prettytable==3.10.0
prompt-toolkit==3.0.39
protobuf==3.20.2
protobuf3-to-dict==0.1.5
psutil==5.9.5
psycopg==3.1.18
PuLP==2.8.0
py==1.11.0
py-bip39-bindings==0.1.11
py-cpuinfo==9.0.0
py-ed25519-zebra-bindings==1.0.1
py-sr25519-bindings==0.2.0
pyarrow==11.0.0
pyasn1==0.6.0
pycodestyle==2.11.1
pycparser==2.21
pycryptodome==3.20.0
pydantic==2.5.3
pydantic_core==2.14.6
pydub==0.25.1
pyfiglet==0.8.post1
pyflakes==3.2.0
Pygments==2.15.1
PyJWT==2.8.0
pylev==1.4.0
PyNaCl==1.5.0
pynvml==11.5.0
pyparsing==2.4.7
pyrsistent==0.14.11
pytest==8.0.2
pytest-asyncio==0.23.4
python-dateutil==2.8.2
python-dotenv==1.0.1
python-Levenshtein==0.24.0
python-multipart==0.0.9
pytz==2023.3
PyYAML==6.0.1
querystring-parser==1.2.4
rapidfuzz==3.6.1
regex==2023.6.3
requests==2.31.0
requests-toolbelt==0.8.0
resolvelib==0.8.1
responses==0.18.0
retry==0.9.2
rich==13.7.0
rsa==4.7.2
ruff==0.6.3
s3transfer==0.10.1
safetensors==0.4.5
sagemaker==2.148.0
scalecodec==1.2.7
schedulefree==1.2.1
schema==0.7.5
scikit-learn==1.4.0
scipy==1.9.3
seaborn==0.13.2
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==1.19.1
setproctitle==1.3.2
shellingham==1.5.4
shortuuid==1.0.11
shtab==1.6.5
sigtools==4.0.1
six==1.16.0
skypilot==0.4.1
smdebug-rulesconfig==1.0.1
smmap==5.0.0
sniffio==1.3.0
SQLAlchemy==1.4.47
sqlparse==0.4.4
starlette==0.36.3
substrate-interface==1.5.2
svgwrite==1.4.3
sympy==1.11.1
synchronicity==0.6.7
tabulate==0.9.0
tblib==1.7.0
tenacity==8.2.2
tensor-parallel==2.0.0
termcolor==2.2.0
text2art==0.2.0
threadpoolctl==3.2.0
tiktoken==0.6.0
time-machine==2.14.1
timm==0.9.16
tokenizers==0.19.1
tokenmonster==1.1.12
toml==0.9.6
tomli==2.0.1
tomlkit==0.12.0
toolz==0.12.1
torch==2.2.0
torchdata==0.6.1
torchdiffeq==0.2.3
TorchFix==0.4.0
torchtext==0.15.2
torchvision==0.17.0
tqdm==4.66.2
transformers==4.44.2
trl==0.9.6
typer==0.12.5
types-certifi==2021.10.8.3
types-requests==2.31.0.20240125
types-setuptools==69.0.0.20240125
types-toml==0.10.8.7
typing==3.7.4.3
typing-inspect==0.8.0
typing_extensions==4.9.0
tyro==0.5.18
tzdata==2023.3
unique-names-generator==1.0.2
urllib3==2.2.2
uvicorn==0.22.0
vector_quantize_pytorch==1.14.1
virtualenv==20.23.0
voyager==2.0.2
wandb==0.16.2
watchfiles==0.21.0
wavedrom==2.0.3.post3
wcwidth==0.2.6
websocket-client==1.7.0
websockets==12.0
Werkzeug==3.0.1
wonderwords==2.2.0
xxhash==3.2.0
yarl==1.8.2
zetascale==2.2.7
zipp==3.15.0

View File

@@ -10,15 +10,13 @@ from pathlib import Path
from setuptools import find_packages, setup
def parse_requirements():
def parse_requirements(extras_require_map):
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
is_extras = (
"deepspeed" in line or "mamba-ssm" in line or "lion-pytorch" in line
)
is_extras = "deepspeed" in line or "mamba-ssm" in line
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
@@ -69,6 +67,7 @@ def parse_requirements():
if (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post2")
extras_require_map["vllm"] = ["vllm==0.8.1"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
@@ -88,7 +87,7 @@ def parse_requirements():
except PackageNotFoundError:
pass
return _install_requires, _dependency_links
return _install_requires, _dependency_links, extras_require_map
def get_package_version():
@@ -105,7 +104,50 @@ def get_package_version():
return version_
install_requires, dependency_links = parse_requirements()
extras_require = {
"flash-attn": ["flash-attn==2.7.4.post1"],
"ring-flash-attn": [
"flash-attn==2.7.4.post1",
"ring-flash-attn>=0.1.4",
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.15.4",
"deepspeed-kernels",
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"galore": [
"galore_torch",
],
"apollo": [
"apollo-torch",
],
"optimizers": [
"galore_torch",
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],
"ray": [
"ray[train]",
],
"vllm": [
"vllm==0.7.2",
],
}
install_requires, dependency_links, extras_require_build = parse_requirements(
extras_require
)
setup(
version=get_package_version(),
@@ -118,40 +160,5 @@ setup(
"axolotl=axolotl.cli.main:main",
],
},
extras_require={
"flash-attn": ["flash-attn==2.7.4.post1"],
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
"deepspeed": [
"deepspeed==0.16.4",
"deepspeed-kernels",
],
"mamba-ssm": [
"mamba-ssm==1.2.0.post1",
"causal_conv1d",
],
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"lion-pytorch": [
"lion-pytorch==0.1.2",
],
"galore": [
"galore_torch",
],
"optimizers": [
"galore_torch",
"lion-pytorch==0.1.2",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],
"ray": [
"ray[train]",
],
"vllm": [
"vllm==0.7.2",
],
},
extras_require=extras_require_build,
)

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.8.0.dev0"
__version__ = "0.8.0"

View File

@@ -35,6 +35,55 @@ class TrainerCliArgs:
num_processes: Optional[int] = field(default=None)
@dataclass
class VllmServeCliArgs:
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
tensor_parallel_size: int = field(
default=1,
metadata={"help": "Number of tensor parallel workers to use."},
)
host: str = field(
default="0.0.0.0", # nosec B104
metadata={"help": "Host address to run the server on."},
)
port: int = field(
default=8000,
metadata={"help": "Port to run the server on."},
)
gpu_memory_utilization: Optional[float] = field(
default=None,
metadata={
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
"out-of-memory (OOM) errors during initialization."
},
)
dtype: Optional[str] = field(
default=None,
metadata={
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
"determined based on the model configuration. Find the supported values in the vLLM documentation."
},
)
max_model_len: Optional[int] = field(
default=None,
metadata={
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
"context size, which might be much larger than the KV cache, leading to inefficiencies."
},
)
enable_prefix_caching: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
"hardware support this feature."
},
)
@dataclass
class EvaluateCliArgs:
"""Dataclass with CLI arguments for `axolotl evaluate` command."""

View File

@@ -56,7 +56,7 @@ def do_inference(
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Inference-specific CLI arguments.
"""
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter
prompter_module = None
@@ -151,7 +151,7 @@ def do_inference_gradio(
"""
import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter
prompter_module = None
@@ -256,7 +256,7 @@ def do_cli(
"""
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, inference=True, **kwargs)
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser(InferenceCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -14,7 +14,12 @@ import yaml
from dotenv import load_dotenv
import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.cli.args import (
EvaluateCliArgs,
PreprocessCliArgs,
TrainerCliArgs,
VllmServeCliArgs,
)
from axolotl.cli.sweeps import generate_sweep_configs
from axolotl.cli.utils import (
add_options_from_config,
@@ -23,6 +28,7 @@ from axolotl.cli.utils import (
fetch_from_github,
filter_none_kwargs,
)
from axolotl.cli.vllm_serve import do_vllm_serve
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -316,6 +322,14 @@ def fetch(directory: str, dest: Optional[str]) -> None:
fetch_from_github(f"{directory}/", dest)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(VllmServeCliArgs)
@filter_none_kwargs
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
do_vllm_serve(config, cli_args)
cli.add_command(lm_eval)

View File

@@ -27,7 +27,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
"""
print_axolotl_text_art()
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True
LOG.info("Running merge of LoRA with base model...")
@@ -44,6 +44,9 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
if processor:
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
@@ -71,8 +74,10 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
load_in_8bit=False,
load_in_4bit=False,
flash_attention=False,
sequence_parallel_degree=None,
deepspeed=None,
fsdp=None,
fsdp_config=None,
**kwargs,
)
@@ -83,13 +88,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
f"Target directory for merge: `{parsed_cfg.lora_model_dir}` does not exist."
)
parsed_cfg.load_in_4bit = False
parsed_cfg.load_in_8bit = False
parsed_cfg.flash_attention = False
parsed_cfg.deepspeed = None
parsed_cfg.fsdp = None
parsed_cfg.fsdp_config = None
do_merge_lora(cfg=parsed_cfg)

View File

@@ -17,6 +17,7 @@ from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
@@ -33,6 +34,9 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Training-specific CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
print_axolotl_text_art()
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:

View File

@@ -13,11 +13,16 @@ from typing import Any, Callable, Type, Union, get_args, get_origin
import click
import requests
from pydantic import BaseModel
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import (
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
)
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.models import load_model, load_processor, load_tokenizer
configure_logging()
LOG = logging.getLogger(__name__)
@@ -295,9 +300,13 @@ def load_model_and_tokenizer(
*,
cfg: DictDefault,
inference: bool = False,
) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
) -> tuple[
PreTrainedModel,
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
ProcessorMixin | None,
]:
"""
Helper function for loading a model and tokenizer specified in the given `axolotl`
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
config.
Args:
@@ -305,7 +314,7 @@ def load_model_and_tokenizer(
inference: Boolean denoting inference mode.
Returns:
`transformers` model and tokenizer.
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
"""
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
@@ -313,4 +322,9 @@ def load_model_and_tokenizer(
LOG.info("loading model...")
model, _ = load_model(cfg, tokenizer, inference=inference)
return model, tokenizer
processor = None
if cfg.is_multimodal:
LOG.info("loading processor...")
processor = load_processor(cfg, tokenizer)
return model, tokenizer, processor

View File

@@ -0,0 +1,55 @@
"""
CLI to start the vllm server for online RL
"""
from pathlib import Path
from typing import Union
from trl.scripts.vllm_serve import ScriptArguments
from trl.scripts.vllm_serve import main as vllm_serve_main
from axolotl.cli.config import load_cfg
def do_vllm_serve(
config: Union[Path, str],
cli_args: dict,
):
"""
Starts the VLLM server for serving LLM models used for online RL
Args
:param cfg: Parsed doct of the YAML config
:param cli_args: dict of additional command-line arguments of type VllmServeCliArgs
Returns:
process_id: the process id of the started VLLM server
"""
cfg = load_cfg(config)
model = cfg.base_model
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)
host = cli_args.get("host") or cfg.vllm.host
port = cli_args.get("port") or cfg.vllm.port
gpu_memory_utilization = (
cli_args.get("gpu_memory_utilization") or cfg.vllm.gpu_memory_utilization
)
dtype = cli_args.get("dtype") or cfg.vllm.dtype
max_model_len = cli_args.get("max_model_len") or cfg.vllm.max_model_len
enable_prefix_caching = (
cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
)
vllm_script_args = ScriptArguments(
model,
tensor_parallel_size=tensor_parallel_size,
host=host,
port=port,
gpu_memory_utilization=gpu_memory_utilization,
dtype=dtype,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
)
vllm_serve_main(vllm_script_args)

View File

@@ -60,6 +60,7 @@ from axolotl.core.training_args import (
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
@@ -68,7 +69,6 @@ from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
SaveModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
log_prediction_callback_factory,
@@ -248,7 +248,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
callbacks.append(SaveModelCallback())
return callbacks
@@ -525,9 +524,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
) or False
# handle ddp
ddp_find_unused_parameters = None
if self.cfg.ddp:
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
training_arguments_kwargs["ddp_find_unused_parameters"] = (
False if self.cfg.ddp else None
ddp_find_unused_parameters
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
report_to = []
@@ -747,6 +752,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.accelerator_config
)
if self.cfg.image_size:
training_arguments_kwargs["image_size"] = self.cfg.image_size
if self.cfg.image_resize_algorithm:
training_arguments_kwargs["image_resize_algorithm"] = (
self.cfg.image_resize_algorithm
)
if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None:
@@ -890,8 +901,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
kwargs["processor"] = self.processor
kwargs["chat_template"] = training_args.chat_template
kwargs["processing_strategy"] = get_processing_strategy(
self.processor,
training_args.chat_template,
self.cfg.chat_template,
image_size=training_args.image_size,
image_resize_algorithm=training_args.image_resize_algorithm,
)
elif self.cfg.batch_flattening:
collator = DataCollatorWithFlattening
collator_args.pop(0)
@@ -925,7 +941,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks.append(SaveModelCallback())
return callbacks

View File

@@ -8,12 +8,11 @@ import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Any, Literal
from typing import Literal
import datasets
import torch
from datasets import Dataset
from torch import nn
from torch.utils.data import (
BatchSampler,
DataLoader,
@@ -28,6 +27,7 @@ from typing_extensions import override
from axolotl.core.trainers.mixins import (
OptimizerMixin,
RngLoaderMixin,
SchedulerMixin,
SequenceParallelMixin,
)
@@ -40,7 +40,9 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__)
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
class AxolotlTrainer(
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
):
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
@@ -112,6 +114,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
sequential=self.args.sample_packing_sequentially,
drop_last=True,
)
@@ -589,27 +592,3 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trai
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
def training_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform training step for.
inputs: Dictionary mapping.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal training step
loss = super().training_step(model, inputs, num_items_in_batch)
return loss

View File

@@ -13,7 +13,7 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.core.trainers.mixins import SchedulerMixin
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
@@ -23,7 +23,7 @@ if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""

View File

@@ -40,18 +40,15 @@ class GRPOStrategy:
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_device"] = (
trl.vllm_device if trl.vllm_device else "auto"
)
if trl.vllm_gpu_memory_utilization:
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
trl.vllm_gpu_memory_utilization
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
if trl.vllm_server_timeout:
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
if trl.vllm_guided_decoding_regex:
grpo_args_kwargs["vllm_guided_decoding_regex"] = (
trl.vllm_guided_decoding_regex
)
if trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
if trl.num_generations:
grpo_args_kwargs["num_generations"] = trl.num_generations
@@ -70,6 +67,25 @@ class GRPOStrategy:
if trl.reward_weights:
grpo_args_kwargs["reward_weights"] = trl.reward_weights
if trl.scale_rewards is not None:
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
if trl.temperature is not None:
grpo_args_kwargs["temperature"] = trl.temperature
if trl.top_p is not None:
grpo_args_kwargs["top_p"] = trl.top_p
if trl.top_k is not None:
grpo_args_kwargs["top_k"] = trl.top_k
if trl.min_p is not None:
grpo_args_kwargs["min_p"] = trl.min_p
if trl.repetition_penalty is not None:
grpo_args_kwargs["repetition_penalty"] = trl.repetition_penalty
if trl.num_iterations is not None:
grpo_args_kwargs["num_iterations"] = trl.num_iterations
if trl.epsilon is not None:
grpo_args_kwargs["epsilon"] = trl.epsilon
return grpo_args_kwargs
@classmethod

View File

@@ -2,108 +2,68 @@
Axolotl GRPO trainer
"""
from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel
from trl import GRPOConfig, GRPOTrainer
from trl.models import unwrap_model_for_generation
from contextlib import nullcontext
from axolotl.core.trainers.base import SchedulerMixin
from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer
from trl.extras.profiling import profiling_decorator
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
if is_deepspeed_available():
import deepspeed
# mypy: ignore-errors
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""
Extend the base GRPOTrainer for axolotl helpers
"""
_tag_names = ["trl", "grpo", "axolotl"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# pylint: disable=access-member-before-definition
# Enable gradient checkpointing if requested
if kwargs["args"].gradient_checkpointing:
# Ensure use_cache is disabled
if hasattr(self.model, "config"):
self.model.config.use_cache = False
# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(self.model) and hasattr(
self.model.base_model, "gradient_checkpointing_enable"
):
self.model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
elif hasattr(self.model, "gradient_checkpointing_enable"):
self.model.gradient_checkpointing_enable()
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
# pylint: enable=access-member-before-definition
def _enable_gradient_checkpointing(
self, model: PreTrainedModel, args: GRPOConfig
) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# pylint: disable=unused-argument,redefined-builtin
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs
or gradient_checkpointing_kwargs["use_reentrant"]
@profiling_decorator
def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
gather_if_zero3 = (
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
)
if use_reentrant:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
if is_peft_model(self.model):
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
# adapters in a sharded manner is not supported.
with gather_if_zero3(list(self.model.parameters())):
self.model.merge_adapter()
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
# Update vLLM weights while parameters are gathered
for name, param in self.model.named_parameters():
# When using PEFT, we need to recover the original parameter name and discard some parameters
name = (
name.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", "")
)
if self.model.prefix in name:
continue
# When module to save, remove its prefix and discard the original module
if "original_module" in name:
continue
name = name.replace("modules_to_save.default.", "")
model.get_input_embeddings().register_forward_hook(
make_inputs_require_grad
)
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
return model
# pylint: enable=unused-argument,redefined-builtin
# Unmerge adapters while parameters are still gathered
self.model.unmerge_adapter()
# Parameters will automatically be repartitioned when exiting the context
else:
# For non-PEFT models, simply gather and update each parameter individually.
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
def _move_model_to_vllm(self):
with unwrap_model_for_generation(
self.model,
self.accelerator,
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
) as unwrapped_model:
if is_compiled_module(unwrapped_model):
unwrapped_model = (
unwrapped_model._orig_mod # pylint: disable=protected-access
)
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", ""): v
for k, v in state_dict.items()
}
# Remove values with adapter prefix (example: "_lora")
state_dict = {
k: v
for k, v in state_dict.items()
if unwrapped_model.prefix not in k
}
# When module to save, remove its prefix and discard the original module
state_dict = {
k.replace("modules_to_save.default.", ""): v
for k, v in state_dict.items()
if "original_module" not in k
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = (
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
)
llm_model.load_weights(state_dict.items())
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()
# Reset cache on main process
if self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()

View File

@@ -4,5 +4,6 @@
# flake8: noqa
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin

View File

@@ -0,0 +1,67 @@
"""
Temporary fix/override for bug in resume from checkpoint
See https://github.com/huggingface/transformers/pull/37162
TODO: Remove when upstream added PR to release
"""
import logging
import os
import random
import numpy as np
import torch
from transformers import Trainer, is_torch_npu_available
from transformers.trainer import safe_globals
from transformers.trainer_pt_utils import set_rng_state_for_device
from transformers.training_args import ParallelMode
LOG = logging.getLogger(__name__)
class RngLoaderMixin(Trainer):
"""
mixin for method override to load RNG states from a checkpoint
"""
def _load_rng_state(self, checkpoint):
# Load RNG states from `checkpoint`
if checkpoint is None:
return
if self.args.world_size > 1:
process_index = self.args.process_index
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
if not os.path.isfile(rng_file):
LOG.info(
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
)
return
else:
rng_file = os.path.join(checkpoint, "rng_state.pth")
if not os.path.isfile(rng_file):
LOG.info(
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
"fashion, reproducibility is not guaranteed."
)
return
# Use safe_globals to ensure numpy RNG states can be deserialized safely under PyTorch 2.6+,
# which requires allowlisted classes when loading with weights_only=True.
with safe_globals():
checkpoint_rng_state = torch.load(rng_file) # nosec B614
random.setstate(checkpoint_rng_state["python"])
np.random.set_state(checkpoint_rng_state["numpy"])
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
if torch.cuda.is_available():
set_rng_state_for_device(
"CUDA", torch.cuda, checkpoint_rng_state, is_distributed
)
if is_torch_npu_available():
set_rng_state_for_device(
"NPU", torch.npu, checkpoint_rng_state, is_distributed
)

View File

@@ -7,6 +7,7 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F
from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
@@ -129,3 +130,53 @@ class SequenceParallelMixin:
)
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
def training_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
num_items_in_batch: int | None = None,
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform training step for.
inputs: Dictionary mapping.
"""
# Set up sequence parallelism for this step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal training step
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
def prediction_step(
self,
model: nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
"""
Perform a prediction step on a batch of inputs. Overrides the
`transformers.trainer.Trainer` method to handle sequence parallelism if
enabled.
Args:
model: Model to perform prediction step for.
inputs: Dictionary mapping of inputs.
prediction_loss_only: Whether to return only the loss.
ignore_keys: Keys to ignore in the inputs.
Returns:
Tuple of (loss, logits, labels).
"""
# Set up sequence parallelism for this prediction step if enabled
if self.args.sequence_parallel_degree > 1:
self._update_ring_flash_attn_params(inputs)
# Proceed with normal prediction step
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore

View File

@@ -1,5 +1,7 @@
"""Module for TRL PPO trainer"""
from typing import Literal, Union
import torch
from tqdm import tqdm
from trl import (
@@ -11,6 +13,7 @@ from trl import (
RewardTrainer,
)
from axolotl.core.trainers.mixins import RngLoaderMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
@@ -72,15 +75,87 @@ class TRLPPOTrainer(PPOTrainer):
)
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
# TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
metrics = {}
forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(
chosen_rewards
).mean()
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(
rejected_rewards
).mean()
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(
reward_accuracies
).mean()
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
chosen_rewards - rejected_rewards
).mean()
metrics[f"{prefix}logps/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
)
metrics[f"{prefix}logps/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
)
metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
policy_rejected_logits.detach().mean()
).mean()
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
policy_chosen_logits.detach().mean()
).mean()
metrics[f"{prefix}nll_loss"] = (
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
)
metrics[f"{prefix}log_odds_ratio"] = (
self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
)
metrics[f"{prefix}log_odds_chosen"] = (
self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
)
for k, v in metrics.items():
metrics[k] = v.item()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss
return loss, metrics
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
@@ -88,15 +163,89 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
policy_chosen_logps,
policy_rejected_logps,
)
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
reward_accuracies = (chosen_rewards > rejected_rewards).float()
prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = (
self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
)
metrics[f"{prefix}rewards/rejected"] = (
self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
)
metrics[f"{prefix}rewards/accuracies"] = (
self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
)
metrics[f"{prefix}rewards/margins"] = (
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards)
.mean()
.item()
)
metrics[f"{prefix}logps/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logps)
.detach()
.mean()
.item()
)
metrics[f"{prefix}logps/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logps)
.detach()
.mean()
.item()
)
metrics[f"{prefix}logits/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean())
.mean()
.item()
)
metrics[f"{prefix}logits/chosen"] = (
self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean())
.mean()
.item()
)
metrics[f"{prefix}nll_loss"] = (
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
)
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss
return loss, metrics
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
@@ -104,7 +253,7 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""

View File

@@ -5,6 +5,7 @@ extra axolotl specific training args
from dataclasses import dataclass, field
from typing import Optional
from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
@@ -33,6 +34,12 @@ class AxolotlTrainingMixins:
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
@@ -212,6 +219,20 @@ class AxolotlTrainingMixins:
metadata={"help": "The number of workers to use in sequence parallelism"},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(
default=None,
metadata={"help": "The size of the image to resize to"},
)
image_resize_algorithm: Resampling | None = field(
default=None,
metadata={"help": "The algorithm to use for image resizing"},
)
# end of multi-modal section
@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):

View File

@@ -15,6 +15,7 @@ from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -159,4 +160,6 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
del model
del tokenizer
cleanup_distributed()
return all_metrics

View File

@@ -1,6 +1,6 @@
# Cut Cross Entropy
Cut Cross Entropy reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.
Cut Cross Entropy (CCE) reduces VRAM usage through optimization on the cross-entropy operation during loss calculation.
See https://github.com/apple/ml-cross-entropy
@@ -29,6 +29,20 @@ plugins:
cut_cross_entropy: true
```
## Supported Models
- llama
- phi3
- gemma
- gemma2
- gemma3
- gemma3_text
- mistral
- mistral3
- qwen2
- cohere
- cohere2
## Citation
```bib

View File

@@ -25,8 +25,8 @@ import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import zero_only
from ...utils.distributed import zero_only
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
@@ -72,7 +72,9 @@ class CutCrossEntropyPlugin(BasePlugin):
if cfg.cut_cross_entropy:
self._check_requirements()
from cut_cross_entropy.transformers import cce_patch
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
cce_patch,
)
with zero_only():
LOG.info(

View File

@@ -0,0 +1,201 @@
"""Cohere and Cohere2 CCE patch."""
# This patch is based off transformers 4.50.0.
# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM.
# It scales the hidden states by the logit scale in advance instead of the logits as the
# operation is done internally and should be mathematically equivalent.
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
_CONFIG_FOR_DOC,
COHERE_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>> from transformers import AutoTokenizer, CohereForCausalLM
>> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
>> prompt = "Hey, are you conscious? Can you talk to me?"
>> inputs = tokenizer(prompt, return_tensors="pt")
>> # Generate
>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
# scale weight by logit_scale in-place of logits
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight * self.logit_scale,
labels,
_PATCH_OPTS,
**kwargs,
)
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
logits = logits * self.logit_scale # main diff from Llama
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def patch_cohere(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.cohere import modeling_cohere
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_cohere.CohereForCausalLM
), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_cohere.CohereForCausalLM.forward = cce_forward
return None
def patch_cohere2(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.cohere2 import modeling_cohere2
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_cohere2.Cohere2ForCausalLM
), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward
return None

View File

@@ -0,0 +1,175 @@
"""Gemma CCE patch"""
# This patch is based off transformers 4.50.0.
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma.modeling_gemma import (
_CONFIG_FOR_DOC,
GEMMA_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, GemmaForCausalLM
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
>>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**kwargs,
)
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def patch_gemma(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.gemma import modeling_gemma
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_gemma.GemmaForCausalLM
), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_gemma.GemmaForCausalLM.forward = cce_forward
return None

View File

@@ -0,0 +1,459 @@
"""Gemma2 and Gemma3 (text and multimodal) CCE patch."""
# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29
# and updated for transformers 4.50.0.
# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works
# with both gemma3 (text and multimodal) models.
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
)
from torch import nn
from transformers.cache_utils import Cache, HybridCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma3.modeling_gemma3 import (
_CONFIG_FOR_DOC,
GEMMA3_INPUTS_DOCSTRING,
Gemma3CausalLMOutputWithPast,
logger,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
defer_logits_calculation: bool = False,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
defer_logits_calculation (`bool`, *optional*):
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Gemma3ForCausalLM
>>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
>>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**loss_kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
softcap=getattr(self.config, "final_logit_softcapping", None),
**loss_kwargs,
)
elif _PATCH_OPTS is not None and defer_logits_calculation:
# defer logits calculation to the ConditionalGeneration forward
logits = hidden_states[:, slice_indices, :]
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if self.config.final_logit_softcapping is not None:
logits = logits / self.config.final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: torch.LongTensor | None = None,
pixel_values: torch.FloatTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
>>> prompt = "answer en Where is the cow standing?"
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"answer en Where is the cow standing?\nbeach"
```"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
is_training = token_type_ids is not None and labels is not None
# Replace image id woth PAD if the image token if OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_index
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids # type: ignore
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0 # type: ignore
)
cache_position = torch.arange( # type: ignore
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(
self.config.image_token_index,
dtype=torch.long,
device=inputs_embeds.device,
)
)
else:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
-1
)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
if (
not is_torchdynamo_compiling()
and inputs_embeds[special_image_mask].numel() != image_features.numel()
):
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
# mask out pad-token-ids in labels for BC
if labels is not None and self.pad_token_id in labels:
logger.warning_once(
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
)
labels = torch.where( # type: ignore
input_ids == self.pad_token_id, self.config.ignore_index, labels
)
causal_mask = self._update_causal_mask( # pylint: disable=protected-access
attention_mask,
token_type_ids,
past_key_values,
cache_position,
inputs_embeds,
is_training,
)
outputs = self.language_model(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
defer_logits_calculation=True, # enable deferred logits calculation
**lm_kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states,
self.language_model.lm_head.weight,
labels,
_PATCH_OPTS,
softcap=getattr(self.config, "final_logit_softcapping", None),
**lm_kwargs,
)
else:
logits = hidden_states
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(
logits.device
)
shift_logits = shift_logits[
shift_attention_mask.to(logits.device) != 0
].contiguous()
shift_labels = shift_labels[
shift_attention_mask.to(shift_labels.device) != 0
].contiguous()
else:
shift_logits = shift_logits.contiguous()
shift_labels = shift_labels.contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Gemma3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
def patch_gemma2(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.gemma2 import modeling_gemma2
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_gemma2.Gemma2ForCausalLM
), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward
return None
def patch_gemma3_text(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.gemma3 import modeling_gemma3
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_gemma3.Gemma3ForCausalLM
), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
return None
def patch_gemma3(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.gemma3 import modeling_gemma3
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration
), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
# patch the causal model to enable deferred logits calculation
maybe_model.language_model.forward = MethodType(
cce_forward, maybe_model.language_model
)
return maybe_model
modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal
# patch the causal model to enable deferred logits calculation
modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward
return None

View File

@@ -0,0 +1,392 @@
"""Mistral and Mistral3 CCE patch."""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mistral3.modeling_mistral3 import (
Mistral3CausalLMOutputWithPast,
)
from transformers.models.mistral.modeling_mistral import (
_CONFIG_FOR_DOC,
MISTRAL_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: Optional[torch.Tensor] | None = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
defer_logits_calculation: bool = False,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
defer_logits_calculation (`bool`, *optional*):
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MistralForCausalLM
>>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**kwargs,
)
elif _PATCH_OPTS is not None and defer_logits_calculation:
# defer logits calculation to the ConditionalGeneration forward
logits = hidden_states[:, slice_indices, :]
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def cce_forward_multimodal(
self,
input_ids: torch.LongTensor | None = None,
pixel_values: torch.FloatTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, list[int]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: torch.Tensor | None = None,
**lm_kwargs,
) -> Union[Tuple, Mistral3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration
>>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
>>> prompt = "<s>[INST][IMG]What is the image?[/INST]"
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is the image?The image depicts two cats lying on a pink blanket."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
vision_feature_layer = (
vision_feature_layer
if vision_feature_layer is not None
else self.config.vision_feature_layer
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
image_sizes=image_sizes,
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
if (
not is_torchdynamo_compiling()
and inputs_embeds[special_image_mask].numel() != image_features.numel()
):
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
defer_logits_calculation=True, # enable deferred logits calculation
**lm_kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states,
self.language_model.lm_head.weight,
labels,
_PATCH_OPTS,
**lm_kwargs,
)
else:
logits = hidden_states
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
logits.device
)
shift_logits = logits[..., :-1, :][
shift_attention_mask.to(logits.device) != 0
].contiguous()
shift_labels = labels[..., 1:][
shift_attention_mask.to(labels.device) != 0
].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1).to(shift_logits.device),
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Mistral3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
def patch_mistral(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.mistral import modeling_mistral
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_mistral.MistralForCausalLM
), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_mistral.MistralForCausalLM.forward = cce_forward
return None
def patch_mistral3(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.mistral import modeling_mistral
from transformers.models.mistral3 import modeling_mistral3
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration
), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
# patch the causal model to enable deferred logits calculation
maybe_model.language_model.forward = MethodType(
cce_forward, maybe_model.language_model
)
return maybe_model
modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal
# patch the causal model to enable deferred logits calculation
modeling_mistral.MistralForCausalLM.forward = cce_forward
return None

View File

@@ -0,0 +1,379 @@
"""Mllama CCE patch."""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mllama.modeling_mllama import (
MLLAMA_INPUTS_DOCSTRING,
_prepare_cross_attention_mask,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class="MllamaTextConfig"
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.LongTensor] = None,
cross_attention_mask: Optional[torch.LongTensor] = None,
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
defer_logits_calculation: bool = False,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
defer_logits_calculation (`bool`, *optional*):
If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the
memory overhead of calculating logits using regular lm_head forward pass and to use CCE.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, MllamaForCausalLM
>>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision")
>>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision")
>>> prompt = "If I had to write a haiku, it would be:"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
>>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
>>> print(result)
If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
I love the idea of snowflakes gently falling, each one
```
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
cross_attention_states=cross_attention_states,
attention_mask=attention_mask,
position_ids=position_ids,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
loss = None
logits = None
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**loss_kwargs,
)
elif _PATCH_OPTS is not None and defer_logits_calculation:
# defer logits calculation to the ConditionalGeneration forward
logits = hidden_states[:, slice_indices, :]
else:
logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class="MllamaConfig"
)
def cce_forward_multimodal(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
aspect_ratio_mask: Optional[torch.Tensor] = None,
aspect_ratio_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, MllamaForConditionalGeneration
>>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
>>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
>>> processor = AutoProcessor.from_pretrained(checkpoint)
>>> prompt = "<|image|>If I had to write a haiku for this one"
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(text=prompt, images=image, return_tensors="pt")
>>> # Generate
>>> output = model.generate(**inputs, max_new_tokens=15)
>>> prompt_len = inputs.input_ids.shape[-1]
>>> generated_ids = output[:, prompt_len:]
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
>>> print(generated_text)
[', it would be:.\\nA stop sign in Chinatown.\\n']
```
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if pixel_values is not None and cross_attention_states is not None:
raise ValueError(
"`pixel_values` and `cross_attention_states` cannot be provided simultaneously"
)
if pixel_values is not None:
if aspect_ratio_ids is None:
raise ValueError(
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
)
# get vision tokens from vision model
vision_outputs = self.vision_model(
pixel_values=pixel_values,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
)
cross_attention_states = vision_outputs[0]
cross_attention_states = self.multi_modal_projector(
cross_attention_states
).reshape(
-1, cross_attention_states.shape[-2], self.hidden_size # type: ignore
)
if cross_attention_mask is not None:
cross_attention_mask, full_text_row_masked_out_mask = (
_prepare_cross_attention_mask(
cross_attention_mask,
num_vision_tokens=self.vision_model.num_patches,
dtype=self.dtype,
)
)
else:
full_text_row_masked_out_mask = None
if cross_attention_mask is not None and cache_position is not None:
cross_attention_mask = cross_attention_mask[:, :, cache_position]
full_text_row_masked_out_mask = full_text_row_masked_out_mask[
:, :, cache_position
]
outputs = self.language_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
past_key_values=past_key_values,
use_cache=use_cache,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
defer_logits_calculation=True, # enable deferred logits calculation
**loss_kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states,
self.language_model.lm_head.weight,
labels,
_PATCH_OPTS,
**loss_kwargs,
)
else:
# Temporary fix to calculate the loss in main class, as the model's vocab size may be resized
logits = hidden_states
if labels is not None:
loss = self.loss_function(
logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs
)
if not return_dict:
return (loss,) + outputs if loss is not None else outputs
return CausalLMOutputWithPast(
loss=loss,
logits=outputs.logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def patch_mllama(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.mllama import modeling_mllama
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_mllama.MllamaForConditionalGeneration
), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
# patch the language model
maybe_model.language_model.forward = MethodType(
cce_forward, maybe_model.language_model
)
return maybe_model
modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal
# patch the causal language model
modeling_mllama.MllamaForCausalLM.forward = cce_forward
return None

View File

@@ -0,0 +1,85 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""Cut Cross Entropy patcher"""
import transformers
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
from cut_cross_entropy.transformers.llama import patch_llama
from cut_cross_entropy.transformers.phi3 import patch_phi3
from cut_cross_entropy.transformers.qwen2 import patch_qwen2
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
patch_cohere,
patch_cohere2,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma
from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
patch_gemma2,
patch_gemma3,
patch_gemma3_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
patch_mistral,
patch_mistral3,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"llama": patch_llama,
"mllama": patch_mllama,
"phi3": patch_phi3,
"gemma": patch_gemma,
"gemma2": patch_gemma2,
"gemma3": patch_gemma3,
"gemma3_text": patch_gemma3_text,
"mistral": patch_mistral,
"mistral3": patch_mistral3,
"qwen2": patch_qwen2,
"cohere": patch_cohere,
"cohere2": patch_cohere2,
}
def cce_patch(
model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig,
impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT,
reduction: str = "mean",
filter_eps: float | str | None = "auto",
accum_e_fp32: bool = False,
accum_c_fp32: bool = False,
filter_e_grad: bool = True,
filter_c_grad: bool = True,
train_only: bool = False,
) -> TransformersModelT | None:
if isinstance(impl, LinearCrossEntropyImpl):
impl = impl.name.lower()
if impl not in (v.name.lower() for v in LinearCrossEntropyImpl):
raise ValueError(f"Unknown {impl=}")
if isinstance(model_type_or_model, transformers.PreTrainedModel):
model_type = model_type_or_model.config.model_type
elif isinstance(model_type_or_model, transformers.PretrainedConfig):
model_type = model_type_or_model.model_type
else:
model_type = model_type_or_model
patch_options = PatchOptions(
impl=impl,
reduction=reduction,
filter_eps=filter_eps,
accum_e_fp32=accum_e_fp32,
accum_c_fp32=accum_c_fp32,
filter_e_grad=filter_e_grad,
filter_c_grad=filter_c_grad,
train_only=train_only,
)
if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type](
model_type_or_model, patch_options
)
raise RuntimeError(f"Unknown model type {model_type}")

View File

@@ -0,0 +1,40 @@
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""Monkeypatch for apply_lce to add softcap."""
import torch
from cut_cross_entropy import linear_cross_entropy
from cut_cross_entropy.transformers.utils import PatchOptions
def apply_lce(
e: torch.Tensor,
c: torch.Tensor,
labels: torch.Tensor,
opts: PatchOptions,
bias: torch.Tensor | None = None,
softcap: float | None = None,
**loss_kwargs,
) -> torch.Tensor:
"""Monkey patch for apply_lce to support softcap kwarg."""
num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
cce_kwargs = opts.to_kwargs()
if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
cce_kwargs["reduction"] = "sum"
else:
num_items_in_batch = None
loss = linear_cross_entropy(
e,
c,
labels.to(e.device),
bias=bias,
shift=True,
softcap=softcap,
**cce_kwargs,
)
if num_items_in_batch is not None:
loss = loss / num_items_in_batch
return loss

View File

@@ -20,6 +20,26 @@ liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```
## Supported Models
- deepseek_v2
- gemma
- gemma2
- gemma3 (partial support, no support for FLCE yet)
- granite
- jamba
- llama
- mistral
- mixtral
- mllama
- mllama_text_model
- olmo2
- paligemma
- phi3
- qwen2
- qwen2_5_vl
- qwen2_vl
## Citation
```bib

View File

@@ -21,6 +21,7 @@ It is designed to be performant, correct, and light-weight.
import inspect
import logging
import sys
from functools import partial
from axolotl.integrations.base import BasePlugin
@@ -41,11 +42,18 @@ class LigerPlugin(BasePlugin):
def pre_model_load(self, cfg):
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
raise ValueError(
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
)
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
@@ -82,6 +90,8 @@ class LigerPlugin(BasePlugin):
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_layer_norm:
modeling_jamba.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
@@ -104,13 +114,51 @@ class LigerPlugin(BasePlugin):
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_glu_activation:
logging.warning("liger_glu_activation is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_layer_norm:
modeling_mod.DeepseekV2MLP.forward = LigerLayerNorm.forward
if cfg.liger_cross_entropy:
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
# nn.CrossEntropyLoss in the forward method.
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
from transformers.models.gemma3 import modeling_gemma3
if cfg.liger_rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
def _liger_rms_norm_wrapper(dim, **kwargs):
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
return LigerRMSNorm(hidden_size=dim, **kwargs)
modeling_gemma3.Gemma3RMSNorm = partial(
_liger_rms_norm_wrapper,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
if cfg.liger_glu_activation:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
if cfg.liger_layer_norm:
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type in ["deepseek_v3"]:
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")

View File

@@ -27,7 +27,7 @@ def get_ring_attn_group() -> dist.ProcessGroup:
return RING_ATTN_GROUP
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup):
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
"""
Setter for ring attention group on this rank.
@@ -38,13 +38,19 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup):
RING_ATTN_GROUP = ring_attn_group
def register_ring_attn(sequence_parallel_degree: int):
def register_ring_attn(sequence_parallel_degree: int, heads_k_stride: int | None):
"""
Create ring attention group and substitute flash attn with ring flash attn.
Args:
sequence_parallel_degree: Sequence parallelism factor.
heads_k_stride: Sequence parallelism K head stride size. Passed
through to `ring_flash_attn.substitute_hf_flash_attn`.
"""
if get_ring_attn_group() is not None:
LOG.info("Ring attention already registered, exiting early...")
return
LOG.info(
"Enabling ring attention sequence parallelism: "
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
@@ -84,6 +90,11 @@ def register_ring_attn(sequence_parallel_degree: int):
if rank == 0:
LOG.info(f"Sequence parallel group assignments: {group_assignments}")
if heads_k_stride is None:
heads_k_stride = 1
from ring_flash_attn import substitute_hf_flash_attn
substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_degree)
substitute_hf_flash_attn(
process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride
)

View File

@@ -0,0 +1,238 @@
"""Monkeypatch for gemma3 conditional generation forward to fix loss exploding"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple, Union
import torch
from transformers.cache_utils import Cache
from transformers.models.gemma3.modeling_gemma3 import (
_CONFIG_FOR_DOC,
GEMMA3_INPUTS_DOCSTRING,
Gemma3CausalLMOutputWithPast,
logger,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def new_forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
>>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
>>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
>>> prompt = "answer en Where is the cow standing?"
>>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"answer en Where is the cow standing?\nbeach"
```"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
is_training = token_type_ids is not None and labels is not None
# Replace image id with PAD if the image token is OOV, to avoid index-errors
if input_ids is not None and self.config.image_token_index >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_index
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
# Merge text and images
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(
self.config.image_token_index,
dtype=torch.long,
device=inputs_embeds.device,
)
)
else:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(
-1
)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
if (
not is_torchdynamo_compiling()
and inputs_embeds[special_image_mask].numel() != image_features.numel()
):
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
# mask out pad-token-ids in labels for BC
if labels is not None and self.pad_token_id in labels:
logger.warning_once(
"`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
"You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
)
labels = torch.where(
input_ids == self.pad_token_id, self.config.ignore_index, labels
)
causal_mask = self._update_causal_mask( # pylint: disable=protected-access
attention_mask,
token_type_ids,
past_key_values,
cache_position,
inputs_embeds,
is_training,
)
outputs = self.language_model(
attention_mask=causal_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
**lm_kwargs,
)
logits = outputs[0]
loss = None
if labels is not None:
if attention_mask is not None:
# Get the shifted attention mask
shift_attention_mask = attention_mask[:, -logits.shape[1] + 1 :].to(
logits.device
) # +1 for shift
# Filter logits and labels based on attention mask
valid_indices = shift_attention_mask != 0
filtered_logits = logits[..., :-1, :][valid_indices]
filtered_labels = labels[..., 1:][valid_indices.to(labels.device)]
# TODO: do we need to handle num_items_in_batch given we filter the logits and labels?
loss = self.loss_function(
logits=filtered_logits,
labels=None, # we pass shift_labels
shift_labels=filtered_labels,
vocab_size=self.config.text_config.vocab_size,
**lm_kwargs,
)
else:
# Standard case without filtering
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.text_config.vocab_size,
**lm_kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Gemma3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
def patch_gemma3conditionalgeneration_forward():
from transformers.models.gemma3.modeling_gemma3 import (
Gemma3ForConditionalGeneration,
)
Gemma3ForConditionalGeneration.forward = new_forward

View File

@@ -252,12 +252,38 @@ def apply_lora_kernel_patches(
LOG.setLevel(logging.INFO)
# Choose activation based on model type
activation = model.config.hidden_act
activation = None
text_config = (
model.config.get_text_config()
if hasattr(model.config, "get_text_config")
else model.config
)
if hasattr(text_config, "hidden_act"):
activation = text_config.hidden_act
elif hasattr(text_config, "hidden_activation"):
activation = text_config.hidden_activation
# map activation to supported activation
if "gelu" in activation:
# gemma3 uses gelu_pytorch_tanh
activation = "gelu"
if activation not in SUPPORTED_ACTIVATIONS:
raise NotImplementedError(f"Activation {activation} is not supported")
layers = []
# check for multimodal models first
if hasattr(model, "language_model"):
layers = model.language_model.model.layers
elif hasattr(model, "model"):
layers = model.model.model.layers
else:
raise NotImplementedError(
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
)
# Patch each layer
for layer in model.model.model.layers:
for layer in layers:
# Add QKV, O fallback implementations to start
# These will be overwritten later (if some conditions apply)
layer.self_attn.apply_qkv = types.MethodType(

View File

@@ -22,6 +22,10 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"phi3",
"gemma",
"gemma2",
"gemma3",
"gemma3_text",
"cohere",
"cohere2",
"gemmoe",
"starcoder2",
"deepseek_v2",

View File

@@ -0,0 +1,278 @@
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
from copy import deepcopy
from typing import Optional
from PIL import Image, ImageOps
from PIL.Image import Resampling
from torch import Tensor
from transformers import ProcessorMixin
from transformers.image_utils import load_image
class ProcessingStrategy:
"""Base Processing Strategy class"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
self.processor = processor
self.chat_template = chat_template
self.image_token = None
self.image_token_id = None
self.image_size = image_size
self.image_resize_algorithm = (
image_resize_algorithm or Image.Resampling.BILINEAR
)
if hasattr(processor, "image_token"):
self.image_token = processor.image_token
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
self.image_token
)
def __call__(self, examples: list[dict]) -> list[dict]:
"""
Preprocess conversation examples to ensure consistent format.
Converts different conversation formats to OpenAI format with 'messages'.
Supports two formats:
1. OpenAI format with 'messages'
2. Legacy format with 'conversations'
Args:
examples: list of conversation dictionaries
Returns:
list of dicts in OpenAI format with 'messages' key
Raises:
ValueError: If the conversation format is not supported
"""
role_mapping = {
"human": "user",
"gpt": "assistant",
}
def normalize_role(role: str) -> str:
"""Normalize role names to OpenAI format. Default to original role if not found."""
return role_mapping.get(role, role)
def convert_legacy_format(example: dict) -> dict:
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
messages = [
{"role": normalize_role(convo["from"]), "content": convo["value"]}
for convo in example["conversations"]
]
# Create new dict without 'conversations' key
result = deepcopy(example)
result.pop("conversations")
result["messages"] = messages
return result
def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
"""Convert regular messages format to Messages format with content type"""
new_messages = []
for message in messages:
if isinstance(message["content"], str):
new_messages.append(
{
"role": message["role"],
"content": [
{
"type": "text",
"text": message["content"],
}
],
}
)
elif isinstance(message["content"], list):
content = message["content"]
new_messages.append(
{
"role": message["role"],
"content": content,
}
)
return new_messages
processed_examples = []
for example in examples:
if not ("messages" in example or "conversations" in example):
raise ValueError(
"Only `messages` and `conversations` message keys are currently supported."
)
processed_example = None
if "messages" in example: # OpenAI format
processed_example = example
else: # Legacy format
processed_example = convert_legacy_format(example)
# convert regular messages format to Messages format with content type
# for compatibility with apply_chat_template
processed_example["messages"] = convert_messages_to_multimedia_messages(
processed_example["messages"]
)
# find the image key if it exists
possible_image_keys = ["images", "image"]
image_key = None
for key in possible_image_keys:
if key in processed_example:
image_key = key
break
# if the image key exists, add the image to the first message
if image_key is not None:
# TODO: check if it's normal to be single image only for common datasets
# From observation, it's usually a list of single image but some datasets may have several columns for images
# Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages
image_value = processed_example[image_key][0]
# Handle image loading (Image, url, path, base64)
image_value = load_image(image_value)
if self.image_size is not None:
assert hasattr(
image_value, "resize"
), "Image does not have a resize method"
if isinstance(self.image_size, tuple):
image_value = image_value.resize(
self.image_size, self.image_resize_algorithm
)
else:
# Set the padding value; here we use black (0, 0, 0) for RGB images
padding_color = (0, 0, 0)
# When image_size is an int (square target), preserve aspect ratio then pad
# This is to prevent aspect ratio distortion when resizing to square
image_value = ImageOps.pad(
image_value,
(self.image_size, self.image_size),
method=self.image_resize_algorithm,
color=padding_color,
)
# Look for any image type in the first message
# some dataset have an {type: "image"} in the first message
ind_to_add = None
for i, content in enumerate(
processed_example["messages"][0]["content"]
):
# Usually datasets created with image columns, don't have it in the messages itself
if content["type"] == "image" and all(
k not in content for k in ["image", "url", "path", "base64"]
):
ind_to_add = i
break
# If an image type is found, add the image to that index
if ind_to_add is not None:
processed_example["messages"][0]["content"][ind_to_add][
"image"
] = image_value
else:
# if no image type is found, add it to end of the first message
processed_example["messages"][0]["content"].append(
{
"type": "image",
"image": image_value,
}
)
processed_examples.append(processed_example)
return processed_examples
def process_labels(self, input_ids: Tensor) -> Tensor:
labels = input_ids.clone()
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels[labels == self.processor.tokenizer.pad_token_id] = -100
# Ignore the image token index in the loss computation (model specific)
labels[labels == self.image_token_id] = -100
return labels
class Qwen2VLProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Qwen2-VL"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.image_token = "<|image_pad|>" # nosec
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
self.image_token
)
class Gemma3ProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Gemma3"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.image_token = processor.tokenizer.special_tokens_map["boi_token"]
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
self.image_token
)
def process_labels(self, input_ids):
labels = input_ids.clone()
# Follows https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora
labels[labels == self.processor.tokenizer.pad_token_id] = -100
labels[labels == self.image_token_id] = -100
labels[labels == 262144] = -100 # corresponds to <image_soft_token>
return labels
def get_processing_strategy(
processor: ProcessorMixin,
chat_template,
chat_template_type,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
if chat_template_type == "qwen2_vl":
return Qwen2VLProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type == "gemma3":
return Gemma3ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type in [
"llama3_2_vision",
"llava",
"mistral_v7_tekken",
"pixtral",
]:
return ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
raise ValueError(f"Unsupported chat template type: {chat_template_type}")

View File

@@ -411,11 +411,15 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if turn_idx >= len(turns):
raise ValueError(f"Turn index {turn_idx} out of range")
# mistral does not output message if it contains only system message
# mistral/gemma3 does not output message if it contains only system message
if (
turn_idx == 0
and turns[0].get("role") == "system"
and "mistral" in self.tokenizer.name_or_path.lower()
and (
"mistral" in self.tokenizer.name_or_path.lower()
# gemma3 uses gemma tokenizer
or "gemma" in self.tokenizer.name_or_path.lower()
)
):
return -1, -1

View File

@@ -14,6 +14,7 @@ import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled
from peft import PeftConfig, PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
@@ -26,6 +27,7 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -156,6 +158,8 @@ def setup_signal_handler(
_model.save_pretrained(
cfg.output_dir, safe_serialization=safe_serialization
)
cleanup_distributed()
sys.exit(0)
_model_weakref = weakref.ref(model)
@@ -302,7 +306,7 @@ def create_model_card(cfg: DictDefault, trainer: Trainer):
model_card_kwarg["dataset_tags"] = dataset_tags
trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError):
except (AttributeError, UnicodeDecodeError, OfflineModeIsEnabled):
pass
elif cfg.hub_model_id:
# Defensively push to the hub to ensure the model card is updated
@@ -314,6 +318,7 @@ def save_initial_configs(
tokenizer: PreTrainedTokenizer,
model: PreTrainedModel,
peft_config: PeftConfig | None,
processor: ProcessorMixin | None,
):
"""
Save initial configurations before training.
@@ -341,6 +346,10 @@ def save_initial_configs(
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
model.config.save_pretrained(str(output_dir))
if processor:
LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
processor.save_pretrained(str(output_dir))
def setup_model_card(cfg: DictDefault):
"""
@@ -408,6 +417,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,
ProcessorMixin | None,
]:
"""
Load model, tokenizer, trainer, etc. Helper function to encapsulate the full
@@ -423,6 +433,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
- Model
- Tokenizer
- PEFT config
- Processor
"""
# Load tokenizer, processor and model
model, tokenizer, peft_config, processor = setup_model_and_tokenizer(cfg)
@@ -453,6 +464,7 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
model,
tokenizer,
peft_config,
processor,
)
@@ -469,42 +481,35 @@ def train(
Returns:
Tuple of (model, tokenizer) after training
"""
# Setup model, tokenizer, (causal or RLHF) trainer etc.
# Setup model, tokenizer, (causal or RLHF) trainer, etc.
(
trainer,
model,
tokenizer,
peft_config,
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
# Determine if we need to resume from a checkpoint
resume_from_checkpoint = determine_resume_checkpoint(cfg)
# Configuration for saving
safe_serialization = cfg.save_safetensors is True
# Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(
cfg, model, tokenizer, train_dataset, safe_serialization
)
# Save initial configs
save_initial_configs(cfg, tokenizer, model, peft_config)
# Set up signal handler for graceful termination
# Additional setup
save_initial_configs(cfg, tokenizer, model, peft_config, processor)
setup_signal_handler(cfg, model, safe_serialization)
# Set up badges and config info for model card
setup_model_card(cfg)
# Execute the training
resume_from_checkpoint = determine_resume_checkpoint(cfg)
execute_training(cfg, trainer, resume_from_checkpoint)
# Save the trained model
# Save the trained model and cleanup
save_trained_model(cfg, trainer, model, safe_serialization)
# Create model card
create_model_card(cfg, trainer)
if not cfg.use_ray:
cleanup_distributed()
return model, tokenizer, trainer

View File

@@ -816,27 +816,6 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
return control
class SaveModelCallback(TrainerCallback):
"""Callback to save model on train end"""
def on_step_end( # pylint: disable=unused-argument
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# Save
if state.global_step >= state.max_steps:
control.should_save = True
def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs
):
control.should_save = True
return control
class GCCallback(TrainerCallback):
"""Callback to garbage collect torch cache"""

File diff suppressed because one or more lines are too long

View File

@@ -112,6 +112,7 @@ class DataCollatorForSeq2Seq:
self.local_world_size = dist.get_world_size(group=sp_group)
def __call__(self, features, return_tensors=None):
has_attn_mask = "attention_mask" in features[0].keys()
labels = None
if return_tensors is None:
return_tensors = self.return_tensors
@@ -164,6 +165,8 @@ class DataCollatorForSeq2Seq:
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=return_tensors,
)
if not has_attn_mask:
del features["attention_mask"]
# prepare decoder_input_ids
if (

View File

@@ -2,15 +2,17 @@
Collators for multi-modal chat messages and packing
"""
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Optional, Union
from PIL import Image
from transformers import PreTrainedTokenizerBase, ProcessorMixin
import torch
from torch import Tensor
from transformers import PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin
from transformers.utils import PaddingStrategy
from axolotl.processing_strategies import ProcessingStrategy
@dataclass
class MultiModalChatDataCollator(DataCollatorMixin):
@@ -19,11 +21,9 @@ class MultiModalChatDataCollator(DataCollatorMixin):
"""
tokenizer: PreTrainedTokenizerBase
processor: ProcessorMixin
return_tensors: str = "pt"
chat_template: Optional[str] = None
processing_strategy: ProcessingStrategy
packing: bool = False
max_images: int = -1
return_tensors: str = "pt"
padding: Union[bool, str, PaddingStrategy] = True
pad_to_multiple_of: Optional[int] = None
@@ -31,162 +31,62 @@ class MultiModalChatDataCollator(DataCollatorMixin):
if self.packing:
raise ValueError("Packing is currently not supported.")
def torch_call(
self, examples: list[Union[list[int], Any, dict[str, Any]]]
) -> dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.
return self.__class__.process_rows(
examples, self.processor, self.chat_template, self.max_images
)
@staticmethod
def process_rows(examples, processor, chat_template, max_images, length_only=False):
# HINT: use `_torch_collate_batch` to stack and pad tensors
# see also DataCollatorWithFlattening and DefaultDataCollator
# *** This is COPIED from the trl example sft_vlm.py code ***
# use this as a starting point
def _preprocess(examples: list[dict]) -> list[dict]:
"""
Preprocess conversation examples to ensure consistent format.
Converts different conversation formats to OpenAI format with 'messages'.
Supports two formats:
1. OpenAI format with 'messages'
2. Legacy format with 'conversations'
Args:
examples: list of conversation dictionaries
Returns:
dict in OpenAI format with 'messages' key
Raises:
ValueError: If the conversation format is not supported
"""
role_mapping = {
"human": "user",
"gpt": "assistant",
}
def normalize_role(role: str) -> str:
"""Normalize role names to OpenAI format. Default to original role if not found."""
return role_mapping.get(role, role)
def convert_legacy_format(example: dict) -> dict:
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
messages = [
{
"role": normalize_role(convo["from"]),
"content": convo["value"],
}
for convo in example["conversations"]
]
# Create new dict without 'conversations' key
result = deepcopy(example)
result.pop("conversations")
return {"messages": messages, **result}
processed_examples = []
for example in examples:
# OpenAI format
if "messages" in example:
processed_examples.append(example)
# Legacy format
elif "conversations" in example:
processed_examples.append(convert_legacy_format(example))
else:
raise ValueError(
"Only `messages` and `conversations` message keys are currently supported."
)
return processed_examples
def _process_images(examples, max_images):
"""
Process images from examples, ensuring consistency in image presence and applying max_images limit.
Args:
examples: List of dictionaries that may contain 'images' key
max_images: Maximum number of images to keep per example (0 means no limit)
Returns:
Either None (if no images) or List[Image objects] (if all examples have images)
Raises:
ValueError: If there's a mix of None and non-None images
"""
def get_image(example):
if "images" not in example:
return None
images = example["images"]
if isinstance(images, str):
return Image.open(images)
return images
images = [get_image(example) for example in examples]
# Count None and non-None images
none_count = sum(1 for img in images if img is None)
# All images are None
if none_count == len(images):
return None
# Mix of None and non-None images
if none_count > 0:
raise ValueError(
"All images should be either None or not None. "
"Please provide images for all examples or None."
)
# Apply max_images limit if specified
if max_images > 0:
images = [
(
img_batch[:max_images]
if isinstance(img_batch, (list, tuple))
else img_batch
)
for img_batch in images
]
return images
def torch_call(self, examples: list[dict]) -> dict[str, Any]:
return self.process_rows(examples)
def process_rows(
self,
examples: list[dict],
) -> dict[str, Tensor]:
# Preprocess the examples
examples = _preprocess(examples)
examples = self.processing_strategy(examples)
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(
example["messages"], chat_template=chat_template, tokenize=False
# Initialize batch
batch: dict[str, Any] = {}
# Process each example
for example in examples:
# Apply chat template to process the example
# This method requires transformers>=4.49.0
result = self.processing_strategy.processor.apply_chat_template(
example["messages"],
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
padding=True,
return_dict=True,
chat_template=self.processing_strategy.chat_template,
)
for example in examples
]
images = _process_images(examples, max_images=max_images)
# TODO: Check if need handling for len(input_ids) > sequence_len
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
# Add the processed tensors to our batch
for key in result.keys():
if key not in batch:
batch[key] = []
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100 #
# Ignore the image token index in the loss computation (model specific)
image_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.image_token
batch[key].append(result[key].squeeze(0))
# Pad sequences to the same length
input_ids = torch.nn.utils.rnn.pad_sequence(
batch["input_ids"],
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
)
labels[labels == image_token_id] = -100
batch["labels"] = labels
if length_only:
return {
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
}
return batch
attention_mask = torch.nn.utils.rnn.pad_sequence(
batch["attention_mask"], batch_first=True, padding_value=0
)
# Create the final batch
final_batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
# Process the labels
final_batch["labels"] = self.processing_strategy.process_labels(
final_batch["input_ids"]
)
return final_batch

View File

@@ -13,7 +13,7 @@ from axolotl.integrations.base import PluginManager
from axolotl.integrations.config import merge_input_args
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
from axolotl.utils.models import MULTIMODAL_AUTO_MODEL_MAPPING, load_model_config
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)
@@ -158,7 +158,7 @@ def normalize_config(cfg):
cfg.is_multimodal = (
hasattr(model_config, "model_type")
and model_config.model_type in ["llava", "mllama"]
and model_config.model_type in MULTIMODAL_AUTO_MODEL_MAPPING
or any(
multimodal_name in cfg.base_model.lower()
for multimodal_name in [
@@ -171,7 +171,6 @@ def normalize_config(cfg):
cfg.processor_config = (
cfg.processor_config or cfg.base_model_config or cfg.base_model
)
model_config = model_config.text_config
cfg.model_config_type = model_config.model_type

View File

@@ -6,8 +6,12 @@ from pathlib import Path
from typing import Optional, Union
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HFValidationError
from huggingface_hub import hf_hub_download, snapshot_download
from huggingface_hub.errors import (
HFValidationError,
RepositoryNotFoundError,
RevisionNotFoundError,
)
from axolotl.utils.dict import DictDefault
@@ -70,20 +74,25 @@ def load_dataset_w_config(
# pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False
ds_trust_remote_code = config_dataset.trust_remote_code
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=True,
snapshot_download(
repo_id=config_dataset.path,
repo_type="dataset",
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=ds_trust_remote_code,
ignore_patterns=["*"],
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
except (
RepositoryNotFoundError,
RevisionNotFoundError,
FileNotFoundError,
ConnectionError,
HFValidationError,
ValueError,
):
pass
ds_from_cloud = False
@@ -229,7 +238,8 @@ def load_dataset_w_config(
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
else:
elif config_dataset.data_files:
fp: str | list[str] | None = None
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
repo_id=config_dataset.path,

View File

@@ -71,8 +71,8 @@ def barrier():
def is_main_process():
"""
Check if the current process is the main process.
If not in distributed mode, always return True.
Check if the current process is the main process. If not in distributed mode,
always return `True`.
"""
if not is_distributed():
return True
@@ -87,6 +87,18 @@ def get_world_size():
return int(os.getenv("WORLD_SIZE", "1"))
def cleanup_distributed():
"""
Destroy process group if torch distributed is initialized. Called in training early
termination or when training successfully completes.
"""
# Ensure that all operations are completed before destroying the process group
torch.cuda.synchronize()
# Destroy the process group
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
@contextmanager
def zero_only():
"""

View File

@@ -8,7 +8,7 @@ import math
import os
import types
from functools import cached_property
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
from typing import Any, Dict, Optional, Tuple
import addict
import bitsandbytes as bnb
@@ -25,7 +25,7 @@ from peft import (
prepare_model_for_kbit_training,
)
from torch import nn
from transformers import ( # noqa: F401
from transformers import (
AddedToken,
AutoConfig,
AutoModelForCausalLM,
@@ -34,12 +34,17 @@ from transformers import ( # noqa: F401
AutoTokenizer,
AwqConfig,
BitsAndBytesConfig,
Gemma3ForConditionalGeneration,
GPTQConfig,
LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration,
MllamaForConditionalGeneration,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
)
from transformers.integrations.deepspeed import (
HfTrainerDeepSpeedConfig,
@@ -69,9 +74,13 @@ from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_mod
LOG = logging.getLogger(__name__)
MULTIMODEL_AUTO_MODEL_MAPPING = {
"llava": LlavaForConditionalGeneration,
MULTIMODAL_AUTO_MODEL_MAPPING = {
"mllama": MllamaForConditionalGeneration,
"llava": LlavaForConditionalGeneration,
"qwen2_vl": Qwen2VLForConditionalGeneration,
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
"mistral3": Mistral3ForConditionalGeneration,
"gemma3": Gemma3ForConditionalGeneration,
}
@@ -99,9 +108,30 @@ def get_module_class_from_name(module, name):
return None
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
# Set use_cache to False
if hasattr(model_config, "use_cache"):
model_config.use_cache = False
if cfg.is_multimodal:
model_config = model_config.text_config
# For multimodal configs, use_cache is set in the text_config
if hasattr(model_config, "get_text_config"):
text_config = model_config.get_text_config()
if hasattr(text_config, "use_cache"):
text_config.use_cache = False
else:
raise ValueError(
"No text config found for multimodal model. Please raise an Issue with model details."
)
# check if image_size is not set and load image size from model config if available
if (
cfg.image_size is None
and hasattr(model_config, "vision_config")
and hasattr(model_config.vision_config, "image_size")
):
cfg.image_size = model_config.vision_config.image_size
LOG.debug(f"Loaded image size: {cfg.image_size} from model config")
quant_config_exists = (
hasattr(model_config, "quantization_config")
@@ -440,6 +470,31 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
**processor_kwargs,
)
# Attempt to load image size from processor if available
if (
cfg.image_size is None
and hasattr(processor, "size")
and any(dim in processor.size for dim in ["width", "height"])
):
im_width = None
im_height = None
if "width" in processor.size:
im_width = processor.size["width"]
if "height" in processor.size:
im_height = processor.size["height"]
# If both width and height are set, use a tuple
if im_width is not None and im_height is not None:
cfg.image_size = (im_width, im_height)
# If only width is set, use as integer
elif im_width is not None:
cfg.image_size = im_width
# If only height is set, use as integer
elif im_height is not None:
cfg.image_size = im_height
LOG.debug(f"Loaded image size: {cfg.image_size} from processor")
return processor
@@ -476,14 +531,19 @@ class ModelLoader:
# init model config
self.model_config = load_model_config(cfg)
if cfg.is_multimodal:
self.text_model_config = self.model_config.text_config
else:
self.text_model_config = self.model_config
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
# patch gemma3 conditional generation forward before loading plugins
# as it could be overridden by plugins
if self.cfg.model_config_type == "gemma3":
from axolotl.monkeypatch.gemma3 import (
patch_gemma3conditionalgeneration_forward,
)
patch_gemma3conditionalgeneration_forward()
# load any patches from plugins
from axolotl.integrations.base import PluginManager
@@ -558,7 +618,10 @@ class ModelLoader:
# Initialize ring attn for sequence parallelism. This must be done after
# model init but before the first forward pass, since it modifies flash
# attn to use ring comm for SP training across multiple GPUs.
register_ring_attn(self.cfg.sequence_parallel_degree)
register_ring_attn(
sequence_parallel_degree=self.cfg.sequence_parallel_degree,
heads_k_stride=self.cfg.heads_k_stride,
)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):
@@ -673,7 +736,7 @@ class ModelLoader:
should be set according to the type of the model.
"""
if self.cfg.is_multimodal:
self.auto_model_loader = MULTIMODEL_AUTO_MODEL_MAPPING.get(
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
self.model_config.model_type, AutoModelForVision2Seq
)
@@ -896,8 +959,6 @@ class ModelLoader:
quantization_config = (
quantization_config or self.model_kwargs["quantization_config"]
)
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = load_sharded_model_quant(
self.base_model,
self.model_config,
@@ -918,9 +979,6 @@ class ModelLoader:
_ = _configure_zero3_memory_efficient_loading()
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
# Load model with random initialization if specified
if self.cfg.random_init_weights:
# AutoModel classes support the from_config method
@@ -975,8 +1033,6 @@ class ModelLoader:
and self.model_type != "AutoModelForCausalLM"
and not self.cfg.trust_remote_code
):
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
if self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
@@ -992,25 +1048,7 @@ class ModelLoader:
**self.model_kwargs,
)
else:
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
if (
hasattr(self.text_model_config, "max_seq_len")
and self.text_model_config.max_seq_len
and self.cfg.sequence_len > self.text_model_config.max_seq_len
):
self.text_model_config.max_seq_len = self.cfg.sequence_len
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
elif (
hasattr(self.text_model_config, "max_sequence_length")
and self.text_model_config.max_sequence_length
and self.cfg.sequence_len > self.text_model_config.max_sequence_length
):
self.text_model_config.max_sequence_length = self.cfg.sequence_len
LOG.warning(f"increasing context length to {self.cfg.sequence_len}")
if self.cfg.gptq:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
@@ -1029,8 +1067,6 @@ class ModelLoader:
_ = _configure_zero3_memory_efficient_loading()
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
@@ -1194,7 +1230,9 @@ class ModelLoader:
)
):
resize_kwargs = {}
if self.cfg.mean_resizing_embeddings is not None:
if self.cfg.mean_resizing_embeddings is not None and not (
self.model_config.model_type == "llava"
):
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
else:
@@ -1293,8 +1331,6 @@ class ModelLoader:
requires_grad.append(f"{name}: {param.requires_grad}")
if len(requires_grad) == 0:
LOG.warning("there are no parameters that require gradient updates")
if hasattr(self.model, "config"):
self.model.config.use_cache = False
if self.cfg.flash_optimum:
from optimum.bettertransformer import BetterTransformer

View File

@@ -8,11 +8,13 @@ from typing import Any, Iterable, List, Union
import numba
import numpy as np
from torch.utils.data import BatchSampler, Sampler
from torch.utils.data import BatchSampler, Sampler, SequentialSampler
from axolotl.utils.distributed import reduce_and_broadcast
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
@numba.njit
@@ -103,6 +105,55 @@ def allocate(
return result, s, len(result) * c * n
@numba.njit
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
"""
Sequential allocator that preserves example order
Parameters:
- lengths: The lengths of all examples
- rank: The current rank (for distributed training)
- c: The capacity of each bin (maximum sequence length)
- n: Number of ranks
Returns:
- result: List of batches for the current rank
- total_used: Number of actual example tokens
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
"""
result = []
total_used = 0
# First, do sequential packing into bins
all_bins = []
current_bin = [0 for i in range(0)] # numba hint
remaining_capacity = c
for idx, size in enumerate(lengths):
if size <= remaining_capacity:
# Example fits in current bin
current_bin.append(idx)
remaining_capacity -= size
total_used += size
else:
# Example doesn't fit, start a new bin
if current_bin: # Add non-empty bin to all_bins
all_bins.append(current_bin)
current_bin = [idx]
remaining_capacity = c - size
total_used += size
# Add the last bin if not empty
if current_bin:
all_bins.append(current_bin)
# Assign bins to ranks - each rank gets every n-th bin
for bin_idx in range(rank, len(all_bins), n):
result.append(all_bins[bin_idx])
return result, total_used, len(all_bins) * c
class MultipackBatchSampler(BatchSampler):
"""Batch sampler class for multipack"""
@@ -115,6 +166,7 @@ class MultipackBatchSampler(BatchSampler):
packing_efficiency_estimate: float = 1.0,
drop_last: bool = False,
num_count_samples: int = 16,
sequential: bool = False,
**kwargs,
):
super().__init__(sampler, batch_size, drop_last)
@@ -122,6 +174,7 @@ class MultipackBatchSampler(BatchSampler):
self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.sequential = sequential
assert isinstance(self.lengths, np.ndarray)
@@ -136,6 +189,11 @@ class MultipackBatchSampler(BatchSampler):
# the minimum packed dataset length across all ranks determined by a gather/broadcast
self.len_across_ranks = None
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warn(
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
)
def set_epoch(self, epoch: int):
self.epoch = epoch
@@ -145,13 +203,21 @@ class MultipackBatchSampler(BatchSampler):
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
if self.sequential:
batches, total_used, total_slots = allocate_sequentially(
lengths=lengths,
rank=0,
c=self.batch_max_len,
n=1,
)
else:
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
batches = [
[

View File

@@ -42,9 +42,11 @@ from axolotl.utils.schemas.model import (
ModelOutputConfig,
SpecialTokensConfig,
)
from axolotl.utils.schemas.multimodal import MultiModalConfig
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig
LOG = logging.getLogger(__name__)
@@ -64,6 +66,7 @@ class AxolotlInputConfig(
LISAConfig,
GradioConfig,
RayConfig,
MultiModalConfig,
RemappedParameters,
DeprecatedParameters,
BaseModel,
@@ -84,6 +87,9 @@ class AxolotlInputConfig(
trl: TRLConfig | None = Field(
default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda
)
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(), # pylint: disable=unnecessary-lambda
)
reward_model: bool | None = None
process_reward_model: bool | None = None
num_labels: int | None = None
@@ -186,6 +192,7 @@ class AxolotlInputConfig(
sample_packing: bool | None = None
sample_packing_group_size: int | None = 100_000
sample_packing_bin_size: int | None = 200
sample_packing_sequentially: bool | None = None
eval_sample_packing: bool | None = None
pad_to_sequence_len: bool | None = None
curriculum_sampling: bool | None = None
@@ -246,6 +253,7 @@ class AxolotlInputConfig(
val_set_size: float | None = Field(default=0.0)
sequence_parallel_degree: int | None = None
heads_k_stride: int | None = None
special_tokens: SpecialTokensConfig | None = None
tokens: list[str] | None = None
@@ -1106,7 +1114,7 @@ class AxolotlInputConfig(
@field_validator("sequence_parallel_degree", mode="before")
@classmethod
def check_sequence_parallel_config(cls, value, info):
def check_sequence_parallel_degree(cls, value, info):
if not value:
value = 1
@@ -1127,6 +1135,17 @@ class AxolotlInputConfig(
return value
@model_validator(mode="before")
@classmethod
def check_muon_deepspeed_fsdp(cls, data):
if data.get("optimizer") == "muon" and (
data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config")
):
raise ValueError(
"Muon optimizer is currently incompatible with DeepSpeed and FSDP"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""
@@ -1262,3 +1281,14 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data["beta"] != data["trl"]["beta"]:
raise ValueError("beta and trl.beta must match or one must be removed")
return data
@model_validator(mode="after")
def check_min_torch_version(self):
if self.env_capabilities and self.env_capabilities.torch_version:
torch_version = self.env_capabilities.torch_version
if version.parse(torch_version) < version.parse("2.5.1"):
LOG.warning(
f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
)
return self

View File

@@ -22,6 +22,7 @@ class ChatTemplate(str, Enum):
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
mistral_v7_tekken = "mistral_v7_tekken" # pylint: disable=invalid-name
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
@@ -36,6 +37,10 @@ class ChatTemplate(str, Enum):
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
exaone = "exaone" # pylint: disable=invalid-name
metharme = "metharme" # pylint: disable=invalid-name
pixtral = "pixtral" # pylint: disable=invalid-name
llava = "llava" # pylint: disable=invalid-name
qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name
gemma3 = "gemma3" # pylint: disable=invalid-name
class CustomSupportedOptimizers(str, Enum):

View File

@@ -0,0 +1,48 @@
"""Pydantic models for multimodal-related configuration"""
from typing import Literal
from PIL.Image import Resampling
from pydantic import BaseModel, Field, field_validator
class MultiModalConfig(BaseModel):
"""Multi-modal configuration subset"""
image_size: int | tuple[int, int] | None = Field(
default=None,
json_schema_extra={
"description": (
"The size of the image to resize to. It can be an integer (resized into padded-square image) or a tuple (width, height)."
"If not provided, we will attempt to load from preprocessor.size, otherwise, images won't be resized."
)
},
)
image_resize_algorithm: (
Literal["bilinear", "bicubic", "lanczos"] | Resampling | None
) = Field(
default=None,
json_schema_extra={
"description": "The resampling algorithm to use for image resizing. Default is bilinear. Please refer to PIL.Image.Resampling for more details."
},
)
@field_validator("image_resize_algorithm", mode="before")
@classmethod
def convert_image_resize_algorithm(cls, image_resize_algorithm):
"""
Convert the image resize algorithm to a PIL.Image.Resampling enum.
"""
if isinstance(image_resize_algorithm, str):
image_resize_algorithm = image_resize_algorithm.lower()
if image_resize_algorithm == "bilinear":
image_resize_algorithm = Resampling.BILINEAR
elif image_resize_algorithm == "bicubic":
image_resize_algorithm = Resampling.BICUBIC
elif image_resize_algorithm == "lanczos":
image_resize_algorithm = Resampling.LANCZOS
else:
raise ValueError(
f"Invalid image resize algorithm: {image_resize_algorithm}"
)
return image_resize_algorithm

View File

@@ -20,27 +20,30 @@ class TRLConfig(BaseModel):
)
# GRPO specific args
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
use_vllm: bool | None = Field(
# Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23
use_vllm: bool = Field(
default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training"},
)
vllm_device: str | None = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
vllm_server_host: str | None = Field(
default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Host of the vLLM server to connect to"},
)
vllm_gpu_memory_utilization: float | None = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
vllm_server_port: int | None = Field(
default=8000,
json_schema_extra={"description": "Port of the vLLM server to connect to"},
)
vllm_dtype: str | None = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
vllm_max_model_len: int | None = Field(
vllm_server_timeout: int | None = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
"description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up "
"after the timeout, a `ConnectionError` is raised."
},
)
vllm_guided_decoding_regex: str | None = Field(
default=None,
json_schema_extra={
"description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."
},
)
@@ -85,3 +88,48 @@ class TRLConfig(BaseModel):
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
},
)
scale_rewards: bool = Field(
default=True,
json_schema_extra={
"description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation."
},
)
temperature: float | None = Field(
default=None,
json_schema_extra={"description": "Sampling temperature for the GRPO policy."},
)
top_p: float | None = Field(
default=None,
json_schema_extra={
"description": "Top-p sampling probability for the generation policy."
},
)
top_k: int | None = Field(
default=None,
json_schema_extra={"description": "Top-k sampling for the generation policy."},
)
min_p: float | None = Field(
default=None,
json_schema_extra={
"description": "Minimum probability for the generation policy."
},
)
repetition_penalty: float | None = Field(
default=None,
json_schema_extra={
"description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far."
},
)
num_iterations: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO."
},
)
epsilon: float | None = Field(
default=None,
json_schema_extra={
"description": "Epsilon value for clipping in the GRPO algorithm."
},
)

View File

@@ -0,0 +1,38 @@
"""
Pydantic models for VLLM configuration, used primarily for RL training with TRL + grpo
"""
from pydantic import BaseModel, Field
class VllmConfig(BaseModel):
"""
Configuration for VLLM server
"""
device: str | None = Field(
default="auto",
json_schema_extra={"description": "Device to use for VLLM"},
)
tensor_parallel_size: int | None = Field(
default=None,
json_schema_extra={"description": "Tensor parallel size for VLLM"},
)
gpu_memory_utilization: float | None = Field(
default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"},
)
dtype: str | None = Field(
default="auto",
json_schema_extra={"description": "Data type for VLLM"},
)
max_model_len: int | None = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the model context for VLLM"
},
)
enable_prefix_caching: bool | None = Field(
default=None,
json_schema_extra={"description": "Enable prefix caching for VLLM"},
)

View File

@@ -13,7 +13,7 @@ import torch
import torch.cuda
from accelerate.logging import get_logger
from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -235,7 +235,7 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.model_config_type == "mamba":
if cfg.model_config_type in ["mamba", "gemma3"]:
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
@@ -456,13 +456,18 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
else:
sampler_batch_size = cfg.micro_batch_size
batch_max_len = cfg.sequence_len
if cfg.curriculum_sampling:
sampler = SequentialSampler(train_dataset)
else:
sampler = RandomSampler(train_dataset)
sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
sampler=sampler,
lengths=get_dataset_lengths(train_dataset),
batch_size=sampler_batch_size,
batch_max_len=batch_max_len,
group_size=cfg.sample_packing_group_size,
bin_size=cfg.sample_packing_bin_size,
sequential=cfg.sample_packing_sequentially,
drop_last=True,
)

0
tests/__init__.py Normal file
View File

View File

@@ -8,10 +8,16 @@ import shutil
import sys
import tempfile
import time
from pathlib import Path
import datasets
import pytest
import requests
from huggingface_hub import snapshot_download
from tokenizers import AddedToken
from transformers import AutoTokenizer
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
def retry_on_request_exceptions(max_retries=3, delay=1):
@@ -25,9 +31,11 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
requests.exceptions.HTTPError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
wait = 2**attempt * delay # in seconds
time.sleep(wait)
else:
raise exc
@@ -37,26 +45,35 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
@retry_on_request_exceptions(max_retries=3, delay=5)
@disable_hf_offline
def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs)
@pytest.fixture(scope="session", autouse=True)
def download_ds_fixture_bundle():
ds_dir = snapshot_download_w_retry(
"axolotl-ai-internal/axolotl-oss-dataset-fixtures", repo_type="dataset"
)
return Path(ds_dir)
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
# download the model
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M")
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_llama_68m_random_model():
# download the model
snapshot_download_w_retry("JackFram/llama-68m")
snapshot_download_w_retry("JackFram/llama-68m", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
# download the model
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B")
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
@@ -94,13 +111,52 @@ def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
@pytest.fixture(scope="session", autouse=True)
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
def download_argilla_distilabel_intel_orca_dpo_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
)
# @pytest.fixture(scope="session", autouse=True)
# def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
# # download the dataset
# snapshot_download_w_retry(
# "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
# )
# @pytest.fixture(scope="session", autouse=True)
# def download_fozzie_alpaca_dpo_dataset():
# # download the dataset
# snapshot_download_w_retry(
# "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
# )
# snapshot_download_w_retry(
# "fozziethebeat/alpaca_messages_2k_dpo_test",
# repo_type="dataset",
# revision="ea82cff",
# )
# @pytest.fixture(scope="session")
# @disable_hf_offline
# def dataset_fozzie_alpaca_dpo_dataset(
# download_fozzie_alpaca_dpo_dataset,
# ): # pylint: disable=unused-argument,redefined-outer-name
# return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
#
#
# @pytest.fixture(scope="session")
# @disable_hf_offline
# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
# download_fozzie_alpaca_dpo_dataset,
# ): # pylint: disable=unused-argument,redefined-outer-name
# return load_dataset(
# "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
# )
@pytest.fixture(scope="session", autouse=True)
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
# download the dataset
@@ -109,10 +165,192 @@ def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
)
@pytest.fixture(scope="session", autouse=True)
def download_argilla_dpo_pairs_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/distilabel-intel-orca-dpo-pairs", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_tiny_shakespeare_dataset():
# download the dataset
snapshot_download_w_retry("Trelis/tiny-shakespeare", repo_type="dataset")
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture():
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_huggyllama_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"huggyllama/llama-7b",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Llama-3.2-1B",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama3_8b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Meta-Llama-3-8B", repo_type="model", allow_patterns=["*token*"]
)
@pytest.fixture(scope="session", autouse=True)
def download_llama3_8b_instruct_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Meta-Llama-3-8B-Instruct",
repo_type="model",
allow_patterns=["*token*"],
)
@pytest.fixture(scope="session", autouse=True)
def download_phi_35_mini_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"microsoft/Phi-3.5-mini-instruct", repo_type="model", allow_patterns=["*token*"]
)
@pytest.fixture(scope="session", autouse=True)
def download_phi_3_medium_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"microsoft/Phi-3-medium-128k-instruct",
repo_type="model",
allow_patterns=["*token*"],
)
@pytest.fixture(scope="session", autouse=True)
def download_mistral_7b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"casperhansen/mistral-7b-instruct-v0.1-awq",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_gemma_2b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"unsloth/gemma-2b-it",
revision="703fb4a",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_gemma2_9b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"mlx-community/gemma-2-9b-it-4bit",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_mlx_mistral_7b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"mlx-community/Mistral-7B-Instruct-v0.3-4bit",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture
def download_llama2_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"NousResearch/Llama-2-7b-hf",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture
@enable_hf_offline
def tokenizer_huggyllama(
download_huggyllama_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = "</s>"
return tokenizer
@pytest.fixture
@enable_hf_offline
def tokenizer_huggyllama_w_special_tokens(
tokenizer_huggyllama,
): # pylint: disable=redefined-outer-name
tokenizer_huggyllama.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
return tokenizer_huggyllama
@pytest.fixture
@enable_hf_offline
def tokenizer_llama2_7b(
download_llama2_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
return tokenizer
@pytest.fixture
@enable_hf_offline
def tokenizer_mistral_7b_instruct(
download_mlx_mistral_7b_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
@pytest.fixture
def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
tokenizer_mistral_7b_instruct.add_special_tokens(
{
"eos_token": AddedToken(
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
)
}
)
tokenizer_mistral_7b_instruct.add_tokens(
[
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
]
)
return tokenizer_mistral_7b_instruct
@pytest.fixture
@@ -178,3 +416,88 @@ def cleanup_monkeypatches():
module_globals = module_name_tuple[1]
for module_global in module_globals:
globals().pop(module_global, None)
@pytest.fixture
def dataset_winglian_tiny_shakespeare(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
return datasets.load_from_disk(ds_path)
@pytest.fixture
def dataset_tatsu_lab_alpaca(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture
def dataset_mhenrichsen_alpaca_2k_test(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture
def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = (
download_ds_fixture_bundle
/ "argilla__ultrafeedback-binarized-preferences-cleaned"
)
return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
download_ds_fixture_bundle: Path,
): # pylint: disable=redefined-outer-name
ds_path = (
download_ds_fixture_bundle
/ "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
)
return datasets.load_from_disk(ds_path)["train"]
# # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures(
# download_smollm2_135m_model,
# download_llama_68m_random_model,
# download_qwen_2_5_half_billion_model,
# download_tatsu_lab_alpaca_dataset,
# download_mhenrichsen_alpaca_2k_dataset,
# download_mhenrichsen_alpaca_2k_w_revision_dataset,
# download_mlabonne_finetome_100k_dataset,
# download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
# download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
# download_fozzie_alpaca_dpo_dataset,
# download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
# download_argilla_dpo_pairs_dataset,
# download_tiny_shakespeare_dataset,
# download_deepseek_model_fixture,
# download_huggyllama_model_fixture,
# download_llama_1b_model_fixture,
# download_llama3_8b_model_fixture,
# download_llama3_8b_instruct_model_fixture,
# download_phi_35_mini_model_fixture,
# download_phi_3_medium_model_fixture,
# download_mistral_7b_model_fixture,
# download_gemma_2b_model_fixture,
# download_gemma2_9b_model_fixture,
# download_mlx_mistral_7b_model_fixture,
# download_llama2_model_fixture,
# ):
# pass

View File

@@ -10,10 +10,13 @@ from transformers import AddedToken, AutoTokenizer
from axolotl.core.chat.format.chatml import format_message
from axolotl.core.chat.messages import ChatFormattedChats, Chats
from tests.hf_offline_utils import enable_hf_offline # noqa
@pytest.fixture(scope="session", name="llama_tokenizer")
@enable_hf_offline
def llama_tokenizer_fixture():
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3.1-8B")
return AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
@pytest.fixture(scope="session", name="chatml_tokenizer")

View File

@@ -5,7 +5,6 @@ e2e tests for kd trainer support in Axolotl
from pathlib import Path
import pytest
from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
@@ -13,6 +12,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
@pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir):

View File

@@ -2,15 +2,13 @@
Simple end-to-end test for Liger integration
"""
from e2e.utils import require_torch_2_4_1
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
class LigerIntegrationTestCase:

View File

View File

@@ -0,0 +1,294 @@
"""
GRPO test suite
"""
import os
import random
import subprocess # nosec B404
import sys
import time
from pathlib import Path
import pytest
import requests
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_vllm
def start_vllm(
model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs
) -> int:
"""
helper function to start the VLLM server in the background, mostly for testing purposes
"""
cmd = [sys.executable, "-m", "trl.scripts.vllm_serve", "--model", model]
if tensor_parallel_size := kwargs.get("tensor_parallel_size"):
cmd.extend(["--tensor-parallel-size", str(tensor_parallel_size)])
if host := kwargs.get("host"):
cmd.extend(["--host", host])
if port := kwargs.get("port"):
cmd.extend(["--port", str(port)])
if gpu_memory_utilization := kwargs.get("gpu_memory_utilization"):
cmd.extend(["--gpu-memory-utilization", str(gpu_memory_utilization)])
if dtype := kwargs.get("dtype"):
cmd.extend(["--dtype", dtype])
if max_model_len := kwargs.get("max_model_len"):
cmd.extend(["--max-model-len", str(max_model_len)])
if kwargs.get("enable_prefix_caching"):
cmd.extend(["--enable-prefix-caching", "True"])
# print out the command to be executed
print(" ".join(cmd))
# start `trl vllm-serve` command in the background and capture the process id
process = subprocess.Popen( # pylint: disable=consider-using-with
cmd,
env=env,
stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
) # nosec B603
# print out the process id so the user can easily kill it later
print(f"VLLM server process started (PID: {process.pid})")
# wait until the http server is ready, even if it 404s, but timeout after 60 seconds
started = False
if wait and host and port:
for _ in range(int(wait)):
try:
response = requests.get(f"http://{host}:{port}", timeout=1)
if int(response.status_code) in [200, 404]:
started = True
break
except requests.exceptions.RequestException:
pass
# also check if the process.pid is still running
if not process.poll() is None:
break
time.sleep(1)
if wait and not started:
print(
f"VLLM server process did not start within {wait} seconds. Please check your server logs."
)
process.kill()
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
# return the process id
return process.pid
class TestGRPO:
"""
Test case for GRPO training using multilpe GPUs
"""
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
fout.write(
"""import random
def rand_reward_func(completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]},],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
"""
)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_dora(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"peft_use_dora": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process_id = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=120,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
)
finally:
os.kill(vllm_process_id, 9)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_fft(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process_id = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=120,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
)
finally:
os.kill(vllm_process_id, 9)

View File

@@ -52,9 +52,9 @@ class TestMultiGPUEval:
},
],
"num_epochs": 1,
"max_steps": 5,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -121,9 +121,9 @@ class TestMultiGPUEval:
},
],
"num_epochs": 1,
"max_steps": 5,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",

View File

@@ -0,0 +1,100 @@
"""
E2E tests for multigpu lora tinyllama
"""
import logging
import os
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@pytest.fixture(scope="session", autouse=True)
def download_model():
# download the model
snapshot_download("axolotl-mirrors/gemma-3-4b-pt", repo_type="model")
class TestMultiGPUGemma3:
"""
Test case for Gemma3 models using LoRA
"""
def test_lora_ddp_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-mirrors/gemma-3-4b-pt",
"sequence_len": 2048,
"ddp_find_unused_parameters": True,
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.0,
"chat_template": "gemma3",
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {
"use_reentrant": False,
},
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high"
)

View File

@@ -1,174 +0,0 @@
"""
GRPO test suite
"""
import random
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import require_vllm
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
class TestGRPO:
"""
Test case for GRPO training using multilpe GPUs
"""
def _utils_write_yaml_and_rewards(self, cfg, temp_dir, suffix=""):
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
with open(f"rewards_{suffix}.py", "w", encoding="utf-8") as fout:
fout.write(
"""import random
def rand_reward_func(completions, **kwargs) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
def oai_gsm8k_transform(cfg, *args, **kwargs):
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [{"role": "user", "content": example["question"]},],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
"""
)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_dora(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
"vllm_gpu_memory_utilization": 0.15,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"peft_use_dora": True,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
)
@require_vllm
def test_llama_fft(self, temp_dir, num_gpus):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"vllm_device": "auto" if num_gpus == 1 else "cuda:1",
"vllm_gpu_memory_utilization": 0.15,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 5,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(num_gpus),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)

View File

@@ -9,12 +9,13 @@ from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard
from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
@@ -57,6 +58,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -120,6 +122,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -192,6 +195,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
@@ -269,6 +273,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
@@ -329,6 +334,7 @@ class TestMultiGPULlama:
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": gradient_accumulation_steps,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
@@ -398,7 +404,8 @@ class TestMultiGPULlama:
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
@@ -477,7 +484,8 @@ class TestMultiGPULlama:
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
@@ -777,9 +785,10 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 5,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",

View File

@@ -46,7 +46,7 @@ class TestMultiGPUQwen2:
},
],
"num_epochs": 1,
"max_steps": 5,
"max_steps": 2,
"warmup_steps": 20,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,

View File

@@ -9,10 +9,11 @@ from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard, require_torch_lt_2_6_0
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
LOG = logging.getLogger(__name__)
os.environ["WANDB_DISABLED"] = "true"
@@ -49,7 +50,7 @@ class TestMultiGPURay:
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",

View File

@@ -144,7 +144,7 @@ def test_swiglu_mlp_integration(small_llama_model):
def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda"
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="auto"
)
peft_config = get_peft_config(
{
@@ -347,7 +347,7 @@ def test_model_architecture(model_config):
"""Test LoRA kernel patches across different model architectures."""
# Load model with appropriate dtype
model = AutoModelForCausalLM.from_pretrained(
model_config["name"], torch_dtype=model_config["dtype"], device_map="cuda"
model_config["name"], torch_dtype=model_config["dtype"], device_map="auto"
)
# Apply LoRA configuration
@@ -408,7 +408,7 @@ def test_kernel_training_integration():
)
# Load model
model, _ = load_model_and_tokenizer(cfg=cfg)
model, _, _ = load_model_and_tokenizer(cfg=cfg)
# Verify correct activation function
layer = model.model.model.layers[0]

Some files were not shown because too many files have changed in this diff Show More