Compare commits


49 Commits

Author SHA1 Message Date
Dan Saunders
78a039e1be add depr warning for preprocess --iterable 2025-08-22 16:02:30 +00:00
Dan Saunders
69f356163e fix 2025-08-22 16:02:30 +00:00
Dan Saunders
53bbca2591 bugfix for sample packing 2025-08-22 16:02:30 +00:00
Dan Saunders
49bd6ece4a remove unused 2025-08-22 16:02:30 +00:00
Dan Saunders
42b38a718a remove eval streaming (not HF supported) 2025-08-22 16:02:30 +00:00
Dan Saunders
4121bcbc33 fix kd test 2025-08-22 16:02:30 +00:00
Dan Saunders
0caa24eab0 comments 2025-08-22 16:02:30 +00:00
Dan Saunders
68bb70bbae fix test 2025-08-22 16:02:30 +00:00
Dan Saunders
5d8d7ef327 lint 2025-08-22 16:02:30 +00:00
Dan Saunders
7836da9ed9 remove unuse 2025-08-22 16:02:30 +00:00
Dan Saunders
7eba3795fe fixes 2025-08-22 16:02:30 +00:00
Dan Saunders
1b7b67d06e smoke test 2025-08-22 16:02:30 +00:00
Dan Saunders
0843dc678a separate out train and eval datasets streaming; cleanup 2025-08-22 16:02:30 +00:00
Dan Saunders
067158e24a nits 2025-08-22 16:02:30 +00:00
Dan Saunders
aa5a497a2c nits 2025-08-22 16:02:30 +00:00
Dan Saunders
2176962231 separate out train and eval dataset streaming 2025-08-22 16:02:30 +00:00
Dan Saunders
10335d5df9 add multidata strats 2025-08-22 16:02:30 +00:00
Dan Saunders
e4e8ffd40c nits 2025-08-22 16:02:30 +00:00
Dan Saunders
846aa41baa nits 2025-08-22 16:02:30 +00:00
Dan Saunders
7bb52d00bb progress on streaming 2025-08-22 16:02:30 +00:00
Dan Saunders
3b2dd05798 remove iterable CLI arg 2025-08-22 16:02:30 +00:00
Dan Saunders
b6431083be nit 2025-08-22 16:02:30 +00:00
Dan Saunders
16ff01df85 separate streaming and pretraining 2025-08-22 16:02:30 +00:00
Wing Lian
ab4d604a8f upgrade peft for 0.17.1 (#3094)
* upgrade peft to 0.17.1

* upgrade for transformers too
2025-08-22 07:26:30 -04:00
Wing Lian
0fa752e58b upgrade flash-attn to 2.8.3 for gpt-oss attn sink support (#3082) 2025-08-21 15:04:10 -04:00
Dan Saunders
08e517ea48 Update .coderabbit.yaml (#3091) [skip ci] 2025-08-20 22:14:13 -04:00
Wing Lian
07fd22f39b better handling of lora w bias with fsdp2 and handling of files when saving model checkpoint (#3090) 2025-08-20 15:17:48 -04:00
Wing Lian
06eaf6c448 misc fixes (#3085) 2025-08-20 08:52:26 -04:00
goggle
050210e637 fix: Sweep runs overwrite each other because output_dir from base config is reused (#3080)
* refactor: improve output_dir handling in generate_config_files

* fix typo

* cli: harden sweep output_dir handling with base fallback

- Ensure sweep permutations always resolve a valid output_dir
- Default to ./model-out if neither permutation nor base config sets output_dir
- Append sweepXXXX suffix consistently for each permutation
- Prevent Path(None) TypeError and improve robustness of sweep config generation

* fix typo

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-08-19 20:25:20 -04:00
Wing Lian
05cedbfb1e add baseten info for gpt-oss recipe (#3078)
* add bsaeten info for gpt-oss recipe

* incorporate PR review
2025-08-19 13:30:37 -04:00
VED
c10eb811fa data_parallel_size in in VllmserveCliArgs (#3074)
* data_parallel_size in in VllmserveCliArgs

* moved to 43
2025-08-18 08:44:37 -04:00
VED
0eef385b1a [feat] truncation support with excess_length_strategy (#3068) [skip ci]
* feat:truncation support with excess_len

* pre-commit

* excess_length_strategy

* requested changes

* lint

* added handle_long_seq_in_dataset in sft

* comments improved
2025-08-18 08:39:13 -04:00
Wing Lian
ecbe8b2b61 [GPT-OSS] improve FSDP shard merging and documentation for GPT-OSS (#3073)
* improve fsdp shard merging

* improve logging

* update information on merging and inferencing GPT-OSS

* cleanup readme

* automate cleanup of FSDP prefix

* import GRPO only if necessary

* only modify config.json on rank0

* merge final checkpoint at end of training

* prevent circular import

* Fix saving for sharded state dict

* devx, move merged to output dir

* move import back to top

* Fix stuck merge

* fix conditionals from pr feedback and add test
2025-08-15 21:25:01 -04:00
Wing Lian
130ef7c51a Various fixes for VLMs (#3063)
* fix to not use batch feature indexing

* more vlm fixes

* use AutoModelForImageTextToText

* add example yaml and need num2words for chat template

* improve handling of adding image tokens to conversation

* add lfm2-vl support

* update the lfm readme

* fix markdown and add rtol for loss checks

* feat: add smolvlm2 processing strat

* fix: check for causal-conv1d in lfm models

* feat: add docs for lfm2

* feat: add new models and tips to docs

* feat: add smolvlm2 docs and remove extra dep

* chore: update docs

* feat: add video instructions

* chore: cleanup

* chore: comments

* fix: typo

* feat: add usage stats

* chore: refactor

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-08-15 10:52:57 -04:00
salman
d1de6f5f3d Add option to skip slow tests in PRs (#3060) [skip ci]
* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* stop running multigpu [skip-e2e]

* should work now [skip-e2e]

* reverting [skip-e2e]

* testing [skip-e2e]

* debug [skip-e2e]

* debug [skip-e2e]

* round 2[skip-e2e]

* removing debug [skip-e2e]

* support skipping whole PR [skip-e2e]

* use script for e2e skip [skip-e2e]

* contributing [skip-e2e]

* contributing [skip-e2e]

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-08-13 22:57:51 -04:00
Wing Lian
48b7ae1677 use updated patch releasE (#3066) 2025-08-13 21:23:05 -04:00
NanoCode012
506e3a3907 fix: fsdp_config validation being None (#3061) [skip ci]
* fix: fsdp_config validation being None

* fix: handling

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-08-13 21:21:50 -04:00
Wing Lian
09145de8fa upgrade transformers==4.55.1 and bitsandbytes==0.47.0 (#3064)
* upgrade transformers==4.55.1

* also upgrade bnb

* remove bnb params4bit patch (upstreamed)

* use latest causal-conv1d

* fix patching ring-flash-attn with now missing imports

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-08-13 19:41:07 -04:00
Wing Lian
e0a2523a3b Workaround to unblock docs build in main (#3055)
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
2025-08-13 11:39:39 +01:00
Wing Lian
3d45620008 remove prepare-from-posids patch (#3052) [skip ci] 2025-08-11 09:34:41 -04:00
github-actions[bot]
ce20e838b5 chore: update pre-commit hooks (#3050) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-08-11 09:32:21 -04:00
Wing Lian
d4d84d48af fix ray train and add fsdp2 smoke test for ray trainer (#3053)
* add fsdp2 smokle test for ray trainer

* fix raytrain with fsdp2
2025-08-11 09:31:54 -04:00
Wing Lian
9b12c05660 use exec instead of subprocess to make ctrl+c nicer for cli (#3044)
* use exec instead of subprocess to make ctrl+c nicer for cli

* change var name to use_exec

* simplify to bool

* flush std*

* patch subprocess as mock in test

* fix tests

* more test fixes
2025-08-10 20:22:20 -04:00
Wing Lian
686933194e fix vllm tagging and add cloud images w/o tmux (#3049) [skip ci] 2025-08-10 20:21:56 -04:00
Wing Lian
d12b461d19 follow up fix for plugin registration (#3054) [skip ci] 2025-08-10 20:21:38 -04:00
Wing Lian
d6b81b3683 update training args check for new defaults (#3051) [skip ci]
* update training args check for new defaults

* skip check for now
2025-08-10 11:26:22 -04:00
Wing Lian
05f1b4b2e8 run monkeypatch tests in seperate runner (#3047) 2025-08-09 14:34:07 -04:00
Wing Lian
7cfc80ec77 set dev version (#3045) [skip ci] 2025-08-08 13:56:53 -04:00
salman
0da6a95efa Add citation.tff (#3043) [skip ci] 2025-08-08 16:18:42 +01:00
70 changed files with 1843 additions and 806 deletions

View File

@@ -12,5 +12,6 @@ reviews:
auto_review:
enabled: true
drafts: false
auto_incremental_review: true
chat:
auto_reply: true

View File

@@ -57,6 +57,13 @@ We welcome ideas for improvements and new features. To suggest an enhancement, o
5. Push your branch to your fork on GitHub.
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
#### Skipping CI Checks
You can skip certain CI checks by including specific keywords in your commit messages:
- `[skip ci]` or `skip ci` - Skips all CI checks for that commit
- `[skip-e2e]` or `skip-e2e` - Skips only end-to-end tests while running other CI checks. You may also include this in the title of your PR to disable end-to-end tests for the entire PR.
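For example (illustrative commits, not ones from this changeset):

```bash
# skip all CI for a docs-only change
git commit -m "docs: fix typo in README [skip ci]"

# run regular CI but skip the end-to-end tests
git commit -m "refactor: rename internal helper [skip-e2e]"
```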
## Style Guidelines
### Code Style

View File

@@ -98,6 +98,12 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
@@ -151,6 +157,18 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -105,7 +105,8 @@ jobs:
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
@@ -179,21 +180,52 @@ jobs:
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
gate-skip-e2e:
needs: [pre-commit, pytest, pytest-sdist]
runs-on: ubuntu-latest
outputs:
skip: ${{ steps.compute.outputs.skip }}
steps:
- uses: actions/github-script@v7
id: compute
with:
script: |
const token = /\[skip-e2e\]/i;
let msg = '';
if (context.eventName === 'push') {
msg = context.payload.head_commit?.message || '';
} else if (context.eventName === 'pull_request') {
const { owner, repo } = context.repo;
const prNumber = context.payload.pull_request.number;
const commits = await github.paginate(
github.rest.pulls.listCommits,
{ owner, repo, pull_number: prNumber, per_page: 100 }
);
msg = commits.at(-1)?.commit?.message || '';
}
const title = context.payload.pull_request?.title || '';
const body = context.payload.pull_request?.body || '';
const skip = token.test(msg) || token.test(title) || token.test(body);
core.setOutput('skip', String(skip));
docker-e2e-tests-1st:
# Run this job first as a gate for running the remainder of the test matrix
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
if: >
github.repository_owner == 'axolotl-ai-cloud' &&
(github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
needs.gate-skip-e2e.outputs.skip != 'true'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
needs: [pre-commit, pytest, pytest-sdist]
needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
strategy:
fail-fast: false
@@ -239,13 +271,16 @@ jobs:
modal run cicd.e2e_tests
docker-e2e-tests:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
if: >
github.repository_owner == 'axolotl-ai-cloud' &&
(github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
needs.gate-skip-e2e.outputs.skip != 'true'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
# Only run the remainder of the matrix if the first e2e check passed;
# this is to save on wasted compute costs for known failures that get caught in the first run
needs: [pre-commit, pytest, docker-e2e-tests-1st]
needs: [pre-commit, pytest, gate-skip-e2e, docker-e2e-tests-1st]
strategy:
fail-fast: false

View File

@@ -3,7 +3,7 @@ default_language_version:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
@@ -23,7 +23,7 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.7
rev: v3.3.8
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy

CITATION.cff Normal file
View File

@@ -0,0 +1,10 @@
cff-version: 1.2.0
type: software
title: "Axolotl: Post-Training for AI Models"
message: "If you use this software, please cite it as below."
authors:
- name: "Axolotl maintainers and contributors"
repository-code: "https://github.com/axolotl-ai-cloud/axolotl"
url: "https://axolotl.ai/"
license: Apache-2.0
date-released: "2023-05-30"

View File

@@ -149,6 +149,20 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
## 📝 Citing Axolotl
If you use Axolotl in your research or projects, please cite it as follows:
```bibtex
@software{axolotl,
title = {Axolotl: Post-Training for AI Models},
author = {{Axolotl maintainers and contributors}},
url = {https://github.com/axolotl-ai-cloud/axolotl},
license = {Apache-2.0},
year = {2023}
}
```
## 📜 License
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.

TODO.md
View File

@@ -1,10 +0,0 @@
# todo list
- [] Validation of parameters for combinations that won't work
## things that are known not to work
- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
- adamw_bnb_8bit doesn't play well with FSDP offload

View File

@@ -37,7 +37,7 @@ WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge

View File

@@ -13,10 +13,13 @@ format:
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Voxtral](#sec-voxtral)
- [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
## Usage
@@ -31,7 +34,7 @@ skip_prepare_dataset: true
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
sample_packing: false # not yet supported with multimodal
chat_template: # see in next section
chat_template: # see in next section if specified
# example dataset
datasets:
@@ -97,6 +100,16 @@ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
chat_template: mistral_v7_tekken
```
### Voxtral {#sec-voxtral}
::: {.callout-tip}
Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'`
:::
```yaml
base_model: mistralai/Voxtral-Mini-3B-2507
```
### Gemma-3 {#sec-gemma-3}
::: {.callout-tip}
@@ -143,6 +156,26 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}
Please make sure to install `num2words` via `pip3 install num2words==0.5.14`
:::
```yaml
base_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct
```
### LFM2-VL {#sec-lfm2-vl}
::: {.callout-warning}
Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
:::
```yaml
base_model: LiquidAI/LFM2-VL-450M
```
## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
@@ -181,6 +214,20 @@ You may need to install `librosa` via `pip3 install librosa==0.11.0`.
:::
### Video
::: {.callout-warning}
This is not well tested at the moment. We welcome contributors!
:::
For video loading, you can use the following keys within `content` alongside `"type": "video"`:
- `"path": "/path/to/video.mp4"`
- `"url": "https://example.com/video.mp4"`
- `"video": np.ndarray | list[PIL.Image.Image] | torch.Tensor` (or list of the aforementioned)
### Example
Here is an example of a multi-modal dataset:

View File

@@ -0,0 +1,58 @@
# Finetune Liquid Foundation Models 2 (LFM2) with Axolotl
[Liquid Foundation Models 2 (LFM2)](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) are a family of small, open-weight models from [Liquid AI](https://www.liquid.ai/) focused on quality, speed, and memory efficiency. Liquid AI released text-only [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) and text+vision [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa) models.
LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference.
This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Run one of the finetuning examples below.
**LFM2**
```bash
# FFT SFT (1x48GB @ 25GiB)
axolotl train examples/LiquidAI/lfm2-350m-fft.yaml
```
**LFM2-VL**
```bash
# LoRA SFT (1x48GB @ 2.7GiB)
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
```
### TIPS
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
```bash
pip uninstall -y causal-conv1d
```
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
- **Dataset Formats**:
- For LFM2 models, the dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For LFM2-VL models, Axolotl follows the multi-content Messages format. See our [Multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) for details.
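As a minimal sketch (the dataset path is a placeholder, not a tested example), a text-only LFM2 dataset entry using the OpenAI Messages format would look roughly like:

```yaml
datasets:
  - path: your-org/your-chat-dataset   # placeholder: any OpenAI Messages-style dataset
    type: chat_template
```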
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -2,7 +2,6 @@ base_model: LiquidAI/LFM2-350M
chunked_cross_entropy: true
chat_template: tokenizer_default
eot_tokens:
- "<|im_end|>"
datasets:

View File

@@ -0,0 +1,58 @@
base_model: LiquidAI/LFM2-VL-450M
trust_remote_code: true
model_type: AutoModelForImageTextToText
processor_type: AutoProcessor
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -33,13 +33,64 @@ Note: Memory usage taken from `device_mem_reserved(gib)` from logs.
### Training 120B
On 8xH100s
On 8xH100s, make sure you have ~3TB of free disk space. Each checkpoint clocks in at ~720GB, so keeping at least 2 checkpoints (~1.4TB) alongside the base
model and the final model output can require at least 3TB of free disk space.
```bash
# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
```
To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node
training of the 120B model using Baseten Truss. You can read more about this recipe on
[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can
be found on their
[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training).
ERRATA: Transformers saves the model architecture name prefixed with `FSDP`, which needs to be manually renamed in `config.json`.
See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
```bash
sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
```
When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your
configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to
merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded
weights to `{output_dir}/merged`.
```bash
axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
```
### Inferencing your fine-tuned model
#### vLLM
GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
for more information about using a special vllm-openai docker image for inferencing with vLLM.
Optionally, vLLM can be installed from nightly:
```bash
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
```
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
```bash
vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8
```
#### SGLang
SGLang has 0-day support in main; see https://github.com/sgl-project/sglang/issues/8833 for information on installing
SGLang from source. Once you've installed SGLang, run the following command to launch an SGLang server:
```bash
python3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8
```
### Tool use
GPT-OSS has a comprehensive tool understanding. Axolotl supports tool calling datasets for Supervised Fine-tuning.
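As a loose sketch only (the exact row schema isn't shown in this diff, and the tool name here is hypothetical), a tool-calling SFT row in an OpenAI-style Messages format might pair a tool definition with the assistant's call and the tool's reply, shown here as YAML for readability:

```yaml
tools:
  - type: function
    function:
      name: get_weather              # hypothetical tool
      parameters:
        type: object
        properties:
          city: {type: string}
messages:
  - role: user
    content: "What's the weather in Paris?"
  - role: assistant
    tool_calls:
      - function: {name: get_weather, arguments: '{"city": "Paris"}'}
  - role: tool
    content: '{"temp_c": 21}'
  - role: assistant
    content: "It's about 21°C in Paris."
```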

View File

@@ -20,6 +20,7 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
save_total_limit: 2 # the 120B model can use up to 720GB of disk space per checkpoint, so let's only keep the last 2
sequence_len: 4096
sample_packing: true
@@ -43,7 +44,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -40,7 +40,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
dataset_prepared_path: ./outputs/last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
@@ -41,7 +41,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
dataset_prepared_path: ./outputs/last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
@@ -40,7 +40,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -53,7 +53,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -1,7 +0,0 @@
# Liquid Foundation Models 2
LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release.
```bash
pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
```

View File

@@ -0,0 +1,49 @@
# Finetune SmolVLM2 with Axolotl
[SmolVLM2](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7) are a family of lightweight, open-source multimodal models from HuggingFace designed to analyze and understand video, image, and text content.
These models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M.
This guide shows how to fine-tune SmolVLM2 models with Axolotl.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install an extra dependency:
```bash
pip3 install num2words==0.5.14
```
3. Run the finetuning example:
```bash
# LoRA SFT (1x48GB @ 6.8GiB)
axolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml
```
## TIPS
- **Dataset Format**: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on [Multimodal Formats](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
- **Dataset Loading**: Read more on how to prepare and load your own datasets in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
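For reference, the example config added in this changeset loads an image SFT dataset in that format:

```yaml
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
```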
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [SmolVLM2 Blog](https://huggingface.co/blog/smolvlm2)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,56 @@
base_model: HuggingFaceTB/SmolVLM2-2.2B-Instruct
trust_remote_code: true
processor_type: AutoProcessor
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.text_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.46.1
bitsandbytes==0.47.0
# triton 3.4.0 is not compatible with CCE
triton>=3.0.0,<3.4.0
mamba-ssm==1.2.0.post1
@@ -13,8 +13,8 @@ liger-kernel==0.6.1
packaging==23.2
huggingface_hub>=0.33.0
peft==0.17.0
transformers==4.55.0
peft>=0.17.0
transformers==4.55.3
tokenizers>=0.21.1
accelerate==1.10.0
datasets==4.0.0

View File

@@ -118,9 +118,9 @@ def get_package_version():
extras_require = {
"flash-attn": ["flash-attn==2.8.2"],
"flash-attn": ["flash-attn==2.8.3"],
"ring-flash-attn": [
"flash-attn==2.8.2",
"flash-attn==2.8.3",
"ring-flash-attn>=0.1.7",
"yunchang==0.6.0",
],

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.12.0"
__version__ = "0.13.0.dev"

View File

@@ -14,9 +14,13 @@ class PreprocessCliArgs:
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
default=None,
default=False,
metadata={
"help": "Use IterableDataset for streaming processing of large datasets"
"help": (
"[DEPRECATED] No longer supported. For streaming datasets, use "
"'axolotl train' and set 'streaming: true' in your YAML config, or "
"pass --streaming instead in the CLI."
)
},
)
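For reference, the replacement workflow that this deprecation message points at looks roughly like the following (sketch; `config.yaml` is a placeholder path):

```yaml
# config.yaml
streaming: true   # stream and tokenize the dataset on the fly during training
```

```bash
axolotl train config.yaml
# or override from the CLI instead of the YAML:
axolotl train config.yaml --streaming
```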
@@ -40,6 +44,12 @@ class VllmServeCliArgs:
default=None,
metadata={"help": "Number of tensor parallel workers to use."},
)
data_parallel_size: Optional[int] = field(
default=None,
metadata={
"help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
},
)
host: Optional[str] = field(
default=None, # nosec B104
metadata={"help": "Host address to run the server on."},

View File

@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res
def get_image(self):
docker_tag = "main-py3.11-cu124-2.6.0"
docker_tag = "main-py3.11-cu126-2.7.1"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"
@@ -200,7 +200,7 @@ class ModalCloud(Cloud):
if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count)
if family == "h100":
return modal.gpu.H100(count=count)
return f"H100:{count}"
if family == "t4":
return modal.gpu.T4(count=count)
if family == "l4":

View File

@@ -153,15 +153,14 @@ def prepare_plugins(cfg: DictDefault):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
for plugin in plugin_manager.plugins.values():
plugin.register(cfg)
def plugin_set_cfg(cfg: DictDefault):
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
plugin_manager.cfg = cfg
# now that we have the finalized cfg, register the plugins individually
for plugin in plugin_manager.plugins.values():
plugin.register(cfg)
def load_cfg(

View File

@@ -64,7 +64,7 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer

View File

@@ -123,9 +123,10 @@ def train(
_launcher = None if kwargs.get("use_ray") else launcher
# Process each configuration
for cfg_file in generate_config_files(config, sweep):
for cfg_file, is_group in generate_config_files(config, sweep):
try:
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
use_exec = is_group is not True
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
except subprocess.CalledProcessError as exc:
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:

View File

@@ -10,6 +10,7 @@ import fire
import torch
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
from accelerate import PartialState
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
@@ -23,6 +24,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.config import load_cfg
from axolotl.utils.logging import get_logger
from axolotl.utils.train import determine_last_checkpoint
LOG = get_logger(__name__)
@@ -143,7 +145,6 @@ def merge_fsdp_weights(
ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
"""
checkpoint_dir_ = Path(checkpoint_dir)
from accelerate.state import PartialState
if not is_torch_version(">=", "2.3.0"):
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
@@ -180,7 +181,6 @@ def merge_fsdp_weights(
if remove_checkpoint_dir:
LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
shutil.rmtree(checkpoint_dir_)
state.wait_for_everyone()
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
@@ -195,11 +195,32 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
parsed_cfg = load_cfg(config, **kwargs)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
if not fsdp_dir.exists():
checkpoint_dir = determine_last_checkpoint(parsed_cfg, update=False)
if checkpoint_dir:
fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
if not fsdp_dir.exists():
raise ValueError(
f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}"
)
output_path = str(Path(parsed_cfg.output_dir) / "merged")
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=str(Path(parsed_cfg.output_dir) / "merged"),
output_path=output_path,
safe_serialization=True,
)
state = PartialState()
state.wait_for_everyone()
LOG.info(
f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
main_process_only=True,
)
LOG.info(
"Merged weights are only the safetensors and doesn't include the model configuration "
f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
main_process_only=True,
)
if __name__ == "__main__":

View File

@@ -35,10 +35,20 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
check_accelerate_default_config()
check_user_token()
if cli_args.iterable:
LOG.error(
"The --iterable CLI argument for 'axolotl preprocess' is no longer "
"supported. For training, set 'streaming: true' in your YAML config or "
"pass '--streaming' in your 'axolotl train' command for on-the-fly "
"preprocessing."
)
return
for key in ["skip_prepare_dataset", "pretraining_dataset"]:
if cfg.get(key):
LOG.error(
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl "
"train' CLI directly instead."
)
return
@@ -97,7 +107,8 @@ def do_cli(
"""
# pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
parsed_cfg = load_cfg(config, **kwargs)
is_preprocess = kwargs.pop("is_preprocess", True)
parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -3,11 +3,12 @@
import random
from copy import deepcopy
from itertools import product
from typing import Any
def generate_sweep_configs(
base_config: dict[str, list], sweeps_config: dict[str, list]
) -> list[dict[str, list]]:
) -> list[dict[str, Any]]:
"""
Recursively generates all possible configurations by applying sweeps to the base config.

View File

@@ -2,7 +2,9 @@
import os
import subprocess # nosec
import sys
import tempfile
from pathlib import Path
from typing import Any, Iterator, Literal
import yaml
@@ -64,10 +66,18 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
return cmd
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
"""Generate list of configuration files to process."""
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
"""
Generate list of configuration files to process. Yields a tuple of the configuration file name and a boolean indicating
whether this is a group of configurations (i.e., a sweep).
Args:
config: Base configuration file
sweep: Sweep configuration file
"""
if not sweep:
yield config
yield config, False
return
# Load sweep and base configurations
@@ -78,7 +88,13 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
# Generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
for permutation in permutations:
is_group = len(permutations) > 1
base_output_dir = base_config.get("output_dir", "./model-out")
for idx, permutation in enumerate(permutations, start=1):
permutation_dir = Path(permutation.get("output_dir", base_output_dir))
permutation_id = f"sweep{idx:04d}"
permutation["output_dir"] = str(permutation_dir / permutation_id)
# pylint: disable=consider-using-with
temp_file = tempfile.NamedTemporaryFile(
mode="w",
@@ -88,7 +104,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
)
yaml.dump(permutation, temp_file)
temp_file.close()
yield temp_file.name
yield temp_file.name, is_group
def launch_training(
@@ -97,6 +113,7 @@ def launch_training(
cloud: str | None,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
) -> None:
"""Execute training with the given configuration."""
launcher_args = launcher_args or []
@@ -105,11 +122,14 @@ def launch_training(
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
elif launcher:
if launcher == "accelerate":
_launch_accelerate_training(cfg_file, kwargs, launcher_args)
_launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
elif launcher == "torchrun":
_launch_torchrun_training(cfg_file, kwargs, launcher_args)
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
elif launcher == "python":
_launch_python_training(cfg_file, kwargs)
elif launcher is None:
# handle ray train launch
_launch_python_training(cfg_file, kwargs)
def _launch_cloud_training(
@@ -136,7 +156,10 @@ def _launch_cloud_training(
def _launch_accelerate_training(
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
cfg_file: str,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
) -> None:
"""Execute training via accelerate launcher."""
launcher_args = launcher_args or []
@@ -161,11 +184,20 @@ def _launch_accelerate_training(
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
if use_exec:
# make sure to flush stdout and stderr before replacing the process
sys.stdout.flush()
sys.stderr.flush()
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
else:
subprocess.run(cmd, check=True) # nosec B603
def _launch_torchrun_training(
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
cfg_file: str,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
) -> None:
"""Execute training via torchrun launcher."""
launcher_args = launcher_args or []
@@ -178,7 +210,13 @@ def _launch_torchrun_training(
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
if use_exec:
# make sure to flush stdout and stderr before replacing the process
sys.stdout.flush()
sys.stderr.flush()
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
else:
subprocess.run(cmd, check=True) # nosec B603
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
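Concretely, with the hardened sweep handling above, each permutation's config is written to a temp file with its `output_dir` suffixed by `sweepNNNN` (a sketch, assuming a base `output_dir` of `./model-out` and three permutations):

```bash
ls ./model-out
# sweep0001  sweep0002  sweep0003
```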

View File

@@ -55,13 +55,11 @@ def load_datasets(
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = getattr(cli_args, "iterable", False)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if (

View File

@@ -5,7 +5,6 @@
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .trl import (
AxolotlCPOTrainer,

View File

@@ -1,18 +1,19 @@
"""Module containing Dataset functionality"""
"""
Module containing dataset functionality.
We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
datasets.
"""
from typing import Any
import torch
from datasets import Dataset, IterableDataset
from axolotl.utils.logging import get_logger
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = get_logger(__name__)
@@ -42,10 +43,13 @@ class TokenizedPromptDataset(Dataset):
**kwargs,
)
def process(self, dataset):
features = dataset.features.keys()
def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset:
"""Apply filtering and tokenization."""
features = None
if not isinstance(dataset, IterableDataset):
features = dataset.features.keys()
map_kwargs = {}
map_kwargs: dict[str, Any] = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 1_000
@@ -54,18 +58,28 @@ class TokenizedPromptDataset(Dataset):
hasattr(self.prompt_tokenizer, "filter_rows")
and self.prompt_tokenizer.filter_rows
):
filter_kwargs: dict[str, Any] = {"desc": "Strategy Filtering Rows"}
if not isinstance(dataset, IterableDataset):
filter_kwargs["num_proc"] = self.process_count
dataset = dataset.filter(
self.prompt_tokenizer.filter_rows,
num_proc=self.process_count,
desc="Strategy Filtering Rows",
**filter_kwargs,
)
map_kwargs = {
**map_kwargs,
"desc": "Tokenizing Prompts",
}
# Only add remove_columns for regular datasets
if not isinstance(dataset, IterableDataset):
map_kwargs["remove_columns"] = features
map_kwargs["num_proc"] = self.process_count
map_kwargs["keep_in_memory"] = self.keep_in_memory
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=self.process_count,
remove_columns=features,
keep_in_memory=self.keep_in_memory,
desc="Tokenizing Prompts",
**map_kwargs,
)
@@ -79,140 +93,16 @@ def wrap_dataset_for_tokenized_prompt(
map_kwargs = {}
if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
features = list(dataset.features.keys())
# Map the dataset and remove original columns
# For IterableDataset, features might be None until first iteration
remove_columns = None
if dataset.features is not None:
remove_columns = list(dataset.features.keys())
return dataset.map(
prompt_tokenizer.tokenize_prompt,
remove_columns=features,
remove_columns=remove_columns,
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.
Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""
def __init__( # pylint: disable=super-init-not-called
self,
tokenizer,
datasets,
seq_length=2048,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: list[IterableDataset] = datasets
self.seq_length = seq_length
vocab_size = len(tokenizer.get_vocab())
if vocab_size <= torch.iinfo(torch.int16).max:
self.tokens_dtype = torch.int16
elif vocab_size <= torch.iinfo(torch.int32).max:
self.tokens_dtype = torch.int32
else:
self.tokens_dtype = torch.int64
def __iter__(self):
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
add_concat_token = False
if example:
example_len = len(example["input_ids"])
add_concat_token = example["input_ids"][-1] != self.concat_token_id
else:
example_len = 0
if not example_len or (
buffer_len + int(add_concat_token) + example_len > self.seq_length
):
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
: self.seq_length
]
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
):
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
LOG.warning(
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if add_concat_token:
input_ids.append(self.concat_token_id)
attention_mask.append(1)
labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)

View File

@@ -76,8 +76,8 @@ class BasePlugin:
def __init__(self):
"""Initializes the BasePlugin."""
def register(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Registers the plugin with the given configuration.
def register(self, cfg: dict): # pylint: disable=unused-argument
"""Registers the plugin with the given configuration as an unparsed dict.
Args:
cfg: The configuration for the plugin.

View File

@@ -1,26 +1,13 @@
"""Shared constants for axolotl.loaders module"""
from transformers import (
Gemma3ForConditionalGeneration,
Gemma3nForConditionalGeneration,
Llama4ForConditionalGeneration,
LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration,
MllamaForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
from transformers import AutoModelForImageTextToText
from transformers.models.auto.modeling_auto import (
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
)
MULTIMODAL_AUTO_MODEL_MAPPING = {
"mllama": MllamaForConditionalGeneration,
"llama4": Llama4ForConditionalGeneration,
"llava": LlavaForConditionalGeneration,
"qwen2_vl": Qwen2VLForConditionalGeneration,
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
"mistral3": Mistral3ForConditionalGeneration,
"gemma3": Gemma3ForConditionalGeneration,
"gemma3n": Gemma3nForConditionalGeneration,
}
MULTIMODAL_AUTO_MODEL_MAPPING = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)
MULTIMODAL_AUTO_MODEL_MAPPING["lfm2-vl"] = AutoModelForImageTextToText
try:
from transformers import VoxtralForConditionalGeneration

View File

@@ -25,6 +25,7 @@ from peft import (
from torch.distributed import DeviceMesh
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForVision2Seq,
AwqConfig,
BitsAndBytesConfig,
@@ -212,6 +213,7 @@ class ModelLoader:
self.model_kwargs["use_kernels"] = self.cfg.use_kernels
self._set_quantization_config()
self._set_attention_config()
self._check_model_requirements()
def _apply_post_model_load_setup(self):
"""Configure the model after it has been loaded."""
@@ -432,6 +434,8 @@ class ModelLoader:
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
self.model_config.model_type, AutoModelForVision2Seq
)
if isinstance(self.auto_model_loader, str):
self.auto_model_loader = AutoModelForImageTextToText
def _set_device_map_config(self):
"""Setup `device_map` according to config"""
@@ -628,6 +632,16 @@ class ModelLoader:
if self.cfg.low_cpu_mem_usage:
self.model_kwargs["low_cpu_mem_usage"] = True
def _check_model_requirements(self):
if self.cfg.model_config_type in ["lfm2-vl", "lfm2"]:
from transformers.utils.import_utils import is_causal_conv1d_available
if is_causal_conv1d_available():
raise ImportError(
"The 'causal-conv1d' package is installed but causes compatibility issues with LFM2 models. "
"Please uninstall it by running: `pip uninstall -y causal-conv1d`"
)
def _configure_zero3_memory_efficient_loading(
self,
) -> HfTrainerDeepSpeedConfig | None:

View File

@@ -73,9 +73,6 @@ class PatchManager:
self._apply_voxtral_patches()
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
patch_prepare_from_posids,
)
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
patch_evaluation_loop,
patch_maybe_log_save_evaluate,
@@ -87,7 +84,6 @@ class PatchManager:
and self.cfg.fsdp_version == 2
)
patch_prepare_from_posids()
patch_evaluation_loop(patch_fsdp2)
patch_maybe_log_save_evaluate()
@@ -289,12 +285,10 @@ class PatchManager:
and self.cfg.adapter == "qlora"
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_bnb_torch_function_patch,
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
)
apply_bnb_torch_function_patch()
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()

View File

@@ -187,7 +187,7 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
if module.base_layer.bias is not None:
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
log_bias_dtype_mismatch = True
module.base_layer.bias.data = module.base_layer.bias.data.to(

View File

@@ -9,73 +9,12 @@ Params4bit parameters.
import importlib
import inspect
import torch
from torch.nn import Parameter
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def patched_torch_function(cls, func, types, args=(), kwargs=None):
"""
Patched version of Params4bit.__torch_function__ for preserving Params4bit
class identity and attributes.
"""
if kwargs is None:
kwargs = {}
if func in [torch.chunk, torch.split]:
tensor = args[0]
result = Parameter.__torch_function__(func, types, args, kwargs)
if isinstance(result, tuple):
return tuple(
cls(
data=chunk,
requires_grad=tensor.requires_grad,
quant_state=tensor.quant_state,
blocksize=tensor.blocksize,
compress_statistics=tensor.compress_statistics,
quant_type=tensor.quant_type,
quant_storage=tensor.quant_storage,
module=tensor.module,
bnb_quantized=tensor.bnb_quantized,
)
for chunk in result
)
return cls(
data=result,
requires_grad=tensor.requires_grad,
quant_state=tensor.quant_state,
blocksize=tensor.blocksize,
compress_statistics=tensor.compress_statistics,
quant_type=tensor.quant_type,
quant_storage=tensor.quant_storage,
module=tensor.module,
bnb_quantized=tensor.bnb_quantized,
)
return Parameter.__torch_function__(func, types, args, kwargs)
# pylint: disable=protected-access
def apply_bnb_torch_function_patch():
"""
Patch Params4bit.__torch_function__ using Axolotl-style approach.
Returns:
True if patching succeeded, False otherwise.
"""
from bitsandbytes.nn.modules import Params4bit
Params4bit.__torch_function__ = classmethod(patched_torch_function)
LOG.info("Successfully patched Params4bit.__torch_function__")
# pylint: disable=protected-access
def apply_init_sharded_param_patch():
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""

View File

@@ -20,12 +20,15 @@ from ring_flash_attn import ring_flash_attn_func
from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
try:
try: # pylint: disable=duplicate-code
from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
try:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
except ImportError:
_flash_supports_window = True
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

View File

@@ -15,12 +15,15 @@ import torch
import torch.distributed as dist
from torch.distributed import DeviceMesh
try:
try: # pylint: disable=duplicate-code
from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
try:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
except ImportError:
_flash_supports_window = True
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger

View File

@@ -1,87 +0,0 @@
"""
Monkey patch to fix transformers.modeling_flash_attention_utils.
see https://github.com/huggingface/transformers/pull/39653/files
"""
import sys
import torch
def _prepare_from_posids(query, key, value, position_ids):
"""
This function returns necessary arguments to call `flash_attn_varlen_func`.
All three query, key, value states will be flattened.
Cumulative lengths of each examples in the batch will be extracted from position_ids.
NOTE: ideally cumulative lengths should be prepared at the data collator stage
Arguments:
query (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
position_ids (`torch.Tensor`):
Int tensor of shape (batch_size, sequence_length) containing per-token positions, restarting from 0 at the start of each packed example.
Return:
query (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
indices_q (`torch.Tensor`):
The indices of non-masked tokens from the flattened input target sequence.
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(
position_ids.size(0), device=position_ids.device, dtype=torch.int32
)
cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(
position_ids.size(), device=position_ids.device, dtype=torch.int32
),
)
)
# NOTE: With torch compile, this will cause a graph break if you don't set
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
# for some models (e.g. qwen2-vl).
max_length = cu_seq_lens.diff().max().item()
return (
query,
key,
value,
indices_q,
(cu_seq_lens, cu_seq_lens),
(max_length, max_length),
)
def patch_prepare_from_posids():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._prepare_from_posids = ( # pylint: disable=protected-access
_prepare_from_posids
)
setattr(
sys.modules["transformers.modeling_flash_attention_utils"],
"_prepare_from_posids",
_prepare_from_posids,
)
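For reference, a tiny worked example (values made up) of the cu_seqlens derivation the removed helper performed: zeros in the packed position_ids mark example starts, and their indices become the cumulative lengths.

import torch

position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
flat = position_ids.flatten()
idx = torch.arange(flat.size(0), dtype=torch.int32)
cu_seq_lens = torch.cat((idx[flat == 0], torch.tensor(flat.size(), dtype=torch.int32)))
print(cu_seq_lens.tolist())             # [0, 3, 5, 9]
print(cu_seq_lens.diff().max().item())  # 4 -> max_seqlen in the packed batch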

View File

@@ -6,7 +6,7 @@ from typing import Optional
from PIL import Image, ImageOps
from PIL.Image import Resampling
from torch import Tensor, zeros_like
from transformers import ProcessorMixin, VoxtralProcessor
from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
from transformers.image_utils import load_image
from axolotl.utils.dict import remove_none_values
@@ -138,7 +138,7 @@ class ProcessingStrategy:
image_key = key
break
# if the image key exists, add the image to the first message
# if the image key exists, add the image to the first user message
if image_key is not None and processed_example[image_key] is not None:
# TODO: check if it's normal to be single image only for common datasets
# From observation, it's usually a list of single image but some datasets may have several columns for images
@@ -179,26 +179,34 @@ class ProcessingStrategy:
# Look for any image type in the first message
# some dataset have an {type: "image"} in the first message
msg_ind_to_add = None
ind_to_add = None
first_user_idx = None
for i, content in enumerate(
processed_example["messages"][0]["content"]
):
# Usually datasets created with image columns, don't have it in the messages itself
if content["type"] == "image" and all(
k not in content for k in ["image", "url", "path", "base64"]
for msg_idx, msg_content in enumerate(processed_example["messages"]):
if first_user_idx is None and msg_content["role"] == "user":
first_user_idx = msg_idx
for i, content in enumerate(
processed_example["messages"][msg_idx]["content"]
):
ind_to_add = i
break
# Datasets created with image columns usually don't embed the image in the messages themselves
if content["type"] == "image" and all(
k not in content for k in ["image", "url", "path", "base64"]
):
msg_ind_to_add = msg_idx
ind_to_add = i
break
# If an image type is found, add the image to that index
if ind_to_add is not None:
processed_example["messages"][0]["content"][ind_to_add][
"image"
] = image_value
if ind_to_add is not None and msg_ind_to_add is not None:
processed_example["messages"][msg_ind_to_add]["content"][
ind_to_add
]["image"] = image_value
else:
# if no image type is found, add it to end of the first message
processed_example["messages"][0]["content"].append(
# if no image type is found, add it to the end of the first user message
if first_user_idx is None:
first_user_idx = 0
processed_example["messages"][first_user_idx]["content"].append(
{
"type": "image",
"image": image_value,
@@ -395,6 +403,24 @@ class VoxtralProcessingStrategy(ProcessingStrategy):
return labels
class SmolVLM2ProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for SmolVLM2"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.image_token = "<image>" # nosec
self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
processor.tokenizer.additional_special_tokens.index(self.image_token)
]
def get_processing_strategy(
processor: ProcessorMixin,
chat_template,
@@ -402,32 +428,43 @@ def get_processing_strategy(
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
processing_kwargs = {
"processor": processor,
"chat_template": chat_template,
"image_size": image_size,
"image_resize_algorithm": image_resize_algorithm,
}
if chat_template_type in [None, "tokenizer_default"] and hasattr(
processor.tokenizer, "chat_template"
):
processing_kwargs["chat_template"] = processor.tokenizer.chat_template
if chat_template_type == "qwen2_vl":
return Qwen2VLProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
**processing_kwargs,
)
if chat_template_type == "gemma3":
return Gemma3ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
**processing_kwargs,
)
if chat_template_type == "gemma3n":
return Gemma3nProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type in [
"llama3_2_vision",
"llama4",
"llava",
"mistral_v7_tekken",
"pixtral",
]:
return ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
**processing_kwargs,
)
if isinstance(processor, VoxtralProcessor):
return VoxtralProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
**processing_kwargs,
)
raise ValueError(f"Unsupported chat template type: {chat_template_type}")
if isinstance(processor, SmolVLMProcessor):
return SmolVLM2ProcessingStrategy(
**processing_kwargs,
)
# llama3_2_vision, llama4, llava
# mistral_v7_tekken, pixtral, lfm2vl
return ProcessingStrategy(
**processing_kwargs,
)
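A hedged usage sketch of the refactored dispatcher above; the processor id is illustrative, and `chat_template_type` is assumed to be the parameter elided from this hunk.

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
strategy = get_processing_strategy(
    processor,
    chat_template=None,
    chat_template_type=None,  # falls back to processor.tokenizer.chat_template
    image_size=(384, 384),
    image_resize_algorithm=None,
)
# SmolVLMProcessor routes to SmolVLM2ProcessingStrategy; unrecognized types now
# fall through to the base ProcessingStrategy instead of raising.
print(type(strategy).__name__)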

View File

@@ -129,13 +129,21 @@ class ChatTemplatePrompter(Prompter):
images=images,
return_tensors="pt",
)
if hasattr(batch, "to_dict"):
batch = batch.to_dict()
else:
batch = dict(batch)
# workaround since processor works in batches instead of single examples
out = {}
for k, val in batch.items():
if k in ["pixel_values"]:
batch[k] = val.tolist()
if hasattr(val, "tolist"):
out[k] = (
val.tolist() if k == "pixel_values" else val.squeeze(0).tolist()
)
else:
batch[k] = val.squeeze().tolist()
return batch
out[k] = val
return out
return self.tokenizer.apply_chat_template(
conversation,
@@ -433,10 +441,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
tokenized_prompt["attention_mask"] = [1] * len(input_ids)
else:
input_ids = tokenized_res["input_ids"]
tokenized_prompt = tokenized_res
tokenized_prompt = dict(tokenized_res)
if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
if isinstance(prompt_ids, dict):
user_prompt_len = len(prompt_ids["input_ids"])
else:
user_prompt_len = len(prompt_ids)
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
else:
labels = input_ids

View File

@@ -4,11 +4,14 @@ from __future__ import annotations
import importlib
import inspect
import json
import os
import shutil
import signal
import sys
import typing
import weakref
from collections import OrderedDict
from contextlib import ExitStack
from pathlib import Path
from typing import Any, Dict
@@ -38,6 +41,7 @@ from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint
from axolotl.utils.trainer import setup_trainer
try:
@@ -46,7 +50,7 @@ except ImportError:
BetterTransformer = None
if typing.TYPE_CHECKING:
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
LOG = get_logger(__name__)
@@ -124,32 +128,6 @@ def setup_reference_model(
return model_ref
def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
"""
Determine the checkpoint to resume from based on configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Path to the checkpoint to resume from, or `None` if not resuming.
"""
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
return cfg.resume_from_checkpoint
def setup_signal_handler(
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
):
@@ -275,19 +253,60 @@ def save_trained_model(
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return
if trainer.is_fsdp_enabled or cfg.fsdp_config:
if ( # pylint: disable=too-many-nested-blocks
trainer.is_fsdp_enabled or cfg.fsdp_config
):
if cfg.fsdp_config or cfg.fsdp:
if cfg.fsdp_config.final_state_dict_type:
state_dict_type = cfg.fsdp_config.final_state_dict_type
else:
state_dict_type = cfg.fsdp_config.state_dict_type
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
trainer.save_model(cfg.output_dir)
trainer.save_model(cfg.output_dir) # only handles FULL_STATE_DICT
if state_dict_type == "SHARDED_STATE_DICT":
LOG.info(
"The final model was saved with a sharded state dict. Please ensure you merge "
"the sharded weights with `merge-sharded-fsdp-weights`."
)
checkpoint_dir = determine_last_checkpoint(cfg, update=False)
if (
not (Path(cfg.output_dir) / "model.safetensors.index.json").exists()
and checkpoint_dir
):
# import here to prevent circular import
from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights
fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
merged_path = str(Path(cfg.output_dir) / "merged")
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=merged_path,
safe_serialization=True,
)
trainer.accelerator.wait_for_everyone()
if trainer.accelerator.is_main_process:
# move all files in merged_path to cfg.output_dir
for merged_file in Path(merged_path).iterdir():
if (Path(cfg.output_dir) / merged_file.name).exists():
(Path(cfg.output_dir) / merged_file.name).unlink()
shutil.move(str(merged_file), cfg.output_dir)
shutil.rmtree(merged_path) # remove what should be an empty dir
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
# cleanup the FSDP prefix in the model config.json
if trainer.accelerator.is_main_process:
with open(
Path(cfg.output_dir) / "config.json", "r", encoding="utf-8"
) as config_file_io:
# read the model config as an OrderedDict
config = json.load(config_file_io, object_pairs_hook=OrderedDict)
config["architectures"] = [
# use removeprefix so leading F/S/D/P characters of the real class name survive
name.removeprefix("FSDP") for name in config["architectures"]
]
# write the updated model config back
with open(
os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8"
) as config_file_io:
json.dump(config, config_file_io, indent=2)
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone()
@@ -564,7 +583,7 @@ def train(
setup_model_card(cfg)
# Execute the training
resume_from_checkpoint = determine_resume_checkpoint(cfg)
resume_from_checkpoint = determine_last_checkpoint(cfg)
execute_training(cfg, trainer, resume_from_checkpoint)
# clear cache
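A hedged manual-recovery sketch using the same helper the sharded-save path above invokes; paths are placeholders.

from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights

merge_fsdp_weights(
    checkpoint_dir="model-out/checkpoint-200/pytorch_model_fsdp_0",
    output_path="model-out/merged",
    safe_serialization=True,
)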

View File

@@ -5,7 +5,6 @@ Collators for multi-modal chat messages and packing
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
from torch import Tensor
from transformers import PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin
@@ -42,62 +41,19 @@ class MultiModalChatDataCollator(DataCollatorMixin):
examples = self.processing_strategy(examples)
# Initialize batch
batch: dict[str, Any] = {}
messages = [ex["messages"] for ex in examples]
# Process each example
for example in examples:
# Apply chat template to process the example
# This method requires transformers>=4.49.0
result = self.processing_strategy.processor.apply_chat_template(
example["messages"],
add_generation_prompt=False,
tokenize=True,
return_tensors="pt",
padding=True,
return_dict=True,
chat_template=self.processing_strategy.chat_template,
)
# TODO: Check if need handling for len(input_ids) > sequence_len
# Add the processed tensors to our batch
for key in result.keys():
if key not in batch:
batch[key] = []
batch[key].append(result[key].squeeze(0))
# Pad sequences to the same length
input_ids = torch.nn.utils.rnn.pad_sequence(
batch["input_ids"],
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
batch = self.processing_strategy.processor.apply_chat_template(
messages,
add_generation_prompt=False,
tokenize=True,
return_tensors="pt",
padding=True,
return_dict=True,
chat_template=self.processing_strategy.chat_template,
)
attention_mask = torch.nn.utils.rnn.pad_sequence(
batch["attention_mask"], batch_first=True, padding_value=0
)
# Create the final batch
final_batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
for key, val in batch.items():
if key in ["input_ids", "attention_mask"]:
continue
if key in ["token_type_ids", "cross_attention_mask"]:
final_batch[key] = torch.nn.utils.rnn.pad_sequence(
val, batch_first=True, padding_value=0
)
else:
final_batch[key] = torch.stack(val)
# Process the labels
final_batch["labels"] = self.processing_strategy.process_labels(
final_batch["input_ids"]
)
batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])
return final_batch
return batch

View File

@@ -9,6 +9,7 @@ from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
)
from transformers import PreTrainedTokenizer, ProcessorMixin
@@ -28,7 +29,7 @@ from axolotl.utils.data.shared import (
)
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
drop_long_seq_in_dataset,
handle_long_seq_in_dataset,
retry_on_request_exceptions,
)
from axolotl.utils.data.wrappers import get_dataset_wrapper
@@ -43,12 +44,24 @@ from axolotl.utils.trainer import (
LOG = get_logger(__name__)
def _is_streaming_enabled(cfg: DictDefault) -> bool:
"""Check if streaming is enabled for a specific split."""
streaming = cfg.get("streaming")
if streaming is True:
return True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = cfg.get("pretraining_dataset") is not None
streaming = has_pretraining and streaming is None
return streaming
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_datasets(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[IterableDataset | Dataset, Dataset | None, int, list[Prompter | None]]:
"""Prepare training and evaluation datasets based on configuration.
@@ -56,23 +69,19 @@ def prepare_datasets(
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: Tokenizer to use for processing text.
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (train_dataset, eval_dataset, total_steps, prompters).
"""
if cfg.pretraining_dataset:
return _prepare_pretraining_dataset(
cfg, tokenizer, processor, preprocess_iterable
)
return _prepare_standard_dataset(cfg, tokenizer, processor, preprocess_iterable)
return _prepare_pretraining_dataset(cfg, tokenizer, processor)
return _prepare_standard_dataset(cfg, tokenizer, processor)
def _prepare_standard_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
preprocess_iterable: bool,
) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]:
"""Prepare standard (non-pretraining) datasets."""
@@ -83,7 +92,6 @@ def _prepare_standard_dataset(
cfg,
split="train",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
# Overwrite eval_dataset if test data exists
@@ -93,7 +101,6 @@ def _prepare_standard_dataset(
cfg,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
return train_dataset, eval_dataset, prompters
@@ -109,7 +116,12 @@ def _prepare_standard_dataset(
return train_dataset, eval_dataset, -1, prompters
# Validate sample packing configuration for evaluation
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
if (
eval_dataset
and cfg.sample_packing
and cfg.eval_sample_packing is not False
and not isinstance(eval_dataset, IterableDataset)
):
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0:
raise ValueError(
@@ -117,13 +129,17 @@ def _prepare_standard_dataset(
"You should set `eval_sample_packing: False` in your config."
)
# Calculate total number of training steps
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
# Set total_num_steps for training
if isinstance(train_dataset, IterableDataset):
total_num_steps = cfg.max_steps
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
return train_dataset, eval_dataset, total_num_steps, prompters
@@ -132,7 +148,6 @@ def _prepare_pretraining_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
preprocess_iterable: bool,
) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
"""
Prepare dataset for pretraining mode.
@@ -153,7 +168,6 @@ def _prepare_pretraining_dataset(
cfg,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if cfg.dataset_exact_deduplication:
@@ -256,7 +270,6 @@ def _load_tokenized_prepared_datasets(
cfg: DictDefault,
split: Literal["train", "test"] = "train",
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset | DatasetDict, list[Prompter | None]]:
"""Load or create tokenized and prepared datasets for training or testing.
@@ -265,39 +278,51 @@ def _load_tokenized_prepared_datasets(
cfg: Configuration object.
split: Dataset split to load ('train' or 'test').
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (dataset, prompters list).
"""
# Select correct dataset configuration based on split
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
# Generate dataset hash for caching
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
# Try loading from hub if push_dataset_to_hub is configured
dataset = None
if cfg.push_dataset_to_hub:
dataset = try_load_from_hub(cfg, dataset_hash, split)
# If not found on hub, try loading from disk
if dataset is None:
dataset = load_preprocessed_dataset(cfg, dataset_hash)
# If not found on disk or skipping prepared dataset, load and process raw datasets
prompters: list[Prompter | None] = []
if dataset is None:
use_streaming = False
if split == "train":
use_streaming = _is_streaming_enabled(cfg)
if use_streaming:
# For streaming datasets, skip caching and load raw datasets directly
dataset, prompters = _load_raw_datasets(
cfg,
datasets_configs,
tokenizer,
split,
processor,
preprocess_iterable,
)
else:
# Generate dataset hash for caching
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
# Try loading from hub if push_dataset_to_hub is configured
dataset = None
if cfg.push_dataset_to_hub:
dataset = try_load_from_hub(cfg, dataset_hash, split)
# If not found on hub, try loading from disk
if dataset is None:
dataset = load_preprocessed_dataset(cfg, dataset_hash)
# If not found on disk or skipping prepared dataset, load and process raw
# datasets
if dataset is None:
dataset, prompters = _load_raw_datasets(
cfg,
datasets_configs,
tokenizer,
split,
processor,
)
return dataset, prompters
@@ -306,9 +331,8 @@ def _load_raw_datasets(
cfg: DictDefault,
datasets_configs: list,
tokenizer: PreTrainedTokenizer,
split: str,
split: Literal["train", "test"],
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset, list[Prompter | None]]:
"""Load, process, merge, and save raw datasets."""
LOG.info("Loading raw datasets...", main_process_only=False)
@@ -329,7 +353,6 @@ def _load_raw_datasets(
split=split,
seed=cfg.seed,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
@@ -339,17 +362,18 @@ def _load_raw_datasets(
if not cfg.skip_prepare_dataset:
if split == "test" and cfg.eval_sequence_len:
dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
else:
dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
# Save the prepared dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
# Only save regular datasets to disk, not streaming datasets
if not isinstance(dataset, IterableDataset):
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
return dataset, prompters
@@ -358,22 +382,22 @@ def _load_and_process_single_dataset(
dataset_config: DictDefault,
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
split: str,
split: Literal["train", "test"],
seed: int,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Load and process a single dataset based on the passed config."""
# Load the dataset
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable
)
use_streaming = False
if split == "train":
use_streaming = _is_streaming_enabled(cfg)
# Parse dataset type
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, use_streaming
)
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
# Select the appropriate split
if isinstance(dataset, DatasetDict):
if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
if dataset_config.split and dataset_config.split in dataset:
dataset = dataset[dataset_config.split]
elif split in dataset:
@@ -418,11 +442,13 @@ def _parse_dataset_type(d_type: str) -> tuple[str | None, str | None]:
def _handle_train_dataset_split(
dataset: Dataset, cfg: DictDefault
) -> tuple[Dataset, Dataset | None]:
dataset: Dataset | IterableDataset, cfg: DictDefault
) -> tuple[Dataset | IterableDataset, Dataset | IterableDataset | None]:
"""Handle processing for train split, including validation set creation."""
val_set_size = (
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
int(cfg.val_set_size)
if cfg.val_set_size and cfg.val_set_size > 1
else float(cfg.val_set_size or 0.0)
)
if val_set_size:
@@ -433,27 +459,33 @@ def _handle_train_dataset_split(
return train_dataset, eval_dataset
# No validation split - apply deduplication if needed and return as train dataset
if cfg.dataset_exact_deduplication:
if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset):
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
else:
if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset):
LOG.info("Deduplication skipped for streaming datasets (not compatible)")
train_dataset = dataset
return train_dataset, None
def _handle_test_dataset_split(
dataset: Dataset, cfg: DictDefault
) -> tuple[None, Dataset | None]:
dataset: Dataset | IterableDataset, cfg: DictDefault
) -> tuple[None, Dataset | IterableDataset | None]:
"""Handle processing for test split."""
if cfg.dataset_exact_deduplication:
if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset):
eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
else:
if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset):
LOG.info("Deduplication skipped for streaming datasets (not compatible)")
eval_dataset = dataset
return None, eval_dataset
def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset:
def _apply_dataset_sharding(
dataset: Dataset | IterableDataset, cfg: DictDefault
) -> Dataset | IterableDataset:
"""Apply dataset sharding if configured.
Args:
@@ -479,7 +511,6 @@ def _load_and_prepare_datasets(
cfg: DictDefault,
split: Literal["train", "test"] = "train",
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset | None, Dataset | None, list[Prompter | None]]:
"""Load and prepare datasets with optional validation split and sharding.
@@ -488,7 +519,6 @@ def _load_and_prepare_datasets(
cfg: Configuration object.
split: Dataset split to load ('train' or 'test').
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (train_dataset, eval_dataset, prompters).
@@ -499,7 +529,6 @@ def _load_and_prepare_datasets(
cfg,
split=split,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
# Apply dataset sharding if configured using shared function
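A hedged config sketch (values illustrative) for reaching the streaming path above: streaming SFT requires max_steps, and held-out evaluation should come from test_datasets rather than val_set_size.

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "streaming": True,  # train datasets are loaded as IterableDatasets
        "max_steps": 1000,  # required: an IterableDataset has no known length
        "datasets": [{"path": "tatsu-lab/alpaca", "type": "alpaca"}],
    }
)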

View File

@@ -13,6 +13,7 @@ from datasets import (
IterableDataset,
IterableDatasetDict,
concatenate_datasets,
interleave_datasets,
load_dataset,
load_from_disk,
)
@@ -524,7 +525,9 @@ def generate_dataset_hash_from_config(
return str(md5(config_str))
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
def merge_datasets(
datasets: list[Dataset | IterableDataset], cfg: DictDefault
) -> Dataset | IterableDataset:
"""Merge multiple datasets into one with optional shuffling.
Args:
@@ -537,23 +540,23 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
if len(datasets) == 1:
ds = datasets[0]
# Do not shuffle if curriculum sampling is enabled or
# shuffle_merged_datasets is disabled
if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets:
if (
cfg.curriculum_sampling
or not cfg.shuffle_merged_datasets
or isinstance(ds, IterableDataset)
):
return ds
return ds.shuffle(seed=cfg.seed)
# If enabled, shuffle each dataset independently before merging.
# This allows curriculum learning strategies to be applied at the dataset level.
if cfg.shuffle_before_merging_datasets:
if cfg.shuffle_before_merging_datasets and all(
isinstance(ds, Dataset) for ds in datasets
):
LOG.info("Shuffling each dataset individually before merging...")
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets)
merged_dataset = _merge_datasets_with_strategy(datasets, cfg)
if cfg.shuffle_merged_datasets:
if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset):
LOG.debug("Shuffling merged datasets...")
if cfg.curriculum_sampling:
LOG.warning(
@@ -562,6 +565,45 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
)
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
else:
LOG.debug("Not shuffling merged datasets.")
if isinstance(merged_dataset, IterableDataset):
LOG.debug("Skipping shuffle for streaming datasets.")
else:
LOG.debug("Not shuffling merged datasets.")
return merged_dataset
def _merge_datasets_with_strategy(
datasets: list[Dataset | IterableDataset], cfg: DictDefault
) -> Dataset | IterableDataset:
"""
Merge datasets using the configured mixing strategy. Works with streaming and non-
streaming datasets.
Args:
datasets: List of datasets to merge.
cfg: Configuration object containing mixing settings.
Returns:
Merged dataset (Dataset or IterableDataset depending on inputs).
"""
strategy = cfg.get("dataset_mixing_strategy", "concatenate")
weights = cfg.get("mixing_weights", None)
LOG.info(f"Merging datasets with mixing strategy: {strategy}...")
if strategy == "concatenate":
if not all(isinstance(ds, Dataset) for ds in datasets):
raise ValueError(
"Cannot concatenate streaming datasets. Use 'round_robin', 'weighted', "
"or 'random' instead."
)
return concatenate_datasets(datasets)
if strategy == "round_robin":
return interleave_datasets(datasets, seed=cfg.seed)
if strategy == "weighted":
return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
if strategy == "random":
equal_weights = [1.0 / len(datasets)] * len(datasets)
return interleave_datasets(datasets, probabilities=equal_weights, seed=cfg.seed)
raise ValueError(f"Unknown dataset mixing strategy: {strategy}")

View File

@@ -148,7 +148,36 @@ def deduplicate_and_log_datasets(
return dataset, other_dataset
def drop_long_seq_in_dataset(
def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
"""
Truncate samples whose sequence length is too long (> sequence_len)
or drop those too short (< min_sequence_len).
"""
min_sequence_len = min_sequence_len or 2
input_ids = sample["input_ids"]
results = []
# Batched (input_ids is a list of lists)
for i, seq in enumerate(input_ids):
length = len(seq)
if length < min_sequence_len:
results.append(False)
elif length > sequence_len:
sample["input_ids"][i] = seq[:sequence_len]
if "attention_mask" in sample:
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
if "labels" in sample:
sample["labels"][i] = sample["labels"][i][:sequence_len]
if "position_ids" in sample:
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
results.append(True)
else:
results.append(True)
return results
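A toy illustration (made-up token ids) of truncate_long_seq on a batched sample: over-long rows are truncated in place and kept, while rows shorter than min_sequence_len are flagged for dropping. In config, this path is selected with `excess_length_strategy: truncate`.

batch = {
    "input_ids": [[1, 2, 3, 4, 5], [7]],
    "attention_mask": [[1, 1, 1, 1, 1], [1]],
}
keep = truncate_long_seq(batch, sequence_len=3, min_sequence_len=2)
print(keep)                   # [True, False]
print(batch["input_ids"][0])  # [1, 2, 3]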
def handle_long_seq_in_dataset(
dataset: Dataset, sequence_len: int, cfg: DictDefault
) -> Dataset:
"""Remove sequences longer than configured maximum from dataset.
@@ -161,11 +190,15 @@ def drop_long_seq_in_dataset(
Returns:
Dataset with over-long sequences dropped (default) or truncated in place.
"""
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
if hasattr(dataset, "column_names") and dataset.column_names:
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This "
"is expected for reward modeling."
)
return dataset
elif isinstance(dataset, IterableDataset):
LOG.info("Skipping drop_long_seq for streaming datasets (not compatible)")
return dataset
drop_long = functools.partial(
@@ -192,8 +225,21 @@ def drop_long_seq_in_dataset(
if filter_map_kwargs:
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
if excess_length_strategy == "truncate":
process_fn = functools.partial(
truncate_long_seq,
sequence_len=sequence_len,
min_sequence_len=cfg.min_sample_len,
)
drop_long_kwargs["desc"] = (
f"Truncating/Filtering Sequences (target_len={sequence_len})"
)
else:
process_fn = drop_long
dataset = dataset.filter(
drop_long,
process_fn,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
@@ -201,6 +247,11 @@ def drop_long_seq_in_dataset(
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset")
action = (
"truncated/filtered"
if excess_length_strategy == "truncate"
else "dropped"
)
LOG.warning(f"{action.title()} {dropped} samples from dataset")
return dataset

View File

@@ -414,6 +414,12 @@ class AxolotlInputConfig(
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
},
)
excess_length_strategy: Literal["drop", "truncate"] | None = Field(
default=None,
json_schema_extra={
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
},
)
eval_sequence_len: int | None = Field(
default=None,
json_schema_extra={
@@ -926,9 +932,27 @@ class AxolotlInputConfig(
fix_untrained_tokens: int | list[int] | None = None
streaming: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use streaming datasets (IterableDataset) for training datasets. When True, data is loaded on-demand during training without upfront preprocessing. Requires max_steps to be set. Pre-training datasets default to streaming unless explicitly set to False."
},
)
dataset_mixing_strategy: str | None = Field(
default="round_robin",
json_schema_extra={
"description": "Strategy for mixing multiple datasets: 'concatenate', 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets."
},
)
mixing_weights: list[float] | None = Field(
default=None,
json_schema_extra={
"description": "Weights for weighted mixing strategy when using multiple datasets. Must sum to 1.0 and have same length as datasets list. Only used when dataset_mixing_strategy='weighted'."
},
)
# INTERNALS - document for now, generally not set externally
is_preprocess: bool | None = None
preprocess_iterable: bool | None = None
total_num_tokens: int | None = Field(
default=None,

View File

@@ -161,7 +161,12 @@ class HyperparametersConfig(BaseModel):
max_grad_norm: float | None = Field(
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
)
num_epochs: float = Field(default=1.0)
num_epochs: float = Field(
default=1.0,
json_schema_extra={
"description": "Number of iterations over dataset for training"
},
)
@field_validator("batch_size")
@classmethod

View File

@@ -3,6 +3,8 @@
# pylint: disable=too-many-boolean-expressions
import json
import os
import sys
import tempfile
from pathlib import Path
@@ -191,6 +193,7 @@ class AttentionValidationMixin:
return data
# pylint: disable=too-many-public-methods
class TrainingValidationMixin:
"""Validation methods related to training configuration."""
@@ -369,10 +372,10 @@ class TrainingValidationMixin:
"see speed improvements. Please consider setting `torch_compile: "
"true` in your config."
)
fsdp_config = data.get("fsdp_config") or {}
if data.get("fp8") and (
data.get("fsdp_config", {}).get("activation_checkpointing", False) is True
or data.get("fsdp_config", {}).get("fsdp_activation_checkpointing", False)
is True
fsdp_config.get("activation_checkpointing", False) is True
or fsdp_config.get("fsdp_activation_checkpointing", False) is True
):
LOG.warning(
"FP8 + FSDP2 + activation checkpointing may be slower than BF16 "
@@ -507,11 +510,58 @@ class TrainingValidationMixin:
# combining these would raise `TypeError: cannot pickle 'dict_keys' object`
# due to trying to count the number of tokens total in the dataset
raise ValueError(
"pretraining_dataset and include_tokens_per_second cannot be used together."
"pretraining_dataset and include_tokens_per_second cannot be used "
"together."
)
return data
@model_validator(mode="before")
@classmethod
def check_max_steps_num_epochs_conflict(cls, data):
"""Handle max_steps and num_epochs configuration and auto-set defaults."""
max_steps = data.get("max_steps")
num_epochs = data.get("num_epochs")
# Auto-set num_epochs to 1 if neither max_steps nor num_epochs is set
if max_steps is None and num_epochs is None:
data["num_epochs"] = 1.0
return data
@model_validator(mode="before")
@classmethod
def check_saves_per_epoch_conflicts(cls, data):
"""Ensure saves_per_epoch is compatible with training configuration."""
saves_per_epoch = data.get("saves_per_epoch")
num_epochs = data.get("num_epochs")
if saves_per_epoch is not None:
# Check if saves_per_epoch is set but num_epochs is unset
if num_epochs is None:
raise ValueError(
"saves_per_epoch requires num_epochs to be set to calculate save "
"intervals."
)
return data
@model_validator(mode="before")
@classmethod
def check_evals_per_epoch_conflicts(cls, data):
"""Ensure evals_per_epoch is compatible with training configuration."""
evals_per_epoch = data.get("evals_per_epoch")
num_epochs = data.get("num_epochs")
if evals_per_epoch is not None:
if num_epochs is None:
raise ValueError(
"evals_per_epoch requires num_epochs to be set to calculate "
"evaluation intervals."
)
return data
class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration."""
@@ -817,13 +867,13 @@ class OptimizationValidationMixin:
@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
if data.get("fsdp_config"):
if data.get("fsdp_config", {}).get("fsdp_version"):
LOG.warning(
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
"Please configure `fsdp_version` as a top-level field."
)
data["fsdp_version"] = data.get("fsdp_config").pop("fsdp_version")
fsdp_config = data.get("fsdp_config") or {}
if fsdp_config and fsdp_config.get("fsdp_version"):
LOG.warning(
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
"Please configure `fsdp_version` as a top-level field."
)
data["fsdp_version"] = fsdp_config.pop("fsdp_version")
return data
@model_validator(mode="before")
@@ -1077,6 +1127,27 @@ class PretrainingValidationMixin:
data["accelerator_config"]["dispatch_batches"] = False
return data
@model_validator(mode="before")
@classmethod
def check_streaming_split_batches_accelerate(cls, data):
# Check if streaming is enabled for training
streaming = data.get("streaming", False)
# If streaming is enabled, configure accelerator
if streaming:
accelerator_config = data.get("accelerator_config", {})
if not accelerator_config:
data["accelerator_config"] = {
"split_batches": False,
"dispatch_batches": False,
}
else:
if accelerator_config.get("split_batches") is None:
data["accelerator_config"]["split_batches"] = False
if accelerator_config.get("dispatch_batches") is None:
data["accelerator_config"]["dispatch_batches"] = False
return data
class ModelCompatibilityValidationMixin:
"""Validation methods for specific model compatibility."""
@@ -1151,10 +1222,8 @@ class ModelCompatibilityValidationMixin:
@classmethod
def check_gpt_oss_fsdp_loading(cls, data):
if data.get("model_quantization_config", "") == "Mxfp4Config":
if (
data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False)
is True
):
fsdp_config = data.get("fsdp_config") or {}
if fsdp_config.get("cpu_ram_efficient_loading", False) is True:
raise ValueError(
"FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
)
@@ -1251,10 +1320,26 @@ class ComplexValidationMixin:
try:
import transformers.modeling_flash_attention_utils
from transformers.utils import is_flash_attn_greater_or_equal
# pylint: disable=protected-access
transformers.modeling_flash_attention_utils._flash_supports_window_size = (
transformers.modeling_flash_attention_utils._flash_supports_window
transformers.modeling_flash_attention_utils._flash_supports_window = (
True
)
setattr(
sys.modules["transformers.modeling_flash_attention_utils"],
"_flash_supports_window",
True,
)
setattr(
sys.modules["transformers.modeling_flash_attention_utils"],
"_flash_supports_window_size",
True,
)
setattr(
sys.modules["transformers.modeling_flash_attention_utils"],
"is_flash_attn_greater_or_equal",
is_flash_attn_greater_or_equal,
)
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
except ImportError as exception:
@@ -1321,6 +1406,128 @@ class GRPOVllmValidationMixin:
return self
class StreamingValidationMixin:
"""Validation methods related to streaming datasets."""
def _is_streaming_enabled(self) -> bool:
"""Check if streaming is enabled."""
# Fall back to main streaming setting
streaming = getattr(self, "streaming", None)
if streaming is True:
return True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
streaming = has_pretraining and streaming is None
return streaming
@model_validator(mode="after")
def check_streaming_requires_max_steps(self):
"""Ensure max_steps is set when using streaming datasets."""
# Check if streaming is enabled for training datasets
if self._is_streaming_enabled():
max_steps = getattr(self, "max_steps", None)
if not max_steps:
raise ValueError("max_steps must be set when using streaming datasets")
return self
@model_validator(mode="after")
def check_streaming_validation_splits_conflict(self):
"""Ensure validation splits are not used with streaming datasets."""
# Check if streaming is enabled for training datasets
if self._is_streaming_enabled():
val_set_size = getattr(self, "val_set_size", 0.0)
if val_set_size and val_set_size > 0:
raise ValueError(
"Validation splits not supported for streaming datasets, please "
"use test_datasets: ... instead"
)
return self
@model_validator(mode="after")
def check_streaming_preprocessing_conflict(self):
"""Ensure preprocessing is not enabled with streaming datasets."""
# Check if streaming is enabled for training datasets
if self._is_streaming_enabled():
if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
raise ValueError("preprocess is not supported for streaming datasets")
return self
@model_validator(mode="after")
def check_dataset_mixing_weights(self):
"""Validate dataset mixing weights configuration."""
valid_strategies = ["concatenate", "round_robin", "weighted", "random"]
# Get datasets to validate length against
datasets = getattr(self, "datasets", None)
# Check main strategy and weights
strategy = getattr(self, "dataset_mixing_strategy", "concatenate")
weights = getattr(self, "mixing_weights", None)
dataset_count = len(datasets) if datasets else 0
self._validate_dataset_strategy_and_weights(
strategy,
weights,
"dataset_mixing_strategy",
"mixing_weights",
valid_strategies,
dataset_count,
)
return self
def _validate_dataset_strategy_and_weights(
self,
strategy,
weights,
strategy_field,
weights_field,
valid_strategies,
dataset_count,
):
"""Helper method to validate dataset mixing strategy and weights pair."""
if strategy not in valid_strategies:
raise ValueError(
f"{strategy_field} must be one of {valid_strategies}, "
f"got '{strategy}'"
)
if strategy == "weighted":
if weights is None:
raise ValueError(
f"{weights_field} must be provided when "
f"{strategy_field}='weighted'"
)
if not isinstance(weights, list) or not all(
isinstance(w, (int, float)) for w in weights
):
raise ValueError(f"{weights_field} must be a list of numbers")
if any(w < 0 for w in weights):
raise ValueError(f"{weights_field} must be non-negative")
if abs(sum(weights) - 1.0) > 1e-6:
raise ValueError(f"{weights_field} must sum to 1.0, got {sum(weights)}")
# Validate weights length against dataset count
if dataset_count > 0 and len(weights) != dataset_count:
raise ValueError(
f"{weights_field} length ({len(weights)}) must match number of datasets ({dataset_count})"
)
elif weights is not None and strategy != "weighted":
LOG.warning(
f"{weights_field} provided but {strategy_field} is '{strategy}'. "
"Weights will be ignored."
)
# pylint: disable=too-many-ancestors
class ValidationMixin(
DatasetValidationMixin,
@@ -1332,6 +1539,7 @@ class ValidationMixin(
SystemValidationMixin,
ChatTemplateValidationMixin,
PretrainingValidationMixin,
StreamingValidationMixin,
ModelCompatibilityValidationMixin,
ComplexValidationMixin,
GRPOVllmValidationMixin,

View File

@@ -0,0 +1,45 @@
"""Training utils for checkpoints"""
from pathlib import Path
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | None:
"""
Determine the checkpoint to resume from based on configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
update: Whether to update the config with the determined checkpoint
Returns:
Path to the checkpoint to resume from, or `None` if not resuming.
"""
last_checkpoint = None
checkpoints = sorted(
(
p
for p in Path(cfg.output_dir).glob("checkpoint-*")
if p.name.split("-")[-1].isdigit()
),
key=lambda p: int(p.name.split("-")[-1]),
)
if checkpoints:
last_checkpoint = str(checkpoints[-1])
if not update:
return last_checkpoint
if (
cfg.resume_from_checkpoint is None
and cfg.auto_resume_from_checkpoints
and last_checkpoint is not None
):
cfg.resume_from_checkpoint = last_checkpoint
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
return cfg.resume_from_checkpoint
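A hedged usage sketch of the new helper (paths illustrative); with update=False the config is inspected but not mutated.

from axolotl.utils.dict import DictDefault
from axolotl.utils.train import determine_last_checkpoint

cfg = DictDefault(
    {
        "output_dir": "model-out",
        "resume_from_checkpoint": None,
        "auto_resume_from_checkpoints": True,
    }
)
print(determine_last_checkpoint(cfg, update=False))  # e.g. "model-out/checkpoint-200"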

View File

@@ -10,7 +10,6 @@ from typing import List, Optional
import numpy as np
import torch
import torch.cuda
from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
@@ -23,6 +22,65 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
def _create_filtered_iterable_dataset(dataset, filter_fn, batched=False):
"""
Create a filtered IterableDataset via a generator, working around a HuggingFace
datasets limitation where filtering an IterableDataset with `None` features fails.
"""
def filtered_generator():
"""Generator that yields only samples that pass the filter function."""
if batched:
batch = []
batch_size = 1000 # Process in batches of 1000
for sample in dataset:
batch.append(sample)
if len(batch) >= batch_size:
# Create a batch dict from list of samples
batch_dict = {}
for key in batch[0].keys():
batch_dict[key] = [sample[key] for sample in batch]
# Apply filter function to batch
keep_mask = filter_fn(batch_dict)
# Yield samples that should be kept
for i, keep in enumerate(keep_mask):
if keep:
yield batch[i]
batch = []
# Process remaining samples in batch
if batch:
batch_dict = {}
for key in batch[0].keys():
batch_dict[key] = [sample[key] for sample in batch]
keep_mask = filter_fn(batch_dict)
for i, keep in enumerate(keep_mask):
if keep:
yield batch[i]
else:
# For non-batched filtering, apply filter to each sample individually
for sample in dataset:
if filter_fn(sample):
yield sample
# Create new IterableDataset from the filtered generator
filtered_dataset = IterableDataset.from_generator(filtered_generator)
# Preserve the original features if they exist
# pylint:disable=protected-access
if hasattr(dataset, "_info") and dataset._info.features is not None:
filtered_dataset._info.features = dataset._info.features
return filtered_dataset
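A small self-contained sketch (toy data) of the generator-based filter above; the filter function receives a columnar batch dict and returns a per-row keep mask.

from datasets import IterableDataset

def gen():
    for n in range(4):
        yield {"input_ids": list(range(n))}

def keep_nonempty(batch):
    return [len(x) > 0 for x in batch["input_ids"]]

ds = IterableDataset.from_generator(gen)
filtered = _create_filtered_iterable_dataset(ds, keep_nonempty, batched=True)
print([ex["input_ids"] for ex in filtered])  # [[0], [0, 1], [0, 1, 2]]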
@torch.jit.script
def weighted_cross_entropy(
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
@@ -282,12 +340,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
# For IterableDatasets, always use custom filtering to avoid features issues
if isinstance(train_dataset, IterableDataset):
# IterableDatasets often have None features after transformations,
# so we use our custom filter implementation that doesn't rely on features
train_dataset = _create_filtered_iterable_dataset(
train_dataset, drop_no_trainable_tokens, batched=True
)
else:
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(train_dataset)
if dropped:
@@ -472,7 +539,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
)
data_loader = DataLoader(
train_dataset.remove_columns(["length"]),
train_dataset,
batch_sampler=sampler,
)
data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
@@ -547,7 +614,7 @@ def setup_deepspeed_env(cfg, stage=None):
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
# NOTE(djsaunde): The distributed state cannot be initialized prior to the
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
# to model load.
if (

View File

@@ -47,7 +47,9 @@ class BaseCliTest:
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock:
mock_fn = "os.execvpe" if command == "train" else "subprocess.run"
with patch(mock_fn) as mock:
result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called
@@ -65,8 +67,12 @@ class BaseCliTest:
if train:
expected.append("--shard=False")
assert mock.call_args.args[0] == expected
assert mock.call_args.kwargs == {"check": True}
if command == "train":
assert mock.call_args.args[0] == "accelerate"
assert mock.call_args.args[1] == expected
else:
assert mock.call_args.args[0] == expected
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):

View File

@@ -85,7 +85,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -104,7 +104,7 @@ class TestTrainCommand(BaseCliTest):
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
@@ -118,7 +118,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -137,7 +137,8 @@ class TestTrainCommand(BaseCliTest):
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert mock_subprocess.call_args.args[0] == "accelerate"
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
@@ -152,7 +153,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -170,7 +171,8 @@ class TestTrainCommand(BaseCliTest):
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert mock_subprocess.call_args.args[0] == "accelerate"
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
@@ -186,7 +188,7 @@ class TestTrainCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
@@ -207,7 +209,8 @@ class TestTrainCommand(BaseCliTest):
assert result.exit_code == 0
mock_subprocess.assert_called_once()
called_cmd = mock_subprocess.call_args.args[0]
assert mock_subprocess.call_args.args[0] == "torchrun"
called_cmd = mock_subprocess.call_args.args[1]
# Verify launcher args
assert "--nproc_per_node=8" in called_cmd
# Verify axolotl args are also present

View File

@@ -281,7 +281,9 @@ class TestHFRLTrainerBuilder:
# Other settings
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True
assert training_arguments.gradient_checkpointing is False
# TODO(wing): restore once trl releases 0.22.0
# assert training_arguments.gradient_checkpointing is True
def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)

View File

@@ -25,7 +25,7 @@ def min_cfg(temp_dir):
"liger_rms_norm": True,
"liger_glu_activation": True,
"torch_compile": True,
"chat_template": "llama3",
"chat_template": "qwen3",
"kd_trainer": True,
"kd_ce_alpha": 0.1,
"kd_alpha": 0.9,

View File

@@ -10,7 +10,11 @@ from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
from tests.e2e.utils import (
check_tensorboard,
require_torch_2_7_0,
require_torch_lt_2_6_0,
)
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -139,3 +143,71 @@ class TestMultiGPURay:
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_2_7_0
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
)
def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"save_first_step": False,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--use-ray",
"--ray-num-workers",
"2",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)

View File

@@ -1,126 +1,28 @@
"""Integration tests for FSDP Params4bit patches."""
"""Integration tests for FSDP2 Params4bit patches."""
from unittest.mock import Mock, patch
import bitsandbytes as bnb
import pytest
import torch
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
from axolotl.monkeypatch.fsdp2_qlora import (
apply_bnb_torch_function_patch,
patched_torch_function,
)
@pytest.fixture
def mock_params4bit():
"""Create a mock Params4bit instance with test attributes."""
mock_instance = Mock()
mock_instance.requires_grad = True
mock_instance.quant_state = "test_state"
mock_instance.blocksize = 128
mock_instance.compress_statistics = True
mock_instance.quant_type = "fp4"
mock_instance.quant_storage = "test_storage"
mock_instance.module = "test_module"
mock_instance.bnb_quantized = True
return mock_instance
class TestBnbTorchFunctionPatch:
"""Test the Params4bit.__torch_function__ patch."""
def test_apply_patch(self):
"""Test that the patch can be applied."""
with patch("bitsandbytes.nn.modules.Params4bit") as mock_cls:
apply_bnb_torch_function_patch()
assert hasattr(mock_cls, "__torch_function__")
assert isinstance(mock_cls.__torch_function__, classmethod)
# pylint: disable=redefined-outer-name
def test_torch_chunk_preserves_attributes(self, mock_params4bit):
"""Test that torch.chunk preserves Params4bit attributes."""
mock_cls = Mock()
chunks = (torch.tensor([1, 2]), torch.tensor([3, 4]))
with patch("torch.nn.Parameter.__torch_function__", return_value=chunks):
result = patched_torch_function(
mock_cls,
torch.chunk,
(type(mock_params4bit),),
args=(mock_params4bit, 2),
)
assert isinstance(result, tuple)
assert len(result) == 2
# Check that Params4bit constructor was called with preserved attributes
assert mock_cls.call_count == 2
for call in mock_cls.call_args_list:
kwargs = call[1]
assert kwargs["requires_grad"] == mock_params4bit.requires_grad
assert kwargs["quant_state"] == mock_params4bit.quant_state
assert kwargs["blocksize"] == mock_params4bit.blocksize
# pylint: disable=redefined-outer-name
def test_other_functions_fallback(self, mock_params4bit):
"""Test that non-chunk/split functions use Parameter fallback."""
mock_cls = Mock()
fallback_result = torch.tensor([5, 6, 7])
with patch(
"torch.nn.Parameter.__torch_function__", return_value=fallback_result
) as mock_fallback:
result = patched_torch_function(
mock_cls, torch.add, (type(mock_params4bit),), args=(mock_params4bit, 1)
)
# Should call Parameter.__torch_function__ and return its result
mock_fallback.assert_called_once()
assert result is fallback_result
mock_cls.assert_not_called()
class TestFSDPPatchIntegration:
"""Test FSDP patch integration."""
@pytest.mark.integration
def test_all_patches_together(self):
def test_fsdp2_init_patches(self):
"""Test that all patches can be applied together."""
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
)
# Store original methods before patching
original_torch_function = getattr(
bnb.nn.modules.Params4bit, "__torch_function__", None
)
# pylint: disable=protected-access
original_init_sharded = FSDPParam._init_sharded_param
original_init_unsharded = FSDPParam.init_unsharded_param
# Apply patches
apply_bnb_torch_function_patch()
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
# Verify patches were applied
current_torch_function = getattr(
bnb.nn.modules.Params4bit, "__torch_function__", None
)
if original_torch_function is not None:
assert (
current_torch_function != original_torch_function
), "Params4bit.__torch_function__ was not patched"
else:
assert (
current_torch_function is not None
), "Params4bit.__torch_function__ was not added"
# Check that FSDP methods were patched
assert (
# pylint: disable=protected-access
FSDPParam._init_sharded_param

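A small, hedged illustration of the failure mode these patches guard against: plain torch.chunk returns bare tensors, so Python-level attributes such as quant_state are dropped unless the patched __torch_function__ rebuilds Params4bit objects around each chunk, which is what the assertions above verify. The attribute value below is a stand-in, not real bitsandbytes state.

import torch

param = torch.nn.Parameter(torch.arange(8.0))
param.quant_state = "test_state"  # stand-in for bitsandbytes quantization metadata

chunks = torch.chunk(param, 2)
# Without the patch, the chunks carry no quantization metadata.
assert all(not hasattr(chunk, "quant_state") for chunk in chunks)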
tests/e2e/test_streaming.py (new file, 185 lines added)

@@ -0,0 +1,185 @@
"""E2E tests for streaming dataset functionality"""
# pylint: disable=duplicate-code
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard
class TestStreamingDatasets:
"""Test case for streaming datasets with different mixing strategies"""
@pytest.mark.parametrize(
("dataset_mixing_strategy", "mixing_weights"),
[
("round_robin", None),
("weighted", [0.7, 0.3]),
("random", None),
],
)
def test_streaming_dataset_mixing_strategies(
self, temp_dir, dataset_mixing_strategy, mixing_weights
):
"""Test different mixing strategies with streaming datasets"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": False,
"dataset_processes": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
# Streaming config
"streaming": True,
"max_steps": 3, # Very small for smoke test
"dataset_mixing_strategy": dataset_mixing_strategy,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
# Add mixing weights if specified
if mixing_weights:
cfg["mixing_weights"] = mixing_weights
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
# Verify training actually happened by checking loss decrease
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
2.5, # Loss should be reasonable for a smoke test (higher threshold for streaming)
"Train Loss (%s) is too high",
)
def test_streaming_validation_error(self, temp_dir):
"""Test that pydantic validation catches invalid streaming configs"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"streaming": True,
"max_steps": 3,
# Invalid: wrong number of weights for datasets
"dataset_mixing_strategy": "weighted",
"mixing_weights": [1.0], # Should be [0.x, 0.y] for 2 datasets
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
}
)
# This should raise a validation error
with pytest.raises(Exception) as exc_info:
validate_config(cfg)
# Verify it's the right validation error
assert "mixing_weights length" in str(exc_info.value)
assert "must match number of datasets" in str(exc_info.value)
def test_streaming_three_datasets_weighted(self, temp_dir):
"""Test weighted mixing with three datasets"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 512,
"sample_packing": False,
"dataset_processes": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
{
"path": "yahma/alpaca-cleaned",
"type": "alpaca",
},
],
# Streaming config
"streaming": True,
"max_steps": 3,
"dataset_mixing_strategy": "weighted",
"mixing_weights": [0.5, 0.3, 0.2],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
2.5,
"Train Loss (%s) is too high",
)
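For orientation, a hedged sketch of what the weighted streaming path boils down to in terms of the public datasets API (the streaming mixing test in tests/utils/test_datasets.py also points at interleave_datasets). The dataset names and weights simply mirror the configs above; this is not the internal implementation.

from datasets import interleave_datasets, load_dataset

ds1 = load_dataset("mhenrichsen/alpaca_2k_test", split="train", streaming=True)
ds2 = load_dataset("tatsu-lab/alpaca", split="train", streaming=True)

mixed = interleave_datasets(
    [ds1, ds2],
    probabilities=[0.7, 0.3],  # matches the "weighted" parametrization above
    seed=42,
)

# Take only a handful of rows, mirroring max_steps=3 in the smoke tests.
for row in mixed.take(3):
    assert "instruction" in row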


@@ -147,7 +147,11 @@ def require_hopper(test_case):
def check_tensorboard(
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
temp_run_dir: str,
tag: str,
lt_val: float,
assertion_err: str,
rtol: float = 0.02,
) -> None:
"""
helper function to parse and check tensorboard logs
@@ -157,6 +161,7 @@ def check_tensorboard(
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name
lt_val = (1 + rtol) * lt_val
if "%s" in assertion_err:
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
else:

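In plain terms, the new rtol argument loosens the loss cutoff by a small relative margin before the comparison runs. A rough sketch of the arithmetic, with made-up values:

def passes_threshold(last_value: float, lt_val: float, rtol: float = 0.02) -> bool:
    # e.g. lt_val=2.3 with rtol=0.02 gives an effective cutoff of 2.346
    return last_value < (1 + rtol) * lt_val

assert passes_threshold(2.34, 2.3)       # near miss, now within tolerance
assert not passes_threshold(2.40, 2.3)   # still fails the smoke test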

@@ -7,13 +7,13 @@ from typing import Any, Generator
from unittest.mock import patch
import pytest
from datasets import Dataset
from datasets import Dataset, IterableDataset
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizer
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets, prepare_datasets
from axolotl.utils.dict import DictDefault
from tests.constants import (
@@ -24,6 +24,7 @@ from tests.constants import (
from tests.hf_offline_utils import enable_hf_offline
# pylint: disable=too-many-public-methods
class TestDatasetPreparation:
"""Test a configured dataloader."""
@@ -46,6 +47,24 @@ class TestDatasetPreparation:
]
)
@pytest.fixture
def streaming_dataset_fixture(self):
"""Create a streaming dataset fixture for testing."""
def generator():
yield {
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
"input": "He finnished his meal and left the resturant",
"output": "He finished his meal and left the restaurant.",
}
yield {
"instruction": "What is the capital of France?",
"input": "",
"output": "The capital of France is Paris.",
}
return IterableDataset.from_generator(generator)
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_load_hub(self, tokenizer):
@@ -486,3 +505,162 @@ class TestDatasetPreparation:
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
def test_streaming_sft_dataset(self, tokenizer, streaming_dataset_fixture):
"""Test streaming SFT dataset preparation with IterableDataset."""
with patch("axolotl.utils.data.sft.load_dataset_with_config") as mock_load:
mock_load.return_value = streaming_dataset_fixture
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 256,
"streaming": True,
"max_steps": 100, # Required for streaming datasets
"datasets": [
{
"path": "dummy/path",
"type": "alpaca",
},
],
}
)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg, tokenizer
)
# Verify it returns an IterableDataset
assert isinstance(train_dataset, IterableDataset)
assert eval_dataset is None # No eval split for streaming
assert total_num_steps == 100 # Should use max_steps
assert len(prompters) == 1
# Test that we can iterate through the dataset
sample_count = 0
for sample in train_dataset:
assert "input_ids" in sample
assert "attention_mask" in sample
assert "labels" in sample
sample_count += 1
if sample_count >= 2: # Just test first few samples
break
assert sample_count == 2
def test_dataset_mixing_strategy_validation(self):
"""Test validation of dataset mixing strategy configuration."""
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Test valid strategies work
valid_strategies = ["round_robin", "weighted", "random"]
dataset1 = Dataset.from_dict({"text": ["a"], "source": ["ds1"]})
dataset2 = Dataset.from_dict({"text": ["b"], "source": ["ds2"]})
for strategy in valid_strategies:
cfg = DictDefault(
{
"dataset_mixing_strategy": strategy,
"mixing_weights": [0.5, 0.5] if strategy == "weighted" else None,
"seed": 42,
}
)
# Should not raise an error
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
assert len(merged) >= 1
def test_regular_dataset_round_robin_mixing(self):
"""Test round-robin mixing for regular datasets."""
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Create test datasets
dataset1 = Dataset.from_dict(
{"text": ["ds1_item1", "ds1_item2"], "source": ["ds1", "ds1"]}
)
dataset2 = Dataset.from_dict(
{"text": ["ds2_item1", "ds2_item2"], "source": ["ds2", "ds2"]}
)
cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
# Should have all samples from both datasets
assert len(merged) == 4
assert isinstance(merged, Dataset)
# Check that samples are interleaved (not just concatenated)
sources = [sample["source"] for sample in merged]
# Round-robin should alternate between datasets
assert sources != ["ds1", "ds1", "ds2", "ds2"] # Not concatenated
def test_regular_dataset_weighted_mixing(self):
"""Test weighted mixing for regular datasets."""
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Create test datasets
dataset1 = Dataset.from_dict(
{
"text": ["ds1_item1", "ds1_item2", "ds1_item3", "ds1_item4"],
"source": ["ds1"] * 4,
}
)
dataset2 = Dataset.from_dict(
{
"text": ["ds2_item1", "ds2_item2", "ds2_item3", "ds2_item4"],
"source": ["ds2"] * 4,
}
)
cfg = DictDefault(
{
"dataset_mixing_strategy": "weighted",
"mixing_weights": [0.75, 0.25], # 3:1 ratio
"seed": 42,
}
)
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
# Should have samples proportional to weights
assert len(merged) > 0
assert isinstance(merged, Dataset)
# Count samples from each dataset
sources = [sample["source"] for sample in merged]
ds1_count = sources.count("ds1")
ds2_count = sources.count("ds2")
# Should have samples from both datasets
assert ds1_count > 0 and ds2_count > 0 # Both datasets should be represented
def test_streaming_dataset_mixing(self):
"""Test that streaming datasets use HuggingFace interleave_datasets."""
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Create test streaming datasets
def gen1():
yield {"text": "stream1_item1", "source": "stream1"}
yield {"text": "stream1_item2", "source": "stream1"}
def gen2():
yield {"text": "stream2_item1", "source": "stream2"}
yield {"text": "stream2_item2", "source": "stream2"}
stream1 = IterableDataset.from_generator(gen1)
stream2 = IterableDataset.from_generator(gen2)
cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})
merged = _merge_datasets_with_strategy([stream1, stream2], cfg)
# Should return an IterableDataset
assert isinstance(merged, IterableDataset)
# Test that we can iterate and get samples
samples = list(merged.take(3))
assert len(samples) >= 2 # Should get at least 2 samples
# Should have samples from both datasets
sources = [sample["source"] for sample in samples]
assert len(set(sources)) >= 1 # At least one unique source
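As a rough reference for the round-robin expectation above: interleaving two in-memory datasets with the public datasets API alternates rows rather than concatenating them. This sketch calls interleave_datasets directly and is only an approximation of _merge_datasets_with_strategy.

from datasets import Dataset, interleave_datasets

ds1 = Dataset.from_dict({"text": ["a1", "a2"], "source": ["ds1", "ds1"]})
ds2 = Dataset.from_dict({"text": ["b1", "b2"], "source": ["ds2", "ds2"]})

merged = interleave_datasets([ds1, ds2])  # alternates example by example
assert [row["source"] for row in merged] == ["ds1", "ds2", "ds1", "ds2"]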


@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.data.utils import handle_long_seq_in_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
)
train_dataset = concatenate_datasets([dataset_wrapper])
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
train_dataset = handle_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
lengths = get_dataset_lengths(train_dataset)
batch_sampler = MultipackBatchSampler(

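Since the rename above changes "drop" to "handle", here is a small hedged sketch of the two behaviors that naming covers, using only the public datasets API: filtering out over-length rows versus truncating them to sequence_len. The actual strategy selection lives in axolotl's config, not in this snippet.

from datasets import Dataset

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], list(range(10))]})
sequence_len = 4

dropped = ds.filter(lambda row: len(row["input_ids"]) <= sequence_len)
truncated = ds.map(lambda row: {"input_ids": row["input_ids"][:sequence_len]})

assert len(dropped) == 1
assert all(len(row["input_ids"]) <= sequence_len for row in truncated)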

@@ -1,16 +1,11 @@
"""Module for testing dataset sequence packing"""
import unittest
from pathlib import Path
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
from axolotl.train import setup_model_and_trainer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
@@ -36,43 +31,6 @@ class TestPacking(unittest.TestCase):
}
)
def test_increments_attention(self):
prompter = AlpacaPrompter("chat")
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
dateset = load_dataset(
"json",
data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
)["train"]
dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
constant_len_dataset = ConstantLengthDataset(
self.tokenizer,
[dataset],
seq_length=2048,
)
packed_dataset = Dataset.from_list(list(constant_len_dataset))
example = packed_dataset[0]
next_bos_index = (
example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
) # add one since we sliced
# first example doesn't have mask reset
assert example["input_ids"][0] == self.tokenizer.bos_token_id
assert example["attention_mask"][0] == 1
assert example["position_ids"][0] == 0
assert example["position_ids"][1] == 1
# but subsequent one does
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
assert example["attention_mask"][next_bos_index] == 2
assert example["position_ids"][next_bos_index] == 0
assert example["position_ids"][next_bos_index + 1] == 1
@with_temp_dir
def test_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code

tests/utils/test_train.py (new file, 24 lines added)

@@ -0,0 +1,24 @@
"""test for train checkpoint utils"""
import os
from axolotl.utils.dict import DictDefault
from axolotl.utils.train import determine_last_checkpoint
def test_determine_last_checkpoint(temp_dir):
cfg = DictDefault(
output_dir=temp_dir,
)
for cpt_idx in [1, 9, 10, 20]:
os.makedirs(
os.path.join(cfg.output_dir, f"checkpoint-{cpt_idx}"), exist_ok=True
)
last_checkpoint = determine_last_checkpoint(cfg, update=False)
assert last_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20")
cfg.resume_from_checkpoint = None
cfg.auto_resume_from_checkpoints = True
determine_last_checkpoint(cfg, update=True)
assert cfg.resume_from_checkpoint == os.path.join(cfg.output_dir, "checkpoint-20")
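A standalone sketch of the ordering the assertions above depend on: the last checkpoint is the directory with the highest numeric suffix, so checkpoint-20 must beat checkpoint-9 even though it sorts earlier lexicographically. This is an approximation for illustration, not axolotl.utils.train.determine_last_checkpoint itself.

import os
import re

def last_checkpoint_dir(output_dir: str) -> str | None:
    pattern = re.compile(r"^checkpoint-(\d+)$")
    candidates = []
    for name in os.listdir(output_dir):
        match = pattern.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    _, best_name = max(candidates)  # compare by numeric step, not string order
    return os.path.join(output_dir, best_name)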