Compare commits

..

31 Commits

Author SHA1 Message Date
Wing Lian
14d670dbf0 v0.9.0 release (#2578)
2025-04-28 18:23:17 -04:00
Wing Lian
2d77165dc0 automatically split out reasoning trace from dataset (#2579)
* automatically split out reasoning trace from dataset

* chore: lint

* fix import
2025-04-28 18:23:03 -04:00
Wing Lian
63b17e3109 chat template and example for qwen3 (#2577) 2025-04-28 15:09:41 -04:00
NanoCode012
1178a15ede Feat: Add qwen3 and CCE for qwen family (#2518) 2025-04-28 12:18:46 -04:00
Wing Lian
c513487d1a support val_set_size for splitting test split from train with DPO (#2572) 2025-04-28 12:12:15 -04:00
Dan Saunders
dda95e6c40 add preview-docs workflow (#2432)
* add preview-docs workflow

* update preview-docs workflow

* use correct publish-dir

* install deps prior to docs build

* use correct publish-dir

* use quarto publish with netlify target

* adding _publish.yml

* fix

* fix

* fix

* remove unused file

* fix naming

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
2025-04-28 11:20:46 -04:00
NanoCode012
7099343c56 feat: add eos_tokens and train_on_eot for chat_template EOT parsing (#2364)
* feat: add eos_tokens and train_on_eot for chat_template EOT parsing

* fix: comments

* chore: add some examples of tokens

* feat: add new potential errors for chat_template to faq

* feat: add examples for EOT handling

* fix: change error to warning for missing EOS

* fix: warning typo

* feat: add tests for eot token handling

* fix: remove broken caplog capture in test

* fix: chattemplate strategy with kd missing eot changes
2025-04-28 10:11:20 -04:00
Wing Lian
5000cb3fe7 grab sys prompt too from dataset (#2397) [skip ci]
* grab sys prompt too from dataset

* chore: add field_system to docs

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-04-28 10:11:06 -04:00
divyanshuaggarwal
170cdb5be9 Add Post_model_load, post_lora_load, post_train, post_train_unload function calls (#2539)
* Update train.py

add post_model_load and post_lora_load model calss.

* Update train.py

add post_train and post_train_unload function calls

* Update train.py

* Update base.py

* Update train.py

* chore: lint

* clarify plugin hooks

* Update src/axolotl/integrations/base.py

Co-authored-by: Dan Saunders <danjsaund@gmail.com>

* Update src/axolotl/utils/models.py

Co-authored-by: Dan Saunders <danjsaund@gmail.com>

* Update src/axolotl/utils/models.py

Co-authored-by: Dan Saunders <danjsaund@gmail.com>

* Update src/axolotl/integrations/base.py

Co-authored-by: Dan Saunders <danjsaund@gmail.com>

* Update models.py

* Update models.py

* remove extra call to post_model_load

* chore: lint

* add test for hooks and gc trainer

* disable duplicated code check for test

* fix the path and add better handling

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-04-28 10:10:28 -04:00
Ezekiel Wotring
5d182a1056 Add runpod sls handler (#2530) [skip ci]
* Add runpod sls handler

* remove LICENSE and fix README

* chore: lint

* use axolotl cloud image as base and various fixes

* fix: trim allowed cuda versions

* restore dockerfile

* chore: update title

* use axolotl cloud image

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-04-28 10:08:32 -04:00
Wing Lian
40f4ea23ab replace references to random 68m model w 135m smollm2 (#2570) [skip ci]
* replace references to random 68m model w 135m smollm2

* use AutoTokenizer for smollm2
2025-04-28 10:08:07 -04:00
NanoCode012
f1df73a798 fix(doc): clarify vllm usage with grpo (#2573) [skip ci]
* fix(doc): clarify vllm usage with grpo

* nit

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update docs/rlhf.qmd

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-04-28 10:07:45 -04:00
Dhruv Mullick
8b33ae1c4f Fix bug in grpo reward module import (#2571) 2025-04-28 00:31:56 -04:00
Wing Lian
dc4da4a7e2 update trl to 0.17.0 (#2560)
* update trl to 0.17.0

* grpo + vllm no longer supported with 2.5.1 due to vllm constraints

* disable VLLM_USE_V1 for ci

* imporve handle killing off of multiprocessing vllm service

* debug why this doesn't run in CI

* increase vllm wait time

* increase timeout to 5min

* upgrade to vllm 0.8.4

* dump out the vllm log for debugging

* use debug logging

* increase vllm start timeout

* use NVL instead

* disable torch compile cache

* revert some commented checks now that grpo tests are fixed

* increase vllm timeoout back to 5min
2025-04-27 19:19:53 -04:00
Wing Lian
f9c7c3bb72 don't use is_main_process during config validation (#2569) 2025-04-26 14:14:52 -04:00
Wing Lian
caf5cb63ea add e2e smoke test for using activation/gradient checkpointing with offload (#2565)
* add e2e smoke test for using activation/gradient checkpointing with offload

* disable duplicate code check for the test

* fix relative import

* seq len too small to test this dataset with packing

* Fix checkpoint ptaching for tests
2025-04-25 21:11:17 -04:00
Wing Lian
5dba5c82a8 fix support for wandb run_name for rl trainers (#2566) [skip ci]
* fix support for wandb run_name for rl trainers

* prefer to use wandb random names for run_name
2025-04-25 21:10:54 -04:00
Chiwan Park
e3c9d541a7 fix: crash when pretraining_dataset with dispatch_batches is false (#2558) 2025-04-25 17:15:03 -04:00
NanoCode012
9eba0ad118 chore(doc): update docker tags on doc (#2559) [skip ci] 2025-04-25 17:14:48 -04:00
Wing Lian
53dbf97d85 make cce default to true when using the plugin (#2562) [skip ci] 2025-04-25 17:14:26 -04:00
Eko Julianto Salim
2c2563bc34 fix: gradient checkpointing functools.partial object has no attribute __self__ (#2563) [skip ci]
* fix: gradient checkpointing causing functools.partial error

* lint

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-04-25 17:02:37 -04:00
Wing Lian
5cb3398460 don't fail on codecov upload for external contributor PRs (#2564) [skip ci] 2025-04-25 15:10:55 -04:00
Dan Saunders
ae1c7ace63 Sequence parallel training context manager (#2553)
* ctx manager for SP

* updates

* update

* further simplifying

* accommodate both training context managers

* simplifying

* simplifying

* nit

* reorg

* tweak codecov yaml

* add gather post hook, simplify, fixes

* pytest

* pytest fix
2025-04-25 10:33:54 -04:00
Wing Lian
1447beb132 make sure to validate the config before normalizing so defaults get set (#2554)
* make sure to validate the config before normalizing so defaults get set

* validation not needed for particular test

* remove duplicate validations

* set qlora correctly
2025-04-24 13:01:43 -04:00
Dan Saunders
66f41ec6f1 disable codecov pr annotations (#2556) 2025-04-24 08:51:51 -04:00
NanoCode012
85053f4bd4 Fix(doc): add delinearize instruction (#2545)
* fix: mention to install pytorch before axolotl

* feat(doc): include instruction to delinearize

* fix: update instruction for delinearize with adapter
2025-04-24 01:03:43 -04:00
Wing Lian
a4d5112ae1 builds for torch 2.7.0 (#2552)
* builds for torch==2.7.0

* use xformers==0.0.29.post3

* no vllm support with torch 2.7

* update default, fix conditional

* no xformers for 270

* no vllm on 2.7.0 for multigpu test too

* remove deprecated verbose arg from scheduler

* 2.7.0 tests on cpu
2025-04-24 00:39:31 -04:00
Wing Lian
0d691cc2a7 add base docker image with pytorch 2.7.0 and variant for cuda 12.8 (#2551)
* add base docker image with pytorch 2.7.0 and variant for cuda 12.8

* my bash is terrible
2025-04-23 14:59:03 -04:00
Dan Saunders
c4053481ff Codecov fixes / improvements (#2549)
* adding codecov reporting

* random change

* codecov fixes

* adding missing dependency

* fix

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
2025-04-23 10:33:30 -04:00
NanoCode012
a6d28d19b1 feat: add glm and glm4 multipack and cce (#2546)
* feat: add glm and glm4 multipack

* feat: add glm4 example

* feat: add cce for glm
2025-04-23 10:27:51 -04:00
Wing Lian
32e335dd51 fix missing host/port for vllm (#2543)
* fix missing host/port for vllm

* set tensor parallel size so it doesn't always default to cli override
2025-04-22 10:16:48 -04:00
118 changed files with 5317 additions and 932 deletions


@@ -46,6 +46,18 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""


@@ -24,13 +24,18 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras: vllm
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras: vllm
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -93,6 +98,11 @@ jobs:
pytorch: 2.6.0
axolotl_extras:
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -138,7 +148,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:


@@ -8,6 +8,7 @@ on:
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
@@ -42,7 +43,14 @@ jobs:
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras: vllm
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
@@ -67,6 +75,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.multigpu

.github/workflows/preview-docs.yml (new file, 55 lines)

@@ -0,0 +1,55 @@
name: Preview
on:
workflow_dispatch:
pull_request:
types: [opened, synchronize, reopened]
permissions:
checks: write
contents: write
deployments: write
issues: write
discussions: write
pages: write
pull-requests: write
statuses: write
jobs:
preview:
runs-on: ubuntu-latest
steps:
- name: Check out repository
uses: actions/checkout@v4
- name: Set up Quarto
uses: quarto-dev/quarto-actions/setup@v2
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install dependencies
run: |
python3 -m pip install jupyter quartodoc
python3 -m pip install -e . --no-deps
- name: Build autodoc
run: quartodoc build
- name: Quarto render
run: quarto render
- name: Netlify Publish
uses: nwtgck/actions-netlify@v3.0
with:
publish-dir: './_site'
enable-pull-request-comment: true
enable-github-deployment: true
github-token: ${{ secrets.GITHUB_TOKEN }}
deploy-message: "Deployed On Netlify"
github-deployment-environment: 'preview'
github-deployment-description: 'Preview Deployment'
env:
NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}


@@ -147,6 +147,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests


@@ -49,7 +49,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
@@ -109,6 +109,7 @@ jobs:
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
@@ -241,6 +242,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
@@ -267,7 +269,13 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
num_gpus: 1
axolotl_extras: vllm
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -288,6 +296,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests

.runpod/.gitignore (new file, 161 lines)

@@ -0,0 +1,161 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
pod/scripts/config.yaml

.runpod/Dockerfile (new file, 18 lines)

@@ -0,0 +1,18 @@
FROM axolotlai/axolotl-cloud:main-py3.11-cu124-2.6.0
COPY .runpod/requirements.txt /requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install --upgrade pip && \
python3 -m pip install --upgrade -r /requirements.txt
# Environment settings
ARG BASE_VOLUME="/runpod-volume"
ENV BASE_VOLUME=$BASE_VOLUME
ENV HF_DATASETS_CACHE="${BASE_VOLUME}/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
COPY .runpod/src /src
WORKDIR /src
CMD ["python3", "/src/handler.py"]

.runpod/README.md (new file, 335 lines)

@@ -0,0 +1,335 @@
<h1>LLM Post-Training: Full Fine-Tune, LoRA, QLoRA, etc. for Llama/Mistral/Gemma and more</h1>
# Configuration Options
This document outlines all available configuration options for training models. The configuration can be provided as a JSON request.
## Usage
You can provide these configuration options as a JSON request body:
```json
{
"input": {
"user_id": "user",
"model_id": "model-name",
"run_id": "run-id",
"credentials": {
"wandb_api_key": "", # add your Weights & biases key. TODO: you will be able to set this in Enviornment variables.
"hf_token": "", # add your HF_token. TODO: you will be able to set this in Enviornment variables.
},
"args": {
"base_model": "NousResearch/Llama-3.2-1B",
// ... other options
}
}
}
```
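As a hedged illustration only, the sketch below shows one way to submit such a body to a deployed serverless endpoint from Python. The endpoint ID, URL pattern, bearer-token header, and `RUNPOD_API_KEY` variable are assumptions about RunPod's serverless API rather than part of this worker; the request body mirrors the example above.

```python
import os

import requests  # third-party HTTP client, assumed to be installed

# Hypothetical endpoint ID and API key; the URL pattern and Authorization
# header are assumptions about RunPod's serverless API, not worker code.
ENDPOINT_ID = "your-endpoint-id"
API_KEY = os.environ["RUNPOD_API_KEY"]

response = requests.post(
    f"https://api.runpod.ai/v2/{ENDPOINT_ID}/run",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "input": {
            "user_id": "user",
            "model_id": "model-name",
            "run_id": "run-id",
            "credentials": {"wandb_api_key": "", "hf_token": ""},
            "args": {"base_model": "NousResearch/Llama-3.2-1B"},
        }
    },
    timeout=30,
)
print(response.json())
```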
## Configuration Options
### Model Configuration
| Option | Description | Default |
| ------------------- | --------------------------------------------------------------------------------------------- | -------------------- |
| `base_model` | Path to the base model (local or HuggingFace) | Required |
| `base_model_config` | Configuration path for the base model | Same as base_model |
| `revision_of_model` | Specific model revision from HuggingFace hub | Latest |
| `tokenizer_config` | Custom tokenizer configuration path | Optional |
| `model_type` | Type of model to load | AutoModelForCausalLM |
| `tokenizer_type` | Type of tokenizer to use | AutoTokenizer |
| `hub_model_id` | Repository ID where the model will be pushed on Hugging Face Hub (format: username/repo-name) | Optional |
## Model Family Identification
| Option | Default | Description |
| -------------------------- | ------- | ------------------------------ |
| `is_falcon_derived_model` | `false` | Whether model is Falcon-based |
| `is_llama_derived_model` | `false` | Whether model is LLaMA-based |
| `is_qwen_derived_model` | `false` | Whether model is Qwen-based |
| `is_mistral_derived_model` | `false` | Whether model is Mistral-based |
## Model Configuration Overrides
| Option | Default | Description |
| ----------------------------------------------- | ---------- | ---------------------------------- |
| `overrides_of_model_config.rope_scaling.type` | `"linear"` | RoPE scaling type (linear/dynamic) |
| `overrides_of_model_config.rope_scaling.factor` | `1.0` | RoPE scaling factor |
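To show how the dotted option names above nest inside the request, here is a minimal sketch of the `args` structure; the values are illustrative placeholders, not recommendations.

```python
# Illustrative nesting of the rope_scaling override inside the request "args".
args = {
    "base_model": "NousResearch/Llama-3.2-1B",
    "overrides_of_model_config": {
        # Extend the usable context window via RoPE scaling.
        "rope_scaling": {"type": "linear", "factor": 2.0},
    },
}
```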
### Model Loading Options
| Option | Description | Default |
| -------------- | ----------------------------- | ------- |
| `load_in_8bit` | Load model in 8-bit precision | false |
| `load_in_4bit` | Load model in 4-bit precision | false |
| `bf16` | Use bfloat16 precision | false |
| `fp16` | Use float16 precision | false |
| `tf32` | Use tensor float 32 precision | false |
## Memory and Device Settings
| Option | Default | Description |
| ------------------ | --------- | ----------------------- |
| `gpu_memory_limit` | `"20GiB"` | GPU memory limit |
| `lora_on_cpu` | `false` | Load LoRA on CPU |
| `device_map` | `"auto"` | Device mapping strategy |
| `max_memory` | `null` | Max memory per device |
## Training Hyperparameters
| Option | Default | Description |
| ----------------------------- | --------- | --------------------------- |
| `gradient_accumulation_steps` | `1` | Gradient accumulation steps |
| `micro_batch_size` | `2` | Batch size per GPU |
| `eval_batch_size` | `null` | Evaluation batch size |
| `num_epochs` | `4` | Number of training epochs |
| `warmup_steps` | `100` | Warmup steps |
| `warmup_ratio` | `0.05` | Warmup ratio |
| `learning_rate` | `0.00003` | Learning rate |
| `lr_quadratic_warmup` | `false` | Quadratic warmup |
| `logging_steps` | `null` | Logging frequency |
| `eval_steps` | `null` | Evaluation frequency |
| `evals_per_epoch` | `null` | Evaluations per epoch |
| `save_strategy` | `"epoch"` | Checkpoint saving strategy |
| `save_steps` | `null` | Saving frequency |
| `saves_per_epoch` | `null` | Saves per epoch |
| `save_total_limit` | `null` | Maximum checkpoints to keep |
| `max_steps` | `null` | Maximum training steps |
### Dataset Configuration
```yaml
datasets:
- path: vicgalle/alpaca-gpt4 # HuggingFace dataset or TODO: You will be able to add the local path.
type: alpaca # Format type (alpaca, gpteacher, oasst, etc.)
ds_type: json # Dataset type
data_files: path/to/data # Source data files
train_on_split: train # Dataset split to use
```
## Chat Template Settings
| Option | Default | Description |
| ------------------------ | -------------------------------- | ---------------------- |
| `chat_template` | `"tokenizer_default"` | Chat template type |
| `chat_template_jinja` | `null` | Custom Jinja template |
| `default_system_message` | `"You are a helpful assistant."` | Default system message |
## Dataset Processing
| Option | Default | Description |
| ----------------------------- | -------------------------- | --------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `dataset_exact_deduplication` | `true` | Deduplicate datasets |
## LoRA Configuration
| Option | Default | Description |
| -------------------------- | ---------------------- | ------------------------------ |
| `adapter` | `"lora"` | Adapter type (lora/qlora) |
| `lora_model_dir` | `""` | Directory with pretrained LoRA |
| `lora_r` | `8` | LoRA attention dimension |
| `lora_alpha` | `16` | LoRA alpha parameter |
| `lora_dropout` | `0.05` | LoRA dropout |
| `lora_target_modules` | `["q_proj", "v_proj"]` | Modules to apply LoRA |
| `lora_target_linear` | `false` | Target all linear modules |
| `peft_layers_to_transform` | `[]` | Layers to transform |
| `lora_modules_to_save` | `[]` | Modules to save |
| `lora_fan_in_fan_out` | `false` | Fan in/out structure |
## Optimization Settings
| Option | Default | Description |
| ------------------------- | ------- | -------------------------- |
| `train_on_inputs` | `false` | Train on input prompts |
| `group_by_length` | `false` | Group by sequence length |
| `gradient_checkpointing` | `false` | Use gradient checkpointing |
| `early_stopping_patience` | `3` | Early stopping patience |
## Learning Rate Scheduling
| Option | Default | Description |
| -------------------------- | ---------- | -------------------- |
| `lr_scheduler` | `"cosine"` | Scheduler type |
| `lr_scheduler_kwargs` | `{}` | Scheduler parameters |
| `cosine_min_lr_ratio` | `null` | Minimum LR ratio |
| `cosine_constant_lr_ratio` | `null` | Constant LR ratio |
| `lr_div_factor` | `null` | LR division factor |
## Optimizer Settings
| Option | Default | Description |
| ---------------------- | ------------ | ------------------- |
| `optimizer` | `"adamw_hf"` | Optimizer choice |
| `optim_args` | `{}` | Optimizer arguments |
| `optim_target_modules` | `[]` | Target modules |
| `weight_decay` | `null` | Weight decay |
| `adam_beta1` | `null` | Adam beta1 |
| `adam_beta2` | `null` | Adam beta2 |
| `adam_epsilon` | `null` | Adam epsilon |
| `max_grad_norm` | `null` | Gradient clipping |
## Attention Implementations
| Option | Default | Description |
| -------------------------- | ------- | ----------------------------- |
| `flash_optimum` | `false` | Use better transformers |
| `xformers_attention` | `false` | Use xformers |
| `flash_attention` | `false` | Use flash attention |
| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
| `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations |
| `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
| `sdp_attention` | `false` | Use scaled dot product |
| `s2_attention` | `false` | Use shifted sparse attention |
## Tokenizer Modifications
| Option | Default | Description |
| ---------------- | ------- | ---------------------------- |
| `special_tokens` | - | Special tokens to add/modify |
| `tokens` | `[]` | Additional tokens |
## Distributed Training
| Option | Default | Description |
| ----------------------- | ------- | --------------------- |
| `fsdp` | `null` | FSDP configuration |
| `fsdp_config` | `null` | FSDP config options |
| `deepspeed` | `null` | Deepspeed config path |
| `ddp_timeout` | `null` | DDP timeout |
| `ddp_bucket_cap_mb` | `null` | DDP bucket capacity |
| `ddp_broadcast_buffers` | `null` | DDP broadcast buffers |
<details>
<summary><h3>Example Configuration Request:</h3></summary>
Here's a complete example for fine-tuning a LLaMA model using LoRA:
```json
{
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "test-run",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "NousResearch/Llama-3.2-1B",
"load_in_8bit": false,
"load_in_4bit": false,
"strict": false,
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca"
}
],
"dataset_prepared_path": "last_run_prepared",
"val_set_size": 0.1,
"output_dir": "./outputs/lora-out",
"adapter": "lora",
"sequence_len": 2048,
"sample_packing": true,
"eval_sample_packing": true,
"pad_to_sequence_len": true,
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"lora_target_modules": [
"gate_proj",
"down_proj",
"up_proj",
"q_proj",
"v_proj",
"k_proj",
"o_proj"
],
"gradient_accumulation_steps": 2,
"micro_batch_size": 2,
"num_epochs": 1,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": false,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"loss_watchdog_threshold": 5,
"loss_watchdog_patience": 3,
"warmup_steps": 10,
"evals_per_epoch": 4,
"saves_per_epoch": 1,
"weight_decay": 0,
"hub_model_id": "runpod/llama-fr-lora",
"wandb_name": "test-run-1",
"wandb_project": "test-run-1",
"wandb_entity": "axo-test",
"special_tokens": {
"pad_token": "<|end_of_text|>"
}
}
}
}
```
</details>
### Advanced Features
#### Wandb Integration
- `wandb_project`: Project name for Weights & Biases
- `wandb_entity`: Team name in W&B
- `wandb_watch`: Monitor model with W&B
- `wandb_name`: Name of the W&B run
- `wandb_run_id`: ID for the W&B run
#### Performance Optimization
- `sample_packing`: Enable efficient sequence packing
- `eval_sample_packing`: Use sequence packing during evaluation
- `torch_compile`: Enable PyTorch 2.0 compilation
- `flash_attention`: Use Flash Attention implementation
- `xformers_attention`: Use xFormers attention implementation
### Available Optimizers
The following optimizers are supported:
- `adamw_hf`: HuggingFace's AdamW implementation
- `adamw_torch`: PyTorch's AdamW
- `adamw_torch_fused`: Fused AdamW implementation
- `adamw_torch_xla`: XLA-optimized AdamW
- `adamw_apex_fused`: NVIDIA Apex fused AdamW
- `adafactor`: Adafactor optimizer
- `adamw_anyprecision`: Anyprecision AdamW
- `adamw_bnb_8bit`: 8-bit AdamW from bitsandbytes
- `lion_8bit`: 8-bit Lion optimizer
- `lion_32bit`: 32-bit Lion optimizer
- `sgd`: Stochastic Gradient Descent
- `adagrad`: Adagrad optimizer
## Notes
- Set `load_in_8bit: true` or `load_in_4bit: true` for memory-efficient training
- Enable `flash_attention: true` for faster training on modern GPUs
- Use `gradient_checkpointing: true` to reduce memory usage
- Adjust `micro_batch_size` and `gradient_accumulation_steps` based on your GPU memory
For more detailed information, please refer to the [documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html).
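As a rough sketch of how the notes above combine, the snippet below assembles a memory-conscious `args` payload using option names from the tables in this document; the specific values are illustrative assumptions, not tuned recommendations.

```python
import json

# Illustrative memory-saving options drawn from the notes above.
args = {
    "base_model": "NousResearch/Llama-3.2-1B",
    "load_in_4bit": True,              # 4-bit weights for memory-efficient training
    "adapter": "qlora",                # pair 4-bit loading with a QLoRA adapter
    "gradient_checkpointing": True,    # trade compute for activation memory
    "flash_attention": True,           # faster attention on modern GPUs
    "micro_batch_size": 1,             # shrink the per-GPU batch...
    "gradient_accumulation_steps": 8,  # ...and recover the effective batch size
}

payload = {"input": {"user_id": "user", "model_id": "model-name",
                     "run_id": "run-id", "args": args}}
print(json.dumps(payload, indent=2))
```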
### Errors:
- If you face any issues with Flash Attention 2, delete your worker and restart.

.runpod/hub.json (new file, 93 lines)

@@ -0,0 +1,93 @@
{
"title": "Axolotl Fine-Tuning",
"description": "Serverless fine-tuning of open-source LLMs with Axolotl. Supports LoRA, QLoRA, DPO, and more using Hugging Face models and datasets.",
"type": "serverless",
"category": "language",
"iconUrl": "https://avatars.githubusercontent.com/u/167502477",
"config": {
"runsOn": "GPU",
"containerDiskInGb": 200,
"gpuCount": 1,
"allowedCudaVersions": [
"12.8",
"12.7",
"12.6",
"12.5",
"12.4"
],
"presets": [],
"env": [
{
"key": "TOKENIZER",
"input": {
"name": "Tokenizer",
"type": "string",
"description": "Name or path of the Hugging Face tokenizer to use.",
"default": "",
"advanced": true
}
},
{
"key": "MAX_NUM_SEQS",
"input": {
"name": "Max Num Seqs",
"type": "number",
"description": "Maximum number of sequences per iteration.",
"default": 256,
"advanced": true
}
},
{
"key": "DISABLE_LOG_STATS",
"input": {
"name": "Disable Log Stats",
"type": "boolean",
"description": "Disable logging statistics.",
"default": false,
"trueValue": "true",
"falseValue": "false"
}
},
{
"key": "LOAD_FORMAT",
"input": {
"name": "Load Format",
"type": "string",
"description": "The format of the model weights to load.",
"default": "auto",
"options": [
{
"label": "auto",
"value": "auto"
},
{
"label": "pt",
"value": "pt"
},
{
"label": "safetensors",
"value": "safetensors"
},
{
"label": "npcache",
"value": "npcache"
},
{
"label": "dummy",
"value": "dummy"
},
{
"label": "tensorizer",
"value": "tensorizer"
},
{
"label": "bitsandbytes",
"value": "bitsandbytes"
}
],
"advanced": true
}
}
]
}
}

.runpod/requirements.txt (new file, 7 lines)

@@ -0,0 +1,7 @@
# Required Python packages get listed here, one per line.
# Recommended to lock the version number to avoid unexpected changes.
# You can also install packages from a git repository, e.g.:
# git+https://github.com/runpod/runpod-python.git
# To learn more, see https://pip.pypa.io/en/stable/reference/requirements-file-format/
runpod~=1.7.0


@@ -0,0 +1,577 @@
# # This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
# # This can also be a relative path to a model on disk
# base_model: ./llama-7b-hf
# # You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
# base_model_ignore_patterns:
# # If the base_model repo on hf hub doesn't include configuration .json files,
# # You can set that here, or leave this empty to default to base_model
# base_model_config: ./llama-7b-hf
# # You can specify to choose a specific model revision from huggingface hub
# model_revision:
# # Optional tokenizer configuration override in case you want to use a different tokenizer
# # than the one defined in the base model
# tokenizer_config:
# # If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
# model_type: AutoModelForCausalLM
# # Corresponding tokenizer for the model AutoTokenizer is a good choice
# tokenizer_type: AutoTokenizer
# # Trust remote code for untrusted source
# trust_remote_code:
# # use_fast option for tokenizer loading from_pretrained, default to True
# tokenizer_use_fast:
# # Whether to use the legacy tokenizer setting, defaults to True
# tokenizer_legacy:
# # Resize the model embeddings when new tokens are added to multiples of 32
# # This is reported to improve training speed on some models
# resize_token_embeddings_to_32x:
# # Used to identify which the model is based on
# is_falcon_derived_model:
# is_llama_derived_model:
# # Please note that if you set this to true, `padding_side` will be set to "left" by default
# is_mistral_derived_model:
# is_qwen_derived_model:
# # optional overrides to the base model configuration
# model_config:
# # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
# rope_scaling:
# type: # linear | dynamic
# factor: # float
# # Whether you are training a 4-bit GPTQ quantized model
# gptq: true
# gptq_groupsize: 128 # group size
# gptq_model_v1: false # v1 or v2
# # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
# load_in_8bit: true
# # Use bitsandbytes 4 bit
# load_in_4bit:
# # Use CUDA bf16
# bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
# # Use CUDA fp16
# fp16: true
# # Use CUDA tf32
# tf32: true # require >=ampere
# # No AMP (automatic mixed precision)
# bfloat16: true # require >=ampere
# float16: true
# # A list of one or more datasets to finetune the model with
# datasets:
# # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
# - path: vicgalle/alpaca-gpt4
# # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
# type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
# ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
# data_files: # Optional[str] path to source data files
# shards: # Optional[int] number of shards to split data into
# name: # Optional[str] name of dataset configuration to load
# train_on_split: train # Optional[str] name of dataset split to load from
# # Optional[str] fastchat conversation type, only used with type: sharegpt
# conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
# field_human: # Optional[str]. Human key to use for conversation.
# field_model: # Optional[str]. Assistant key to use for conversation.
# # Custom user prompt
# - path: repo
# type:
# # The below are defaults. only set what's needed.
# system_prompt: ""
# system_format: "{system}"
# field_system: system
# field_instruction: instruction
# field_input: input
# field_output: output
# # Customizable to be single line or multi-line
# # 'format' can include {input}
# format: |-
# User: {instruction} {input}
# Assistant:
# # 'no_input_format' cannot include {input}
# no_input_format: "{instruction} "
# # For `completion` datasets only, uses the provided field instead of `text` column
# field:
# # Axolotl attempts to save the dataset as an arrow after packing the data together so
# # subsequent training attempts load faster, relative path
# dataset_prepared_path: data/last_run_prepared
# # Push prepared dataset to hub
# push_dataset_to_hub: # repo path
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# # if not set.
# dataset_processes: # defaults to os.cpu_count() if not set
# # push checkpoints to hub
# hub_model_id: # repo path to push finetuned model
# # how to push checkpoints to hub
# # https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
# hub_strategy:
# # Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# # Required to be true when used in combination with `push_dataset_to_hub`
# hf_use_auth_token: # boolean
# # How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
# val_set_size: 0.04
# # Num shards for whole dataset
# dataset_shard_num:
# # Index of shard to use for whole dataset
# dataset_shard_idx:
# # The maximum length of an input to train with, this should typically be less than 2048
# # as most models have a token/context limit of 2048
# sequence_len: 2048
# # Pad inputs so each step uses constant sized buffers
# # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
# pad_to_sequence_len:
# # Max sequence length to concatenate training samples together up to
# # Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# # FutureWarning: This will soon be DEPRECATED
# max_packed_sequence_len: 1024
# # Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
# sample_packing:
# # Set to 'false' if getting errors during eval with sample_packing on.
# eval_sample_packing:
# # You can set these packing optimizations AFTER starting a training at least once.
# # The trainer will provide recommended values for these values.
# sample_packing_eff_est:
# total_num_tokens:
# # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
# adapter: lora
# # If you already have a lora model trained that you want to load, put that here.
# # This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
# lora_model_dir:
# # LoRA hyperparameters
# # For more details about the following options, see:
# # https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
# lora_r: 8
# lora_alpha: 16
# lora_dropout: 0.05
# lora_target_modules:
# - q_proj
# - v_proj
# # - k_proj
# # - o_proj
# # - gate_proj
# # - down_proj
# # - up_proj
# lora_target_linear: # If true, will target all linear layers
# # If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# # For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# # `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# # https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
# lora_modules_to_save:
# # - embed_tokens
# # - lm_head
# # Once you complete training, the model will be saved to the following directory.
# # If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# # Make sure `lora_model_dir` points to this directory if you want to use the trained model.
# lora_out_dir:
# lora_fan_in_fan_out: false
# # ReLoRA configuration
# # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
# relora_steps: # Number of steps per ReLoRA restart
# relora_warmup_steps: # Number of per-restart warmup steps
# relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# # wandb configuration if you're using it
# wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
# wandb_project: # Your wandb project name
# wandb_entity: # A wandb Team name if using a Team
# wandb_watch:
# wandb_run_id: # Set the name of your wandb run
# wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# # Where to save the full-finetuned model to
# output_dir: ./completed-model
# # Whether to use torch.compile and which backend to use
# torch_compile: # bool
# torch_compile_backend: # Optional[str]
# # Training hyperparameters
# # If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
# gradient_accumulation_steps: 1
# # The number of samples to include in each batch. This is the number of samples sent to each GPU.
# micro_batch_size: 2
# eval_batch_size:
# num_epochs: 4
# warmup_steps: 100 # cannot use with warmup_ratio
# warmup_ratio: 0.05 # cannot use with warmup_steps
# learning_rate: 0.00003
# lr_quadratic_warmup:
# logging_steps:
# save_strategy: # Set to `no` to skip checkpoint saves
# save_steps: # Leave empty to save at each epoch
# eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
# save_total_limit: # Checkpoints saved at a time
# # Maximum number of iterations to train for. It precedes num_epochs which means that
# # if both are set, num_epochs will not be guaranteed.
# # e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
# max_steps:
# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
# # Save model as safetensors (require safetensors package)
# save_safetensors:
# # Whether to mask out or include the human's prompt from the training labels
# train_on_inputs: false
# # Group similarly sized data to minimize padding.
# # May be slower to start, as it must download and sort the entire dataset.
# # Note that training loss may have an oscillating pattern with this enabled.
# group_by_length: false
# # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
# gradient_checkpointing: false
# # Stop training after this many evaluation losses have increased in a row
# # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
# early_stopping_patience: 3
# # Specify a scheduler and kwargs to use with the optimizer
# lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
# lr_scheduler_kwargs:
# # For one_cycle optim
# lr_div_factor: # Learning rate div factor
# # For log_sweep optim
# log_sweep_min_lr:
# log_sweep_max_lr:
# # Specify optimizer
# # Valid values are driven by the Transformers OptimizerNames class, see:
# # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
# #
# # Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
# # torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
# # in the examples/ for your model and fine-tuning use case.
# #
# # Valid values for 'optimizer' include:
# # - adamw_hf
# # - adamw_torch
# # - adamw_torch_fused
# # - adamw_torch_xla
# # - adamw_apex_fused
# # - adafactor
# # - adamw_anyprecision
# # - sgd
# # - adagrad
# # - adamw_bnb_8bit
# # - lion_8bit
# # - lion_32bit
# # - paged_adamw_32bit
# # - paged_adamw_8bit
# # - paged_lion_32bit
# # - paged_lion_8bit
# optimizer:
# # Specify weight decay
# weight_decay:
# # adamw hyperparams
# adam_beta1:
# adam_beta2:
# adam_epsilon:
# # Gradient clipping max norm
# max_grad_norm:
# # Augmentation techniques
# # NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
# # currently only supported on Llama and Mistral
# noisy_embedding_alpha:
# # Whether to bettertransformers
# flash_optimum:
# # Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
# xformers_attention:
# # Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
# flash_attention:
# flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
# flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# # Whether to use scaled-dot-product attention
# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# sdp_attention:
# # Landmark attention (only llama)
# landmark_attention:
# # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# # LLaMA only
# xpos_rope:
# # Resume from a specific checkpoint dir
# resume_from_checkpoint:
# # If resume_from_checkpoint isn't set and you simply want it to start where it left off.
# # Be careful with this being turned on between different models.
# auto_resume_from_checkpoints: false
# # Don't mess with this, it's here for accelerate and torchrun
# local_rank:
# # Add or change special tokens.
# # If you add tokens here, you don't need to add them to the `tokens` list.
# special_tokens:
# # bos_token: "<s>"
# # eos_token: "</s>"
# # unk_token: "<unk>"
# # Add extra tokens.
# tokens:
# # FSDP
# fsdp:
# fsdp_config:
# # Deepspeed config path. e.g., deepspeed/zero3.json
# deepspeed:
# # Advanced DDP Arguments
# ddp_timeout:
# ddp_bucket_cap_mb:
# ddp_broadcast_buffers:
# # Path to torch distx for optim 'adamw_anyprecision'
# torchdistx_path:
# # Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize
# pretraining_dataset:
# # Debug mode
# debug:
# # Seed
# seed:
# # Allow overwrite yml config using from cli
# strict:
base_model: ${BASE_MODEL}
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
base_model_config: ${BASE_MODEL_CONFIG}
revision_of_model: ${REVISION_OF_MODEL}
tokenizer_config: ${TOKENIZER_CONFIG}
model_type: ${MODEL_TYPE}
tokenizer_type: ${TOKENIZER_TYPE}
trust_remote_code: ${TRUST_REMOTE_CODE}
tokenizer_use_fast: ${TOKENIZER_USE_FAST}
tokenizer_legacy: ${TOKENIZER_LEGACY}
resize_token_embeddings_to_32x: ${RESIZE_TOKEN_EMBEDDINGS_TO_32X}
is_falcon_derived_model: ${IS_FALCON_DERIVED_MODEL}
is_llama_derived_model: ${IS_LLAMA_DERIVED_MODEL}
is_qwen_derived_model: ${IS_QWEN_DERIVED_MODEL}
is_mistral_derived_model: ${IS_MISTRAL_DERIVED_MODEL}
overrides_of_model_config:
rope_scaling:
type: ${ROPE_SCALING_TYPE}
factor: ${ROPE_SCALING_FACTOR}
bnb_config_kwargs:
llm_int8_has_fp16_weight: ${BNB_LLM_INT8_HAS_FP16_WEIGHT}
bnb_4bit_quant_type: ${BNB_4BIT_QUANT_TYPE}
bnb_4bit_use_double_quant: ${BNB_4BIT_USE_DOUBLE_QUANT}
gptq: ${GPTQ}
load_in_8bit: ${LOAD_IN_8BIT}
load_in_4bit: ${LOAD_IN_4BIT}
bf16: ${BF16}
fp16: ${FP16}
tf32: ${TF32}
bfloat16: ${BFLOAT16}
float16: ${FLOAT16}
gpu_memory_limit: ${GPU_MEMORY_LIMIT}
lora_on_cpu: ${LORA_ON_CPU}
datasets:
- path: ${DATASET_PATH}
type: ${DATASET_TYPE}
ds_type: ${DATASET_DS_TYPE}
data_files: ${DATASET_DATA_FILES}
shards: ${DATASET_SHARDS}
name: ${DATASET_NAME}
train_on_split: ${DATASET_TRAIN_ON_SPLIT}
revision: ${DATASET_REVISION}
trust_remote_code: ${DATASET_TRUST_REMOTE_CODE}
rl: ${RL}
dpo_use_weighting: ${DPO_USE_WEIGHTING}
chat_template: ${CHAT_TEMPLATE}
chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
dataset_prepared_path: ${DATASET_PREPARED_PATH}
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
dataset_processes: ${DATASET_PROCESSES}
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
hub_model_id: ${HUB_MODEL_ID}
hub_strategy: ${HUB_STRATEGY}
hf_use_auth_token: ${HF_USE_AUTH_TOKEN}
val_set_size: ${VAL_SET_SIZE}
dataset_shard_num: ${DATASET_SHARD_NUM}
dataset_shard_idx: ${DATASET_SHARD_IDX}
sequence_len: ${SEQUENCE_LEN}
pad_to_sequence_len: ${PAD_TO_SEQUENCE_LEN}
sample_packing: ${SAMPLE_PACKING}
eval_sample_packing: ${EVAL_SAMPLE_PACKING}
sample_packing_eff_est: ${SAMPLE_PACKING_EFF_EST}
total_num_tokens: ${TOTAL_NUM_TOKENS}
sample_packing_group_size: ${SAMPLE_PACKING_GROUP_SIZE}
sample_packing_bin_size: ${SAMPLE_PACKING_BIN_SIZE}
batch_flattening: ${BATCH_FLATTENING}
device_map: ${DEVICE_MAP}
max_memory: ${MAX_MEMORY}
adapter: ${ADAPTER}
lora_model_dir: ${LORA_MODEL_DIR}
lora_r: ${LORA_R}
lora_alpha: ${LORA_ALPHA}
lora_dropout: ${LORA_DROPOUT}
lora_target_modules:
- ${LORA_TARGET_MODULES}
lora_target_linear: ${LORA_TARGET_LINEAR}
peft_layers_to_transform: ${PEFT_LAYERS_TO_TRANSFORM}
lora_modules_to_save: ${LORA_MODULES_TO_SAVE}
lora_fan_in_fan_out: ${LORA_FAN_IN_FAN_OUT}
loraplus_lr_ratio: ${LORAPLUS_LR_RATIO}
loraplus_lr_embedding: ${LORAPLUS_LR_EMBEDDING}
peft:
loftq_config:
loftq_bits: ${LOFTQ_BITS}
relora_steps: ${RELORA_STEPS}
relora_warmup_steps: ${RELORA_WARMUP_STEPS}
relora_anneal_steps: ${RELORA_ANNEAL_STEPS}
relora_prune_ratio: ${RELORA_PRUNE_RATIO}
relora_cpu_offload: ${RELORA_CPU_OFFLOAD}
wandb_mode: ${WANDB_MODE}
wandb_project: ${WANDB_PROJECT}
wandb_entity: ${WANDB_ENTITY}
wandb_watch: ${WANDB_WATCH}
wandb_name: ${WANDB_NAME}
wandb_run_id: ${WANDB_RUN_ID}
wandb_log_model: ${WANDB_LOG_MODEL}
mlflow_tracking_uri: ${MLFLOW_TRACKING_URI}
mlflow_experiment_name: ${MLFLOW_EXPERIMENT_NAME}
mlflow_run_name: ${MLFLOW_RUN_NAME}
hf_mlflow_log_artifacts: ${HF_MLFLOW_LOG_ARTIFACTS}
use_comet: ${USE_COMET}
comet_api_key: ${COMET_API_KEY}
comet_workspace: ${COMET_WORKSPACE}
comet_project_name: ${COMET_PROJECT_NAME}
comet_experiment_key: ${COMET_EXPERIMENT_KEY}
comet_mode: ${COMET_MODE}
comet_online: ${COMET_ONLINE}
comet_experiment_config: ${COMET_EXPERIMENT_CONFIG}
output_dir: ${OUTPUT_DIR}
torch_compile: ${TORCH_COMPILE}
torch_compile_backend: ${TORCH_COMPILE_BACKEND}
gradient_accumulation_steps: ${GRADIENT_ACCUMULATION_STEPS}
micro_batch_size: ${MICRO_BATCH_SIZE}
eval_batch_size: ${EVAL_BATCH_SIZE}
num_epochs: ${NUM_EPOCHS}
warmup_steps: ${WARMUP_STEPS}
warmup_ratio: ${WARMUP_RATIO}
learning_rate: ${LEARNING_RATE}
lr_quadratic_warmup: ${LR_QUADRATIC_WARMUP}
logging_steps: ${LOGGING_STEPS}
eval_steps: ${EVAL_STEPS}
evals_per_epoch: ${EVALS_PER_EPOCH}
save_strategy: ${SAVE_STRATEGY}
save_steps: ${SAVE_STEPS}
saves_per_epoch: ${SAVES_PER_EPOCH}
save_total_limit: ${SAVE_TOTAL_LIMIT}
max_steps: ${MAX_STEPS}
eval_table_size: ${EVAL_TABLE_SIZE}
eval_max_new_tokens: ${EVAL_MAX_NEW_TOKENS}
eval_causal_lm_metrics: ${EVAL_CAUSAL_LM_METRICS}
profiler_steps: ${PROFILER_STEPS}
loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}
save_safetensors: ${SAVE_SAFETENSORS}
train_on_inputs: ${TRAIN_ON_INPUTS}
group_by_length: ${GROUP_BY_LENGTH}
gradient_checkpointing: ${GRADIENT_CHECKPOINTING}
early_stopping_patience: ${EARLY_STOPPING_PATIENCE}
lr_scheduler: ${LR_SCHEDULER}
lr_scheduler_kwargs: ${LR_SCHEDULER_KWARGS}
cosine_min_lr_ratio: ${COSINE_MIN_LR_RATIO}
cosine_constant_lr_ratio: ${COSINE_CONSTANT_LR_RATIO}
lr_div_factor: ${LR_DIV_FACTOR}
optimizer: ${OPTIMIZER}
optim_args: ${OPTIM_ARGS}
optim_target_modules: ${OPTIM_TARGET_MODULES}
weight_decay: ${WEIGHT_DECAY}
adam_beta1: ${ADAM_BETA1}
adam_beta2: ${ADAM_BETA2}
adam_epsilon: ${ADAM_EPSILON}
max_grad_norm: ${MAX_GRAD_NORM}
neftune_noise_alpha: ${NEFTUNE_NOISE_ALPHA}
flash_optimum: ${FLASH_OPTIMUM}
xformers_attention: ${XFORMERS_ATTENTION}
flash_attention: ${FLASH_ATTENTION}
flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV}
flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
sdp_attention: ${SDP_ATTENTION}
s2_attention: ${S2_ATTENTION}
resume_from_checkpoint: ${RESUME_FROM_CHECKPOINT}
auto_resume_from_checkpoints: ${AUTO_RESUME_FROM_CHECKPOINTS}
local_rank: ${LOCAL_RANK}
special_tokens:
bos_token: ${SPECIAL_TOKEN_BOS}
eos_token: ${SPECIAL_TOKEN_EOS}
unk_token: ${SPECIAL_TOKEN_UNK}
pad_token: ${SPECIAL_TOKEN_PAD}
tokens: ${TOKENS}
fsdp: ${FSDP}
fsdp_config: ${FSDP_CONFIG}
deepspeed: ${DEEPSPEED}
ddp_timeout: ${DDP_TIMEOUT}
ddp_bucket_cap_mb: ${DDP_BUCKET_CAP_MB}
ddp_broadcast_buffers: ${DDP_BROADCAST_BUFFERS}
torchdistx_path: ${TORCHDISTX_PATH}
pretraining_dataset: ${PRETRAINING_DATASET}
debug: ${DEBUG}
seed: ${SEED}
strict: ${STRICT}

.runpod/src/handler.py (new file, 64 lines)

@@ -0,0 +1,64 @@
"""
Runpod serverless entrypoint handler
"""
import os
import runpod
import yaml
from huggingface_hub._login import login
from train import train
from utils import get_output_dir
BASE_VOLUME = os.environ.get("BASE_VOLUME", "/runpod-volume")
if not os.path.exists(BASE_VOLUME):
os.makedirs(BASE_VOLUME)
logger = runpod.RunPodLogger()
async def handler(job):
runpod_job_id = job["id"]
inputs = job["input"]
run_id = inputs.get("run_id", "default_run_id")
args = inputs.get("args", {})
# Set output directory
output_dir = os.path.join(BASE_VOLUME, get_output_dir(run_id))
args["output_dir"] = output_dir
# First save args to a temporary config file
config_path = "/workspace/test_config.yaml"
# Add run_name and job_id to args before saving
args["run_name"] = run_id
args["runpod_job_id"] = runpod_job_id
yaml_data = yaml.dump(args, default_flow_style=False)
with open(config_path, "w", encoding="utf-8") as file:
file.write(yaml_data)
# Handle credentials
credentials = inputs.get("credentials", {})
if "wandb_api_key" in credentials:
os.environ["WANDB_API_KEY"] = credentials["wandb_api_key"]
if "hf_token" in credentials:
os.environ["HF_TOKEN"] = credentials["hf_token"]
if os.environ.get("HF_TOKEN"):
login(token=os.environ["HF_TOKEN"])
else:
logger.info("No HF_TOKEN provided. Skipping login.")
logger.info("Starting Training.")
async for result in train(config_path): # Pass the config path instead of args
logger.info(result)
logger.info("Training Complete.")
# Cleanup
del os.environ["WANDB_API_KEY"]
del os.environ["HF_TOKEN"]
runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})


@@ -0,0 +1,61 @@
{
"input": {
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "NousResearch/Meta-Llama-3-8B",
"model_type": "LlamaForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_8bit": true,
"load_in_4bit": false,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca"
}
],
"val_set_size": 0.05,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 4,
"micro_batch_size": 2,
"num_epochs": 1,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": false,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|end_of_text|>"
}
}
}
}

.runpod/src/train.py (new file, 45 lines)

@@ -0,0 +1,45 @@
"""
Runpod train entrypoint
"""
import asyncio
async def train(config_path: str, gpu_id: str = "0", preprocess: bool = True):
"""
Run preprocessing (if enabled) and training with the given config file
:param config_path: Path to the YAML config file
:param gpu_id: GPU ID to use (default: "0")
:param preprocess: Whether to run preprocessing (default: True)
"""
# First check if preprocessing is needed
if preprocess:
# Preprocess command
preprocess_cmd = (
f"CUDA_VISIBLE_DEVICES={gpu_id} axolotl preprocess {config_path}"
)
process = await asyncio.create_subprocess_shell(
preprocess_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
if process.stdout is not None:
async for line in process.stdout:
yield f"Preprocessing: {line.decode().strip()}"
await process.wait()
yield "Preprocessing completed."
else:
yield "Skipping preprocessing step."
# Training command
train_cmd = f"axolotl train {config_path}"
process = await asyncio.create_subprocess_shell(
train_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
)
if process.stdout is not None:
async for line in process.stdout:
yield f"Training: {line.decode().strip()}"
await process.wait()
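Since `train` is an async generator that yields log lines from the preprocessing and training subprocesses, a minimal hypothetical caller outside the RunPod handler could consume it as sketched below; the config path reuses the temporary file that handler.py writes.

```python
import asyncio

from train import train  # the async generator defined above


async def main():
    # Stream preprocessing and training output line by line,
    # mirroring how handler.py iterates over train(config_path).
    async for update in train("/workspace/test_config.yaml"):
        print(update)


asyncio.run(main())
```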

.runpod/src/utils.py (new file, 89 lines)

@@ -0,0 +1,89 @@
"""
Runpod launcher utils
"""
import os
import yaml
def get_output_dir(run_id):
path = f"fine-tuning/{run_id}"
return path
def make_valid_config(input_args):
"""
Creates and saves updated config file, returns the path to the new config
:param input_args: dict of input args
:return: str, path to the updated config file
"""
# Load default config
with open("config/config.yaml", "r", encoding="utf-8") as fin:
all_args = yaml.safe_load(fin)
if not input_args:
print("No args provided, using defaults")
else:
all_args.update(input_args)
# Create updated config path
updated_config_path = "config/updated_config.yaml"
# Save updated config to new file
with open(updated_config_path, "w", encoding="utf-8") as f:
yaml.dump(all_args, f)
return updated_config_path
def set_config_env_vars(args: dict):
"""
Convert API arguments into environment variables.
Handles nested dictionaries, lists, and special values.
Args:
args (dict): The arguments dictionary from the API request
"""
def process_value(value):
"""Convert Python values to string format for environment variables"""
if value is None:
return ""
if isinstance(value, bool):
return str(value).lower()
if isinstance(value, (list, dict)):
return str(value)
return str(value)
def set_env_vars(data, prefix=""):
"""Recursively set environment variables from nested dictionary"""
for key, value in data.items():
env_key = prefix + key.upper()
# Handle special cases
if isinstance(value, dict):
# For nested dictionaries (like special_tokens)
set_env_vars(value, f"{env_key}_")
elif isinstance(value, list):
# Handle list of dictionaries (like datasets)
if value and isinstance(value[0], dict):
for i, item in enumerate(value):
set_env_vars(item, f"{env_key}_{i}_")
else:
# For simple lists (like lora_target_modules)
os.environ[env_key] = process_value(value)
else:
# Handle all other cases
os.environ[env_key] = process_value(value)
# Clear any existing related environment variables
# This prevents old values from persisting
for key in list(os.environ.keys()):
if key.startswith(
("BASE_MODEL", "MODEL_TYPE", "TOKENIZER_TYPE", "DATASET", "LORA_", "WANDB_")
):
del os.environ[key]
# Set new environment variables
set_env_vars(args)

85
.runpod/tests.json Normal file
View File

@@ -0,0 +1,85 @@
{
"input": {
"name": "quick_smoke_test_sft",
"user_id": "user",
"model_id": "llama-test",
"run_id": "llama-test",
"credentials": {
"wandb_api_key": "",
"hf_token": ""
},
"args": {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"load_in_8bit": true,
"load_in_4bit": false,
"strict": false,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca"
}
],
"val_set_size": 0.05,
"output_dir": "./outputs/lora-out",
"sequence_len": 4096,
"sample_packing": true,
"eval_sample_packing": false,
"pad_to_sequence_len": true,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": true,
"lora_modules_to_save": [
"embed_tokens",
"lm_head"
],
"gradient_accumulation_steps": 4,
"micro_batch_size": 2,
"num_epochs": 1,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"learning_rate": 0.0002,
"train_on_inputs": false,
"group_by_length": false,
"bf16": "auto",
"tf32": true,
"gradient_checkpointing": true,
"logging_steps": 1,
"flash_attention": true,
"warmup_steps": 1,
"evals_per_epoch": 1,
"eval_max_new_tokens": 128,
"saves_per_epoch": 1,
"weight_decay": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>"
}
},
"timeout": 100000
},
"config": {
"gpuTypeId": "NVIDIA GeForce RTX 4090",
"gpuCount": 1,
"containerDiskInGb": 200,
"env": [
{
"key": "TOKENIZER",
"value": ""
},
{
"key": "DISABLE_LOG_STATS",
"value": "true"
}
],
"allowedCudaVersions": [
"12.8",
"12.7",
"12.6",
"12.5",
"12.4"
]
}
}

View File

@@ -9,8 +9,7 @@ pytest -v --durations=10 -n8 \
--ignore=tests/patched/ \
--ignore=tests/cli \
/workspace/axolotl/tests/ \
--cov=axolotl \
--cov-report=xml:coverage.xml
--cov=axolotl
# Run lora kernels tests with coverage append
pytest -v --durations=10 \
@@ -51,11 +50,6 @@ pytest -v --durations=10 \
/workspace/axolotl/tests/e2e/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:coverage.xml
--cov-report=xml:e2e-coverage.xml
# Upload coverage to Codecov
if [ -f e2e-coverage.xml ]; then
codecov -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION} || true

View File

@@ -28,6 +28,7 @@ df_args = {
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}

View File

@@ -29,6 +29,7 @@ df_args = {
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}

View File

@@ -6,13 +6,13 @@ pytest -v -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \
--cov=axolotl \
--cov-report=xml:multigpu-coverage.xml
--cov=axolotl
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/ \
# Run solo tests with coverage append
pytest -v --durations=10 -n1 \
/workspace/axolotl/tests/e2e/multigpu/solo/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:multigpu-coverage.xml
--cov-append
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov=axolotl \
@@ -20,8 +20,4 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov
if [ -f multigpu-coverage.xml ]; then
codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi
codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true

View File

@@ -1,5 +1,7 @@
codecov:
require_ci_to_pass: yes
notify:
wait_for_ci: true
coverage:
precision: 2
@@ -49,3 +51,6 @@ comment:
require_changes: no
require_base: no
require_head: yes
github_checks:
annotations: false

View File

@@ -37,3 +37,7 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -199,6 +199,17 @@ output_dir: # Directory to save evaluation results
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
### delinearize-llama4
Delinearizes a Llama 4 linearized model into a regular HuggingFace Llama 4 model. This only works with the non-quantized linearized model.
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```
This is necessary to use the model with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.
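A rough sketch of that workflow with the Axolotl CLI (paths and config name are placeholders, and the merged-model output location may differ in your setup):

```bash
# 1. Merge the LoRA adapter into the non-quantized linearized model
axolotl merge-lora your_config.yaml --lora-model-dir ./outputs/lora-out

# 2. Delinearize the merged model so it can be loaded by other frameworks
axolotl delinearize-llama4 --model ./outputs/lora-out/merged --output ./outputs/llama4-delinearized
```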
## Legacy CLI Usage
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:

View File

@@ -55,46 +55,20 @@ overrides_of_model_config:
overrides_of_model_kwargs:
# use_cache: False
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
# These are default values
llm_int8_has_fp16_weight: false
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# Quantization configuration.
quantization:
backend: bnb | hqq | gptq
bits: 8
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:
# These are default values
llm_int8_has_fp16_weight: false
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true
# If using hqq config, additional config parameters are needed. See: https://huggingface.co/docs/transformers/main/en//quantization/hqq
hqq_config:
# pick one of the following, depending on if you want to uniformly quantize the whole model or
# apply different quantization settings to specific layers in the model:
# if uniformly quantize the whole model:
group_size: 64
# if we want to invoke dynamic_config in order to apply specific layers with different quantization settings:
- nbits: 4
group_size: 64
target_modules:
- self_attn.k_proj
- self_attn.v_proj
- self_attn.o_proj
- nbits: 3
group_size: 32
target_modules:
- mlp.gate_proj
- mlp.up_proj
- mlp.down_proj
# (Internal Use Only)
# Whether you are training a 4-bit GPTQ quantized model
gptq:
gptq: true
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit:
load_in_8bit: true
# Use bitsandbytes 4 bit
load_in_4bit:
@@ -180,6 +154,10 @@ datasets:
# Key containing the messages (default: "messages")
field_messages: messages
# Key containing the system message (default: "system")
# If the system message is not present in the dataset sample, it will be loaded from the field_system property.
field_system: system
# Mapping of properties from the input dataset to the chat template.
# (default: message_property_mappings={'role':'role', 'content':'content'})
# If a property exists in the template but not in this mapping, the system will attempt
@@ -209,7 +187,7 @@ datasets:
# IMPORTANT: The following fields determine which parts of the conversation to train on.
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
# See examples at `docs/dataset-formats/conversation.qmd`
# Note: If the below 4 fields are set to empty, defaults to training only on the last message.
# Note: If the below 5 fields are empty, defaults to training only on the last message.
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
roles_to_train: ["assistant"] # default
@@ -218,7 +196,13 @@ datasets:
# - turn (default): train on the EOS token at the end of each trainable turn
# - last: train on the last EOS token in the conversation
# TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
train_on_eos: last
train_on_eos: turn
# Optional[str]. Which EOT (End-of-Turn) tokens to train on in the conversation. Possible values are:
# - all: train on all EOT tokens
# - turn: train on the EOT token at the end of each trainable turn
# - last: train on the last EOT token in the conversation
# If not specified, defaults to the value of train_on_eos for backward compatibility.
train_on_eot:
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
message_field_training: training
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
@@ -301,8 +285,17 @@ process_reward_model:
chat_template: tokenizer_default
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
chat_template_jinja: null
# Changes the default system message. Currently only supports chatml.
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
# Optional[List[str]]. Custom EOT (End-of-Turn) tokens to mask/unmask during training.
# These tokens mark the boundaries between conversation turns.
# For example: ["/INST", "</s>", "[/SYSTEM_PROMPT]"]
# If not specified, defaults to just the model's eos_token.
# This is useful for templates that use multiple delimiter tokens.
eot_tokens:
# - "</s>"
# - "[/INST]"
# - "[/SYSTEM_PROMPT]"
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
@@ -687,8 +680,10 @@ special_tokens:
# unk_token: "<unk>"
# pad_token: "[PAD]"
# Add extra tokens.
# Optional[list[str]]. Add extra tokens to the tokenizer.
tokens:
# - "<|startoftext|>"
# - "<|endoftext|>"
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
# Only works for tokens that are not part of the base vocab (aka are added_tokens).

View File

@@ -4,18 +4,6 @@ description: Conversation format for supervised fine-tuning.
order: 3
---
## sharegpt
::: {.callout-important}
ShareGPT is deprecated! Please see the [chat_template](#chat_template) section below.
:::
## pygmalion
```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "value": "..."}]}
```
## chat_template
The chat_template strategy uses a jinja2 template that converts a list of messages into a prompt. It supports using the tokenizer's template, a built-in supported template, or a custom jinja2 template.
@@ -64,7 +52,7 @@ We recommend checking the below examples for other usecases.
### Examples
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
1. (Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
```yaml
datasets:
@@ -109,10 +97,55 @@ datasets:
```
::: {.callout-important}
Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
Please make sure that your `tokenizer.eos_token` is the same as the EOS (End-of-Sequence) token in the template. Otherwise, set `eos_token` under `special_tokens: `.
:::
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
5. If you are using a template whose EOT (End-of-Turn) token differs from the EOS token, or that uses multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: `, which defaults to `turn`.
```yaml
eot_tokens:
- "[/INST]"
# - "[/SYSTEM_PROMPT]"
datasets:
- path: ...
type: chat_template
# optional
train_on_eot: turn # defaults read from train_on_eos (which defaults to turn)
```
::: {.callout-tip}
See [config documentation](../config.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens.
:::
::: {.callout-note}
Using `eot_tokens` requires each token that exists in `chat_template` to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior.
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
:::
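For example, a minimal sketch of remapping unused reserved added_tokens so that each EOT marker encodes as a single token (the token ids and strings below are placeholders, not taken from any particular tokenizer):

```yaml
eot_tokens:
  - "[/INST]"
  - "[/SYSTEM_PROMPT]"
# Hypothetical overrides: token_id -> new token string
added_tokens_overrides:
  128011: "[/INST]"
  128012: "[/SYSTEM_PROMPT]"
```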
6. Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
```yaml
eot_tokens:
- "[/INST]"
# ...
datasets:
- path: ...
type: chat_template
train_on_eos: last
train_on_eot: turn
```
::: {.callout-tip}
If the EOS token only appears at the end of a prompt, `train_on_eos: last` is equivalent to `train_on_eos: turn`. Therefore, you can generally leave both at their defaults and omit them.
:::
7. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
For a data sample that looks like:
@@ -162,3 +195,15 @@ datasets:
::: {.callout-tip}
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
## sharegpt
::: {.callout-important}
ShareGPT is deprecated! Please see the [chat_template](#chat_template) section.
:::
## pygmalion
```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "value": "..."}]}
```

View File

@@ -28,6 +28,8 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples:
- `main-base-py3.11-cu128-2.7.0`
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
- `main-base-py3.11-cu124-2.4.1`
@@ -50,7 +52,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
# on push to main
main-py{python_version}-cu{cuda_version}-{pytorch_version}
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)
main-latest
# nightly build
@@ -68,6 +70,7 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples:
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-py3.11-cu124-2.4.1`

View File

@@ -73,10 +73,40 @@ description: Frequently asked questions
> A: This is likely an empty turn.
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
**Q: The EOS token is incorrectly being masked or not being masked / `EOS token __ not found in chat template`.**
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
> A: There can be two reasons:
> 1. There is a mismatch between `tokenizer.eos_token` and the EOS token in the template. Please make sure to set `eos_token: ` under `special_tokens: ` to the same EOS token as in the template.
> 2. The EOS token is not in the template. Please check if your template is correct. As an example, `phi_35` template does not use its dedicated EOS token `<|endoftext|>` at the end.
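For example, a minimal sketch of aligning the tokenizer EOS with a ChatML-style template (the token string is illustrative; use whatever your template actually emits):

```yaml
special_tokens:
  eos_token: "<|im_end|>"
```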
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
**Q: The EOT token(s) are incorrectly being masked or not being masked / `EOT token __ not found in chat template`.**
> A: There can be two reasons:
> 1. The EOT token is different from the EOS token and was not specified under `eot_tokens: `. Please set `eot_tokens: ` to the same EOT token(s) as in template.
> 2. There is more than one EOT token per turn in the template. Please raise an issue with examples as we recognize this as an edge case.
**Q: `EOT token encoding failed. Please check if the token is valid and can be encoded.`**
> A: There could be some issue with the tokenizer or unicode encoding. Please raise an issue with examples with the EOT token & tokenizer causing the issue.
**Q: `EOT token __ is encoded as multiple tokens.`**
> A: This is because the EOT token is encoded as multiple tokens which can cause unexpected behavior. Please add it under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `.
**Q: `Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot`**
> A: This occurs when the EOS token is included in `eot_tokens: ` while `train_on_eos: ` and `train_on_eot: ` differ, which would cause one setting to override the other. Please ensure that `train_on_eos: ` and `train_on_eot: ` are the same, or remove the EOS token from `eot_tokens: `.
**Q: If `eot_tokens: ` is not provided, what happens?**
> A: If `eot_tokens: ` is not provided, the default behavior is the same as before. EOS tokens used to delimit turns are masked/unmasked depending on whether the turn is trainable.
> Internally, `eot_tokens` defaults to `tokenizer.eos_token` and `train_on_eot` defaults to `train_on_eos` (which itself defaults to `turn`). This transition helps clarify the naming and behavior of EOT/EOS tokens.
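In config terms, the default behavior is roughly equivalent to the following sketch (assuming a tokenizer whose `eos_token` is `</s>`):

```yaml
eot_tokens:
  - "</s>"          # tokenizer.eos_token
train_on_eot: turn  # inherited from train_on_eos, which defaults to turn
```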

View File

@@ -19,6 +19,12 @@ This guide covers all the ways you can install and set up Axolotl for your envir
## Installation Methods {#sec-installation-methods}
::: {.callout-important}
Please make sure to have PyTorch installed before installing Axolotl in your local environment.
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}

View File

@@ -502,9 +502,7 @@ The input format is a simple JSON input with customizable fields based on the ab
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::
If you have multiple GPUs available, we recommend using `vLLM` with the `GRPOTrainer` to significantly speed up trajectory generation during training.
First, launch a `vLLM` server using `trl vllm-serve` - you may use a config file or CLI overrides to configure your vLLM server. In this example, we're
using 4 GPUs - 2 for training, and 2 for vLLM:
In the latest GRPO implementation, `vLLM` is used to significantly speed up trajectory generation during training. In this example, we're using 4 GPUs - 2 for training, and 2 for vLLM:
::: {.callout-important}
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
@@ -539,6 +537,10 @@ Your `vLLM` instance will now attempt to spin up, and it's time to kick off trai
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
```
::: {.callout-note}
Due to TRL's implementation with vLLM, the vLLM instance must use the last N GPUs instead of the first N GPUs. This is why in the example above, we use `CUDA_VISIBLE_DEVICES=2,3` for the vLLM instance.
:::
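A companion sketch for the server side, assuming the `axolotl vllm-serve` entrypoint and the same `grpo.yaml` config (GPU ids are illustrative):

```bash
# Reserve the last two GPUs for the vLLM rollout server
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo.yaml
```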
#### Reward functions
GRPO uses custom reward functions and transformations. Please have them ready locally.
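As a rough illustration of the expected shape (a hypothetical module, not part of this changeset): each reward function is a plain Python callable with at least two parameters that returns one score per completion, and it is referenced from the config by its fully qualified name (e.g. `rewards.format_reward`).

```python
# rewards.py -- hypothetical local module with a GRPO reward function.
# Axolotl imports it by fully qualified name and requires a signature
# with at least two parameters; extra dataset columns arrive via **kwargs.

def format_reward(prompts, completions, **kwargs):
    """Toy reward: 1.0 if a completion ends with a closing </answer> tag, else 0.0."""
    return [
        1.0 if completion.strip().endswith("</answer>") else 0.0
        for completion in completions
    ]
```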

View File

@@ -0,0 +1,62 @@
base_model: THUDM/GLM-4-32B-0414
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_4bit: true
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -26,3 +26,11 @@ Multi-GPU (4xH100) for Llama 4 Scout uses 62.8GB VRAM/GPU @ 4k context length @
### Llama 4 Maverick 17Bx128Experts (400B)
Coming Soon
## Delinearized Llama 4 Models
We provide a script to delinearize Llama 4 linearized models into regular HuggingFace Llama 4 models.
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```

View File

@@ -10,7 +10,6 @@ plugins:
liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
cut_cross_entropy: true
llama4_linearized_experts: true # needed with custom linearized experts model
load_in_4bit: true

View File

@@ -0,0 +1,69 @@
base_model: Qwen/Qwen3-32B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: offload
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -0,0 +1,68 @@
base_model: Qwen/Qwen3-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:

View File

@@ -1,4 +1,5 @@
codecov
codecov-cli
pytest
pytest-cov
pytest-retry

View File

@@ -11,18 +11,18 @@ liger-kernel==0.5.8
packaging==23.2
peft==0.15.1
peft==0.15.2
transformers==4.51.3
tokenizers>=0.21.1
accelerate==1.6.0
datasets==3.5.0
deepspeed>=0.15.4
trl==0.16.1
trl==0.17.0
hf_xet==1.0.0
hqq==0.2.5
optimum==1.16.2
hf_transfer
hqq==0.2.5
sentencepiece
gradio==5.23.3

View File

@@ -51,7 +51,7 @@ def parse_requirements(extras_require_map):
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.5.1"
torch_version = "2.6.0" # default to torch 2.6
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -64,10 +64,16 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 6):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post2")
extras_require_map["vllm"] = ["vllm==0.8.3"]
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.4"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
"xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6
extras_require_map["vllm"] = ["vllm==0.8.4"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.8.0"
__version__ = "0.9.0"

View File

@@ -39,16 +39,16 @@ class TrainerCliArgs:
class VllmServeCliArgs:
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
tensor_parallel_size: int = field(
default=1,
tensor_parallel_size: Optional[int] = field(
default=None,
metadata={"help": "Number of tensor parallel workers to use."},
)
host: str = field(
default="0.0.0.0", # nosec B104
host: Optional[str] = field(
default=None, # nosec B104
metadata={"help": "Host address to run the server on."},
)
port: int = field(
default=8000,
port: Optional[int] = field(
default=None,
metadata={"help": "Port to run the server on."},
)
gpu_memory_utilization: Optional[float] = field(

View File

@@ -1,5 +1,6 @@
"""CLI to run training on a model."""
import gc
import logging
import os
from pathlib import Path
@@ -48,8 +49,11 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
del model, tokenizer, trainer
gc.collect()
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train_unload(cfg)

View File

@@ -932,9 +932,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator = DataCollatorForSeq2Seq
kwargs["return_tensors"] = "pt"
if issubclass(collator, DataCollatorForSeq2Seq):
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
kwargs["ring_attn_func"] = training_args.ring_attn_func
return collator(
*collator_args,
@@ -1051,6 +1048,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl == "simpo":
@@ -1121,6 +1121,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
**training_args_kwargs,
)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
return training_args
def build(self, total_num_steps):

View File

@@ -371,13 +371,15 @@ class AxolotlTrainer(
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
loss = super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return loss
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}

View File

@@ -3,15 +3,29 @@ DPO trainer for axolotl
"""
import gc
import random
from functools import wraps
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union
import pandas as pd
import torch
import wandb
from accelerate import PartialState
from datasets import Dataset, IterableDataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers import Trainer
from torch.utils.data import DataLoader
from transformers import (
BaseImageProcessor,
FeatureExtractionMixin,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
)
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from trl import DPOConfig, DPOTrainer, maybe_apply_chat_template, maybe_extract_prompt
from trl.trainer.utils import log_table_to_comet_experiment
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.utils import (
@@ -81,6 +95,64 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
return super().push_to_hub(*args, **kwargs)
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
def _prepare_dataset(
self,
dataset: Union[Dataset, IterableDataset],
processing_class: Union[
PreTrainedTokenizerBase,
BaseImageProcessor,
FeatureExtractionMixin,
ProcessorMixin,
],
args: DPOConfig,
dataset_name: str,
) -> Union[Dataset, IterableDataset]:
# Build the kwargs for the `map` function
map_kwargs: Dict[str, Any] = {"writer_batch_size": 10}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
map_kwargs["num_proc"] = args.dataset_num_proc
with PartialState().main_process_first():
# Extract prompt if needed
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
# Apply the chat template if needed
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
dataset = dataset.map(
maybe_apply_chat_template,
fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
**map_kwargs,
)
# Tokenize the dataset
if isinstance(
dataset, Dataset
): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
dataset = dataset.map(
self.tokenize_row if not self.is_vision_model else self.process_row,
remove_columns=["chosen", "rejected"],
fn_kwargs={
"processing_class": processing_class,
"max_prompt_length": args.max_prompt_length,
"max_completion_length": args.max_completion_length,
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
"add_special_tokens": False,
},
**map_kwargs,
)
return dataset
@staticmethod
def tokenize_row(
features,
@@ -124,3 +196,67 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
gc.collect()
torch.cuda.empty_cache()
return loss
# TODO: remove this once https://github.com/huggingface/trl/pull/3377 is in a release
def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
Overriding built-in evaluation loop to store metrics for each batch.
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Works both with or without labels.
"""
# Sample and save to game log if requested (for one batch to save time)
if self.generate_during_eval:
# Generate random indices within the range of the total number of samples
num_samples = len(dataloader.dataset)
random_indices = random.sample(
range(num_samples), k=self.args.eval_batch_size
)
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
random_batch_dataset = dataloader.dataset.select(random_indices)
random_batch = self.data_collator(random_batch_dataset)
random_batch = self._prepare_inputs(random_batch)
policy_output_decoded, ref_output_decoded = (
self.generate_from_model_and_ref(self.model, random_batch)
)
table = pd.DataFrame(
columns=["Prompt", "Policy", "Ref Model"],
data=[
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
for prompt, pol, ref in zip(
random_batch_dataset["prompt"],
policy_output_decoded,
ref_output_decoded,
)
],
)
if "wandb" in self.args.report_to and self.accelerator.is_main_process:
wandb.log({"game_log": wandb.Table(data=table)})
if "comet_ml" in self.args.report_to:
log_table_to_comet_experiment(
name="game_log.csv",
table=table,
)
# Base evaluation
initial_output = super().evaluation_loop(
dataloader,
description,
prediction_loss_only,
ignore_keys,
metric_key_prefix,
)
return initial_output

View File

@@ -40,8 +40,8 @@ class GRPOStrategy:
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port
if trl.vllm_server_timeout:
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
if trl.vllm_guided_decoding_regex:
@@ -135,7 +135,9 @@ class GRPOStrategy:
try:
# use importlib to dynamically load the reward function from the module
reward_func_module_name = reward_func_fqn.split(".")[-1]
reward_func_module = importlib.import_module(reward_func_fqn.split(".")[-2])
reward_func_module = importlib.import_module(
".".join(reward_func_fqn.split(".")[:-1])
)
reward_func = getattr(reward_func_module, reward_func_module_name)
if not len(inspect.signature(reward_func).parameters) >= 2:
raise ValueError(

View File

@@ -6,4 +6,4 @@
from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin
from .sequence_parallel import SequenceParallelContextManager, SequenceParallelMixin

View File

@@ -1,16 +1,86 @@
"""Module for Axolotl trainer sequence parallelism mixin"""
"""
Module for Axolotl trainer sequence parallelism mixin and training context manager
"""
import functools
import logging
import torch
import torch.distributed as dist
from datasets import Dataset
from torch import nn
from torch.utils.data import DistributedSampler, Sampler
from torch.utils.hooks import RemovableHandle
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group,
update_ring_attn_params,
)
LOG = logging.getLogger(__name__)
def apply_sequence_parallelism(
batch: dict[str, torch.Tensor],
local_rank: int,
local_world_size: int,
ring_attn_func: RingAttnFunc,
) -> dict[str, torch.Tensor]:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary (e.g., input_ids, attention_mask, etc.)
local_rank: Local rank in the sequence parallel group
local_world_size: World size of the sequence parallel group
ring_attn_func: The ring attention function to use
Returns:
Sliced batch dictionary.
"""
# Update ring attention params if needed
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing
total_seq_len = batch["input_ids"].size(1)
for key in batch:
if (
key in batch
and isinstance(batch[key], torch.Tensor)
and batch[key].dim() > 1
and batch[key].size(1) == total_seq_len
):
if ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
# Split in sequential fashion and grab this rank's chunk
batch[key] = (
batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous()
)
elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = batch[key].chunk(2 * local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[local_rank],
chunks[2 * local_world_size - local_rank - 1],
]
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# Split into striped data and stack
tensor = torch.stack(
batch[key].split(local_world_size, dim=1),
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, local_rank].contiguous()
return batch
class SequenceParallelMixin:
"""
Mixin class for sequence parallelism support in trainers.
@@ -87,3 +157,157 @@ class SequenceParallelMixin:
return self._create_sequence_parallel_sampler(
eval_dataset, shuffle=False, is_eval=True
)
class SequenceParallelContextManager:
"""
Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism
during model forward passes using a pre-forward hook, and gather outputs from
across the sequence parallelism group using a post-forward hook.
"""
def __init__(
self,
model: nn.Module,
sequence_parallel_degree: int,
ring_attn_func: RingAttnFunc,
):
self.model = model
self.sequence_parallel_degree = sequence_parallel_degree
self.ring_attn_func = ring_attn_func
self.process_group = get_ring_attn_group()
# Initialize sequence parallel group details
self.local_rank = dist.get_rank(self.process_group)
self.local_world_size = dist.get_world_size(self.process_group)
# Will store hook handles for removal
self.hook_handles: list[RemovableHandle] = []
# Create a partially applied version of the apply_sequence_parallelism function
# with pre-configured params
self.apply_sequence_parallelism = functools.partial(
apply_sequence_parallelism,
local_rank=self.local_rank,
local_world_size=self.local_world_size,
ring_attn_func=self.ring_attn_func,
)
def __enter__(self):
# Forward pre-hook to apply sequence parallelism
def sequence_parallel_pre_hook(_, args, kwargs):
# Apply sequence parallelism to kwargs
kwargs = self.apply_sequence_parallelism(batch=kwargs)
return args, kwargs
# Forward post-hook to gather outputs
def sequence_parallel_post_hook(_, __, output):
# Gather the sharded outputs
return self.gather_outputs(output)
# Register both hooks
self.hook_handles.append(
self.model.register_forward_pre_hook(
sequence_parallel_pre_hook, with_kwargs=True
)
)
self.hook_handles.append(
self.model.register_forward_hook(sequence_parallel_post_hook)
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Remove all hooks
for handle in self.hook_handles:
handle.remove()
self.hook_handles = []
def gather_outputs(self, output):
"""Gather sharded outputs from all ranks and reconstruct the full tensor."""
# Handle different output formats (dict, tensor, etc.)
if isinstance(output, dict):
gathered_output = {}
for key, value in output.items():
if isinstance(value, torch.Tensor) and value.dim() > 1:
# Gather logits or other sequence-sharded tensors
gathered_value = self.gather_tensor(value)
gathered_output[key] = gathered_value
else:
gathered_value = value.clone()
dist.all_reduce(
gathered_value, op=dist.ReduceOp.SUM, group=self.process_group
)
gathered_output[key] = gathered_value
return gathered_output
if isinstance(output, torch.Tensor):
return self.gather_tensor(output)
return output
def gather_tensor(self, tensor):
"""Gather a sharded tensor from all ranks."""
# Prepare tensors for all_gather
world_size = self.local_world_size
# Create list to store tensors from all ranks
gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
# All-gather operation
dist.all_gather(gathered_tensors, tensor, group=self.process_group)
# Concatenate along sequence dimension (typically dim=1)
if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]:
# Simple concatenation for standard sharding
return torch.cat(gathered_tensors, dim=1)
if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
# Each rank has a pattern of (rank, world_size*2-rank-1)
reconstituted_tensors = [None] * (world_size * 2)
# First, split each gathered tensor into its two chunks
for rank, gathered_tensor in enumerate(gathered_tensors):
# Each tensor contains two chunks in the sequence dimension
chunk_size = gathered_tensor.size(1) // 2
chunk1, chunk2 = gathered_tensor.split(chunk_size, dim=1)
# Place chunks in their original positions
reconstituted_tensors[rank] = chunk1
reconstituted_tensors[world_size * 2 - rank - 1] = chunk2
# Concatenate the reconstituted tensors in the correct order
return torch.cat(reconstituted_tensors, dim=1)
# Otherwise, RingAttnFunc.BATCH_STRIPE
# In striping, each rank has every world_size-th slice
batch_size = tensor.size(0)
hidden_dim = tensor.size(-1)
# First, determine the full sequence length
total_seq_len = 0
for t in gathered_tensors:
total_seq_len += t.size(1)
# Create a tensor to hold the unstriped result
result = torch.zeros(
batch_size,
total_seq_len,
hidden_dim,
dtype=tensor.dtype,
device=tensor.device,
)
# For each rank's tensor, distribute its slices to the correct positions
for rank, gathered_tensor in enumerate(gathered_tensors):
# The rank's tensor contains every world_size-th slice
# starting from its rank position
seq_len = gathered_tensor.size(1)
for i in range(seq_len):
# Calculate the position in the full tensor
pos = i * world_size + rank
if pos < total_seq_len:
result[:, pos] = gathered_tensor[:, i]
return result

View File

@@ -36,9 +36,10 @@ class BasePlugin:
Methods:
register(cfg): Registers the plugin with the given configuration.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_load(cfg, model): Performs actions after the model is loaded.
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
@@ -77,6 +78,14 @@ class BasePlugin:
None
"""
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is built/loaded, but before any adapters are applied.
Args:
cfg (dict): The configuration for the plugin.
"""
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is loaded.
@@ -329,9 +338,22 @@ class PluginManager:
for plugin in self.plugins.values():
plugin.pre_model_load(cfg)
def post_model_build(self, cfg, model):
"""
Calls the post_model_build method of all registered plugins after the model has been built/loaded,
but before any adapters have been applied.
Args:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
"""
for plugin in self.plugins.values():
plugin.post_model_build(cfg, model)
def post_model_load(self, cfg, model):
"""
Calls the post_model_load method of all registered plugins.
Calls the post_model_load method of all registered plugins after the model has been loaded
inclusive of any adapters
Parameters:
cfg (dict): The configuration for the plugins.
@@ -458,6 +480,20 @@ class PluginManager:
callbacks.extend(plugin_callbacks)
return callbacks
def post_train(self, cfg, model):
"""
Calls the post_train method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.post_train(cfg, model)
def post_train_unload(self, cfg):
"""
Calls the post_train_unload method of all registered plugins.

View File

@@ -27,15 +27,13 @@ pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transform
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
```
## Supported Models
- llama
- llama4_text
- llama4
- llama4_text
- mllama
- phi3
- gemma
@@ -45,8 +43,15 @@ cut_cross_entropy: true
- mistral
- mistral3
- qwen2
- qwen2_moe
- qwen2_vl
- qwen2_5_vl
- qwen3
- qwen3_moe
- cohere
- cohere2
- glm
- glm4
## Citation

View File

@@ -28,7 +28,7 @@ class CutCrossEntropyArgs(BaseModel):
Input args for Cut Cross Entropy.
"""
cut_cross_entropy: Optional[bool] = None
cut_cross_entropy: Optional[bool] = True
@model_validator(mode="before")
@classmethod

View File

@@ -0,0 +1,57 @@
"""GLM 4 patch. GLM family inherits from Llama."""
from types import MethodType
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
)
def patch_glm(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
# Set the _PATCH_OPTS in the llama patch file
import cut_cross_entropy.transformers.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from cut_cross_entropy.transformers.llama import cce_forward
from transformers.models.glm import modeling_glm
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_glm.GlmForCausalLM
), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_glm.GlmForCausalLM.forward = cce_forward
return None
def patch_glm4(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
# Set the _PATCH_OPTS in the llama patch file
import cut_cross_entropy.transformers.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from cut_cross_entropy.transformers.llama import cce_forward
from transformers.models.glm4 import modeling_glm4
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_glm4.Glm4ForCausalLM
), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_glm4.Glm4ForCausalLM.forward = cce_forward
return None

View File

@@ -0,0 +1,174 @@
"""Llama CCE patch. Adapted from transformers v4.51.2"""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.models.llama.modeling_llama import (
_CONFIG_FOR_DOC,
LLAMA_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
_PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
if hidden_states is None:
raise ValueError("hidden_states is None")
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**kwargs,
)
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def patch_llama(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
"""Patch Llama for CCE."""
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.llama import modeling_llama
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_llama.LlamaForCausalLM
), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_llama.LlamaForCausalLM.forward = cce_forward
return None

View File

@@ -5,9 +5,7 @@
import transformers
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
from cut_cross_entropy.transformers.llama import patch_llama
from cut_cross_entropy.transformers.phi3 import patch_phi3
from cut_cross_entropy.transformers.qwen2 import patch_qwen2
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
@@ -20,6 +18,13 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
patch_gemma3,
patch_gemma3_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
patch_glm,
patch_glm4,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
patch_llama,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
patch_llama4,
patch_llama4_text,
@@ -29,6 +34,22 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
patch_mistral3,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import (
patch_qwen2,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import (
patch_qwen2_5_vl,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import (
patch_qwen2_moe,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import (
patch_qwen2_vl,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import (
patch_qwen3_moe,
)
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"llama": patch_llama,
@@ -43,8 +64,15 @@ CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"mistral": patch_mistral,
"mistral3": patch_mistral3,
"qwen2": patch_qwen2,
"qwen2_moe": patch_qwen2_moe,
"qwen2_vl": patch_qwen2_vl,
"qwen2_5_vl": patch_qwen2_5_vl,
"qwen3": patch_qwen3,
"qwen3_moe": patch_qwen3_moe,
"cohere": patch_cohere,
"cohere2": patch_cohere2,
"glm": patch_glm,
"glm4": patch_glm4,
}

View File

@@ -0,0 +1,37 @@
"""Qwen2 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""
# pylint: disable=duplicate-code
from types import MethodType
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
)
def patch_qwen2(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
from transformers.models.qwen2 import modeling_qwen2
# Set the _PATCH_OPTS in the llama patch file
import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
cce_forward,
)
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_qwen2.Qwen2ForCausalLM
), f"Expected a Qwen2ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward
return None
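
Like the Llama patch, this Qwen2 patch chooses between rebinding `forward` on a single instance via `MethodType` and patching the class attribute. The toy example below, with made-up names, only illustrates that distinction in plain Python; it is not Axolotl code.

```python
from types import MethodType


class Toy:
    def forward(self, x):
        return x


def patched_forward(self, x):
    return x * 2


a, b = Toy(), Toy()

# Instance-level patch: only `a` is affected, mirroring the isinstance() branch above.
a.forward = MethodType(patched_forward, a)
assert a.forward(3) == 6 and b.forward(3) == 3

# Class-level patch: every existing and future instance picks it up,
# mirroring the fallback `Qwen2ForCausalLM.forward = cce_forward` assignment.
Toy.forward = patched_forward
assert b.forward(3) == 6
```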

View File

@@ -0,0 +1,246 @@
"""Qwen2.5 VL CCE patch. Adapted from transformers v4.51.2"""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLCausalLMOutputWithPast,
)
_PATCH_OPTS: PatchOptions | None = None
def cce_forward_multimodal(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
>>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
>>> messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
mask = input_ids == self.config.image_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
mask = input_ids == self.config.video_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (
(cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore
):
position_ids, rope_deltas = self.get_rope_index(
input_ids,
image_grid_thw,
video_grid_thw,
second_per_grid_ts,
attention_mask,
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = (
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
if cache_position is not None
else 0
)
position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore
position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore
position_ids = position_ids.add(delta) # type: ignore
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore
outputs = self.model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
logits = None
loss = None
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states,
self.lm_head.weight,
labels,
_PATCH_OPTS,
)
else:
logits = self.lm_head(hidden_states)
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Qwen2_5_VLCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=self.rope_deltas,
)
def patch_qwen2_5_vl(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
), f"Expected a Qwen2_5_VLForConditionalGeneration model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
return maybe_model
modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = (
cce_forward_multimodal
)
return None

View File

@@ -0,0 +1,188 @@
"""Qwen2 MoE CCE patch. Adapted from transformers v4.51.2"""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
_CONFIG_FOR_DOC,
QWEN2MOE_INPUTS_DOCSTRING,
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
load_balancing_loss_func,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
_PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**loss_kwargs,
) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM
>>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
)
hidden_states = outputs.last_hidden_state
loss = None
logits = None
if hidden_states is None:
raise ValueError("hidden_states is None")
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**loss_kwargs,
)
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore
loss.device # type: ignore
) # make sure to reside in the same device
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss, # type: ignore
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
def patch_qwen2_moe(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.qwen2_moe import modeling_qwen2_moe
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_qwen2_moe.Qwen2MoeForCausalLM
), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(forward, maybe_model)
return maybe_model
modeling_qwen2_moe.Qwen2MoeForCausalLM.forward = forward
return None

View File

@@ -0,0 +1,249 @@
"""Qwen2 VL CCE patch. Adapted from transformers v4.51.2"""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
_CONFIG_FOR_DOC,
QWEN2_VL_INPUTS_DOCSTRING,
Qwen2VLCausalLMOutputWithPast,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
_PATCH_OPTS: PatchOptions | None = None
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
rope_deltas: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
>>> messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.visual.get_dtype())
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
image_mask = (
(input_ids == self.config.image_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore
if pixel_values_videos is not None:
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
video_mask = (
(input_ids == self.config.video_token_id)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (
(cache_position is not None and cache_position[0] == 0)
or self.rope_deltas is None
or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore
):
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = (
cache_position[0] + self.rope_deltas
if cache_position is not None
else 0
)
position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore
position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore
delta = delta.to(position_ids.device) # type: ignore
position_ids = position_ids.add(delta) # type: ignore
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore
outputs = self.model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
logits = None
loss = None
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states,
self.lm_head.weight,
labels,
_PATCH_OPTS,
)
else:
logits = self.lm_head(hidden_states)
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Qwen2VLCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
rope_deltas=self.rope_deltas,
)
def patch_qwen2_vl(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.qwen2_vl import modeling_qwen2_vl
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_qwen2_vl.Qwen2VLForConditionalGeneration
), f"Expected a Qwen2VLForConditionalGeneration model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
return maybe_model
modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = cce_forward_multimodal
return None

View File

@@ -0,0 +1,35 @@
"""Qwen3 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""
# pylint: disable=duplicate-code
from types import MethodType
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
)
def patch_qwen3(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
from transformers.models.qwen3 import modeling_qwen3
# Set the _PATCH_OPTS in the llama patch file
import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import cce_forward
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_qwen3.Qwen3ForCausalLM
), f"Expected a Qwen3ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_qwen3.Qwen3ForCausalLM.forward = cce_forward
return None

View File

@@ -0,0 +1,194 @@
"""Qwen3 MoE CCE patch. Adapted from transformers v4.51.2"""
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
_CONFIG_FOR_DOC,
QWEN3_MOE_INPUTS_DOCSTRING,
KwargsForCausalLM,
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
load_balancing_loss_func,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
_PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
>>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
if hidden_states is None:
raise ValueError("hidden_states is None")
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**kwargs,
)
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore
loss.device # type: ignore
) # make sure to reside in the same device
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss, # type: ignore
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
def patch_qwen3_moe(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.qwen3_moe import modeling_qwen3_moe
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_qwen3_moe.Qwen3MoeForCausalLM
), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(forward, maybe_model)
return maybe_model
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = forward
return None
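
The MoE patches keep the upstream `logits_to_keep` convention: an `int` of `0` means "all positions" because `slice(-0, None)` is identical to `slice(0, None)`, while a tensor of indices is used as-is. A quick illustration of that slicing, on a random tensor:

```python
import torch

hidden_states = torch.randn(1, 8, 16)  # (batch, seq, hidden)

# logits_to_keep == 0: slice(-0, None) == slice(0, None) -> keep every position.
assert hidden_states[:, slice(-0, None), :].shape[1] == 8

# logits_to_keep == 1: only the final position, the common generation case.
assert hidden_states[:, slice(-1, None), :].shape[1] == 1

# A 1D index tensor can also be passed through unchanged for packed formats.
keep = torch.tensor([2, 5, 7])
assert hidden_states[:, keep, :].shape[1] == 3
```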

View File

@@ -35,6 +35,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
sequence_len,
roles_to_train=None,
train_on_eos=None,
train_on_eot=None,
eot_tokens=None,
logprobs_field="logprobs",
gen_temperature=1.0,
kd_temperature=1.0,
@@ -50,6 +52,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
sequence_len,
roles_to_train=roles_to_train,
train_on_eos=train_on_eos,
train_on_eot=train_on_eot,
eot_tokens=eot_tokens,
)
@property
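
The KD strategy simply forwards the two new knobs (`train_on_eot`, `eot_tokens`) to the base `ChatTemplateStrategy`. The dict below is only a sketch of the kind of kwargs the constructor now accepts, with illustrative values rather than defaults; the real values come from the user config via the loader changes later in this compare.

```python
# Illustrative only: the shape of the kwargs ChatTemplateStrategyWithKD now accepts.
strategy_kwargs = {
    "train_on_inputs": False,
    "sequence_len": 4096,
    "roles_to_train": ["assistant"],
    "train_on_eos": "turn",
    "train_on_eot": "turn",        # falls back to train_on_eos when left as None
    "eot_tokens": ["<|im_end|>"],  # defaults to [tokenizer.eos_token] when None
    "logprobs_field": "logprobs",
    "gen_temperature": 1.0,
    "kd_temperature": 1.0,
}
```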

View File

@@ -31,6 +31,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"starcoder2",
"deepseek_v2",
"deepseek_v3",
"glm",
"glm4",
]

View File

@@ -272,7 +272,7 @@ class ReLoRAScheduler(LRScheduler):
self.warmup_steps = warmup_steps
self.anneal_steps = anneal_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
super().__init__(optimizer, inner_schedule.last_epoch)
def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch

View File

@@ -33,6 +33,7 @@ class ChatTemplatePrompter(Prompter):
message_field_training: Optional[str] = None,
message_field_training_detail: Optional[str] = None,
field_messages: str = "messages",
field_system: str = "system",
roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False,
):
@@ -62,6 +63,7 @@ class ChatTemplatePrompter(Prompter):
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.field_messages = field_messages
self.field_system = field_system
self.tokenizer = tokenizer
self.processor: Optional[ProcessorMixin] = processor
self.chat_template = chat_template
@@ -220,10 +222,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self,
prompter: "ChatTemplatePrompter",
tokenizer,
train_on_inputs,
sequence_len,
roles_to_train=None,
train_on_eos=None,
train_on_inputs: bool,
sequence_len: int,
roles_to_train: Optional[List[str]] = None,
train_on_eos: Optional[str] = None,
train_on_eot: Optional[str] = None,
eot_tokens: Optional[List[str]] = None,
split_thinking: Optional[bool] = False,
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.prompter: ChatTemplatePrompter = prompter
@@ -236,12 +241,88 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
]
self.train_on_eos = train_on_eos
# Backward compatibility, load from train_on_eos
self.train_on_eot = train_on_eot if train_on_eot is not None else train_on_eos
# Default to eos_token if eot_tokens not provided
self.eot_tokens = (
eot_tokens if eot_tokens is not None else [self.tokenizer.eos_token]
)
self.split_thinking = split_thinking
self.images = "images"
LOG.debug(
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
)
self._validate_eot_and_eos_tokens()
def _validate_eot_and_eos_tokens(self):
"""
- Validates that EOT tokens (or eos_token) are in the chat_template
- Checks if EOT tokens are encoded as multiple tokens in the tokenizer.
- Checks for potential conflicts between train_on_eos and train_on_eot.
"""
if self.prompter.chat_template is None:
# Usually this should not happen
LOG.warning(
"No chat template provided, skipping EOT and EOS token validation"
)
return
# If the EOT token is the same as the EOS token, we need to check differently
if len(self.eot_tokens) == 1 and self.eot_tokens[0] == self.tokenizer.eos_token:
# Check if the eos_token is in the chat_template or as a variable `eos_token`
# Note: we only check that the literal string `eos_token` appears in the template; it may not actually be used as a template variable
if (
self.tokenizer.eos_token not in self.prompter.chat_template
and "eos_token" not in self.prompter.chat_template
):
LOG.warning(
f"EOS token '{self.tokenizer.eos_token}' not found in chat_template. Please check if your template/EOS token is correct."
)
return
# Create a new list to store tokens that should be kept
valid_eot_tokens = []
for token in self.eot_tokens:
# Check if EOT token is in the chat_template
if token not in self.prompter.chat_template:
LOG.warning(f"EOT token '{token}' not found in chat_template.")
# Don't add to the valid tokens list
continue
valid_eot_tokens.append(token)
# Replace the original list with the filtered one
self.eot_tokens = valid_eot_tokens
for token in self.eot_tokens:
# If token in template, check if EOT token is in tokenizer and not encoded as multiple tokens
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
if not token_ids:
raise ValueError(
"EOT token encoding failed. Please check if the token is valid and can be encoded."
)
if token_ids and len(token_ids) > 1:
raise ValueError(
f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config "
"or (recommended) override unused added_tokens via `added_tokens_overrides: `."
)
# If eos_token is in eot_tokens and conflict between train_on_eos and train_on_eot, raise an error
if (
self.tokenizer.eos_token in self.eot_tokens
and self.train_on_eos != self.train_on_eot
):
raise ValueError(
"Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot"
f"train_on_eos: {self.train_on_eos}, train_on_eot: {self.train_on_eot}"
f"eot_tokens: {self.eot_tokens}"
f"eos_token: {self.tokenizer.eos_token}"
)
@property
def supports_batched(self) -> bool:
# Let calling code know we can handle lists of examples
@@ -285,6 +366,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if (
not self.roles_to_train
and not self.train_on_eos
and not self.train_on_eot
and not self.prompter.message_field_training # type: ignore
and not self.prompter.message_field_training_detail # type: ignore
):
@@ -320,6 +402,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
labels = [IGNORE_TOKEN_ID] * len(input_ids)
last_eos_idx = -1
last_eot_idx = -1
for index, turn in enumerate(turns):
role = turn.get("role")
content = turn.get("content")
@@ -368,24 +451,45 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
LOG.debug(f"Labels after processing turn {index}: {labels}")
# Handle EOS token
eos_idx = self.find_first_eos_token(input_ids, start_idx=turn_end_idx)
if abs(eos_idx - turn_end_idx) <= 3: # Allow for some template padding
last_eos_idx = eos_idx
if self.train_on_eos == "all" or (
self.train_on_eos == "turn" and should_train
):
labels[eos_idx] = input_ids[eos_idx]
LOG.debug(f"EOS token set for training at index {eos_idx}")
else:
LOG.debug(
f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
)
# Handle special tokens (EOT and EOS)
for token_type, find_func, train_option in [
("EOT", self.find_first_eot_token, self.train_on_eot),
("EOS", self.find_first_eos_token, self.train_on_eos),
]:
token_idx = find_func(input_ids, start_idx=turn_end_idx)
# Handle 'last' option for train_on_eos
if self.train_on_eos == "last" and last_eos_idx != -1:
labels[last_eos_idx] = input_ids[last_eos_idx]
LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")
if (
token_idx != -1 and abs(token_idx - turn_end_idx) <= 3
): # Allow for some template padding
# Update the last token index
if token_type == "EOT": # nosec B105
last_eot_idx = token_idx
else:
last_eos_idx = token_idx
# Set labels if needed for this turn
if train_option == "all" or (
train_option == "turn" and should_train
):
labels[token_idx] = input_ids[token_idx]
LOG.debug(
f"{token_type} token set for training at index {token_idx}"
)
else:
LOG.debug(
f"{token_type} token missing after turn {turn}. {token_type.lower()}_idx: {token_idx}, turn_end_idx: {turn_end_idx}"
)
# Handle 'last' option for special tokens
for token_type, last_idx, train_option in [
("EOT", last_eot_idx, self.train_on_eot),
("EOS", last_eos_idx, self.train_on_eos),
]:
if train_option == "last" and last_idx != -1:
labels[last_idx] = input_ids[last_idx]
LOG.debug(
f"Last {token_type} token set for training at index {last_idx}"
)
LOG.debug(f"Final labels: {labels}")
@@ -402,6 +506,25 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return i
return -1
def find_first_eot_token(self, input_ids, start_idx):
"""Find the first EOT token in the input_ids starting from start_idx."""
# Get token IDs for all EOT tokens
eot_token_ids = []
for token in self.eot_tokens:
token_ids = self.tokenizer.encode(token, add_special_tokens=False)
if len(token_ids) != 1:
raise ValueError(
f"EOT token '{token}' is encoded as multiple tokens: {token_ids}. Please add it under `tokens: ` in the config."
)
eot_token_ids.append(token_ids[0])  # single-token EOT; the check above guarantees exactly one ID
# Search for any of the EOT token IDs
for i in range(start_idx, len(input_ids)):
if input_ids[i] in eot_token_ids:
return i
return -1
def find_turn(self, turns: list[dict], turn_idx: int):
"""
Locate the starting and ending indices of the specified turn in a conversation.
@@ -488,6 +611,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def get_conversation_thread(self, prompt):
turns = []
possible_sys_turn = self.transform_message(
prompt[self.prompter.field_messages][0]
)
if (
possible_sys_turn["role"] != "system"
and self.prompter.field_system in prompt
):
turn = {"role": "system", "content": prompt[self.prompter.field_system]}
turns.append(turn)
for message in prompt[self.prompter.field_messages]:
transformed_message = self.transform_message(message)
@@ -523,6 +657,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
transformed_message["role"], transformed_message["role"]
)
# TODO handle reasoning_content with split_thinking
# if the role is assistant, split any thinking block out into reasoning_content
if self.split_thinking and transformed_message["role"] == "assistant":
content = transformed_message["content"]
pairs = [("<think>", "</think>"), ("<reasoning>", "</reasoning>")]
for pair in pairs:
if pair[0] in content and pair[1] in content:
start_idx = content.find(pair[0])
end_idx = content.find(pair[1])
thinking_content = content[start_idx + len(pair[0]) : end_idx]
transformed_message["reasoning_content"] = thinking_content.strip()
transformed_message["content"] = content[
end_idx + len(pair[1]) :
].lstrip()
break
# Determine which keys in the original message were not mapped
mapped_values = set(self.prompter.message_property_mappings.values())
remaining_keys = set(message) - mapped_values
@@ -555,6 +705,9 @@ class StrategyLoader:
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
"train_on_eot": ds_cfg.get("train_on_eot", None),
"eot_tokens": cfg.get("eot_tokens", None), # loads from cfg, not ds_cfg
"split_thinking": ds_cfg.get("split_thinking", False),
}
def __call__(
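
`split_thinking` moves a leading `<think>…</think>` (or `<reasoning>…</reasoning>`) block out of the assistant content and into a separate `reasoning_content` field. Below is a small standalone sketch of the same string handling, applied to an invented message; it mirrors the logic in the diff above rather than calling Axolotl itself.

```python
def split_thinking(message: dict) -> dict:
    """Mirror of the splitting logic above, for illustration only."""
    content = message["content"]
    for open_tag, close_tag in (("<think>", "</think>"), ("<reasoning>", "</reasoning>")):
        if open_tag in content and close_tag in content:
            start = content.find(open_tag)
            end = content.find(close_tag)
            message["reasoning_content"] = content[start + len(open_tag) : end].strip()
            message["content"] = content[end + len(close_tag) :].lstrip()
            break
    return message


msg = {
    "role": "assistant",
    "content": "<think>The user wants a haiku about rain.</think>Rain taps the window...",
}
print(split_thinking(msg))
# {'role': 'assistant',
#  'content': 'Rain taps the window...',
#  'reasoning_content': 'The user wants a haiku about rain.'}
```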

View File

@@ -6,6 +6,7 @@ import os
import signal
import sys
import weakref
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Dict
@@ -25,6 +26,10 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainers.mixins.sequence_parallel import (
SequenceParallelContextManager,
)
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
@@ -185,16 +190,28 @@ def execute_training(
trainer: The configured trainer object.
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...")
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
# Define the context managers to use
flash_context = (
torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
)
if cfg.flash_optimum
else nullcontext()
)
sequence_parallel_context = (
SequenceParallelContextManager(
model=trainer.model,
sequence_parallel_degree=cfg.sequence_parallel_degree,
ring_attn_func=cfg.ring_attn_func,
)
if cfg.sequence_parallel_degree > 1
else nullcontext()
)
LOG.info("Starting trainer...")
with flash_context, sequence_parallel_context:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
@@ -517,4 +534,7 @@ def train(
if not cfg.use_ray:
cleanup_distributed()
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train(cfg, model)
return model, tokenizer, trainer
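
The rewritten `execute_training` composes the SDPA-kernel context and the sequence-parallel context with `contextlib.nullcontext`, so a single `with` statement covers every combination of the two flags. A minimal sketch of that pattern with stand-in contexts (names here are illustrative, not Axolotl's):

```python
from contextlib import contextmanager, nullcontext


@contextmanager
def noisy(name):
    print(f"enter {name}")
    yield
    print(f"exit {name}")


def run(flash_optimum: bool, sequence_parallel_degree: int):
    flash_ctx = noisy("sdp_kernel") if flash_optimum else nullcontext()
    sp_ctx = noisy("sequence_parallel") if sequence_parallel_degree > 1 else nullcontext()
    # Both enabled, one enabled, or neither: the training call site stays identical.
    with flash_ctx, sp_ctx:
        print("trainer.train(...)")


run(flash_optimum=True, sequence_parallel_degree=2)
run(flash_optimum=False, sequence_parallel_degree=1)
```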

File diff suppressed because one or more lines are too long

View File

@@ -1,20 +1,12 @@
"""
Data collators for axolotl to pad labels and position_ids for packed sequences. Also
includes logic for handling sequence parallelism collation.
"""
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
@dataclass
class DataCollatorForSeq2Seq:
@@ -49,8 +41,6 @@ class DataCollatorForSeq2Seq:
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
sequence_parallel_degree (`int`):
The degree of sequence parallelism. Default to 1 for no sequence parallelism.
"""
tokenizer: PreTrainedTokenizerBase
@@ -61,17 +51,6 @@ class DataCollatorForSeq2Seq:
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
sequence_parallel_degree: int = 1
ring_attn_func: RingAttnFunc | None = None
def __post_init__(self):
if self.sequence_parallel_degree > 1:
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
# Get information about our position in the SP group
sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=sp_group)
self.local_world_size = dist.get_world_size(group=sp_group)
def __call__(self, features, return_tensors=None):
has_attn_mask = "attention_mask" in features[0].keys()
@@ -141,62 +120,8 @@ class DataCollatorForSeq2Seq:
)
features["decoder_input_ids"] = decoder_input_ids
if self.sequence_parallel_degree > 1:
features = self.apply_sequence_parallelism(features)
return features
def apply_sequence_parallelism(
self, batch: dict[str, torch.Tensor]
) -> torch.Tensor:
"""
Apply sequence parallelism slicing to a batch.
Args:
batch: Batch dictionary from parent collator.
Returns:
Sliced batch dictionary.
"""
# Get local (start, end) for sequence parallelism slicing
total_seq_len = batch["input_ids"].size(1)
# Update params for varlen ring attention calculation
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
# Slice batch for sequence parallel processing
for key in batch:
if batch[key].size(1) == total_seq_len:
if self.ring_attn_func in [
RingAttnFunc.VARLEN_LLAMA3,
RingAttnFunc.BATCH_RING,
]:
batch[key] = (
batch[key]
.chunk(self.local_world_size, dim=1)[self.local_rank]
.contiguous()
)
elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG:
chunks = batch[key].chunk(2 * self.local_world_size, dim=1)
# Take rank's chunk and opposing chunk for zigzag pattern
selected_chunks = [
chunks[self.local_rank],
chunks[2 * self.local_world_size - self.local_rank - 1],
]
batch[key] = torch.cat(selected_chunks, dim=1).contiguous()
elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE:
# TODO(djsaunde): This doesn't seem to work as expected
# Split into striped data and stack
tensor = torch.stack(
batch[key].split(self.local_world_size, dim=1),
dim=1,
).transpose(1, 2)
batch[key] = tensor[:, self.local_rank].contiguous()
return batch
@dataclass
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
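
The `apply_sequence_parallelism` method removed here (sequence parallelism now goes through the `SequenceParallelContextManager` seen earlier in this compare) decided which sequence chunks each rank keeps. For the `BATCH_ZIGZAG` pattern with a local world size of 2, rank 0 takes chunks 0 and 3 while rank 1 takes chunks 1 and 2. An index-only sketch of that selection:

```python
import torch

local_world_size = 2
seq = torch.arange(8).unsqueeze(0)  # one sequence of length 8 -> 4 chunks of 2

for local_rank in range(local_world_size):
    chunks = seq.chunk(2 * local_world_size, dim=1)
    # take this rank's chunk plus the mirrored chunk for the zigzag pattern
    selected = [chunks[local_rank], chunks[2 * local_world_size - local_rank - 1]]
    print(local_rank, torch.cat(selected, dim=1).tolist())

# rank 0 -> [[0, 1, 6, 7]]   (chunks 0 and 3)
# rank 1 -> [[2, 3, 4, 5]]   (chunks 1 and 2)
```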

View File

@@ -126,9 +126,6 @@ def normalize_config(cfg):
with open(ds_config_path, encoding="utf-8") as f:
cfg.deepspeed = json.load(f)
if cfg.sequence_parallel_degree is None:
cfg.sequence_parallel_degree = 1
if cfg.saves_per_epoch:
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
if save_steps < 1.0: # prevent saves on every step
@@ -236,18 +233,6 @@ def normalize_config(cfg):
log_gpu_memory_usage(LOG, "baseline", cfg.device)
if cfg.quantization:
if cfg.quantization.backend in ["bnb"]:
if cfg.quantization.bits == 8:
cfg.load_in_8bit = True
elif cfg.quantization.bits == 4:
cfg.load_in_4bit = True
if cfg.quantization.backend == "gptq":
cfg.gptq = True
elif cfg.quantization.backend == "hqq":
cfg.hqq = True
def normalize_cfg_datasets(cfg):
"""

View File

@@ -204,7 +204,37 @@ def load_prepare_preference_datasets(cfg):
else:
eval_dataset = load_split(cfg.test_datasets, cfg)
if not eval_dataset:
eval_dataset = None
if cfg.val_set_size:
# ensure every rank ends up with the same fingerprint, so the split done on rank 0 is cached and reused
to_hash_train = (
train_dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "train"
+ "|"
+ str(cfg.seed or 42)
)
to_hash_test = (
train_dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "test"
+ "|"
+ str(cfg.seed or 42)
)
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)
ds_w_test_split = train_dataset.train_test_split(
test_size=cfg.val_set_size,
seed=cfg.seed,
shuffle=False,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
eval_dataset = ds_w_test_split["test"]
train_dataset = ds_w_test_split["train"]
if not train_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
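
With this change, setting `val_set_size` on a preference (e.g. DPO) run carves the eval split out of the training set using deterministic fingerprints, so every rank computes the same split and hits the same cache. The fingerprints are just the dataset fingerprint joined with the split size, split name, and seed; a sketch of that composition, using a stand-in for Axolotl's `md5` helper (assumed here to hash a string):

```python
from hashlib import md5 as _md5


def md5(to_hash: str) -> str:
    # Stand-in for the md5() utility used above; assumed to hash the given string.
    return _md5(to_hash.encode("utf-8")).hexdigest()


fingerprint = "abc123"  # train_dataset._fingerprint in the real code
val_set_size, seed = 0.05, 42

train_fp = md5(f"{fingerprint}|{val_set_size}|train|{seed}")
test_fp = md5(f"{fingerprint}|{val_set_size}|test|{seed}")
print(train_fp, test_fp)  # stable across ranks, so the split is cacheable
```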

View File

@@ -134,10 +134,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
"csv", data_files=f.name, split="train", streaming=True
)
else:
if is_local_main_process():
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")

View File

@@ -1,5 +1,7 @@
"""custom checkpointing utils"""
from functools import partial
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
@@ -9,6 +11,10 @@ def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer.__self__,
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)
else decoder_layer.__self__
),
*args,
)
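
The checkpointing wrapper now handles the case where the decoder layer it receives is a `functools.partial` wrapping a bound method rather than the bound method itself, so the owning module has to be recovered via `.func.__self__`. A small illustration with a dummy layer (not Axolotl code):

```python
from functools import partial


class DummyLayer:
    def forward(self, x):
        return x


layer = DummyLayer()
bound = layer.forward
wrapped = partial(layer.forward)  # e.g. a partial used upstream to pre-bind kwargs

# A plain bound method exposes its module directly...
assert bound.__self__ is layer
# ...but a partial hides it one level down, hence the isinstance(decoder_layer, partial) branch.
assert not hasattr(wrapped, "__self__")
assert wrapped.func.__self__ is layer
```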

View File

@@ -36,7 +36,6 @@ from transformers import (
BitsAndBytesConfig,
Gemma3ForConditionalGeneration,
GPTQConfig,
HqqConfig,
Llama4ForConditionalGeneration,
LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration,
@@ -54,6 +53,7 @@ from transformers.integrations.deepspeed import (
)
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.integrations.base import PluginManager
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -75,6 +75,7 @@ from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
LOG = logging.getLogger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
MULTIMODAL_AUTO_MODEL_MAPPING = {
"mllama": MllamaForConditionalGeneration,
@@ -572,10 +573,8 @@ class ModelLoader:
patch_gemma3conditionalgeneration_forward()
# load any patches from plugins
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)
PLUGIN_MANAGER.pre_model_load(self.cfg)
# monkey patch to allow additional Accelerator init kwargs
if self.cfg.fp8:
@@ -834,13 +833,6 @@ class ModelLoader:
del self.model_kwargs["device_map"]
def set_quantization_config(self) -> None:
if (
(not self.cfg.quantization)
and (not self.cfg.load_in_8bit)
and (not self.cfg.load_in_4bit)
and not self.cfg.gptq
):
return
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
@@ -862,21 +854,21 @@ class ModelLoader:
and hasattr(self.model_config, "quantization_config")
and self.model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes"]
and not self.cfg.hqq
):
quant_config_class_dict = {
"gptq": GPTQConfig,
"awq": AwqConfig,
"bitsandbytes": BitsAndBytesConfig,
}
quant_config_class = quant_config_class_dict[
self.model_config.quantization_config["quant_method"]
]
self.model_kwargs["quantization_config"] = quant_config_class(
**self.model_config.quantization_config
)
if self.model_config.quantization_config["quant_method"] == "gptq":
self.model_kwargs["quantization_config"] = GPTQConfig(
**self.model_config.quantization_config
)
elif self.model_config.quantization_config["quant_method"] == "awq":
self.model_kwargs["quantization_config"] = AwqConfig(
**self.model_config.quantization_config
)
elif (
self.model_config.quantization_config["quant_method"] == "bitsandbytes"
):
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = {
"load_in_4bit": True,
@@ -894,8 +886,8 @@ class ModelLoader:
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32
if self.cfg.quantization and self.cfg.quantization.bnb_config_kwargs:
bnb_config.update(self.cfg.quantization.bnb_config_kwargs)
if self.cfg.bnb_config_kwargs:
bnb_config.update(self.cfg.bnb_config_kwargs)
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
@@ -911,13 +903,6 @@ class ModelLoader:
**bnb_config,
)
if self.cfg.hqq:
from axolotl.utils.schemas.quant import get_hqq_quant_config_kwargs
self.model_kwargs["quantization_config"] = HqqConfig(
**get_hqq_quant_config_kwargs(self.cfg)
)
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
self.model_kwargs.pop("load_in_8bit", None)
@@ -1051,12 +1036,6 @@ class ModelLoader:
config=self.model_config,
)
else:
if self.cfg.hqq and torch.cuda.device_count() < 2:
# for some reason on single gpu, we need to set device_map to auto/cuda
# otherwise you run into tensors on two devices error during training
# Doesn't affect multi-gpu tho
self.model_kwargs["device_map"] = "auto"
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
@@ -1211,7 +1190,7 @@ class ModelLoader:
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit or self.cfg.hqq)
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
self.model = prepare_model_for_kbit_training(
@@ -1273,6 +1252,7 @@ class ModelLoader:
try:
skip_move_to_device = self.build_model(qlora_fsdp)
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
except Exception as err: # pylint: disable=broad-exception-caught
LOG.exception(err)
raise err
@@ -1352,6 +1332,8 @@ class ModelLoader:
before_kbit_train_or_finetune=False,
)
PLUGIN_MANAGER.pre_lora_load(self.cfg, self.model)
# ---------------------------------------------------------
# load lora or adapter
# ---------------------------------------------------------
@@ -1413,7 +1395,7 @@ class ModelLoader:
gc.collect()
torch.cuda.empty_cache()
# TODO resume_from_checkpoint handling
PLUGIN_MANAGER.post_model_load(self.cfg, self.model)
return self.model, lora_config
@@ -1448,9 +1430,13 @@ def load_adapter(model, cfg, adapter, inference=False):
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter in ["lora", "qlora"]:
return load_lora(model, cfg, inference=inference)
model, lora_config = load_lora(model, cfg, inference=inference)
PLUGIN_MANAGER.post_lora_load(cfg, model)
return model, lora_config
if adapter == "llama-adapter":
return load_llama_adapter(model, cfg)
model, lora_config = load_llama_adapter(model, cfg)
PLUGIN_MANAGER.post_lora_load(cfg, model)
return model, lora_config
raise NotImplementedError(f"{adapter} peft adapter not available")
@@ -1481,16 +1467,7 @@ def load_llama_adapter(model, cfg):
def find_all_linear_names(model):
from hqq.core.peft import HQQLinearLoRA
from hqq.core.quantize import HQQLinear
cls = (
bnb.nn.Linear4bit,
bnb.nn.Linear8bitLt,
torch.nn.Linear,
HQQLinear,
HQQLinearLoRA,
)
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
lora_module_names = set()
for name, module in model.named_modules():
if (
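
models.py now routes model construction through module-level `PLUGIN_MANAGER` calls (`post_model_build`, `pre_lora_load`, `post_lora_load`, `post_model_load`). The class below is only a sketch of what a plugin consuming those hooks might look like; the import path and base-class name are assumptions, and the `(cfg, model)` signatures are inferred from the call sites in this diff.

```python
from axolotl.integrations.base import BasePlugin  # assumed import path for the plugin base class


class LoggingPlugin(BasePlugin):
    """Hypothetical plugin that just reports when each hook fires."""

    def post_model_build(self, cfg, model):
        print("model built (before adapters/quantized prep):", type(model).__name__)

    def pre_lora_load(self, cfg, model):
        print("about to attach LoRA adapters")

    def post_lora_load(self, cfg, model):
        print("LoRA adapters attached")

    def post_model_load(self, cfg, model):
        print("model fully loaded and patched")
```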

View File

@@ -309,6 +309,7 @@ class AxolotlInputConfig(
| Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")]
) | None = None
chat_template_jinja: str | None = None
eot_tokens: list[str] | None = None
default_system_message: str | None = None
fix_untrained_tokens: int | list[int] | None = None
@@ -1149,22 +1150,17 @@ class AxolotlInputConfig(
return data
@field_validator("sequence_parallel_degree", mode="after")
@classmethod
def check_sequence_parallel_degree(cls, value, info):
if not value:
value = 1
if value > 1:
if not info.data.get("flash_attention"):
@model_validator(mode="after")
def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree:
self.sequence_parallel_degree = 1
elif self.sequence_parallel_degree > 1:
if not self.flash_attention:
raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1"
)
if (
info.data.get("sample_packing")
and not info.data["micro_batch_size"] == 1
):
if self.sample_packing and self.micro_batch_size > 1:
raise ValueError(
"micro_batch_size must be set to 1 when sample_packing is enabled"
"due to a `ring-flash-attn` requirement"
@@ -1184,42 +1180,40 @@ class AxolotlInputConfig(
# according to the proportion of non-padding tokens per rank.
LOG.warning(
"Sequence parallelism (SP) is enabled with "
f"sequence_parallel_degree={value}. Please note that logged losses may "
"differ slightly to the non-SP losses due to transformers Trainer "
"implementation details. Please see "
"https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
"Please note that logged losses may differ slightly to the non-SP "
"losses due to transformers Trainer implementation details. "
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
"for more details."
)
return value
return self
@field_validator("ring_attn_func", mode="after")
@classmethod
def check_ring_attn_func(cls, value, info):
if not info.data.get("sequence_parallel_degree", 1) > 1:
return value
@model_validator(mode="after")
def validate_ring_attn_func(self):
if getattr(self, "sequence_parallel_degree", 1) == 1:
return self
from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc
if value is not None:
# Set the ring attention function if passed in config
if self.ring_attn_func is not None:
valid_funcs = list(RingAttnFunc)
if value in valid_funcs:
value = RingAttnFunc(value)
if self.ring_attn_func in valid_funcs:
self.ring_attn_func = RingAttnFunc(self.ring_attn_func)
else:
raise ValueError(
f"ring_attn_func: {value} must be one of {valid_funcs}"
f"ring_attn_func: {self.ring_attn_func} must be in {valid_funcs}"
)
else:
# Default ring attention function selection
sample_packing = info.data.get("sample_packing")
value = (
sample_packing = getattr(self, "sample_packing", False)
self.ring_attn_func = (
RingAttnFunc.VARLEN_LLAMA3
if sample_packing
else RingAttnFunc.BATCH_RING
)
return value
return self
@model_validator(mode="before")
@classmethod

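Both validators above move from a per-field field_validator, which only sees previously declared fields through info.data, to a whole-model @model_validator(mode="after") that reads and mutates fields on self and returns self. A stripped-down sketch of that pattern, with hypothetical field names:
from pydantic import BaseModel, model_validator
class ExampleConfig(BaseModel):
    """Hypothetical config showing the after-mode model validator pattern."""
    degree: int | None = None
    flash_attention: bool = False
    @model_validator(mode="after")
    def check_degree(self):
        # default the value, then validate it against a sibling field
        if not self.degree:
            self.degree = 1
        elif self.degree > 1 and not self.flash_attention:
            raise ValueError("flash_attention must be true when degree > 1")
        return self
# ExampleConfig().degree == 1
# ExampleConfig(degree=2, flash_attention=True) validates fine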
View File

@@ -50,6 +50,7 @@ class SFTDataset(BaseModel):
message_property_mappings: dict[str, str] | None = None
message_field_training: str | None = None
message_field_training_detail: str | None = None
split_thinking: bool | None = None
logprobs_field: str | None = None
temperature: float | None = None
roles_to_train: list[str] | None = None
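The new split_thinking option relates to splitting a reasoning trace out of an assistant turn before templating. A rough sketch of the idea, assuming the trace is wrapped in <think>...</think> tags; the actual axolotl handling may differ:
import re
THINK_RE = re.compile(r"<think>(.*?)</think>\s*", re.DOTALL)
def split_thinking(content: str) -> tuple[str | None, str]:
    """Return (reasoning, reply) for an assistant message, or (None, content)."""
    match = THINK_RE.search(content)
    if not match:
        return None, content
    reasoning = match.group(1).strip()
    reply = THINK_RE.sub("", content, count=1).strip()
    return reasoning, reply
# split_thinking("<think>2 + 2 = 4</think>The answer is 4.")
# -> ("2 + 2 = 4", "The answer is 4.")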

View File

@@ -35,6 +35,7 @@ class ChatTemplate(str, Enum):
jamba = "jamba" # pylint: disable=invalid-name
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
qwen3 = "qwen3" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
exaone = "exaone" # pylint: disable=invalid-name
metharme = "metharme" # pylint: disable=invalid-name

View File

@@ -1,8 +1,8 @@
"""Pydantic models for PEFT-related configuration"""
from pydantic import BaseModel, Field, field_validator, model_validator
from typing import Any
from axolotl.utils.schemas.quant import QuantizationConfig
from pydantic import BaseModel, Field, field_validator, model_validator
class LoftQConfig(BaseModel):
@@ -23,11 +23,8 @@ class PeftConfig(BaseModel):
class LoraConfig(BaseModel):
"""Peft / LoRA configuration subset"""
quantization: QuantizationConfig | None = None
load_in_4bit: bool | None = None # for internal use
load_in_8bit: bool | None = None # for internal use
hqq: bool | None = None # for internal use
gptq: bool | None = None # for internal use
load_in_8bit: bool | None = Field(default=False)
load_in_4bit: bool | None = Field(default=False)
adapter: str | None = None
lora_model_dir: str | None = None
@@ -53,6 +50,8 @@ class LoraConfig(BaseModel):
},
)
lora_on_cpu: bool | None = None
gptq: bool | None = None
bnb_config_kwargs: dict[str, Any] | None = None
loraplus_lr_ratio: float | None = Field(
default=None,
@@ -75,11 +74,11 @@ class LoraConfig(BaseModel):
if (
not data.get("adapter")
and not data.get("inference")
and (data.get("quantization"))
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
):
raise ValueError(
"Quantization is not supported without setting an adapter."
"If you want to full finetune, please turn off Quantization."
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
return data
@@ -87,26 +86,25 @@ class LoraConfig(BaseModel):
def validate_qlora(self):
if self.adapter == "qlora":
if self.merge_lora:
if self.quantization.bits == 8 or self.load_in_8bit:
# can't merge qlora if loaded in 8bit or 4bit
if self.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit")
if self.quantization.backend == "gptq":
raise ValueError("Can't merge qlora if using gptq")
if self.gptq:
raise ValueError("Can't merge qlora if gptq")
if self.quantization.bits == 4 or self.load_in_4bit:
if self.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
else:
if self.quantization:
if self.quantization.bits == 8 or self.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if self.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if self.quantization.backend == "gptq":
raise ValueError("Can't load qlora if using gptq")
if not self.quantization.bits == 4 or self.load_in_4bit:
raise ValueError("Require quantization.bits <= 4 for qlora")
if self.gptq:
raise ValueError("Can't load qlora if gptq")
if not self.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self
@field_validator("loraplus_lr_embedding")
@@ -123,24 +121,6 @@ class LoraConfig(BaseModel):
data["lora_dropout"] = 0.0
return data
@model_validator(mode="before")
@classmethod
def validate_hqq(cls, data):
if (
data.get("quantization")
and data.get("quantization").get("backend") == "hqq"
):
if not data.get("quantization").get("hqq_config"):
raise ValueError(
"If using HQQ, must set `hqq_config` under `quantization`"
)
if data.get("load_in_4bit") or data.get("load_in_8bit"):
raise ValueError(
"If using HQQ quantization, please remove load_in_4bit or load_in_8bit"
)
return data
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""

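With the quantization block removed, the validators above key directly off load_in_4bit and load_in_8bit, so a QLoRA setup is expressed the way the updated multi-GPU tests do. An illustrative minimal fragment (other training fields omitted):
from axolotl.utils.dict import DictDefault
qlora_fragment = DictDefault(
    {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "load_in_4bit": True,  # required by validate_qlora when adapter == "qlora"
        "adapter": "qlora",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "lora_target_linear": True,
    }
)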
View File

@@ -1,93 +0,0 @@
""" "
Takes care of quantization configuration
"""
from typing import Annotated, Any, Literal
from annotated_types import MinLen
from pydantic import BaseModel, Field, model_validator
class HQQConfig(BaseModel):
"""HQQ configuration subset"""
nbits: Literal[8, 4, 3, 2, 1] | None = Field(
default=None,
json_schema_extra={
"description": "Number of bits for HQQ quantization. 8, 4, 3, 2, or 1."
},
)
group_size: int = Field(default=64)
target_modules: list[str] | str | None = Field(
default=None,
json_schema_extra={
"description": "Target modules for HQQ quantization. If not specified, the whole model will be quantized."
},
)
class QuantizationConfig(BaseModel):
"""Over all Quantization configuration subset"""
# We will use this class as base future refactoring of all quantization configs
backend: Literal["bnb", "hqq", "gptq"] | None = None
bits: Literal[8, 4, 3, 2, 1] | None = None
bnb_config_kwargs: dict[str, Any] | None = None
hqq_config: HQQConfig | Annotated[list[HQQConfig], MinLen(1)] | None = None
@model_validator(mode="before")
@classmethod
def check_hqq_config(cls, data):
if data.get("backend") == "hqq" and not data.get("hqq_config"):
raise ValueError("If using HQQ, must set `group_size` under `hqq_config`")
if data.get("hqq_config") and len(data.get("hqq_config")) > 1:
for hqq_config in data.get("hqq_config"):
if hqq_config.get("target_modules") is None:
raise ValueError(
"For list of hqq configs, `target_modules` must be specified for each"
)
return data
def get_hqq_quant_config_kwargs(cfg):
# If no target module is specified, then target the whole model
if not isinstance(cfg.quantization.hqq_config, list):
cfg.quantization.hqq_config = [cfg.quantization.hqq_config]
if (
len(cfg.quantization.hqq_config) == 1
and cfg.quantization.hqq_config[0].target_modules is None
):
nbits = (
cfg.quantization.hqq_config[0].nbits
if cfg.quantization.hqq_config[0].nbits is not None
else cfg.quantization.bits
)
return {
"nbits": nbits,
"group_size": cfg.quantization.hqq_config[0].group_size,
}
hqq_quant_config_kwargs = {"dynamic_config": {}}
for hqq_config in cfg.quantization.hqq_config:
nbits = (
hqq_config.nbits if hqq_config.nbits is not None else cfg.quantization.bits
)
target_modules = hqq_config.target_modules
if not isinstance(target_modules, list):
target_modules = [target_modules]
for module in target_modules:
hqq_quant_config_kwargs["dynamic_config"][module] = {
"nbits": nbits,
"group_size": hqq_config.group_size,
}
return hqq_quant_config_kwargs
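For reference, the removed get_hqq_quant_config_kwargs helper flattened a list of HQQ sub-configs into per-module kwargs. With two targeted sub-configs its return value looked roughly like the dict below; module names and bit widths here are illustrative only:
hqq_quant_config_kwargs = {
    "dynamic_config": {
        "self_attn.q_proj": {"nbits": 4, "group_size": 64},
        "mlp.down_proj": {"nbits": 2, "group_size": 64},
    }
}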

View File

@@ -36,3 +36,11 @@ class VllmConfig(BaseModel):
default=None,
json_schema_extra={"description": "Enable prefix caching for VLLM"},
)
host: str | None = Field(
default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Host for the vLLM server to start on"},
)
port: int | None = Field(
default=8000,
json_schema_extra={"description": "Port of the vLLM server to start on"},
)

View File

@@ -348,7 +348,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing or cfg.sequence_parallel_degree > 1:
elif cfg.sample_packing:
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
@@ -358,7 +358,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
**filter_map_kwargs,
**drop_long_kwargs,
)
if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1:
if cfg.eval_sample_packing:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
@@ -528,6 +528,13 @@ def setup_torch_compile_env(cfg):
def setup_deepspeed_env(cfg, stage=None):
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
from axolotl.utils.distributed import distributed_state
if distributed_state and distributed_state.initialized:
raise RuntimeError(
"Distributed State already initialized before Deepspeed setup"
)
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if stage:

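The sample-packing hunks above keep the add_position_ids mapping, whose job is to give each packed sub-sequence position ids that restart at zero. A self-contained sketch of that idea, not the axolotl implementation:
def position_ids_for_packed(lengths: list[int]) -> list[int]:
    """Concatenate 0..len-1 ranges so positions restart at each packed example."""
    position_ids: list[int] = []
    for length in lengths:
        position_ids.extend(range(length))
    return position_ids
# two examples of length 3 and 2 packed into one row:
# position_ids_for_packed([3, 2]) -> [0, 1, 2, 0, 1]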
View File

@@ -79,9 +79,9 @@ def download_smollm2_135m_model():
@pytest.fixture(scope="session", autouse=True)
def download_llama_68m_random_model():
def download_smollm2_135m_gptq_model():
# download the model
snapshot_download_w_retry("JackFram/llama-68m", repo_type="model")
snapshot_download_w_retry("lilmeaty/SmolLM2-135M-Instruct-GPTQ", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
@@ -90,6 +90,12 @@ def download_qwen_2_5_half_billion_model():
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_qwen3_half_billion_model():
# download the model
snapshot_download_w_retry("Qwen/Qwen3-0.6B", repo_type="model")
@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
# download the dataset

View File

@@ -8,7 +8,7 @@ from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
@@ -56,6 +56,7 @@ class TestCutCrossEntropyIntegration:
# pylint: disable=redefined-outer-name
def test_llama_w_cce(self, min_cfg, temp_dir):
cfg = DictDefault(min_cfg)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -101,6 +102,7 @@ class TestCutCrossEntropyIntegration:
"bf16": "auto",
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -129,6 +131,7 @@ class TestCutCrossEntropyIntegration:
attention_type: True,
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
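The recurring change across the test files in this diff is an explicit validate_config call ahead of plugin preparation. A condensed sketch of the setup order the updated tests follow:
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
def prepare(cfg):
    # order used before loading datasets and training in the e2e tests
    cfg = validate_config(cfg)  # schema validation via the pydantic config models
    prepare_plugins(cfg)  # import and register any plugins listed in cfg.plugins
    normalize_config(cfg)  # fill in derived runtime settings
    return cfg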

View File

@@ -0,0 +1,184 @@
"""
e2e tests to make sure all the hooks are fired on the plugin
"""
import os
from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import BasePlugin
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
class LogHooksPlugin(BasePlugin):
"""
fixture to capture in a log file each hook that was fired
"""
base_dir = Path("/tmp/axolotl-log-hooks")
def __init__(self):
self.base_dir.mkdir(parents=True, exist_ok=True)
try:
os.remove(self.base_dir.joinpath("plugin_hooks.log"))
except FileNotFoundError:
pass
def pre_model_load(self, cfg): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("pre_model_load\n")
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_model_build\n")
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("pre_lora_load\n")
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_lora_load\n")
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_model_load\n")
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("create_optimizer\n")
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("get_trainer_cls\n")
def create_lr_scheduler(
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("create_lr_scheduler\n")
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("add_callbacks_pre_trainer\n")
return []
def add_callbacks_post_trainer(
self, cfg, trainer
): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("add_callbacks_post_trainer\n")
return []
def post_train(self, cfg, model): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_train\n")
def post_train_unload(self, cfg): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_train_unload\n")
class TestPluginHooks:
"""
e2e tests to make sure all the hooks are fired during the training
"""
def test_plugin_hooks(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"tests.e2e.integrations.test_hooks.LogHooksPlugin",
],
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"flash_attention": True,
"bf16": "auto",
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
with open(
"/tmp/axolotl-log-hooks" + "/plugin_hooks.log", "r", encoding="utf-8"
) as f:
file_contents = f.readlines()
file_contents = "\n".join(file_contents)
assert "pre_model_load" in file_contents
assert "post_model_build" in file_contents
assert "pre_lora_load" in file_contents
assert "post_lora_load" in file_contents
assert "post_model_load" in file_contents
# assert "create_optimizer" in file_contents # not implemented yet
assert "get_trainer_cls" in file_contents
# assert "create_lr_scheduler" in file_contents # not implemented yet
assert "add_callbacks_pre_trainer" in file_contents
assert "add_callbacks_post_trainer" in file_contents
assert "post_train" in file_contents
# assert "post_train_unload" in file_contents # not called from test train call
try:
os.remove("/tmp/axolotl-log-hooks" + "/plugin_hooks.log")
except FileNotFoundError:
pass

View File

@@ -5,7 +5,7 @@ Simple end-to-end test for Liger integration
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
@@ -54,6 +54,7 @@ class LigerIntegrationTestCase:
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -100,6 +101,7 @@ class LigerIntegrationTestCase:
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()

View File

@@ -4,11 +4,14 @@ GRPO test suite
import os
import random
import shutil
import subprocess # nosec B404
import sys
import tempfile
import time
from pathlib import Path
import psutil
import pytest
import requests
import yaml
@@ -21,8 +24,8 @@ from tests.e2e.utils import require_vllm
def start_vllm(
model: str, env: dict | None = None, wait: int | None = None, quiet=False, **kwargs
) -> int:
model: str, env: dict, wait: int | None = None, quiet=False, **kwargs
) -> subprocess.Popen:
"""
helper function to start the VLLM server in the background, mostly for testing purposes
"""
@@ -46,10 +49,41 @@ def start_vllm(
# print out the command to be executed
print(" ".join(cmd))
vllm_logging_json = Path(tempfile.mkdtemp()) / "vllm_logging.json"
with open(vllm_logging_json, "w", encoding="utf-8") as temp_file:
temp_file.write(
"""{
"formatters": {
"json": {
"class": "pythonjsonlogger.jsonlogger.JsonFormatter"
}
},
"handlers": {
"file": {
"class": "logging.FileHandler",
"formatter": "json",
"level": "DEBUG",
"filename": "/tmp/vllm.log",
"mode": "a"
}
},
"loggers": {
"vllm": {
"handlers": ["file"],
"level": "DEBUG",
"propagate": false
}
},
"version": 1
}"""
)
cmd_env = env.copy()
cmd_env.update({"VLLM_LOGGING_CONFIG_PATH": vllm_logging_json})
# start `trl vllm-serve` command in the background and capture the process id
process = subprocess.Popen( # pylint: disable=consider-using-with
cmd,
env=env,
env=cmd_env,
stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
stderr=subprocess.DEVNULL if quiet else subprocess.PIPE,
) # nosec B603
@@ -58,32 +92,51 @@ def start_vllm(
print(f"VLLM server process started (PID: {process.pid})")
# wait until the http server is ready, even if it 404s, timing out after `wait` seconds
period_seconds = 5
started = False
if wait and host and port:
for _ in range(int(wait)):
for i in range(0, int(wait), period_seconds):
try:
response = requests.get(f"http://{host}:{port}", timeout=1)
print(f"{i}: VLLM server (status: {response.status_code})")
if int(response.status_code) in [200, 404]:
started = True
break
except requests.exceptions.RequestException:
pass
except requests.exceptions.RequestException as exc:
print(f"{i}: VLLM server failed to start: {str(exc)}")
# also check if the process.pid is still running
if not process.poll() is None:
break
time.sleep(1)
time.sleep(period_seconds)
if wait and not started:
print(
f"VLLM server process did not start within {wait} seconds. Please check your server logs."
)
process.kill()
recursive_kill(process)
with open("/tmp/vllm.log", "r", encoding="utf-8") as log_file:
print(log_file.read())
shutil.rmtree("/tmp/vllm.log")
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
# return the process id
return process.pid
# return the process
return process
def recursive_kill(process: subprocess.Popen):
"""
Recursively kill a process and its children
"""
process = psutil.Process(process.pid)
for child in psutil.Process(process.pid).children(recursive=True):
child.terminate()
child.kill()
os.kill(child.pid, 9)
process.terminate()
process.kill()
os.kill(process.pid, 9)
class TestGRPO:
@@ -174,16 +227,17 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
"NCCL_P2P_LEVEL": "NVL",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_USE_V1": "0",
"VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
}
vllm_process_id = start_vllm(
vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=120,
wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
@@ -202,10 +256,14 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
env={
"NCCL_P2P_LEVEL": "NVL",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
os.kill(vllm_process_id, 9)
recursive_kill(vllm_process)
@pytest.mark.parametrize(
"num_gpus",
@@ -262,16 +320,17 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
"NCCL_P2P_LEVEL": "NVL", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_USE_V1": "0",
"VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
}
vllm_process_id = start_vllm(
vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=120,
wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
@@ -290,7 +349,11 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={"NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env},
env={
"NCCL_P2P_LEVEL": "NVL",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
os.kill(vllm_process_id, 9)
recursive_kill(vllm_process)

View File

@@ -30,10 +30,8 @@ class TestMultiGPUEval:
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
@@ -101,10 +99,8 @@ class TestMultiGPUEval:
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",

View File

@@ -171,10 +171,7 @@ class TestMultiGPULlama:
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"quantization": {
"backend": "bnb",
"bits": 8,
},
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
@@ -252,10 +249,7 @@ class TestMultiGPULlama:
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
@@ -554,10 +548,7 @@ class TestMultiGPULlama:
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
"adapter": "qlora",
"mean_resizing_embeddings": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
@@ -657,10 +648,7 @@ class TestMultiGPULlama:
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
}
else:
adapter = {}
@@ -734,10 +722,7 @@ class TestMultiGPULlama:
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
}
else:
adapter = {}
@@ -811,10 +796,7 @@ class TestMultiGPULlama:
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
}
else:
adapter = {}

View File

@@ -28,10 +28,7 @@ class TestMultiGPUQwen2:
cfg = DictDefault(
{
"base_model": base_model,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
"rl": "dpo",
"chat_template": "chatml",
"sequence_len": 2048,

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -28,7 +28,7 @@ class Test4dMultipackLlama(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": False,
"sdp_attention": True,
"sample_packing": True,
@@ -41,6 +41,9 @@ class Test4dMultipackLlama(unittest.TestCase):
"lora_target_linear": True,
"sequence_len": 1024,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
@@ -60,6 +63,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"fp16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -72,7 +76,7 @@ class Test4dMultipackLlama(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": False,
"sdp_attention": False,
"sample_packing": True,
@@ -85,6 +89,9 @@ class Test4dMultipackLlama(unittest.TestCase):
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
@@ -104,6 +111,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"fp16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -0,0 +1,77 @@
"""
E2E tests for activation checkpointing
"""
import pytest
import transformers
from torch.utils.checkpoint import checkpoint
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
@pytest.fixture()
def fix_checkpoint_after_test():
yield
transformers.modeling_utils.checkpoint = checkpoint
class TestActivationCheckpointing:
"""
E2E tests for activation checkpointing
"""
def test_activation_checkpointing_offload(
self,
temp_dir,
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"eos_token": "<|im_end|>",
},
"datasets": [
{
"chat_template": "chatml",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": "offload",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -32,10 +32,7 @@ class TestFalconPatched(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,
@@ -66,6 +63,7 @@ class TestFalconPatched(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -106,6 +104,7 @@ class TestFalconPatched(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -32,7 +32,7 @@ class TestFusedLlama(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"pad_to_sequence_len": True,
"flash_attn_fuse_qkv": True,
@@ -41,9 +41,7 @@ class TestFusedLlama(unittest.TestCase):
"sequence_len": 1024,
"val_set_size": 0.02,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -67,6 +65,7 @@ class TestFusedLlama(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -31,8 +31,8 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 16384,
"sample_packing": False,
"flash_attention": True,
@@ -44,7 +44,9 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {},
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "Yukang/LongAlpaca-12k",
@@ -65,6 +67,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -77,14 +80,16 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 16384,
"sample_packing": False,
"flash_attention": True,
"s2_attention": True,
"val_set_size": 0.02,
"special_tokens": {},
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "Yukang/LongAlpaca-12k",
@@ -105,6 +110,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_availab
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -31,8 +31,8 @@ class TestLoraLlama(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
@@ -44,9 +44,7 @@ class TestLoraLlama(unittest.TestCase):
"lora_target_linear": True,
"val_set_size": 0.2,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -70,6 +68,7 @@ class TestLoraLlama(unittest.TestCase):
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -83,15 +82,12 @@ class TestLoraLlama(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"base_model": "lilmeaty/SmolLM2-135M-Instruct-GPTQ",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"quantization": {
"backend": "gptq",
},
"load_in_8bit": True,
"adapter": "lora",
"gptq": True,
@@ -102,9 +98,7 @@ class TestLoraLlama(unittest.TestCase):
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -123,6 +117,7 @@ class TestLoraLlama(unittest.TestCase):
"lr_scheduler": "cosine",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,7 @@ class TestMistral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -104,6 +105,7 @@ class TestMistral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -33,10 +33,7 @@ class TestMixtral(unittest.TestCase):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"quantization": {
"backend": "bnb",
"bits": 4,
},
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 16,
"lora_alpha": 32,
@@ -63,6 +60,7 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -101,6 +99,7 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -6,7 +6,7 @@ import unittest
import transformers
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@@ -47,6 +47,7 @@ class TestModelPatches(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)
@@ -79,6 +80,7 @@ class TestModelPatches(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,7 @@ class TestPhiMultipack(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -82,7 +83,7 @@ class TestPhiMultipack(unittest.TestCase):
"sample_packing": True,
"flash_attention": True,
"pad_to_sequence_len": True,
"load_in_8bit": False,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 64,
"lora_alpha": 32,
@@ -114,6 +115,7 @@ class TestPhiMultipack(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, most_recent_subdir
@@ -68,6 +68,7 @@ class TestResumeLlama:
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -2,14 +2,19 @@
# pylint: disable=redefined-outer-name,unused-argument
import functools
import sys
from unittest.mock import MagicMock, patch
import pytest
import torch
from accelerate.state import PartialState
from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.dict import DictDefault
@@ -47,6 +52,27 @@ def fixture_cfg():
return cfg
@pytest.fixture
def sequence_parallel_batch():
"""Create a test batch for sequence parallelism tests."""
batch_size = 1
seq_len = 8
# Create test tensors
input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
attention_mask = torch.ones(batch_size, seq_len)
position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
# Create test batch
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
return batch
class TestRingAttention:
"""Tests for the ring attention functionality."""
@@ -73,11 +99,6 @@ class TestRingAttention:
self, mock_world_size, mock_rank, mock_new_group, partial_state
):
"""Test that ring attention groups are created correctly."""
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
register_ring_attn,
)
# Setup mocks
mock_world_size.return_value = 8 # 8 GPUs total
mock_rank.return_value = 3 # GPU #3
@@ -101,88 +122,303 @@ class TestRingAttention:
set_ring_attn_group(None)
# Mock a simplified DataCollator test
@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group")
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_sequence_parallel_slicing(
mock_world_size, mock_rank, mock_get_group, partial_state
):
"""Test the basic sequence slicing logic without full collator instantiation."""
# Setup mocks
mock_get_group.return_value = MagicMock()
mock_rank.return_value = 1 # Second GPU
mock_world_size.return_value = 4 # 4 GPUs total
class TestConfigValidation:
"""Tests for validating sequence parallelism configurations."""
# Create a sample batch
batch = {
"input_ids": torch.tensor(
[
[101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
[201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
]
),
"attention_mask": torch.ones(2, 12),
}
@pytest.fixture(autouse=True)
def setup_mocks(self, monkeypatch):
"""Set up mocks for all tests in this class."""
# Mock the ring_flash_attn module
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
# Simplified slicing logic from SequenceParallelDataCollator
def slice_batch(batch, rank, world_size):
result = {}
for key in batch:
seq_len = batch[key].shape[1]
slice_size = seq_len // world_size
start_idx = rank * slice_size
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
result[key] = batch[key][:, start_idx:end_idx]
return result
@pytest.fixture
def base_cfg(self):
"""Create a base configuration for testing."""
return DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-3,
"output_dir": "./model-out",
"sequence_len": 512,
"special_tokens": {"pad_token": "<|endoftext|>"},
}
)
# Slice the batch
result = slice_batch(
batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
)
# Check slicing
assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU
expected_input_ids = torch.tensor(
@pytest.mark.parametrize(
"config_updates, expected_values, should_pass, error_msg",
[
[104, 105, 106], # Second slice of first sequence
[204, 205, 206], # Second slice of second sequence
]
# Valid configuration
(
{"sequence_parallel_degree": 2, "flash_attention": True},
{"sequence_parallel_degree": 2, "flash_attention": True},
True,
None,
),
# Default sequence_parallel_degree
({}, {"sequence_parallel_degree": 1}, True, None),
# Invalid: sequence_parallel_degree > 1 without flash_attention
(
{"sequence_parallel_degree": 2, "flash_attention": False},
None,
False,
"flash_attention: true must be set",
),
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
(
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": True,
"micro_batch_size": 2,
"pad_to_sequence_len": True,
},
None,
False,
"micro_batch_size must be set to 1",
),
],
ids=[
"valid_config",
"default_sp_degree",
"without_flash_attention",
"sample_packing_with_large_batch",
],
)
assert torch.all(result["input_ids"] == expected_input_ids)
def test_sequence_parallel_config_validation(
self, base_cfg, config_updates, expected_values, should_pass, error_msg
):
"""Test various sequence parallelism configuration scenarios."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
cfg = base_cfg
cfg.update(config_updates)
if should_pass:
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check expected values
for key, value in expected_values.items():
assert getattr(config, key) == value
else:
# Should raise exception
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
assert error_msg in str(excinfo.value)
@pytest.mark.parametrize(
"ring_attn_func, sample_packing, expected_func",
[
(None, True, RingAttnFunc.VARLEN_LLAMA3),
(None, False, RingAttnFunc.BATCH_RING),
],
ids=["default_with_sample_packing", "default_without_sample_packing"],
)
def test_ring_attn_func_validation(
self, base_cfg, ring_attn_func, sample_packing, expected_func
):
"""Test ring_attn_func validation and defaults."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": sample_packing,
}
if ring_attn_func is not None:
cfg["ring_attn_func"] = ring_attn_func
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check ring_attn_func value
assert config.ring_attn_func.value == expected_func
def test_invalid_ring_attn_func(self, base_cfg):
"""Test that an invalid ring_attn_func is rejected."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Invalid configuration with invalid ring_attn_func
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
"ring_attn_func": "INVALID_FUNC",
}
# Should raise ValidationError
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
# Verify error message
assert "ring_attn_func: INVALID_FUNC must be in" in str(excinfo.value)
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
def test_config_validation_with_valid_inputs(cfg):
"""Test that valid sequence parallelism configurations pass validation."""
# Import the actual model class with appropriate mocks
from axolotl.utils.schemas.config import AxolotlInputConfig
class TestApplySequenceParallelism:
"""Tests for the apply_sequence_parallelism function."""
# Valid configuration: sequence_parallel_degree > 1 and flash_attention is True
cfg = cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
}
@pytest.fixture(autouse=True)
def mock_distributed(self, monkeypatch):
"""Mock torch.distributed functions for testing."""
# Mock is_initialized to return True
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
# Should validate without errors
config = AxolotlInputConfig(**cfg)
assert config.sequence_parallel_degree == 2
assert config.flash_attention is True
# Mock get_rank to return 0 by default
monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0)
# Mock get_world_size to return 2 by default
monkeypatch.setattr(
torch.distributed, "get_world_size", lambda *args, **kwargs: 2
)
def test_config_validation_with_invalid_inputs(cfg):
"""Test that invalid sequence parallelism configurations fail validation."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Mock the process group
monkeypatch.setattr(
"axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group",
MagicMock,
)
# Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False
cfg = cfg | {
"sequence_parallel_degree": 2,
"flash_attention": False,
}
# Mock update_ring_attn_params
monkeypatch.setattr(
"axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params",
lambda **kwargs: None,
)
# Should raise ValidationError
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
def test_world_size_one(self, sequence_parallel_batch):
"""Test that function returns original batch when world size is 1."""
result = apply_sequence_parallelism(
batch=sequence_parallel_batch,
local_rank=0,
local_world_size=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verify error message
assert "flash_attention: true must be set" in str(excinfo.value)
# Should return the original batch unchanged
assert result == sequence_parallel_batch
def test_batch_ring_rank0(self, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
result = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Check that sequence dimension was sharded correctly
assert result["input_ids"].shape[1] == seq_len // 2
assert result["attention_mask"].shape[1] == seq_len // 2
# Verify content: rank 0 should get the first half of the sequence
assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2])
assert torch.equal(
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
)
def test_batch_ring_rank1(self, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
result = apply_sequence_parallelism(
batch=batch,
local_rank=1,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verify content: rank 1 should get the second half of the sequence
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
def test_batch_zigzag(self, sequence_parallel_batch):
"""Test BATCH_ZIGZAG sharding pattern."""
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()
seq_len = batch["input_ids"].size(1)
# Test rank 0
result_rank0 = apply_sequence_parallelism(
batch={k: v.clone() for k, v in batch.items()},
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
)
# Test rank 1
result_rank1 = apply_sequence_parallelism(
batch={k: v.clone() for k, v in batch.items()},
local_rank=1,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
)
# Checks for both ranks
assert result_rank0["input_ids"].shape[1] == seq_len // 2
assert result_rank1["input_ids"].shape[1] == seq_len // 2
# For a 2-rank system with 8 tokens, check specific zigzag pattern
# Rank 0 should get chunks [0, 1] and [6, 7]
# Rank 1 should get chunks [2, 3] and [4, 5]
if seq_len == 8:
# Create expected tensors for comparison
rank0_expected = torch.cat(
[original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
)
rank1_expected = torch.cat(
[original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
)
assert torch.equal(result_rank0["input_ids"], rank0_expected)
assert torch.equal(result_rank1["input_ids"], rank1_expected)
def test_partial_application(self, sequence_parallel_batch):
"""Test that we can create a partially applied version of the function."""
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()
# Create a partially applied function
rank0_ring_parallel = functools.partial(
apply_sequence_parallelism,
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Use the partially applied function
result = rank0_ring_parallel(batch=batch)
# Verify it works as expected
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
assert torch.equal(
result["input_ids"],
original_input_ids[:, : original_input_ids.shape[1] // 2],
)
def test_missing_position_ids(self, sequence_parallel_batch):
"""Test handling of batch without position_ids."""
# Create a batch without position_ids
batch = {
k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
}
original_input_ids = batch["input_ids"].clone()
# This should run without error even though position_ids is missing
result = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verification should pass
assert "position_ids" not in result
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2

View File

@@ -10,7 +10,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard
@@ -72,6 +72,7 @@ class TestUnslothQLoRA:
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -122,6 +123,7 @@ class TestUnslothQLoRA:
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -177,6 +179,7 @@ class TestUnslothQLoRA:
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -34,10 +34,7 @@ class TestReLoraLlama(unittest.TestCase):
"sample_packing": True,
"pad_to_sequence_len": True,
"flash_attention": True,
"quantization": {
"backend": "bnb",
"bits": 8,
},
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,

Some files were not shown because too many files have changed in this diff Show More