Compare commits

...

38 Commits

Author SHA1 Message Date
Dan Saunders
6f6d917a99 Revert "checkpoint model on first step callback (#2906)"
This reverts commit 10ba1622f7.
2025-07-15 15:01:12 -04:00
Dan Saunders
10ba1622f7 checkpoint model on first step callback (#2906)
* checkpoint model on first step callback

* remove debug

* add test cases; update existing tests not to save on first step

* move test out of solo

* delete

* default to False

* typo
2025-07-15 15:00:48 -04:00
Wing Lian
d320ef6199 fix for upstream refactor of KwargsForCausalLM (#2911) 2025-07-15 11:28:41 -04:00
NanoCode012
354eaaf0d3 feat: add call method to mistral tokenizer wrapper (#2898) 2025-07-14 22:33:35 -04:00
greenhestu
a061446540 Fix: Prevents merging of tool arguments during preprocessing (#2909) 2025-07-14 22:33:10 -04:00
Wing Lian
cd079b5536 Tensor parallel w DeepSpeed AutoTP (#2574)
* support for deepspeed autotp

* bump to latest deepspeed that supports deepcompile too

* add deepcompile support too

* fix total steps calculation for TP

* setup fixture for tp

* update ds config to ensure weights are gathered for checkpoint

* fix duplicate validation names

* chore: lint
2025-07-14 21:33:48 -04:00
Wing Lian
5cc16040a8 move the plugin post trainer create to the setup trainer (#2907)
* move the plugin post trainer create to the setup trainer

* move post-train plugins to execute-training fn
2025-07-14 20:11:33 -04:00
Wing Lian
38359a8997 allow profiling in mid-training rather than from the start (#2899) [skip ci]
* allow profiling in mid-training rather than from the start

* simplify based on PR feedback

* fix logic, improve saving at end, add tests
2025-07-14 20:11:11 -04:00
Wing Lian
7dc3ac6cb3 update nightly builds (#2921) [skip ci] 2025-07-14 20:10:43 -04:00
Wing Lian
99187cd208 Activation Offloading w CUDA Streams (#2900) [skip ci]
* use cuda streams for activation offloading

* use torch native ops

* update cfg schema for streams

* fix literal constructor for set

* use context for training step so it doesn't affect evals

* disable streams

* auto gc on eval steps

* use activation_offloading config arg

* add docs for gradient checkpointing

* handle validation for gc/ao

* use cuda streams for act offloading

* add more validation for AC w/o GC

* fix docs

* move activation_offloading lower in definition so it doesn't break args/kwargs

* fix kd due to import order
2025-07-14 20:10:20 -04:00
Wing Lian
aa684122f1 upgrade peft==0.16.0 and datasets==4.0.0 (#2917) [skip ci]
* upgrade peft to 0.16.0

* upgrade datasets to 4.0.0

* refactor dupes from merge/rebase

* fix check for fsdp1 + sharded_state_dict

* use full state dict for ci
2025-07-14 20:09:26 -04:00
Wing Lian
ca4d4ef793 don't init distributed for deepspeed if preprocessing (#2920)
* don't init distributed for deepspeed if preprocessing

* add e2e test to validate preprocess cli with deepspeed

* ignore duplicate code for cfg
2025-07-14 14:19:19 -04:00
Dan Saunders
37edbe4999 Remove extra torch.compile call (#2904)
* debug

* debug

* debug

* moving validation code to transformers

* revert unneeded change

* add accelerator config to base trainer builder

* add back accumulated_cache_size_limit setting

* lint
2025-07-14 12:32:45 -04:00
Wing Lian
e581c15d40 refactor dupes from merge/rebase (#2919) [skip ci] 2025-07-14 10:05:26 -04:00
Wing Lian
af92151a7b FSDP2 fix validation and add tests (#2910)
* fix validation and add tests

* remove debugging and add more tests

* remove migrate_fsdp
2025-07-14 09:25:44 -04:00
Wing Lian
80dc4c261a fix xformers version for pytorch 2.6 (#2916) [skip ci] 2025-07-14 09:24:29 -04:00
Wing Lian
7ccbbd8e77 upgrade liger to 0.6.0 (#2893) [skip ci] 2025-07-14 09:24:07 -04:00
Wing Lian
5081db7f8a upgrade trl==0.19.1 (#2892) [skip ci]
* upgrade trl==0.19.1

* add vllm for tests for grpo

* fixes to work with latest trl

* need data_parallel_size config too

* support for vllm_mode for server / colocate

* vllm settings for colocate

* relax vllm version

* bump min hf hub for latest vllm support

* add hints on string literal for vllm mode

* use latest transformers 4.53.2

* tweak acceptable loss on flaky test_ds_zero3_packed test

* don't run flaky vllm/grpo tests for now
2025-07-14 09:23:42 -04:00
Wing Lian
41664c7c4c fix ddp for incorrect steps (#2915)
* fix ddp for incorrect steps

* add test
2025-07-14 07:51:16 -04:00
Wing Lian
9a8073e73d Liquid Foundation Model 2 support (#2905)
* LFM2 support

* docs

* packing seems to work

* update install to force install in case already on dev version

* default to use chunked cross entropy
2025-07-12 11:41:34 -04:00
Jiawei Liu
7fb8441e0e fix: customized dataset with simpo (#2894) [skip ci] 2025-07-12 11:40:30 -04:00
NanoCode012
4dc5910e1c feat(doc): re-add docker 2.7.0 tag back (#2902) [skip ci] 2025-07-12 11:40:01 -04:00
Wing Lian
fb7bc9250d move unmaintained examples to archive (#2903) [skip ci] 2025-07-12 11:39:51 -04:00
salman
d6e4a611e5 FSDP1 -> FSDP2 (#2760)
* FSDP2 args migration implementation

This commit implements the migration to FSDP2 arguments including:
- FSDP2 support with LoRA training
- DPO integration with FSDP2
- Model loading fixes and refactoring
- CPU offloading and PEFT handling
- Test updates and CI improvements
- Bug fixes for dtype errors and various edge cases
2025-07-12 15:18:01 +01:00
Ed Sealing
eb662557a7 Register Plugins in Ray Workers (#2901) [skip ci]
* Access plugins in ray cluster

* Add comment

* chore: lint

---------

Co-authored-by: Ed Sealing <ed.sealing@patapsco.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-07-11 16:59:59 -04:00
salman
03b2a113fe Update doc preview workflow to use sticky comments (#2873) 2025-07-11 14:08:35 +01:00
NanoCode012
9b95a625ab feat: add devstral small 2507 (#2896)
* feat: add devstral small 2507

* chore: update blog doc
2025-07-11 09:34:19 +07:00
Wing Lian
c370d0795c [doc] Fix docs for text field mapping for completion datasets (#2890)
* Fix docs for text field mapping for completion datasets

* update another reference
2025-07-09 14:52:44 -04:00
Wing Lian
76aeb16156 tiled_mlp supports single gpu (#2891)
* tiled_mlp supports single gpu

* use checkpoint offloading for arctic training

* patch torch checkpoint too

* support for single gpu zero3

* add linkback to where it was copied from
2025-07-09 12:48:22 -04:00
Wing Lian
7c5ea0010f bump dev version (#2889) [skip ci] 2025-07-09 09:43:42 -04:00
Wing Lian
c6d69d5c1b release v0.11.0 (#2875)
* release v0.11.0

* don't build vllm into release for now

* remove 2.5.1 references

* smollm3 multipack support

* fix ordering of e2e tests
2025-07-09 09:22:35 -04:00
Wing Lian
4ff96a2526 fix xformers version (#2888) 2025-07-09 08:43:40 -04:00
salman
89e99eaaa7 slowest durations (#2887) [skip ci] 2025-07-09 08:43:26 -04:00
Wing Lian
6ed501f6dc add 2.7.0 torch images back to support vllm (#2885) 2025-07-08 16:28:14 -04:00
NanoCode012
8c6a6ea6eb Feat: add devstral model support (#2880) [skip ci]
* fix: do not add training and training_detail block by default

* fixed: magistral docs

* fix: address pad adding new fields and use built-in from_openai

* feat: try enable multiprocessing

* fix: check for keys before deleting attn_mask

* feat: add mistral pad test

* feat: add tool calling test

* feat: add devstral tokenizer tests

* fix: comma format

* chore: remove unused support_preprocessing as tokenizer is pickable now

* chore: update magistral doc

* feat: add devstral readme and example

* chore: refactor error handling
2025-07-08 11:01:19 -04:00
NanoCode012
78bff4925e fix: set add_generation_prompt to False when applying chat template (#2859) [skip ci] 2025-07-08 11:00:44 -04:00
NanoCode012
b237c8a3f3 chore: update cce commit to include gemma3n fixes (#2881) [skip ci] 2025-07-08 10:59:35 -04:00
float-trip
1032e22650 Fix link in FSDP + QLoRA docs. (#2879) [skip ci] 2025-07-08 09:19:09 -04:00
144 changed files with 3384 additions and 879 deletions

View File

@@ -29,11 +29,11 @@ jobs:
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "124"
cuda_version: 12.4.1
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
@@ -43,7 +43,7 @@ jobs:
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"

View File

@@ -15,15 +15,15 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
@@ -82,17 +82,17 @@ jobs:
strategy:
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"

View File

@@ -33,11 +33,11 @@ jobs:
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
pytorch: 2.7.0
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 126

View File

@@ -12,16 +12,16 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -65,16 +65,16 @@ jobs:
strategy:
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -28,6 +28,8 @@ jobs:
steps:
- name: Check out repository
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Set up Quarto
uses: quarto-dev/quarto-actions/setup@v2
@@ -50,10 +52,11 @@ jobs:
- name: Netlify Publish
uses: nwtgck/actions-netlify@v3.0
id: netlify
with:
publish-dir: './_site'
enable-pull-request-comment: true
enable-github-deployment: true
enable-pull-request-comment: false
enable-github-deployment: false
github-token: ${{ secrets.GITHUB_TOKEN }}
deploy-message: "Deployed On Netlify"
github-deployment-environment: 'preview'
@@ -61,3 +64,13 @@ jobs:
env:
NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
- name: Update PR with preview link
if: ${{ steps.netlify.outcome == 'success' }}
uses: marocchino/sticky-pull-request-comment@v2
with:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
message: |
📖 **Documentation Preview**: ${{ steps.netlify.outputs.deploy-url }}
Deployed on Netlify from commit ${{ github.event.pull_request.head.sha }}

View File

@@ -26,7 +26,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
pytorch_version: ["2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
@@ -80,9 +80,9 @@ jobs:
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |

View File

@@ -52,7 +52,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
timeout-minutes: 20
steps:
@@ -102,9 +102,9 @@ jobs:
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
@@ -125,7 +125,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
timeout-minutes: 20
steps:
@@ -175,9 +175,9 @@ jobs:
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v tests/patched/
pytest -v tests/cli/
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |
@@ -198,7 +198,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 126
@@ -252,18 +252,6 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: llmcompressor
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
num_gpus: 1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1

View File

@@ -97,7 +97,7 @@
# # 'no_input_format' cannot include {input}
# no_input_format: "{instruction} "
# # For `completion` datsets only, uses the provided field instead of `text` column
# # For `completion` datasets only, uses the provided field instead of `text` column
# field:
# # Axolotl attempts to save the dataset as an arrow after packing the data together so

View File

@@ -55,7 +55,7 @@ Features:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.5.1
- PyTorch ≥2.6.0
### Installation

View File

@@ -276,6 +276,7 @@ website:
- docs/torchao.qmd
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- section: "Troubleshooting"
contents:

View File

@@ -24,9 +24,9 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
"CUDA": os.environ.get("CUDA", "124"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
"CUDA": os.environ.get("CUDA", "126"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),

View File

@@ -24,9 +24,9 @@ df_template = template_env.get_template(dockerfile)
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
"CUDA": os.environ.get("CUDA", "124"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
"CUDA": os.environ.get("CUDA", "126"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),

View File

@@ -187,6 +187,7 @@ Instead of passing `tools` via the system prompt, an alternative method would be
"role": "assistant", // call the function via assistant
"tool_calls": [
{
"id": "...", // required only for mistral
"type": "function",
"function": {
"name": "...",
@@ -199,6 +200,7 @@ Instead of passing `tools` via the system prompt, an alternative method would be
},
{
"role": "tool",
"tool_call_id": "...", // required only for mistral
"name": "...",
"content": "..."
},

View File

@@ -34,9 +34,9 @@ Tags examples:
- `main-base-py3.11-cu128-2.7.1`
- `main-base-py3.11-cu126-2.7.1`
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu126-2.6.0`
- `main-base-py3.11-cu124-2.6.0`
- `main-base-py3.11-cu124-2.5.1`
## Main
@@ -76,12 +76,12 @@ Tags examples:
- `main-py3.11-cu128-2.7.1`
- `main-py3.11-cu126-2.7.1`
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu126-2.6.0`
- `main-py3.11-cu124-2.6.0`
- `main-py3.11-cu124-2.5.1`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu124-2.5.1`
- `main-20250303-py3.11-cu126-2.6.0`
- `0.10.1`
## Cloud

View File

@@ -20,7 +20,7 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
> See the [example config](#example-config) file in addition to reading these instructions.
1. Set `adapter: qlora` in your axolotl config file.
2. Enable FSDP in your axolotl config, as [described here](https://github.com/axolotl-ai-cloud/axolotl?tab=readme-ov-file#fsdp).
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
## Example Config

View File

@@ -0,0 +1,29 @@
---
title: Gradient Checkpointing and Activation Offloading
---
Gradient checkpointing and activation offloading are techniques that reduce the memory footprint of training deep learning
models, at the cost of some extra recomputation or data movement.
### Enabling Gradient Checkpointing
```yaml
gradient_checkpointing: true
```
### Enabling Activation Offloading
```yaml
gradient_checkpointing: true # required for activation offloading
activation_offloading: true
```
Activation offloading variants:
The default `activation_offloading: true` offloads activations to CPU and uses CUDA streams
to overlap the communications and computations when offloading.
The `activation_offloading: legacy` option naively offloads activations to CPU without additional optimizations.
For resource-constrained environments with limited CPU memory, `activation_offloading: disk` offloads
activations to disk instead of CPU RAM, so that much larger context lengths can be trained with minimal memory.
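For example, here is a minimal sketch of the disk variant, using only the config keys described above:
```yaml
gradient_checkpointing: true   # still required for activation offloading
activation_offloading: disk    # offload activations to disk instead of CPU RAM
```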

View File

@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.11
- PyTorch ≥2.5.1
- PyTorch ≥2.6.0
## Installation Methods {#sec-installation-methods}

View File

@@ -23,8 +23,6 @@ Axolotl supports several methods for multi-GPU training:
## DeepSpeed {#sec-deepspeed}
DeepSpeed is the recommended approach for multi-GPU training due to its stability and performance. It provides various optimization levels through ZeRO stages.
### Configuration {#sec-deepspeed-config}
Add to your YAML config:
@@ -32,7 +30,6 @@ Add to your YAML config:
```{.yaml}
deepspeed: deepspeed_configs/zero1.json
```
### Usage {#sec-deepspeed-usage}
```{.bash}
@@ -66,9 +63,75 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
:::
## FSDP {#sec-fsdp}
::: {.callout-tip}
### Basic FSDP Configuration {#sec-fsdp-config}
Using ZeRO Stage 3 with Single-GPU training
ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
:::
## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
::: {.callout-note}
FSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl.
:::
### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2}
To migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and
also follow the config field mapping below to update field names.
#### Config mapping
FSDP1 | FSDP2
-------- | --------
fsdp_sharding_strategy | reshard_after_forward
fsdp_backward_prefetch_policy | **REMOVED**
fsdp_backward_prefetch | **REMOVED**
fsdp_forward_prefetch | **REMOVED**
fsdp_sync_module_states | **REMOVED**
fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
fsdp_state_dict_type | state_dict_type
fsdp_use_orig_params | **REMOVED**
For example, if you were using the following FSDP1 config:
```{.yaml}
fsdp_version: 1
fsdp_config:
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
```
You can migrate to the following FSDP2 config:
```{.yaml}
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
state_dict_type: FULL_STATE_DICT
reshard_after_forward: true
```
### FSDP1 (deprecated) {#sec-fsdp-config}
::: {.callout-note}
Using `fsdp` to configure FSDP is deprecated and will be removed in an upcoming release of Axolotl. Please use `fsdp_config` as above instead.
:::
```{.yaml}
fsdp:
@@ -80,6 +143,7 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Sequence parallelism {#sec-sequence-parallelism}
We support sequence parallelism (SP) via the

View File

@@ -40,13 +40,13 @@ use_cpu: false
Configure your model to use FSDP in the Axolotl yaml. For example:
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_version: 2
fsdp_config:
fsdp_offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
offload_params: true
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
reshard_after_forward: true
```
All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.

View File

@@ -17,7 +17,6 @@ feedback. Various methods include, but not limited to:
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl, if you're interested in contributing, please reach out!)
## RLHF using Axolotl
@@ -275,15 +274,14 @@ rl: dpo
datasets:
- path: ...
split: train
type: user_defined.default
field_prompt: "prompt"
field_system: "system"
field_chosen: "chosen"
field_rejected: "rejected"
prompt_format: "{prompt}"
chosen_format: "{chosen}"
rejected_format: "{rejected}"
type:
field_prompt: "prompt"
field_system: "system"
field_chosen: "chosen"
field_rejected: "rejected"
prompt_format: "{prompt}"
chosen_format: "{chosen}"
rejected_format: "{rejected}"
```
The input format is a simple JSON input with customizable fields based on the above config.
@@ -476,14 +474,13 @@ rl: kto
datasets:
- path: ...
split: train
type: user_defined.default
field_prompt: "prompt"
field_system: "system"
field_completion: "completion"
field_label: "label"
prompt_format: "{prompt}"
completion_format: "{completion}"
type:
field_prompt: "prompt"
field_system: "system"
field_completion: "completion"
field_label: "label"
prompt_format: "{prompt}"
completion_format: "{completion}"
```
The input format is a simple JSON input with customizable fields based on the above config.

View File

@@ -0,0 +1,5 @@
# Archived Examples
This directory contains examples that are no longer maintained and may no longer be functional.
We keep them around for archival purposes in case they are useful to others.

View File

@@ -0,0 +1,70 @@
# Finetune Devstral with Axolotl
Devstral Small is a 24B-parameter open-source model from MistralAI, available on Hugging Face as [Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505) and [Devstral-Small-2507](https://huggingface.co/mistralai/Devstral-Small-2507). `Devstral-Small-2507` is the latest version of the model and has [function calling](https://mistralai.github.io/mistral-common/usage/tools/) support.
This guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking.
The model was fine-tuned on top of [Mistral-Small-3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503) without the vision layer and has a context length of up to 128k tokens.
Thanks to the team at MistralAI for giving us early access to prepare for this release.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from `main`, as Devstral support is currently only available in nightly builds, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have PyTorch installed (PyTorch 2.6.0+)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Run the finetuning example:
```bash
axolotl train examples/devstral/devstral-small-qlora.yml
```
This config uses about 21GB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template); a minimal sketch follows this list.
- Learn how to use function calling with Axolotl at [docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use).
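Here is a minimal sketch of one training example in that messages format (the `messages` field name and the role/content keys are the defaults assumed for `type: chat_template`):
```yaml
# one conversation in the OpenAI messages format (illustrative sketch)
messages:
  - role: user
    content: "Write a function that reverses a string."
  - role: assistant
    content: "def reverse(s): return s[::-1]"
```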
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
- [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels)
## Limitations
At the moment, we only support the `mistral-common` tokenizer for supervised fine-tuning, and only with `type: chat_template`.
In addition, we do not support overriding tokens yet.
## Related Resources
- [MistralAI Devstral Blog](https://mistral.ai/news/devstral)
- [MistralAI Devstral 1.1 Blog](https://mistral.ai/news/devstral-2507)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
## Future Work
- Add parity to Preference Tuning, RL, Multi-modal, etc.
- Add parity to other tokenizer configs like overriding tokens.

View File

@@ -0,0 +1,64 @@
base_model: mistralai/Devstral-Small-2507
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Enable to use mistral-common tokenizer
tokenizer_use_mistral_common: true
load_in_8bit: false
load_in_4bit: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_linear: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_ratio: 0.05
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

examples/lfm2/README.md (new file)
View File

@@ -0,0 +1,7 @@
# Liquid Foundation Models 2
LFM2 support exists in the `transformers` main branch but is not yet included in a released version, so install it from source:
```bash
pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
```

View File

@@ -0,0 +1,48 @@
base_model: LiquidAI/LFM2-350M
chunked_cross_entropy: true
chat_template: tokenizer_default
eot_tokens:
- "<|im_end|>"
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
bf16: true
tf32: true
gradient_checkpointing: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -18,16 +18,10 @@ git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Download the example config:
```bash
axolotl fetch examples
```
3. Run the finetuning example:
2. Run the finetuning example:
```bash
axolotl train examples/magistral/magistral-small-qlora.yaml
@@ -42,7 +36,7 @@ Let us know how it goes. Happy finetuning! 🚀
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
@@ -54,7 +48,7 @@ Let us know how it goes. Happy finetuning! 🚀
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
In addition, we do not support overriding tokens yet.
## Related Resources

View File

@@ -6,19 +6,19 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.10
liger-kernel==0.6.0
# END section
packaging==23.2
huggingface_hub==0.32.2
peft==0.15.2
transformers==4.53.1
huggingface_hub>=0.33.0
peft==0.16.0
transformers==4.53.2
tokenizers>=0.21.1
accelerate==1.8.1
datasets==3.6.0
datasets==4.0.0
deepspeed>=0.17.0
trl==0.18.2
trl==0.19.1
hf_xet==1.1.2
optimum==1.16.2
@@ -68,4 +68,4 @@ schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
mistral-common==1.6.3
mistral-common==1.7.0

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"'
)

View File

@@ -66,13 +66,16 @@ def parse_requirements(extras_require_map):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
if patch == 0:
_install_requires.append("xformers==0.0.30")
else:
_install_requires.append("xformers==0.0.31.post1")
extras_require_map["vllm"] = ["vllm>=0.9.0"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
"xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6
_install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
@@ -118,7 +121,7 @@ extras_require = {
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.17.1",
"deepspeed==0.17.2",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.11.0.dev"
__version__ = "0.12.0.dev"

View File

@@ -1,5 +1,6 @@
"""CLI to run preprocessing of a dataset."""
import os
import warnings
from pathlib import Path
from typing import Union
@@ -95,6 +96,7 @@ def do_cli(
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)

View File

@@ -109,6 +109,13 @@ def ray_train_func(kwargs: dict):
# initialize accelerator before model instantiation
Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
# Register plugins in Ray workers
if cfg.get("plugins"):
from axolotl.cli.config import plugin_set_cfg, prepare_plugins
prepare_plugins(cfg)
plugin_set_cfg(cfg)
kwargs["cfg"] = cfg
do_train(**kwargs)

View File

@@ -37,7 +37,6 @@ def do_vllm_serve(
Returns:
process_id: the process id of the started VLLM server
"""
patch_vllm_worker()
cfg = load_cfg(config)
model = cfg.base_model
@@ -47,6 +46,9 @@ def do_vllm_serve(
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)
data_parallel_size = (
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
)
host = cli_args.get("host") or cfg.vllm.host
port = cli_args.get("port") or cfg.vllm.port
gpu_memory_utilization = (
@@ -68,6 +70,7 @@ def do_vllm_serve(
vllm_script_args = AxolotlScriptArguments(
model=model,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=data_parallel_size,
host=host,
port=port,
gpu_memory_utilization=gpu_memory_utilization,

View File

@@ -112,13 +112,6 @@ class TrainerBuilderBase(abc.ABC):
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
)
)
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
@@ -145,6 +138,14 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(GPUStatsCallback(cfg=self.cfg))
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
profiler_steps_start=self.cfg.profiler_steps_start,
)
)
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -418,6 +419,9 @@ class TrainerBuilderBase(abc.ABC):
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_args_kwargs["torch_compile_backend"] = (
@@ -426,8 +430,16 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.torch_compile_mode:
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
def _configure_accelerator_config(self, training_args_kwargs: dict):
if self.cfg.accelerator_config:
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.gradient_checkpointing:
if self.cfg.activation_offloading is True:
# don't use the HF gradient checkpointing, manually wrap
training_args_kwargs["gradient_checkpointing"] = False
training_args_kwargs["activation_offloading"] = True
elif self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
@@ -501,10 +513,15 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.reward_model or self.cfg.rl:
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.fsdp_config or self.cfg.fsdp:
training_args_kwargs["fsdp_config"] = self.cfg.fsdp_config
training_args_kwargs["fsdp"] = self.cfg.fsdp if self.cfg.fsdp else True
self._configure_reporting(training_args_kwargs)
self._configure_hub_parameters(training_args_kwargs)
self._configure_scheduler(training_args_kwargs)
self._configure_optimizer(training_args_kwargs, trainer_kwargs)
self._configure_torch_compile(training_args_kwargs)
self._configure_accelerator_config(training_args_kwargs)
return training_args_kwargs, trainer_kwargs

View File

@@ -151,14 +151,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps
)
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = {
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
}
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
@@ -318,11 +310,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.neftune_noise_alpha
)
if self.cfg.accelerator_config:
training_arguments_kwargs["accelerator_config"] = (
self.cfg.accelerator_config
)
if self.cfg.image_size:
training_arguments_kwargs["image_size"] = self.cfg.image_size
if self.cfg.image_resize_algorithm:

View File

@@ -208,7 +208,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
if self.cfg.fsdp:
if self.cfg.fsdp_config or self.cfg.fsdp:
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
@@ -218,21 +218,3 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer.add_callback(callback)
return trainer
class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
HF Factory class for PPO Trainer
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build(self, total_num_steps):
# TODO: build PPOConfig
raise NotImplementedError("PPO trainer builder is not implemented yet.")

View File

@@ -14,5 +14,4 @@ from .trl import (
AxolotlORPOTrainer,
AxolotlPRMTrainer,
AxolotlRewardTrainer,
TRLPPOTrainer,
)

View File

@@ -25,6 +25,7 @@ from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.mixins import (
ActivationOffloadingMixin,
CheckpointSaveMixin,
OptimizerMixin,
PackingMixin,
@@ -48,6 +49,7 @@ class AxolotlTrainer(
OptimizerMixin,
RngLoaderMixin,
CheckpointSaveMixin,
ActivationOffloadingMixin,
Trainer,
):
"""Extend the base Trainer for axolotl helpers"""
@@ -75,18 +77,6 @@ class AxolotlTrainer(
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def _create_multipack_sampler(
self, base_sampler: Sampler, dataset: Dataset
) -> MultipackBatchSampler:

View File

@@ -14,6 +14,7 @@ from axolotl.core.trainers.grpo.trainer import (
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig
LOG = get_logger(__name__)
@@ -41,9 +42,18 @@ class GRPOStrategy:
return grpo_args_kwargs
trl: TRLConfig = cfg.trl # type: ignore
vllm_cfg: VllmConfig = cfg.vllm # type: ignore
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
if trl.vllm_mode == "colocate":
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
vllm_cfg.gpu_memory_utilization
)
grpo_args_kwargs["vllm_tensor_parallel_size"] = (
vllm_cfg.tensor_parallel_size
)
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host # type: ignore[attr-defined]
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port # type: ignore[attr-defined]
if trl.vllm_server_timeout:

View File

@@ -59,42 +59,6 @@ class AxolotlGRPOTrainer(
_tag_names = ["trl", "grpo", "axolotl"]
def get_train_dataloader(self):
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
data_collator = self.data_collator
if isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)
dataloader_params = {
"batch_size": self._train_batch_size
* self.args.steps_per_generation, # < this is the change
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""
@@ -252,7 +216,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
dataloader_params["drop_last"] = self.args.dataloader_drop_last
if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker
dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
# Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params)

View File

@@ -3,6 +3,7 @@
# pylint: disable=unused-import
# flake8: noqa
from .activation_checkpointing import ActivationOffloadingMixin
from .checkpoints import CheckpointSaveMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin

View File

@@ -0,0 +1,37 @@
"""
Trainer mixin for activation checkpointing w offloading
"""
import contextlib
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers import GradientCheckpointingLayer, Trainer
from trl.models.activation_offloading import get_act_offloading_ctx_manager
class ActivationOffloadingMixin(Trainer):
"""
Trainer mixin class for activation checkpointing w offloading
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.args.activation_offloading:
self.activation_offload_context = get_act_offloading_ctx_manager(
self.model, use_streams=True
)
else:
self.activation_offload_context = contextlib.nullcontext()
def training_step(self, *args, **kwargs):
with self.activation_offload_context:
return super().training_step(*args, **kwargs)
def ac_wrap_hf_model(model: nn.Module, **kwargs):
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)

View File

@@ -1,12 +1,9 @@
"""Module for TRL PPO trainer"""
"""Module for TRL RL trainers"""
import torch
from tqdm import tqdm
from trl import (
CPOTrainer,
KTOTrainer,
ORPOTrainer,
PPOTrainer,
PRMTrainer,
RewardTrainer,
)
@@ -16,64 +13,6 @@ from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, Optimizer
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
class TRLPPOTrainer(PPOTrainer):
"""Wrapper for TRL PPO trainer to handle customizations"""
tag_names = ["axolotl", "ppo"]
def train(
self,
reward_pipe,
resume_from_checkpoint=None, # pylint: disable=unused-argument
):
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": self.tokenizer.eos_token_id,
"max_new_tokens": 32,
}
sent_kwargs = {
"return_all_scores": True,
"function_to_apply": "none",
"batch_size": 16,
}
for _, batch in tqdm(enumerate(self.dataloader)):
query_tensors = batch["input_ids"]
# generate model response
response_tensors, ref_response_tensors = self.generate(
query_tensors,
return_prompt=False,
generate_ref_response=True,
**generation_kwargs,
)
batch["response"] = self.tokenizer.batch_decode(response_tensors)
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
# Compute sentiment score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
pipe_outputs = reward_pipe(texts, **sent_kwargs)
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
ref_rewards = [
torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
]
batch["ref_rewards"] = ref_rewards
# Run PPO step
stats = self.step(query_tensors, response_tensors, rewards)
self.log_stats(
stats,
batch,
rewards,
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
)
class AxolotlORPOTrainer(
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
):

View File

@@ -217,6 +217,11 @@ class AxolotlTrainingMixins:
},
)
activation_offloading: bool | None = field(
default=None,
metadata={"help": "Use activation offloading with CUDA streams for training."},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(

View File

@@ -48,13 +48,6 @@ class TokenizedPromptDataset(Dataset):
features = dataset.features.keys()
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
LOG.info(
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
)
num_proc = 1
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"
```
## Usage

View File

@@ -32,7 +32,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"`'
)

View File

@@ -11,7 +11,7 @@ kd_ce_alpha: 0.1
kd_alpha: 0.9
kd_temperature: 1.0
torch_compile: True # torch>=2.5.1, recommended to reduce vram
torch_compile: True # torch>=2.6.0, recommended to reduce vram
datasets:
- path: ...

View File

@@ -6,15 +6,21 @@ from typing import Optional, Union, Unpack
import torch
from transformers import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import LossKwargs
try:
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import LossKwargs
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
"""
placeholder kwargs for hf model classes
"""
class TransformersKwargs(FlashAttentionKwargs, LossKwargs):
"""
placeholder kwargs for hf model classes
"""
except ImportError:
from transformers.utils.generic import ( # type: ignore[no-redef]
TransformersKwargs,
)
def kldiv_forward_llama_like(
@@ -33,7 +39,7 @@ def kldiv_forward_llama_like(
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument
**kwargs: Unpack[KwargsForCausalLM], # type: ignore[misc]
**kwargs: Unpack[TransformersKwargs], # type: ignore[misc]
) -> CausalLMOutputWithPast:
# pylint: disable=duplicate-code
output_attentions = (

View File

@@ -122,9 +122,9 @@ def load_lora(
rank = int(os.environ.get("LOCAL_RANK", 0))
if (
cfg.fsdp
cfg.fsdp_config
and cfg.adapter
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.fsdp_config.cpu_ram_efficient_loading
and rank != 0
):
setup_quantized_meta_for_peft(model)
@@ -152,9 +152,9 @@ def load_lora(
"Exception caught during model.print_trainable_parameters(): %s", exc
)
elif (
cfg.fsdp
cfg.fsdp_config
and cfg.adapter
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and cfg.fsdp_config.cpu_ram_efficient_loading
and rank != 0
):
setup_quantized_peft_meta_for_training(model)

View File

@@ -140,10 +140,15 @@ class ModelLoader:
"""Check if flash attention is installed."""
return find_spec("flash_attn") is not None
@cached_property
def qlora_fsdp(self):
@property
def is_fsdp_enabled(self):
"""Property that determines if FSDP is enabled."""
return self.cfg.fsdp_config is not None or self.cfg.fsdp is not None
@property
def is_qlora_and_fsdp_enabled(self):
"""Property that determines if FSDP with QLoRA is enabled."""
return self.cfg.fsdp and self.cfg.adapter == "qlora"
return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
"""Load and prepare the model with all configurations and patches.
@@ -189,15 +194,25 @@ class ModelLoader:
# Handle PeftModel if needed
if (
isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM))
and not self.qlora_fsdp
and not self.is_qlora_and_fsdp_enabled
):
self.model = self.model.merge_and_unload()
self._apply_activation_checkpointing()
self._resize_token_embeddings()
self._adjust_model_config()
self._log_memory_usage()
self._configure_embedding_dtypes()
self._configure_qat()
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
def _apply_activation_checkpointing(self):
if self.cfg.activation_offloading is True:
from axolotl.core.trainers.mixins.activation_checkpointing import (
ac_wrap_hf_model,
)
# ^^ importing this at the module level breaks plugins
ac_wrap_hf_model(self.model)
def _resize_token_embeddings(self):
"""Resize token embeddings if needed."""
@@ -251,22 +266,13 @@ class ModelLoader:
):
self.model.config.eos_token_id = self.tokenizer.eos_token_id
def _log_memory_usage(self):
"""Log device memory usage after model load."""
if hasattr(self.model, "device") and self.model.device.type in (
"cuda",
"mps",
"npu",
):
log_gpu_memory_usage(LOG, "after model load", self.model.device)
def _configure_embedding_dtypes(self):
"""Configure embedding module dtypes."""
# Get embedding modules
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
# Initial dtype conversion
if not self.cfg.fsdp:
if not self.is_fsdp_enabled:
# We don't run this during FSDP because this will leave mixed and bfloat16
# dtypes in the model which FSDP doesn't like
if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
@@ -282,7 +288,7 @@ class ModelLoader:
self._set_z3_leaf_modules()
# Apply gradient checkpointing if needed
needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp
needs_fa2_dtype = self.cfg.adapter or self.is_fsdp_enabled
if self.cfg.adapter in ["lora", "qlora"]:
needs_fa2_dtype = True
if self.cfg.gradient_checkpointing:
@@ -298,10 +304,12 @@ class ModelLoader:
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
(
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
and not self.qlora_fsdp
and not self.is_qlora_and_fsdp_enabled
)
or (
# CCE requires embedding layers to be in fp16/bf16 for backward pass
self.cfg.cut_cross_entropy
)
# CCE requires embedding layers to be in fp16/bf16 for backward pass
or self.cfg.cut_cross_entropy
)
if should_convert:
@@ -357,7 +365,6 @@ class ModelLoader:
and not (self.cfg.rl and self.cfg.load_in_4bit)
and not skip_move_to_device
):
# TODO: validate this conditional
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
@@ -430,7 +437,17 @@ class ModelLoader:
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
if not is_deepspeed_zero3_enabled():
is_ds_zero3 = is_deepspeed_zero3_enabled()
# FSDP requires control over device placement, so don't set device_map when FSDP is enabled
if self.is_fsdp_enabled:
# For QLoRA + FSDP, we still need to set device_map to "auto" for proper initialization
if self.is_qlora_and_fsdp_enabled:
self.model_kwargs["device_map"] = {
"": int(os.environ.get("LOCAL_RANK", 0))
}
# For other FSDP cases, don't set device_map at all
elif not is_ds_zero3:
self.model_kwargs["device_map"] = device_map
cur_device = get_device_type()
@@ -499,7 +516,7 @@ class ModelLoader:
"bnb_4bit_quant_storage": torch.bfloat16,
}
if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
self.cfg.deepspeed or self.cfg.fsdp
self.cfg.deepspeed or self.is_fsdp_enabled
):
# for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16
@@ -604,9 +621,21 @@ class ModelLoader:
def _build_model(self) -> bool:
"""Load model, with load strategy depending on config."""
skip_move_to_device = False
if self.is_fsdp_enabled:
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
skip_move_to_device = True
# Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
if (
"device_map" in self.model_kwargs
and not self.is_qlora_and_fsdp_enabled
):
del self.model_kwargs["device_map"]
elif self.is_qlora_and_fsdp_enabled:
skip_move_to_device = True
if (
self.qlora_fsdp
and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
self.is_qlora_and_fsdp_enabled
and self.cfg.fsdp_config.cpu_ram_efficient_loading
and (
self.cfg.model_config_type == "dbrx"
or self.cfg.qlora_sharded_model_loading
@@ -632,12 +661,6 @@ class ModelLoader:
and not self.cfg.trust_remote_code
and not self.cfg.gptq
):
# TODO: Do we need to open this up for all models?
if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"]
# Please don't remove underscore binding without reading the fn docstring.
_ = self._configure_zero3_memory_efficient_loading()
@@ -691,33 +714,22 @@ class ModelLoader:
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
elif self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
if self.cfg.gptq:
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
if (
self.cfg.fsdp
and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
):
# disabling either of these two still leads to VRAM spike before setting back down
skip_move_to_device = True
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"]
# Please don't remove underscore binding without reading the fn docstring.
_ = self._configure_zero3_memory_efficient_loading()
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
# Please don't remove underscore binding without reading the fn docstring.
_ = self._configure_zero3_memory_efficient_loading()
self.model = self.auto_model_loader.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
if is_deepspeed_zero3_enabled():
skip_move_to_device = True
@@ -753,8 +765,8 @@ class ModelLoader:
skip_prepare_model_for_kbit_training = True
if (
self.qlora_fsdp
or (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
self.is_qlora_and_fsdp_enabled
or (self.is_fsdp_enabled and self.cfg.fsdp_config.cpu_ram_efficient_loading)
or is_deepspeed_zero3_enabled()
):
# Make sure everything is in the same dtype

Some files were not shown because too many files have changed in this diff.