Compare commits


43 Commits

Author SHA1 Message Date
Wing Lian
e1c7a61243 fix reentrant when using offloading 2025-09-14 10:42:15 -04:00
salman
9640338d37 Default include_tkps to true (#3134)
* default true

* force e2e

* causal trainer only

* fix eval loggin [skip-ci]

* revert setup.py

* force tests

* guarding

* guarding

* fix test case

* use evaluate [skip-e2e]

* use evaluate [skip-e2e]

* kick off ci

* fixing

* reverting
2025-09-09 10:50:21 -04:00
Wing Lian
b5d4c7ff54 allow 1% deviation for codecov (#3138) [skip ci] 2025-09-07 11:01:03 -04:00
Seungduk Kim
8fd9221f13 Add ipo as an rl type that shares DPODataset config (#3128)
* Add `ipo` as an `rl` type that shares DPODataset config

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-09-07 10:49:10 -04:00
github-actions[bot]
bf00f29f3a chore: update pre-commit hooks (#3137) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-09-07 10:33:20 -04:00
NanoCode012
1d32278755 feat: upgrade transformers to v4.56.1 (#3127)
* feat: upgrade transformers to v4.56

* fix handling of CP/SP now that position_ids are default even for unpacked sequences

* feat: monkeypatch list_repo_templates

* fix: apply patch for tests only

* see if updated main works at least

* fix: update to patch release and remove monkeypatch

* remove fsdp2 eval patch

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-09-05 11:00:54 -04:00
NanoCode012
c6ae5c43cb fix: chat template jinja file not being loaded during inference (#3112)
* fix: chat template jinja file not being loaded during inference

* fix: bot comment
2025-09-03 16:25:09 -04:00
yardenhoch
efa1da52d5 Center rewards coefficient (#3124)
* feat: add center_rewards_coefficient for reward modeling

- Add center_rewards_coefficient parameter to Pydantic schema with paper reference
- Pass parameter through base builder and causal builder to training args
- Add documentation section with usage examples and theoretical background
- Enable parameter in reward modeling example configs with recommended value
- Enables reward centering for improved training stability in RLHF workflows

Implements auxiliary loss from Eisenstein et al. 2023 (https://huggingface.co/papers/2312.09244)
to incentivize mean-zero reward outputs without post-training normalization.

* Update description

* test: add unit tests for center_rewards_coefficient integration

* Update src/axolotl/core/builders/base.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update docs/reward_modelling.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update docs/reward_modelling.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* reference to TRL documentation.

* add new reward model configuration for qwen3 with comprehensive parameters

* Verified center_rewards_coefficient is correctly passed through the trainer builder to training arguments.

* Refactor reward modeling documentation to consolidate information on center_rewards_coefficient

* Remove unit tests for center_rewards_coefficient integration as part of codebase cleanup.

* linting

* nit

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* lint

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
2025-09-03 16:22:37 -04:00
mhenrichsen
48db520d92 Create 270m-qlora.yml (#3075) [skip ci]
Adds 270m gemma3 qlora
2025-09-03 16:20:32 -04:00
NanoCode012
53a0c1f39c feat: add peft_trainable_token_indices (#3062)
* feat: add peft_trainable_token_indices

* feat: add warning compat with fix_untrained_tokens
2025-09-03 01:48:01 -04:00
github-actions[bot]
4cc6038d52 chore: update pre-commit hooks (#3122) [skip ci]
Co-authored-by: djsaunde <1245942+djsaunde@users.noreply.github.com>
2025-09-03 01:41:34 -04:00
NanoCode012
e48aa8a5b1 feat(doc): improve visibility for colab notebooks (#3110) [skip ci]
* feat: improve visibility for colab notebooks

* fix: link to GH colab

* feat: change to badge and move higher
2025-09-03 01:40:53 -04:00
xuyifann
24aba5caca Clamping the len of dataloader to minimum of 1 (#3100) [skip ci]
* Clamping the len of dataloader to minimum of 1

* linter reformat
2025-09-03 01:40:27 -04:00
Wing Lian
06bebcb65f run cu128-2.8.0 e2e tests on B200 (#3126)
* run cu128-2.8.0 e2e tests on B200

* not an int 🤦

* fix yaml
2025-09-02 13:13:23 -04:00
Dan Saunders
231a67e70b Streaming SFT support (#3101)
* working

* fixes

* deprecate --iterable; cleanup

* pretrain_multipack_buffer_size -> streaming_multipack_buffer_size

* improvements

* tests

* remove unused

* docs, examples

* nit

* nit

* add val_set_size validation

* val

* nit

* min

* coderabbito

* cleanup

* nit

* add depr warning, cleanup

* nit

* fix test, fix quarto

* fix

* review comments

* review comments

* fix
2025-09-02 12:08:44 -04:00
Wing Lian
0094a2d744 support for tiledmlp for GPT-OSS (#3116)
* fix use of flex attn kwargs and add support for tiledmlp for GPT-OSS

* add logging back

* update deps
2025-08-29 13:52:49 -04:00
Wing Lian
7ed40f1d70 automatically set env vars for single gpu deepspeed zero3 (#3118) [skip ci]
* automatically set env vars for single gpu deepspeed zero3

* use setdefault
2025-08-29 13:36:47 -04:00
VED
5b6ec2820f patch for ds_grads_remaining in deepspeed (#3102) [skip ci]
* patch deepspeed

* deepspeed patch for ds_grads_remaining

* patch in Patchmanager

* chore: lint

* deepseed utils

* chore2

* patch ds_grads_remaining chore

* chore lint

* chore lint

* remove torch.nn patch

* lint

* Update src/axolotl/monkeypatch/utils.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* patched with checkpointwarapper

* lint

* only apply deepspeed patch when using activation offloading

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-08-29 12:12:09 -04:00
Wing Lian
6afba3871d Add support for PyTorch 2.8.0 (#3106)
* Add support for PyTorch 2.8.0

* loosen triton requirements

* handle torch 2.8.0 in setup.py

* fix versions

* no vllm for torch 2.8.0

* remove comment

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-08-28 09:10:40 -04:00
Dan Saunders
dc338c3b0e Update .coderabbit.yaml (#3109) [skip ci]
Oops, should be false.
2025-08-27 09:50:52 -04:00
salman
d0d2fc5606 Tokens per second logging [skip-e2e] (#3072) 2025-08-27 09:10:14 +01:00
Wing Lian
e1131e9619 make always skip_move_to_device default as true (#3084) 2025-08-26 09:30:22 -04:00
Wing Lian
c4c4b90638 add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json (#3093)
* add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json

* fix test import
2025-08-26 09:30:04 -04:00
Wing Lian
0e9945e3b9 deploy training jobs to baseten w truss in axolotl cli (#3086) [skip ci]
* deploy training jobs to baseten w truss in axolotl cli

* cleanup
2025-08-26 09:29:50 -04:00
NanoCode012
0de254a0d0 feat: add gemma3_text attention handling for lora kernels (#3103) 2025-08-26 16:47:26 +07:00
Dan Saunders
79ddaebe9a Add ruff, remove black, isort, flake8, pylint (#3092)
* black, isort, flake8 -> ruff

* remove unused

* add back needed import

* fix
2025-08-23 23:37:33 -04:00
Dan Saunders
eea7a006e1 make multipack sampler patch explicit (#3096)
* make multipack sampler patch explicit

* combining
2025-08-22 14:29:10 -04:00
Wing Lian
ab4d604a8f upgrade peft for 0.17.1 (#3094)
* upgrade peft to 0.17.1

* upgrade for transformers too
2025-08-22 07:26:30 -04:00
Wing Lian
0fa752e58b upgrade flash-attn to 2.8.3 for gpt-oss attn sink support (#3082) 2025-08-21 15:04:10 -04:00
Dan Saunders
08e517ea48 Update .coderabbit.yaml (#3091) [skip ci] 2025-08-20 22:14:13 -04:00
Wing Lian
07fd22f39b better handling of lora w bias with fsdp2 and handling of files when saving model checkpoint (#3090) 2025-08-20 15:17:48 -04:00
Wing Lian
06eaf6c448 misc fixes (#3085) 2025-08-20 08:52:26 -04:00
goggle
050210e637 fix: Sweep runs overwrite each other because output_dir from base config is reused (#3080)
* refactor: improve output_dir handling in generate_config_files

* fix typo

* cli: harden sweep output_dir handling with base fallback

- Ensure sweep permutations always resolve a valid output_dir
- Default to ./model-out if neither permutation nor base config sets output_dir
- Append sweepXXXX suffix consistently for each permutation
- Prevent Path(None) TypeError and improve robustness of sweep config generation

* fix typo

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-08-19 20:25:20 -04:00
Wing Lian
05cedbfb1e add baseten info for gpt-oss recipe (#3078)
* add bsaeten info for gpt-oss recipe

* incorporate PR review
2025-08-19 13:30:37 -04:00
VED
c10eb811fa data_parallel_size in in VllmserveCliArgs (#3074)
* data_parallel_size in in VllmserveCliArgs

* moved to 43
2025-08-18 08:44:37 -04:00
VED
0eef385b1a [feat] truncation support with excess_length_strategy (#3068) [skip ci]
* feat:truncation support with excess_len

* pre-commit

* excess_length_strategy

* requested changes

* lint

* added handle_long_seq_in_dataset in sft

* comments improved
2025-08-18 08:39:13 -04:00
Wing Lian
ecbe8b2b61 [GPT-OSS] improve FSDP shard merging and documentation for GPT-OSS (#3073)
* improve fsdp shard merging

* improve logging

* update information on merging and inferencing GPT-OSS

* cleanup readme

* automate cleanup of FSDP prefix

* import GRPO only if necessary

* only modify config.json on rank0

* merge final checkpoint at end of training

* prevent circular import

* Fix saving for sharded state dict

* devx, move merged to output dir

* move import back to top

* Fix stuck merge

* fix conditionals from pr feedback and add test
2025-08-15 21:25:01 -04:00
Wing Lian
130ef7c51a Various fixes for VLMs (#3063)
* fix to not use batch feature indexing

* more vlm fixes

* use AutoModelForImageTextToText

* add example yaml and need num2words for chat template

* improve handling of adding image tokens to conversation

* add lfm2-vl support

* update the lfm readme

* fix markdown and add rtol for loss checks

* feat: add smolvlm2 processing strat

* fix: check for causal-conv1d in lfm models

* feat: add docs for lfm2

* feat: add new models and tips to docs

* feat: add smolvlm2 docs and remove extra dep

* chore: update docs

* feat: add video instructions

* chore: cleanup

* chore: comments

* fix: typo

* feat: add usage stats

* chore: refactor

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-08-15 10:52:57 -04:00
salman
d1de6f5f3d Add option to skip slow tests in PRs (#3060) [skip ci]
* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* testing e2e skip [skip-e2e]

* stop running multigpu [skip-e2e]

* should work now [skip-e2e]

* reverting [skip-e2e]

* testing [skip-e2e]

* debug [skip-e2e]

* debug [skip-e2e]

* round 2[skip-e2e]

* removing debug [skip-e2e]

* support skipping whole PR [skip-e2e]

* use script for e2e skip [skip-e2e]

* contributing [skip-e2e]

* contributing [skip-e2e]

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-08-13 22:57:51 -04:00
Wing Lian
48b7ae1677 use updated patch releasE (#3066) 2025-08-13 21:23:05 -04:00
NanoCode012
506e3a3907 fix: fsdp_config validation being None (#3061) [skip ci]
* fix: fsdp_config validation being None

* fix: handling

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-08-13 21:21:50 -04:00
Wing Lian
09145de8fa upgrade transformers==4.55.1 and bitsandbytes==0.47.0 (#3064)
* upgrade transformers==4.55.1

* also upgrade bnb

* remove bnb params4bit patch (upstreamed)

* use latest causal-conv1d

* fix patching ring-flash-attn with now missing imports

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-08-13 19:41:07 -04:00
Wing Lian
e0a2523a3b Workaround to unblock docs build in main (#3055)
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
2025-08-13 11:39:39 +01:00
349 changed files with 13515 additions and 12356 deletions

View File

@@ -1,3 +1,3 @@
[bandit]
exclude = tests
skips = B101,B615
skips = B101,B615,B102,B110

View File

@@ -12,5 +12,6 @@ reviews:
auto_review:
enabled: true
drafts: false
auto_incremental_review: false
chat:
auto_reply: true

View File

@@ -1,5 +0,0 @@
[flake8]
max-line-length = 88
select = C,E,F,W,B,B950
extend-ignore = E203, E501, W503

View File

@@ -57,6 +57,13 @@ We welcome ideas for improvements and new features. To suggest an enhancement, o
5. Push your branch to your fork on GitHub.
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
#### Skipping CI Checks
You can skip certain CI checks by including specific keywords in your commit messages:
- `[skip ci]` or `skip ci` - Skips all CI checks for that commit
- `[skip-e2e]` or `skip-e2e` - Skips only end-to-end tests while running other CI checks. You may also include this in the title of your PR to disable end-to-end tests for the entire PR.
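
As a quick illustration of these keywords (hypothetical commits, not taken from this diff):

```bash
# skip all CI checks for a documentation-only change
git commit -m "docs: fix typo in README [skip ci]"

# run the normal CI checks but skip the end-to-end tests
git commit -m "refactor: tidy up helper functions [skip-e2e]"
```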
## Style Guidelines
### Code Style

View File

@@ -36,6 +36,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -110,6 +115,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -169,6 +179,12 @@ jobs:
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -33,13 +33,6 @@ jobs:
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -47,6 +40,13 @@ jobs:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:

View File

@@ -55,7 +55,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
timeout-minutes: 20
steps:
@@ -130,7 +130,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
timeout-minutes: 20
steps:
@@ -188,13 +188,44 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
gate-skip-e2e:
needs: [pre-commit, pytest, pytest-sdist]
runs-on: ubuntu-latest
outputs:
skip: ${{ steps.compute.outputs.skip }}
steps:
- uses: actions/github-script@v7
id: compute
with:
script: |
const token = /\[skip-e2e\]/i;
let msg = '';
if (context.eventName === 'push') {
msg = context.payload.head_commit?.message || '';
} else if (context.eventName === 'pull_request') {
const { owner, repo } = context.repo;
const prNumber = context.payload.pull_request.number;
const commits = await github.paginate(
github.rest.pulls.listCommits,
{ owner, repo, pull_number: prNumber, per_page: 100 }
);
msg = commits.at(-1)?.commit?.message || '';
}
const title = context.payload.pull_request?.title || '';
const body = context.payload.pull_request?.body || '';
const skip = token.test(msg) || token.test(title) || token.test(body);
core.setOutput('skip', String(skip));
docker-e2e-tests-1st:
# Run this job first as a gate for running the remainder of the test matrix
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
if: >
github.repository_owner == 'axolotl-ai-cloud' &&
(github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
needs.gate-skip-e2e.outputs.skip != 'true'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
needs: [pre-commit, pytest, pytest-sdist]
needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
strategy:
fail-fast: false
@@ -209,7 +240,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -240,13 +271,16 @@ jobs:
modal run cicd.e2e_tests
docker-e2e-tests:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
if: >
github.repository_owner == 'axolotl-ai-cloud' &&
(github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
needs.gate-skip-e2e.outputs.skip != 'true'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
# Only run the remainder of the matrix if the first e2e check passed;
# this is to save on wasted compute costs for known failures that get caught in the first run
needs: [pre-commit, pytest, docker-e2e-tests-1st]
needs: [pre-commit, pytest, gate-skip-e2e, docker-e2e-tests-1st]
strategy:
fail-fast: false
@@ -264,6 +298,13 @@ jobs:
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
num_gpus: 1
gpu_type: "B200"
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -284,6 +325,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
@@ -300,10 +342,10 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
steps:

View File

@@ -1,4 +0,0 @@
[settings]
profile=black
known_third_party=wandb,comet_ml
known_local_folder=src,tests

View File

@@ -10,22 +10,12 @@ repos:
- id: trailing-whitespace
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/psf/black
rev: 25.1.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.12
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 6.0.1
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 7.3.0
hooks:
- id: flake8
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.8
hooks:
- id: pylint
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.17.1
hooks:

View File

@@ -1,15 +0,0 @@
[MASTER]
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
[TYPECHECK]
# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=numpy.*, torch.*
[pylint.messages_control]
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-positional-arguments, possibly-used-before-assignment

View File

@@ -17,6 +17,7 @@
<br/>
<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
<a href="https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google-colab" style="height: 20px;"></a>
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
@@ -70,6 +71,10 @@ Features:
- Python 3.11
- PyTorch ≥2.6.0
### Google Colab
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
### Installation
#### Using pip

10
TODO.md
View File

@@ -1,10 +0,0 @@
# todo list
- [] Validation of parameters for combinations that won't work
## things that are known not to work
- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
- adamw_bnb_8bit doesn't play well with FSDP offload

View File

@@ -153,7 +153,7 @@ quartodoc:
- utils.distributed
- utils.dict
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.streaming
- utils.data.sft
- utils.quantization
- title: Schemas
@@ -272,6 +272,7 @@ website:
contents:
- docs/batch_vs_grad.qmd
- docs/dataset_preprocessing.qmd
- docs/streaming.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd

View File

@@ -2,8 +2,6 @@
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
@@ -63,7 +61,7 @@ def run_cmd(cmd: str, run_folder: str):
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
exit(exit_code)
@app.function(

View File

@@ -1,7 +1,5 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
@@ -59,7 +57,8 @@ VOLUME_CONFIG = {
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = f"L40S:{N_GPUS}"
GPU_TYPE = os.environ.get("GPU_TYPE", "L40S")
GPU_CONFIG = f"{GPU_TYPE}:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str):
@@ -70,4 +69,4 @@ def run_cmd(cmd: str, run_folder: str):
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
exit(exit_code)

View File

@@ -12,7 +12,7 @@ coverage:
default:
# basic
target: auto
threshold: 0%
threshold: 1%
base: auto
# advanced
branches: null
@@ -27,7 +27,7 @@ coverage:
default:
# basic
target: auto
threshold: 0%
threshold: 1%
base: auto
# advanced
branches: null

View File

@@ -37,7 +37,7 @@ WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge

View File

@@ -134,7 +134,7 @@ For providers supporting Docker:
### Google Colab {#sec-colab}
Use our [example notebook](../examples/colab-notebooks/colab-axolotl-example.ipynb).
[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
## Platform-Specific Instructions {#sec-platform-specific}

View File

@@ -63,15 +63,6 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
:::
::: {.callout-tip}
Using ZeRO Stage 3 with Single-GPU training
ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
:::
## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
::: {.callout-note}

View File

@@ -13,10 +13,13 @@ format:
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Voxtral](#sec-voxtral)
- [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
## Usage
@@ -31,7 +34,7 @@ skip_prepare_dataset: true
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
sample_packing: false # not yet supported with multimodal
chat_template: # see in next section
chat_template: # see in next section if specified
# example dataset
datasets:
@@ -97,6 +100,16 @@ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
chat_template: mistral_v7_tekken
```
### Voxtral {#sec-voxtral}
::: {.callout-tip}
Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'`
:::
```yaml
base_model: mistralai/Voxtral-Mini-3B-2507
```
### Gemma-3 {#sec-gemma-3}
::: {.callout-tip}
@@ -143,6 +156,26 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}
Please make sure to install `num2words` via `pip3 install num2words==0.5.14`
:::
```yaml
base_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct
```
### LFM2-VL {#sec-lfm2-vl}
::: {.callout-warning}
Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
:::
```yaml
base_model: LiquidAI/LFM2-VL-450M
```
## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
@@ -181,6 +214,20 @@ You may need to install `librosa` via `pip3 install librosa==0.11.0`.
:::
### Video
::: {.callout-warning}
This is not well tested at the moment. We welcome contributors!
:::
For video loading, you can use the following keys within `content` alongside `"type": "video"`:
- `"path": "/path/to/video.mp4"`
- `"url": "https://example.com/video.mp4"`
- `"video": np.ndarray | list[PIL.Image.Image] | torch.Tensor` (or list of the aforementioned)
### Example
Here is an example of a multi-modal dataset:

View File

@@ -11,6 +11,7 @@ We support the reward modelling techniques supported by `trl`.
### (Outcome) Reward Models
Outcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (e.g. rather than per-turn or per-step).
For improved training stability, you can use the `center_rewards_coefficient` parameter to encourage mean-zero reward outputs ([see TRL docs](https://huggingface.co/docs/trl/v0.10.1/en/reward_trainer#centering-rewards)).
```yaml
base_model: google/gemma-2-2b
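# Illustrative addition, not part of this diff hunk: the option described above
# can be enabled alongside the reward-model config; 0.01 matches the value used
# in the example config added elsewhere in this changeset.
center_rewards_coefficient: 0.01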

View File

@@ -47,7 +47,6 @@ class QuartoGenerator:
"""Check if a type is a Pydantic BaseModel."""
return inspect.isclass(type_obj) and issubclass(type_obj, BaseModel)
# pylint: disable=too-many-return-statements
def _extract_nested_type(self, field_type) -> Any:
"""Extract the actual type from complex type annotations."""
# Handle Annotated types (Python 3.9+)
@@ -124,7 +123,6 @@ class QuartoGenerator:
return field_type
# pylint: disable=too-many-return-statements
def _extract_all_pydantic_models_from_type(
self, field_type
) -> list[type[BaseModel]]:
@@ -318,7 +316,6 @@ class QuartoGenerator:
return all_groups
# pylint: disable=too-many-return-statements
def _extract_field_groups_from_source(
self, model_class: type[BaseModel]
) -> list[dict]:
@@ -503,7 +500,7 @@ class QuartoGenerator:
nested_schema = nested_model.model_json_schema()
nested_properties = nested_schema.get("properties", {})
nested_required = nested_schema.get("required", [])
except Exception: # pylint: disable=broad-exception-caught
except Exception:
# Fallback: use model fields directly
nested_properties = {}
nested_required = []
@@ -607,7 +604,7 @@ class QuartoGenerator:
schema = model_class.model_json_schema()
properties = schema.get("properties", {})
required = schema.get("required", [])
except Exception as e: # pylint: disable=broad-exception-caught
except Exception as e:
print(
f"Warning: Could not generate JSON schema ({e}). Using model fields instead."
)

120
docs/streaming.qmd Normal file
View File

@@ -0,0 +1,120 @@
---
title: Streaming Datasets
description: How to use streaming mode for large-scale datasets and memory-efficient training
order: 10
---
Streaming enables memory-efficient training with large datasets by loading data
incrementally rather than loading the entire dataset into memory at once.
Use streaming when:
- Your dataset is too large to fit in memory (e.g. when you're doing pretraining with massive text corpora)
- You want to start training immediately without preprocessing the entire dataset
Streaming works with both remote and locally stored datasets!
::: {.callout-note}
Streaming currently only supports a single dataset. Multi-dataset support will be added soon.
:::
## Configuration
### Basic Streaming
Enable streaming mode by setting the `streaming` flag:
```yaml
streaming: true
```
### Pretraining with Streaming
For pretraining tasks, streaming is automatically enabled when using `pretraining_dataset`:
```yaml
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
type: pretrain
text_column: text
split: train
# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
### SFT with Streaming
For supervised fine-tuning with streaming:
```yaml
streaming: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
split: train
# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
## Configuration Options
### `streaming_multipack_buffer_size`
Controls the buffer size for multipack streaming (default: 10,000). This determines how
many samples are buffered before packing. Larger buffers can improve packing efficiency
but use more memory.
### `shuffle_merged_datasets`
When enabled, shuffles the streaming dataset using the buffer. This requires additional
memory for the shuffle buffer.
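
As a minimal sketch (an editorial illustration, not part of the new file shown here), the two options described in this section might be combined as:

```yaml
streaming_multipack_buffer_size: 10000  # larger buffers pack more efficiently but use more memory
shuffle_merged_datasets: true           # shuffle the stream using the buffer
```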
## Sample Packing with Streaming
Sample packing is supported for streaming datasets. When enabled, multiple samples are
packed into a single sequence to maximize GPU utilization:
```yaml
sample_packing: true
streaming_multipack_buffer_size: 10000
# For SFT: attention is automatically isolated between packed samples
# For pretraining: control with pretrain_multipack_attn
pretrain_multipack_attn: true # prevent cross-attention between packed samples
```
For more information, see our [documentation](multipack.qmd) on multipacking.
## Important Considerations
### Memory Usage
While streaming reduces memory usage compared to loading entire datasets, you still need
to consider:
- You can control the memory usage by adjusting `streaming_multipack_buffer_size`
- Sample packing requires buffering multiple samples
- Shuffling requires additional memory for the shuffle buffer
### Performance
- Streaming may have slightly higher latency compared to preprocessed datasets, as samples are processed on-the-fly
- Network speed and disk read speed are important when streaming from remote sources or a local dataset, respectively
- Consider using `axolotl preprocess` for smaller or more frequently used datasets
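
For the last point above, a typical invocation would be (the config filename is a placeholder):

```bash
# tokenize and cache the dataset ahead of time instead of streaming it at train time
axolotl preprocess your_config.yaml
```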
### Evaluation Datasets
Evaluation datasets are not streamed to ensure consistent evaluation metrics. They're
loaded normally even when training uses streaming.
## Examples
See the `examples/streaming/` directory for complete configuration examples:
- `pretrain.yaml`: Pretraining with streaming dataset
- `sft.yaml`: Supervised fine-tuning with streaming

View File

@@ -0,0 +1,58 @@
# Finetune Liquid Foundation Models 2 (LFM2) with Axolotl
[Liquid Foundation Models 2 (LFM2)](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) are a family of small, open-weight models from [Liquid AI](https://www.liquid.ai/) focused on quality, speed, and memory efficiency. Liquid AI released text-only [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) and text+vision [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa) models.
LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference.
This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Run one of the finetuning examples below.
**LFM2**
```bash
# FFT SFT (1x48GB @ 25GiB)
axolotl train examples/LiquidAI/lfm2-350m-fft.yaml
```
**LFM2-VL**
```bash
# LoRA SFT (1x48GB @ 2.7GiB)
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
```
### TIPS
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
```bash
pip uninstall -y causal-conv1d
```
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
- **Dataset Formats**:
- For LFM2 models, the dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For LFM2-VL models, Axolotl follows the multi-content Messages format. See our [Multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) for details.
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -2,7 +2,6 @@ base_model: LiquidAI/LFM2-350M
chunked_cross_entropy: true
chat_template: tokenizer_default
eot_tokens:
- "<|im_end|>"
datasets:

View File

@@ -0,0 +1,58 @@
base_model: LiquidAI/LFM2-VL-450M
trust_remote_code: true
model_type: AutoModelForImageTextToText
processor_type: AutoProcessor
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,10 @@
provider: baseten
project_name:
secrets:
- HF_TOKEN
- WANDB_API_KEY
gpu: h100
gpu_count: 8
node_count: 1

File diff suppressed because it is too large

View File

@@ -0,0 +1,68 @@
base_model: google/gemma-3-270m-it
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
load_in_8bit: false
load_in_4bit: true
# huggingface repo
chat_template: gemma3
eot_tokens:
- <end_of_turn>
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -33,13 +33,64 @@ Note: Memory usage taken from `device_mem_reserved(gib)` from logs.
### Training 120B
On 8xH100s
On 8xH100s, make sure you have ~3TB of free disk space. With each checkpoint clocking in at ~720GB, along with the base
model, and final model output, you may need at least 3TB of free disk space to keep at least 2 checkpoints.
```bash
# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
```
To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node
training of the 120B model using Baseten Truss. You can read more about this recipe on
[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can
be found on their
[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training).
ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
```bash
sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
```
When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your
configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to
merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded
weights to `{output_dir}/merged`.
```bash
axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
```
### Inferencing your fine-tuned model
#### vLLM
GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
for more information about using a special vllm-openai docker image for inferencing with vLLM.
Optionally, vLLM can be installed from nightly:
```bash
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
```
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
```bash
vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8
```
#### SGLang
SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for information on installing
SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:
```bash
python3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8
```
### Tool use
GPT-OSS has a comprehensive tool understanding. Axolotl supports tool calling datasets for Supervised Fine-tuning.

View File

@@ -20,6 +20,7 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
save_total_limit: 2 # the 120B model can use up to 720GB of disk space per checkpoint, so let's only keep the last 2
sequence_len: 4096
sample_packing: true
@@ -43,7 +44,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -40,7 +40,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
dataset_prepared_path: ./outputs/last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
@@ -41,7 +41,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
dataset_prepared_path: ./outputs/last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
@@ -40,7 +40,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -53,7 +53,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -1,7 +0,0 @@
# Liquid Foundation Models 2
LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release.
```bash
pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
```

View File

@@ -0,0 +1,44 @@
base_model: Skywork/Skywork-Reward-V2-Qwen3-8B
model_type: AutoModelForSequenceClassification
num_labels: 1
reward_model: true
center_rewards_coefficient: 0.01 # Incentivize mean-zero rewards for improved stability
chat_template: qwen3
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
deepspeed: deepspeed_configs/zero1.json
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
eval_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: linear
learning_rate: 0.00002
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
warmup_ratio: 0.1
logging_steps: 1
weight_decay: 0.01

View File

@@ -0,0 +1,49 @@
# Finetune SmolVLM2 with Axolotl
[SmolVLM2](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7) are a family of lightweight, open-source multimodal models from HuggingFace designed to analyze and understand video, image, and text content.
These models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M.
This guide shows how to fine-tune SmolVLM2 models with Axolotl.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install an extra dependency:
```bash
pip3 install num2words==0.5.14
```
3. Run the finetuning example:
```bash
# LoRA SFT (1x48GB @ 6.8GiB)
axolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml
```
## TIPS
- **Dataset Format**: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on [Multimodal Formats](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
- **Dataset Loading**: Read more on how to prepare and load your own datasets in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [SmolVLM2 Blog](https://huggingface.co/blog/smolvlm2)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,56 @@
base_model: HuggingFaceTB/SmolVLM2-2.2B-Instruct
trust_remote_code: true
processor_type: AutoProcessor
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.text_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,50 @@
# Streaming Dataset Examples
This directory contains example configurations for using Axolotl's streaming dataset
functionality, which enables memory-efficient training with large datasets.
## Examples
Run the following examples with e.g. `axolotl train examples/streaming/sft.yaml`; no
`axolotl preprocess` required!
### Pretraining (`pretrain.yaml`)
Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset
with SmolLM2-135M.
- Uses `pretraining_dataset` configuration for automatic streaming
- Multipack attention control to prevent cross-attention between packed sequences
- Buffer size configuration for memory management
### SFT (`sft.yaml`)
Shows how to use streaming for supervised fine-tuning with the Alpaca dataset.
- Explicit `streaming: true` flag for SFT datasets
- Memory-efficient training on instruction datasets
- Evaluation datasets are currently not streamed
## Key Configuration Options
### `streaming`
- Enables streaming mode for standard datasets
- Automatically enabled for `pretraining_dataset`
### `streaming_multipack_buffer_size`
- Controls buffer size for sample packing (default: 10,000)
- Larger values improve packing efficiency but use more memory
- Adjust based on available memory
### `shuffle_merged_datasets`
- Enables shuffling of streaming datasets
- Requires additional memory for shuffle buffer
### `sample_packing`
- Packs multiple samples into single sequences
- Minimize per-step padding tokens
## Performance Tips
- Download small / frequently-used datasets locally for better performance
- Larger buffer sizes improve packing efficiency

View File

@@ -0,0 +1,57 @@
base_model: HuggingFaceTB/SmolLM2-135M
# Streaming pretraining configuration
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
name: sample-10BT
type: pretrain
text_column: text
split: train
# Streaming-specific settings
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-pretrain-streaming
# Sequence and packing settings
sequence_len: 1024
sample_packing: true
pretrain_multipack_attn: true # Prevent cross-attention between packed sequences
flash_attention: true
# Batch size settings
gradient_accumulation_steps: 8
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 5e-4
warmup_ratio: 0.1
weight_decay: 0.01
# Precision and performance
bf16: auto
tf32: true
# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 250
save_total_limit: 3
# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# Special tokens
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,55 @@
base_model: HuggingFaceTB/SmolLM2-135M
# Dataset configuration
datasets:
- path: tatsu-lab/alpaca
type: alpaca
split: train
# Streaming-specific settings
streaming: true
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-sft-streaming
# Sequence and packing settings
sequence_len: 1024
sample_packing: true
flash_attention: true
# Batch size settings
gradient_accumulation_steps: 4
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-4
warmup_ratio: 0.1
weight_decay: 0.0
# Precision and performance
bf16: auto
tf32: true
# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 100
save_total_limit: 3
# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# Special tokens
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -26,3 +26,34 @@ include-package-data = true
[tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
[tool.ruff]
line-length = 88
target-version = "py310"
[tool.ruff.lint]
select = ["E", "F", "W", "C90", "B"]
ignore = [
"E203", # Whitespace before ':'
"E501", # Line too long
"C901", # Too complex
"B019", # Use of functools.cache on methods
"E722", # Bare except
"F821", # Undefined name (for dynamic exec)
]
[tool.ruff.lint.isort]
known-third-party = ["wandb", "comet_ml"]
known-local-folder = ["src", "tests"]
# Black-compatible isort settings
force-single-line = false
combine-as-imports = true
split-on-trailing-comma = true
[tool.ruff.format]
# Use black's formatting style exactly
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false

View File

@@ -1,9 +1,8 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.46.1
# triton 3.4.0 is not compatible with CCE
triton>=3.0.0,<3.4.0
bitsandbytes==0.47.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
@@ -13,8 +12,8 @@ liger-kernel==0.6.1
packaging==23.2
huggingface_hub>=0.33.0
peft==0.17.0
transformers==4.55.0
peft>=0.17.0
transformers==4.56.1
tokenizers>=0.21.1
accelerate==1.10.0
datasets==4.0.0

View File

@@ -27,7 +27,7 @@ def parse_dataset(dataset=None, split="train"):
break
if not field_messages:
raise ValueError(
f'No conversation field found in dataset: {", ".join(feature_keys)}'
f"No conversation field found in dataset: {', '.join(feature_keys)}"
)
ds_cfg["field_messages"] = field_messages
@@ -40,7 +40,7 @@ def parse_dataset(dataset=None, split="train"):
break
if not message_property_mappings["role"]:
raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}'
f"No role field found in messages: {', '.join(message_fields)}"
)
for key in ["content", "text", "value"]:
@@ -49,7 +49,7 @@ def parse_dataset(dataset=None, split="train"):
break
if not message_property_mappings["content"]:
raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}'
f"No content field found in messages: {', '.join(message_fields)}"
)
ds_cfg["message_property_mappings"] = message_property_mappings

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
)

View File

@@ -1,11 +1,10 @@
# noqa
# pylint: skip-file
import sys
try:
import torch
except ImportError:
raise ImportError("Install torch via `pip install torch`")
except ImportError as error:
raise ImportError("Install torch via `pip install torch`") from error
from packaging.version import Version as V
use_uv = "--uv" in sys.argv[1:]

View File

@@ -64,7 +64,9 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 7):
if (major, minor) >= (2, 8):
pass
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.30")
@@ -118,14 +120,14 @@ def get_package_version():
extras_require = {
"flash-attn": ["flash-attn==2.8.2"],
"flash-attn": ["flash-attn==2.8.3"],
"ring-flash-attn": [
"flash-attn==2.8.2",
"flash-attn==2.8.3",
"ring-flash-attn>=0.1.7",
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.17.2",
"deepspeed==0.17.5",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -14,9 +14,13 @@ class PreprocessCliArgs:
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
default=None,
default=False,
metadata={
"help": "Use IterableDataset for streaming processing of large datasets"
"help": (
"Deprecated in v0.13.0, will be removed in v0.14.0. For streaming "
"datasets, use 'axolotl train' and set 'streaming: true' in your YAML "
"config, or pass --streaming instead in the CLI."
)
},
)
@@ -40,6 +44,12 @@ class VllmServeCliArgs:
default=None,
metadata={"help": "Number of tensor parallel workers to use."},
)
data_parallel_size: Optional[int] = field(
default=None,
metadata={
"help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
},
)
host: Optional[str] = field(
default=None, # nosec B104
metadata={"help": "Host address to run the server on."},

View File

@@ -22,7 +22,7 @@ HAS_PRINTED_LOGO = False
def print_axolotl_text_art():
"""Prints axolotl ASCII art."""
global HAS_PRINTED_LOGO # pylint: disable=global-statement
global HAS_PRINTED_LOGO
if HAS_PRINTED_LOGO:
return
if is_main_process():

View File

@@ -7,6 +7,8 @@ from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
from axolotl.cli.cloud.baseten import BasetenCloud
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
@@ -38,8 +40,15 @@ def do_cli_train(
cwd=None,
**kwargs,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
provider = cloud_cfg.provider or "modal"
cloud: Cloud | None
if provider == "modal":
cloud = ModalCloud(cloud_cfg)
elif provider == "baseten":
cloud = BasetenCloud(cloud_cfg.to_dict())
else:
raise ValueError(f"Unsupported cloud provider: {provider}")
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
local_dirs = {}

View File

@@ -0,0 +1,48 @@
"""Baseten Cloud CLI"""
import shutil
import subprocess # nosec B404
import tempfile
from os.path import dirname
from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
class BasetenCloud(Cloud):
"""Baseten Cloud Axolotl CLI"""
def __init__(self, config: dict):
self.config = config
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
raise NotImplementedError(
"Separate preprocess function for Baseten is not "
"implemented and will happen during hte train step."
)
def train(
self,
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None, # pylint: disable=unused-argument
**kwargs,
):
with tempfile.TemporaryDirectory() as tmp_dir:
config = self.config.copy()
config["launcher"] = launcher
config["launcher_args"] = launcher_args
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
yaml.dump(config, cloud_fout)
with open(tmp_dir + "/train.yaml", "w", encoding="utf-8") as config_fout:
config_fout.write(config_yaml)
shutil.copyfile(dirname(__file__) + "/template/run.sh", tmp_dir + "/run.sh")
shutil.copyfile(
dirname(__file__) + "/template/train_sft.py", tmp_dir + "/train_sft.py"
)
subprocess.run( # nosec B603 B607
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
)

View File

@@ -0,0 +1,9 @@
#!/bin/bash
set -eux
export NCCL_SOCKET_IFNAME="^docker0,lo"
export NCCL_IB_DISABLE=0
export NCCL_TIMEOUT=1800000
axolotl preprocess train.yaml
axolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS}

View File

@@ -0,0 +1,71 @@
"""
Baseten Training Script for Axolotl
"""
# pylint: skip-file
import yaml
from truss.base import truss_config
# Import necessary classes from the Baseten Training SDK
from truss_train import definitions
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
gpu = cloud_config.get("gpu", "h100")
gpu_count = int(cloud_config.get("gpu_count", 1))
node_count = int(cloud_config.get("node_count", 1))
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
secrets = cloud_config.get("secrets", [])
launcher = cloud_config.get("launcher", "accelerate")
launcher_args = cloud_config.get("launcher_args", [])
script_name = "run.sh"
launcher_args_str = ""
if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args)
# 1. Define a base image for your training job
# must use torch 2.7.0 for vllm
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
# 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.
# Secrets from the Baseten workspace, like API keys, are referenced using
# `SecretReference`.
env_vars = {
"AXOLOTL_LAUNCHER": launcher,
"AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
}
for secret_name in secrets:
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
training_runtime = definitions.Runtime(
start_commands=[ # Example: list of commands to run your training script
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
],
environment_variables=env_vars,
)
# 3. Define the Compute Resources for the Training Job
training_compute = definitions.Compute(
node_count=node_count,
accelerator=truss_config.AcceleratorSpec(
accelerator=truss_config.Accelerator.H100,
count=gpu_count,
),
)
# 4. Define the Training Job
# This brings together the image, compute, and runtime configurations.
my_training_job = definitions.TrainingJob(
image=definitions.Image(base_image=BASE_IMAGE),
compute=training_compute,
runtime=training_runtime,
)
# This config will be pushed using the Truss CLI.
# The association of the job to the project happens at the time of push.
first_project_with_job = definitions.TrainingProject(
name=project_name, job=my_training_job
)

View File
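
The script above pulls its settings from the `cloud.yaml` written by the CLI. A hedged sketch of a config covering the fields it reads, with `provider: baseten` selecting the Baseten path added to the cloud train entry point earlier (secret names are illustrative):

```python
# Sketch of a cloud config for the Baseten path; each key below is read either by the
# provider dispatch shown earlier or by train_sft.py above. Secret names are illustrative.
import yaml

cloud_cfg = {
    "provider": "baseten",   # omitted -> defaults to "modal"
    "gpu": "h100",
    "gpu_count": 8,
    "node_count": 1,
    "project_name": "axolotl-project",
    "secrets": ["HF_TOKEN", "WANDB_API_KEY"],  # exposed to the job via SecretReference
}

with open("cloud.yaml", "w", encoding="utf-8") as fout:
    yaml.dump(cloud_cfg, fout)
```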

@@ -41,7 +41,7 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
if exit_code := subprocess.call( # nosec B603
cmd.split(), cwd=run_folder, env=new_env
):
exit(exit_code) # pylint: disable=consider-using-sys-exit
exit(exit_code)
# Commit writes to volume.
if volumes:
@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res
def get_image(self):
docker_tag = "main-py3.11-cu124-2.6.0"
docker_tag = "main-py3.11-cu126-2.7.1"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"
@@ -130,7 +130,6 @@ class ModalCloud(Cloud):
res = []
if self.config.secrets:
for key in self.config.get("secrets", []):
# pylint: disable=duplicate-code
if isinstance(key, str):
if val := os.environ.get(key, ""):
res.append(modal.Secret.from_dict({key: val}))
@@ -177,8 +176,8 @@ class ModalCloud(Cloud):
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
volumes={k: v[0] for k, v in self.volumes.items()},
*args,
volumes={k: v[0] for k, v in self.volumes.items()},
**kwargs,
)
@@ -187,7 +186,7 @@ class ModalCloud(Cloud):
return int(self.config.timeout)
return 60 * 60 * 24 # 24 hours
def get_train_gpu(self): # pylint: disable=too-many-return-statements
def get_train_gpu(self):
count = self.config.gpu_count or 1
family = self.config.gpu.lower() or "l40s"
@@ -200,7 +199,7 @@ class ModalCloud(Cloud):
if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count)
if family == "h100":
return modal.gpu.H100(count=count)
return f"H100:{count}"
if family == "t4":
return modal.gpu.T4(count=count)
if family == "l4":
@@ -277,7 +276,7 @@ def _train(
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
volumes=None,
**kwargs, # pylint: disable=unused-argument
**kwargs,
):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:

View File

@@ -210,7 +210,7 @@ def load_cfg(
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
except:
gpu_version = None
prepare_plugins(cfg)

View File

@@ -28,7 +28,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: CLI arguments.
"""
# pylint: disable=duplicate-code
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
@@ -49,7 +49,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -14,10 +14,7 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -35,7 +32,7 @@ def get_multi_line_input() -> str:
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
instruction += line
return instruction
@@ -64,7 +61,9 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
chat_template_str = get_chat_template_from_config(
cfg, ds_cfg=None, tokenizer=tokenizer
)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
@@ -159,7 +158,13 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
chat_template_str = get_chat_template_from_config(
cfg, ds_cfg=None, tokenizer=tokenizer
)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -167,7 +172,6 @@ def do_inference_gradio(
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
@@ -252,7 +256,7 @@ def do_cli(
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser(InferenceCliArgs)

View File

@@ -1,7 +1,5 @@
"""Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name
import os
import subprocess # nosec B404
from typing import Literal, Optional

View File

@@ -43,7 +43,10 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
tokenizer.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
if processor:
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))

View File

@@ -10,6 +10,7 @@ import fire
import torch
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
from accelerate import PartialState
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
@@ -23,6 +24,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.config import load_cfg
from axolotl.utils.logging import get_logger
from axolotl.utils.train import determine_last_checkpoint
LOG = get_logger(__name__)
@@ -30,7 +32,7 @@ LOG = get_logger(__name__)
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
"""A custom planner to cast tensors to bfloat16 on the fly during loading."""
def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
def commit_tensor(self, read_item, tensor):
tensor.copy_(tensor.to(torch.bfloat16))
@@ -57,10 +59,10 @@ def _distributed_checkpoint_to_merged_weights(
state_dict: Dict = {}
save_path_ = Path(save_path)
save_path_.mkdir(exist_ok=True)
dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access
dist_cp_format_utils._load_state_dict(
state_dict,
storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
planner=BFloat16CastPlanner(), # pylint: disable=protected-access
planner=BFloat16CastPlanner(),
no_dist=True,
)
@@ -143,7 +145,6 @@ def merge_fsdp_weights(
ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
"""
checkpoint_dir_ = Path(checkpoint_dir)
from accelerate.state import PartialState
if not is_torch_version(">=", "2.3.0"):
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
@@ -180,7 +181,6 @@ def merge_fsdp_weights(
if remove_checkpoint_dir:
LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
shutil.rmtree(checkpoint_dir_)
state.wait_for_everyone()
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
@@ -191,15 +191,36 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
if not fsdp_dir.exists():
checkpoint_dir = determine_last_checkpoint(parsed_cfg, update=False)
if checkpoint_dir:
fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
if not fsdp_dir.exists():
raise ValueError(
f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}"
)
output_path = str(Path(parsed_cfg.output_dir) / "merged")
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=str(Path(parsed_cfg.output_dir) / "merged"),
output_path=output_path,
safe_serialization=True,
)
state = PartialState()
state.wait_for_everyone()
LOG.info(
f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
main_process_only=True,
)
LOG.info(
"Merged weights are only the safetensors and doesn't include the model configuration "
f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
main_process_only=True,
)
if __name__ == "__main__":

View File
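
The merge CLI above now falls back to the most recent checkpoint when the FSDP shards are not directly under `output_dir`. A small sketch of the lookup order (the checkpoint name stands in for whatever `determine_last_checkpoint()` returns):

```python
# Sketch of the FSDP shard lookup described above; "checkpoint-1234" is illustrative.
from pathlib import Path

output_dir = Path("./outputs/my-run")
candidates = [
    output_dir / "pytorch_model_fsdp_0",                       # checked first
    output_dir / "checkpoint-1234" / "pytorch_model_fsdp_0",   # then the last checkpoint
]
fsdp_dir = next((p for p in candidates if p.exists()), None)
merged_path = output_dir / "merged"   # consolidated safetensors land here
print(fsdp_dir, merged_path)
```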

@@ -35,10 +35,20 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
check_accelerate_default_config()
check_user_token()
if cli_args.iterable:
LOG.error(
"The --iterable CLI argument for 'axolotl preprocess' is no longer "
"supported. For training, set 'streaming: true' in your YAML config or "
"pass '--streaming' in your 'axolotl train' command for on-the-fly "
"preprocessing."
)
return
for key in ["skip_prepare_dataset", "pretraining_dataset"]:
if cfg.get(key):
LOG.error(
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl "
"train' CLI directly instead."
)
return
@@ -73,7 +83,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True
)
except Exception as exc: # pylint: disable=broad-exception-caught,unused-variable # nosec B110 # noqa F841
except Exception: # nosec B110
pass
# fmt: on
@@ -95,9 +105,10 @@ def do_cli(
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
parsed_cfg = load_cfg(config, **kwargs)
is_preprocess = kwargs.pop("is_preprocess", True)
parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -84,5 +84,6 @@ def do_quantize(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")

View File

@@ -59,7 +59,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -65,7 +65,7 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
for field in reversed(dataclasses.fields(config_class)):
field_type = _strip_optional_type(field.type)
if field_type == bool:
if field_type is bool:
field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
@@ -103,7 +103,7 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
for name, field in reversed(config_class.model_fields.items()):
field_type = _strip_optional_type(field.annotation)
if field_type == bool:
if field_type is bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(

View File
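
Both helpers above turn boolean fields into paired click flags; the change simply uses identity comparison against `bool`. A tiny sketch of the naming it produces (the dataclass and field are hypothetical):

```python
# Sketch: a bool field named `download` becomes a paired --download/--no-download option,
# mirroring the field.name.replace("_", "-") logic above. The dataclass is hypothetical.
import dataclasses

@dataclasses.dataclass
class ExampleArgs:
    download: bool = True

for field in dataclasses.fields(ExampleArgs):
    if field.type is bool:
        field_name = field.name.replace("_", "-")
        print(f"--{field_name}/--no-{field_name}")  # --download/--no-download
```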

@@ -3,11 +3,12 @@
import random
from copy import deepcopy
from itertools import product
from typing import Any
def generate_sweep_configs(
base_config: dict[str, list], sweeps_config: dict[str, list]
) -> list[dict[str, list]]:
) -> list[dict[str, Any]]:
"""
Recursively generates all possible configurations by applying sweeps to the base config.
@@ -48,7 +49,10 @@ def generate_sweep_configs(
new_config = {}
# new_config = deepcopy(base_config)
# Combine regular parameters with paired parameters
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
full_combo = {
**dict(zip(param_names, reg_combo, strict=False)),
**paired_set,
}
for param_name, param_value in full_combo.items():
new_config[param_name] = param_value
print(new_config)
@@ -57,7 +61,7 @@ def generate_sweep_configs(
# If no paired values, just use regular combinations
# new_config = deepcopy(base_config)
new_config = {}
for param_name, param_value in zip(param_names, reg_combo):
for param_name, param_value in zip(param_names, reg_combo, strict=False):
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)

View File

@@ -4,6 +4,7 @@ import os
import subprocess # nosec
import sys
import tempfile
from pathlib import Path
from typing import Any, Iterator, Literal
import yaml
@@ -67,14 +68,12 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
"""
Generate list of configuration files to process.
Generate list of configuration files to process. Yields a tuple of the configuration file name and a boolean indicating
whether this is a group of configurations (i.e., a sweep).
Args:
config: Base configuration file
sweep: Sweep configuration file
Yields:
Tuple of configuration file name and whether this is a group of configurations
"""
if not sweep:
@@ -90,8 +89,12 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
# Generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
is_group = len(permutations) > 1
for permutation in permutations:
# pylint: disable=consider-using-with
base_output_dir = base_config.get("output_dir", "./model-out")
for idx, permutation in enumerate(permutations, start=1):
permutation_dir = Path(permutation.get("output_dir", base_output_dir))
permutation_id = f"sweep{idx:04d}"
permutation["output_dir"] = str(permutation_dir / permutation_id)
temp_file = tempfile.NamedTemporaryFile(
mode="w",
suffix=".yaml",

View File
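
Each sweep permutation now gets its own run directory under the base `output_dir`. A short sketch of the resulting paths (the permutations are illustrative):

```python
# Sketch of the sweep output naming above: permutation idx -> <output_dir>/sweepNNNN.
from pathlib import Path

base_output_dir = "./model-out"
permutations = [{"learning_rate": 1e-5}, {"learning_rate": 2e-5}, {"learning_rate": 5e-5}]

for idx, permutation in enumerate(permutations, start=1):
    permutation_dir = Path(permutation.get("output_dir", base_output_dir))
    permutation["output_dir"] = str(permutation_dir / f"sweep{idx:04d}")

print([p["output_dir"] for p in permutations])
# ['model-out/sweep0001', 'model-out/sweep0002', 'model-out/sweep0003']
```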

@@ -39,7 +39,7 @@ def do_vllm_serve(
model = cfg.base_model
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
tensor_parallel_size = 1
data_parallel_size = 1
@@ -68,7 +68,6 @@ def do_vllm_serve(
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
)
# pylint: disable=unexpected-keyword-arg
vllm_script_args = AxolotlScriptArguments(
model=model,
tensor_parallel_size=tensor_parallel_size,

View File

@@ -6,7 +6,7 @@ from dataclasses import dataclass
from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
import axolotl.monkeypatch.data.batch_dataset_fetcher # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
@@ -55,13 +55,11 @@ def load_datasets(
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = getattr(cli_args, "iterable", False)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if (

View File

@@ -67,9 +67,7 @@ class JsonToJsonlConverter:
self.json_parser = json_parser
self.jsonl_serializer = jsonl_serializer
def convert(
self, input_file_path, output_file_path
): # pylint: disable=unused-argument
def convert(self, input_file_path, output_file_path):
content = self.file_reader.read(input_file_path)
data = self.json_parser.parse(content)
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations

View File

@@ -84,9 +84,7 @@ def create_causal_mask(
batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
if attention_mask is not None:
def causal_doc_mask_mod(
batch_idx, head_idx, q_idx, kv_idx
): # pylint: disable=unused-argument
def causal_doc_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
@@ -103,9 +101,7 @@ def create_causal_mask(
mask_factory_function = causal_doc_mask_mod
else:
mask_factory_function = causal_mask_function
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[
config._attn_implementation # pylint: disable=protected-access
]
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
# Do not allow skip if we are compiling (this is to match BC)
allow_is_causal_skip = (

View File

@@ -24,9 +24,7 @@ from pathlib import Path
from typing import Any
import torch
from transformers import (
TrainerCallback,
)
from transformers import TrainerCallback
from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager
@@ -44,7 +42,7 @@ from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
with suppress(ImportError):
import torch._dynamo # pylint: disable=ungrouped-imports
import torch._dynamo
class TrainerBuilderBase(abc.ABC):
@@ -260,14 +258,14 @@ class TrainerBuilderBase(abc.ABC):
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "muon":
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
from axolotl.contribs.mit.muon import (
MuonOptimizerFactory,
)
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
from axolotl.contribs.mit.dion import (
DionOptimizerFactory,
)
@@ -414,12 +412,8 @@ class TrainerBuilderBase(abc.ABC):
def _configure_torch_compile(self, training_args_kwargs: dict):
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.accumulated_cache_size_limit = 256
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_args_kwargs["torch_compile_backend"] = (
@@ -516,6 +510,7 @@ class TrainerBuilderBase(abc.ABC):
self.cfg.eval_batch_size
)
training_args_kwargs["include_tkps"] = self.cfg.include_tkps
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs

View File

@@ -7,10 +7,7 @@ from pathlib import Path
from typing import Type, Union
import transformers
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
)
from transformers import DataCollatorWithFlattening, EarlyStoppingCallback
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.builders.base import TrainerBuilderBase
@@ -26,12 +23,12 @@ from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.qat import QATCallback
@@ -42,6 +39,7 @@ from axolotl.utils.collators import (
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger
@@ -74,6 +72,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat))
if self.cfg.include_tkps:
callbacks.append(
TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
)
)
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -340,20 +344,22 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
if self.cfg.center_rewards_coefficient is not None:
training_arguments_kwargs["center_rewards_coefficient"] = (
self.cfg.center_rewards_coefficient
)
elif self.cfg.process_reward_model:
training_args_cls = AxolotlPRMConfig
else:
training_args_cls = AxolotlTrainingArguments
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
training_args = training_args_cls(
**training_arguments_kwargs,
)
training_args = self.hook_post_create_training_args(training_args)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
training_args.run_name = None
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
@@ -406,6 +412,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
# if the trainer has the `axolotl_cfg` property, set it
if hasattr(trainer, "axolotl_cfg"):
trainer.axolotl_cfg = self.cfg
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)

View File

@@ -168,16 +168,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if plugin_training_args:
training_args_kwargs.update(plugin_training_args)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
training_args = training_args_cls(
logging_first_step=True,
**training_args_kwargs,
)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
training_args.run_name = None
return training_args, trainer_kwargs

View File

@@ -10,7 +10,7 @@ from .shared import wrap_tools
def format_message(
message: Messages,
message_index: Optional[int] = None, # pylint: disable=unused-argument
message_index: Optional[int] = None,
) -> Messages:
if message.is_chat_formatted:
return message

View File

@@ -15,11 +15,11 @@ class MessageRoles(str, Enum):
Message roles for the system, user, assistant, and tools
"""
system = "system" # pylint: disable=invalid-name
user = "user" # pylint: disable=invalid-name
assistant = "assistant" # pylint: disable=invalid-name
tool = "tool" # pylint: disable=invalid-name
ipython = ( # pylint: disable=invalid-name
system = "system"
user = "user"
assistant = "assistant"
tool = "tool"
ipython = (
# for responses from builtin tools
"ipython"
)
@@ -30,12 +30,12 @@ class MessageContentTypes(str, Enum):
Message content types for text, image, audio, tool calls, and tool responses
"""
special_token = "special_token" # pylint: disable=invalid-name # nosec B105
text = "text" # pylint: disable=invalid-name
image = "image" # pylint: disable=invalid-name
audio = "audio" # pylint: disable=invalid-name
tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant
tool_response = "tool_response" # pylint: disable=invalid-name
special_token = "special_token" # nosec B105
text = "text"
image = "image"
audio = "audio"
tool_call = "tool_call"
tool_response = "tool_response"
class SpecialToken(str, Enum):
@@ -43,8 +43,8 @@ class SpecialToken(str, Enum):
Special tokens for beginning of string and end of string
"""
bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105
eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105
bos_token = "bos_token" # nosec B105
eos_token = "eos_token" # nosec B105
class ToolCallFunction(BaseModel):
@@ -73,7 +73,7 @@ class ToolCallContents(BaseModel):
name: str
arguments: dict[str, Union[str, int]]
id: Optional[str] = None # pylint: disable=invalid-name
id: Optional[str] = None
def __str__(self) -> str:
data = {"name": self.name, "arguments": self.arguments}
@@ -89,7 +89,7 @@ class ToolResponseContents(BaseModel):
name: str
content: Union[str, dict[str, Union[str, int, float]]]
id: Optional[str] = None # pylint: disable=invalid-name
id: Optional[str] = None
def __str__(self) -> str:
data = {"name": self.name, "content": self.content}

View File

@@ -1,23 +1,17 @@
"""
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
This module contains a function that builds a transform that takes a row from the
dataset and converts it to a Chat.
"""
from typing import Any, Mapping, Union
from typing import Any, Mapping
def chat_message_transform_builder( # pylint: disable=dangerous-default-value
def chat_message_transform_builder(
train_on_inputs=False,
conversations_field: str = "conversations",
message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role"
message_field_content: Union[str, list[str]] = [
"value",
"text",
"content",
], # commonly "content"
message_field_training: Union[str, list[str]] = [
"train",
"weight",
], # commonly "weight"
message_field_role: str | list[str] | None = None, # commonly "role"
message_field_content: str | list[str] | None = None, # commonly "content"
message_field_training: str | list[str] | None = None, # commonly "weight"
):
"""Builds a transform that takes a row from the dataset and converts it to a Chat
@@ -39,6 +33,12 @@ def chat_message_transform_builder( # pylint: disable=dangerous-default-value
A function that takes a list of conversations and returns a list of messages.
"""
if message_field_training is None:
message_field_training = ["train", "weight"]
if message_field_content is None:
message_field_content = ["value", "text", "content"]
if message_field_role is None:
message_field_role = ["role", "from"]
message_field_role = (
[message_field_role]
if isinstance(message_field_role, str)

View File

@@ -1,11 +1,9 @@
"""Init for axolotl.core.trainers"""
# pylint: disable=unused-import
# flake8: noqa
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .trl import (
AxolotlCPOTrainer,

View File

@@ -1,7 +1,5 @@
"""Module for customized trainers"""
# pylint: disable=too-many-lines
from __future__ import annotations
import os
@@ -44,6 +42,7 @@ from axolotl.core.trainers.utils import (
)
from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -65,6 +64,15 @@ class AxolotlTrainer(
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
_axolotl_cfg: DictDefault | None = None
@property
def axolotl_cfg(self):
return self._axolotl_cfg
@axolotl_cfg.setter
def axolotl_cfg(self, cfg):
self._axolotl_cfg = cfg
def __init__(
self,
@@ -80,7 +88,6 @@ class AxolotlTrainer(
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
@@ -285,9 +292,9 @@ class AxolotlTrainer(
# fmt: off
if dataloader_key is not None and self.args.dataloader_persistent_workers:
if hasattr(self, "_eval_dataloaders"):
self._eval_dataloaders[dataloader_key] = dataloader # type: ignore # pylint: disable=access-member-before-definition
self._eval_dataloaders[dataloader_key] = dataloader # type: ignore
else:
self._eval_dataloaders = {dataloader_key: dataloader} # pylint: disable=attribute-defined-outside-init
self._eval_dataloaders = {dataloader_key: dataloader}
# fmt: on
return self.accelerator.prepare(dataloader)
@@ -329,6 +336,17 @@ class AxolotlTrainer(
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
# track number of tokens for tokens per second calculation
if self.args.include_tkps:
inputs_key = "labels" if "labels" in inputs else "input_ids"
if hasattr(self.state, "num_tokens"):
self.state.num_tokens = (
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
)
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -443,7 +461,7 @@ class AxolotlTrainer(
model,
inputs,
return_outputs=False,
num_items_in_batch=None, # pylint: disable=unused-argument
num_items_in_batch=None,
):
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs,
@@ -524,15 +542,10 @@ class AxolotlTrainer(
accelerator_config = self.args.accelerator_config.to_dict()
use_configured_state = accelerator_config.get("use_configured_state", False)
if not use_configured_state:
AcceleratorState._reset_state( # pylint: disable=protected-access
reset_partial_state=True
)
AcceleratorState._reset_state(reset_partial_state=True)
super().create_accelerator_and_postprocess()
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
if self.is_fsdp_enabled:
if (
"limit_all_gathers" in self.args.fsdp_config
@@ -540,7 +553,6 @@ class AxolotlTrainer(
):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
# pylint: disable=unused-argument
def additional_accelerator_args(
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
) -> dict[str, Any]:
@@ -581,12 +593,19 @@ class AxolotlTrainer(
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_mem_active(gib)"] = round(active, 2)
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
logs["memory/max_active (GiB)"] = round(active, 2)
logs["memory/max_allocated (GiB)"] = round(allocated, 2)
logs["memory/device_reserved (GiB)"] = round(reserved, 2)
except (ValueError, TypeError, FileNotFoundError):
pass
if self.args.include_tkps and train_eval == "train":
# each rank will log its own tokens per second
# for logging_steps > 1 we obtain a moving average of this metric
logs["tokens_per_second_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
del self._stored_metrics[train_eval]
return super().log(logs, start_time)
@@ -662,6 +681,11 @@ class AxolotlTrainer(
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
self.data_collator.tokenizer.save_pretrained(output_dir)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File
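
The tokens-per-second accounting added above counts supervised tokens per step by masking the `-100` ignore index, and the log hook then divides the accumulated per-rank rate by `logging_steps` for a moving average. A minimal sketch of the counting (the tensor is illustrative):

```python
# Sketch of the token counting behind tokens_per_second_per_gpu above: positions whose
# label is the ignore index (-100) are excluded. The labels tensor is illustrative.
import torch

labels = torch.tensor([
    [-100, -100, 42, 43, 44],
    [-100, 7, 8, -100, -100],
])
num_tokens = (labels != -100).sum()
print(int(num_tokens))  # 5
```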

@@ -101,11 +101,11 @@ class AxolotlDPOTrainer(
) -> dict[str, torch.Tensor]:
if self.args.dpo_norm_loss:
# fmt: off
loss_type: str = self.loss_type # type: ignore[has-type] # pylint: disable=access-member-before-definition
loss_type: str = self.loss_type # type: ignore[has-type]
# fmt: on
# concatenated_forward handles avg token logprob for ipo case already
self.loss_type = "ipo" # pylint: disable=attribute-defined-outside-init
self.loss_type = "ipo"
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
self.loss_type = loss_type # pylint: disable=attribute-defined-outside-init
self.loss_type = loss_type
return res
return super().concatenated_forward(model, batch, is_ref_model=is_ref_model)

View File

@@ -128,9 +128,7 @@ class GRPOStrategy:
return grpo_args_kwargs
@classmethod
def set_trainer_args(
cls, cfg: DictDefault
) -> list[Any]: # pylint: disable=unused-argument
def set_trainer_args(cls, cfg: DictDefault) -> list[Any]:
trainer_args = []
if cfg.trl and cfg.trl.reward_funcs:
reward_funcs = []
@@ -151,7 +149,7 @@ class GRPOStrategy:
return trainer_kwargs
@classmethod
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
def get_collator(cls, *args, **kwargs):
# No data collation is needed in GRPO, handled by trl's trainer __init__
return None

View File

@@ -1,7 +1,5 @@
"""Axolotl GRPO trainers (with and without sequence parallelism handling)"""
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings
from functools import partial
from typing import Any
@@ -52,7 +50,6 @@ from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, Optimizer
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig
@@ -253,7 +250,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
def get_train_dataloader(self) -> DataLoader:
"""Get dataloader for training"""
train_dataset = self.train_dataset
# pylint: disable=access-member-before-definition
data_collator = self.data_collator # type: ignore
# Handle dataset preprocessing
@@ -266,7 +263,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
train_dataset, description="training"
)
else:
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
self.data_collator = self._get_collator_with_removed_columns(
data_collator,
description="training",
)
@@ -308,10 +305,10 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
# pylint: disable=access-member-before-definition
if self.state.global_step != self._last_loaded_step: # type: ignore[has-type]
self._move_model_to_vllm()
# pylint: disable=attribute-defined-outside-init
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
@@ -333,8 +330,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
# Extract prompts from this SP group, accounting for num_generations duplicates
# We only need prompts from one rank in each SP group
group_prompts = all_prompts_text[
group_leader_rank
* len(prompts_text) : (group_leader_rank + 1)
group_leader_rank * len(prompts_text) : (
group_leader_rank + 1
)
* len(prompts_text) : self.num_generations
]
@@ -485,7 +483,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text):
for prompt, completion in zip(prompts, completions_text, strict=False):
bootstrap = (
prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
)
@@ -503,6 +501,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
self.reward_funcs,
self.reward_processing_classes,
self.reward_func_names,
strict=False,
)
):
with profiling_context(self, reward_func_name):
@@ -511,14 +510,17 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]):
messages = [
{"messages": p + c} for p, c in zip(prompts, completions)
{"messages": p + c}
for p, c in zip(prompts, completions, strict=False)
]
texts = [
apply_chat_template(x, reward_processing_class)["text"]
for x in messages
]
else:
texts = [p + c for p, c in zip(prompts, completions)]
texts = [
p + c for p, c in zip(prompts, completions, strict=False)
]
reward_inputs = reward_processing_class(
text=texts,
return_tensors="pt",
@@ -564,7 +566,8 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
row_reward_kwargs["completion"] = completions[nan_row_idx]
warnings.warn(
f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
"Please ensure that at least one reward function returns a valid reward."
"Please ensure that at least one reward function returns a valid reward.",
stacklevel=2,
)
# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the

View File

@@ -5,7 +5,6 @@ import torch
from axolotl.core.trainers.base import AxolotlTrainer
# pylint: disable=too-many-ancestors
class AxolotlMambaTrainer(AxolotlTrainer):
"""Mamba specific trainer to handle loss calculation"""
@@ -15,8 +14,8 @@ class AxolotlMambaTrainer(AxolotlTrainer):
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
return_outputs=False,
num_items_in_batch=None,
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits

View File

@@ -1,6 +1,5 @@
"""Init for axolotl.core.trainers.mixins"""
# pylint: disable=unused-import
# flake8: noqa
from .activation_checkpointing import ActivationOffloadingMixin

View File

@@ -3,11 +3,14 @@ Trainer mixin for activation checkpointing w offloading
"""
import contextlib
from functools import partial
from peft import PeftModel
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers import GradientCheckpointingLayer, Trainer
@@ -46,9 +49,20 @@ class ActivationOffloadingMixin(Trainer):
return super().training_step(*args, **kwargs)
def ac_wrap_hf_model(model: nn.Module, **kwargs):
def ac_wrap_hf_model(model: nn.Module, use_reentrant=None, **kwargs):
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
if use_reentrant:
checkpoint_wrapper_fn = partial(
checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT
)
else:
checkpoint_wrapper_fn = checkpoint_wrapper
apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=checkpoint_wrapper_fn,
auto_wrap_policy=auto_wrap_policy,
**kwargs,
)
def get_lora_act_offloading_ctx_manager(
@@ -92,7 +106,7 @@ def get_lora_act_offloading_ctx_manager(
`contextlib.ContextDecorator`:
Activation offloading context manager for the model.
"""
# pylint: disable=unnecessary-dunder-call
activations_handling_ctx = OffloadActivations(
use_pin_memory=use_pin_memory,
use_streams=use_streams,

View File
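
The wrapper above now honors a `use_reentrant` flag by swapping the checkpoint implementation passed to `apply_activation_checkpointing`. A hedged, standalone sketch of the same mechanism on a toy module (the model and `check_fn` are illustrative, not the policy used above):

```python
# Sketch: choose the reentrant checkpoint implementation and apply it to a toy model,
# mirroring the checkpoint_wrapper_fn selection above on a plain nn.Sequential.
from functools import partial

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

use_reentrant = True
wrapper_fn = (
    partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT)
    if use_reentrant
    else checkpoint_wrapper  # the non-reentrant implementation is the default
)

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=wrapper_fn,
    check_fn=lambda module: isinstance(module, nn.Linear),  # wrap only the Linear layers
)
print(model)  # the Linear layers are now wrapped in CheckpointWrapper
```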

@@ -26,7 +26,6 @@ class DistributedParallelMixin(Trainer):
self.accelerator.distributed_type == "FSDP"
and self.accelerator.state.fsdp_plugin is None
):
# pylint: disable=protected-access
# handle Context Parallelism without FSDP
self.accelerator.state.distributed_type = "MULTI_GPU"
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"

View File

@@ -70,11 +70,11 @@ class OptimizerMixin(Trainer):
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
lr = optimizer_kwargs["lr"]
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
lr *= self.args.embedding_lr_scale
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
lr = self.args.embedding_lr
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
@@ -143,7 +143,7 @@ class OptimizerMixin(Trainer):
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer = create_loraplus_optimizer(
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
@@ -185,17 +185,15 @@ class OptimizerMixin(Trainer):
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
LOG.info(f"skipped {module}: {skipped / 2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
LOG.info(f"skipped: {skipped / 2**20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
self.optimizer = smp.DistributedOptimizer(self.optimizer)
return self.optimizer

View File

@@ -46,7 +46,7 @@ class SchedulerMixin(Trainer):
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
if self.lr_scheduler is None: # type: ignore
# fmt: on
plugin_manager = PluginManager.get_instance()
lr_scheduler: LRScheduler | None = plugin_manager.create_lr_scheduler(
@@ -90,7 +90,7 @@ class SchedulerMixin(Trainer):
LOG.warning(
"Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
@@ -98,7 +98,7 @@ class SchedulerMixin(Trainer):
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant(
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
@@ -107,7 +107,7 @@ class SchedulerMixin(Trainer):
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
self.lr_scheduler = get_cosine_schedule_with_min_lr(
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
@@ -133,7 +133,7 @@ class SchedulerMixin(Trainer):
)
if not self.lr_scheduler:
super().create_scheduler(num_training_steps, optimizer)
self.lr_scheduler = JaggedLRRestartScheduler( # pylint: disable=attribute-defined-outside-init
self.lr_scheduler = JaggedLRRestartScheduler(
optimizer,
self.lr_scheduler,
self.args.jagged_restart_steps,

View File

@@ -14,7 +14,6 @@ class AxolotlTrainingMixins:
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
@@ -50,6 +49,12 @@ class AxolotlTrainingMixins:
default=False,
metadata={"help": "Use real batches for efficient training."},
)
include_tkps: bool = field(
default=True,
metadata={
"help": "Whether to include tokens per second in the training metrics."
},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},

View File

@@ -1,18 +1,17 @@
"""Module containing Dataset functionality"""
"""
Module containing dataset functionality.
We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
datasets.
"""
import torch
from datasets import Dataset, IterableDataset
from axolotl.utils.logging import get_logger
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = get_logger(__name__)
@@ -26,7 +25,7 @@ class TokenizedPromptDataset(Dataset):
keep_in_memory: Whether to keep the tokenized dataset in memory.
"""
def __init__( # pylint: disable=super-init-not-called
def __init__(
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset,
@@ -86,133 +85,3 @@ def wrap_dataset_for_tokenized_prompt(
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.
Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""
def __init__( # pylint: disable=super-init-not-called
self,
tokenizer,
datasets,
seq_length=2048,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: list[IterableDataset] = datasets
self.seq_length = seq_length
vocab_size = len(tokenizer.get_vocab())
if vocab_size <= torch.iinfo(torch.int16).max:
self.tokens_dtype = torch.int16
elif vocab_size <= torch.iinfo(torch.int32).max:
self.tokens_dtype = torch.int32
else:
self.tokens_dtype = torch.int64
def __iter__(self):
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
add_concat_token = False
if example:
example_len = len(example["input_ids"])
add_concat_token = example["input_ids"][-1] != self.concat_token_id
else:
example_len = 0
if not example_len or (
buffer_len + int(add_concat_token) + example_len > self.seq_length
):
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
: self.seq_length
]
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
):
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
LOG.warning(
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if add_concat_token:
input_ids.append(self.concat_token_id)
attention_mask.append(1)
labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)

View File

@@ -79,7 +79,7 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
model, tokenizer, _, processor = setup_model_and_tokenizer(cfg)
# Get datasets
# pylint: disable=duplicate-code
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps

View File

@@ -76,7 +76,7 @@ class BasePlugin:
def __init__(self):
"""Initializes the BasePlugin."""
def register(self, cfg: dict): # pylint: disable=unused-argument
def register(self, cfg: dict):
"""Registers the plugin with the given configuration as an unparsed dict.
Args:
@@ -104,14 +104,13 @@ class BasePlugin:
dataset_meta: The metadata for the training dataset.
"""
def pre_model_load(self, cfg: DictDefault): # pylint: disable=unused-argument
def pre_model_load(self, cfg: DictDefault):
"""Performs actions before the model is loaded.
Args:
cfg: The configuration for the plugin.
"""
# pylint: disable=unused-argument
def post_model_build(self, cfg: DictDefault, model: PreTrainedModel):
"""Performs actions after the model is built/loaded, but before any adapters are applied.
@@ -119,7 +118,6 @@ class BasePlugin:
cfg: The configuration for the plugin.
"""
# pylint: disable=unused-argument
def pre_lora_load(self, cfg: DictDefault, model: PreTrainedModel):
"""Performs actions before LoRA weights are loaded.
@@ -128,7 +126,6 @@ class BasePlugin:
model: The loaded model.
"""
# pylint: disable=unused-argument
def post_lora_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after LoRA weights are loaded.
@@ -137,7 +134,6 @@ class BasePlugin:
model: The loaded model.
"""
# pylint: disable=unused-argument
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after the model is loaded.
@@ -146,7 +142,6 @@ class BasePlugin:
model: The loaded model.
"""
# pylint: disable=unused-argument
def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
"""Returns a custom class for the trainer.
@@ -157,7 +152,6 @@ class BasePlugin:
The first non-`None` trainer class returned by a plugin.
"""
# pylint: disable=unused-argument
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Performs actions after the trainer is created.
@@ -166,7 +160,7 @@ class BasePlugin:
trainer: The trainer object for training.
"""
def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument):
def get_training_args(self, cfg: DictDefault):
"""
Returns custom training arguments to set on TrainingArgs.
@@ -177,9 +171,7 @@ class BasePlugin:
object: dict containing the training arguments.
"""
def get_collator_cls_and_kwargs(
self, cfg: DictDefault, is_eval: bool = False
): # pylint: disable=unused-argument):
def get_collator_cls_and_kwargs(self, cfg: DictDefault, is_eval: bool = False):
"""
Returns a custom class for the collator.
@@ -191,7 +183,6 @@ class BasePlugin:
class: The class for the collator.
"""
# pylint: disable=unused-argument
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
"""Creates and returns an optimizer for training.
@@ -203,7 +194,6 @@ class BasePlugin:
The created optimizer.
"""
# pylint: disable=unused-argument
def create_lr_scheduler(
self,
cfg: DictDefault,
@@ -223,7 +213,6 @@ class BasePlugin:
The created learning rate scheduler.
"""
# pylint: disable=unused-argument
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> list[Callable]:
@@ -238,7 +227,6 @@ class BasePlugin:
"""
return []
# pylint: disable=unused-argument
def add_callbacks_post_trainer(
self, cfg: DictDefault, trainer: Trainer
) -> list[Callable]:
@@ -254,7 +242,6 @@ class BasePlugin:
"""
return []
# pylint: disable=unused-argument
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
"""Performs actions after training is complete.
@@ -263,7 +250,7 @@ class BasePlugin:
model: The loaded model.
"""
def post_train_unload(self, cfg: DictDefault): # pylint: disable=unused-argument
def post_train_unload(self, cfg: DictDefault):
"""Performs actions after training is complete and the model is unloaded.
Args:
@@ -311,7 +298,7 @@ def load_plugin(plugin_name: str) -> BasePlugin:
return plugin
class PluginManager: # pylint: disable=too-many-public-methods
class PluginManager:
"""The `PluginManager` class is responsible for loading and managing plugins. It
should be a singleton so it can be accessed from anywhere in the codebase.


@@ -50,15 +50,9 @@ def merge_input_args():
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
namespace: Dict[Any, Any] = {}
exec( # pylint: disable=exec-used # nosec B102
dynamic_input, globals(), namespace
)
AxolotlInputConfig = namespace[ # pylint: disable=invalid-name
"AxolotlInputConfig"
]
AxolotlConfigWCapabilities = namespace[ # pylint: disable=invalid-name
"AxolotlConfigWCapabilities"
]
exec(dynamic_input, globals(), namespace) # nosec B102
AxolotlInputConfig = namespace["AxolotlInputConfig"]
AxolotlConfigWCapabilities = namespace["AxolotlConfigWCapabilities"]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
@@ -74,7 +68,7 @@ def merge_training_args() -> Type:
Returns:
tuple: A tuple containing the newly created classes, AxolotlTrainingMixins.
"""
# pylint: disable=duplicate-code
from axolotl.core.training_args_base import (
AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
)
@@ -93,11 +87,7 @@ def merge_training_args() -> Type:
namespace: Dict[Any, Any] = {}
local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
exec( # pylint: disable=exec-used # nosec B102
dynamic_input, {**globals(), **local_vars}, namespace
)
AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name
"AxolotlTrainingMixins"
]
exec(dynamic_input, {**globals(), **local_vars}, namespace) # nosec B102
AxolotlTrainingMixins = namespace["AxolotlTrainingMixins"]
return AxolotlTrainingMixins
return AxolotlTrainingMixinsBase
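Both `merge_input_args` and `merge_training_args` build a combined class from generated source and `exec` it into a scratch namespace; the diff only flattens the call formatting and drops the now-redundant pylint disables. A stripped-down sketch of the pattern, with made-up class names standing in for the real config mixins:

```python
from typing import Any, Dict

from pydantic import BaseModel


class BaseArgs(BaseModel):
    learning_rate: float = 1e-4


class PluginAArgs(BaseModel):
    use_feature_a: bool = False


def merge_classes() -> type:
    # Generate source for a class inheriting from the base plus plugin mixins,
    # exec it into a private namespace, then pull the new class back out.
    dynamic_src = "class MergedArgs(BaseArgs, PluginAArgs):\n    pass\n"
    namespace: Dict[Any, Any] = {}
    local_vars = {"BaseArgs": BaseArgs, "PluginAArgs": PluginAArgs}
    exec(dynamic_src, {**globals(), **local_vars}, namespace)  # nosec B102
    return namespace["MergedArgs"]


MergedArgs = merge_classes()
print(MergedArgs(learning_rate=2e-5, use_feature_a=True))
```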


@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
```
## Usage


@@ -18,6 +18,7 @@ Module for the Plugin for Cut Cross Entropy integration with Axolotl.
Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""
import importlib
from functools import partial
@@ -28,13 +29,13 @@ from axolotl.utils import get_pytorch_version
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
from axolotl.utils.logging import get_logger
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
from .args import CutCrossEntropyArgs as CutCrossEntropyArgs
LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`'
)
@@ -106,9 +107,7 @@ class CutCrossEntropyPlugin(BasePlugin):
"""
from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic(
maybe_model, patch_options, model_type: str
): # pylint: disable=unused-argument
def patch_generic(maybe_model, patch_options, model_type: str):
import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward
@@ -121,12 +120,10 @@ class CutCrossEntropyPlugin(BasePlugin):
)
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
cut_cross_entropy.transformers.llama._PATCH_OPTS = ( # pylint: disable=protected-access
patch_options
)
cut_cross_entropy.transformers.llama._PATCH_OPTS = patch_options
model_cls.forward = cce_forward
# pylint: disable=duplicate-code
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import ForCausalLM class for model_type: {model_type}. "


@@ -15,6 +15,7 @@
"""
Module for handling Cut Cross Entropy input arguments.
"""
from typing import Optional
from pydantic import BaseModel, model_validator


@@ -7,7 +7,7 @@ from transformers.trainer_callback import TrainerCallback
from axolotl.utils.logging import get_logger
from ..base import BasePlugin
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
from .args import GrokfastArgs as GrokfastArgs
from .optimizer import gradfilter_ema
LOG = get_logger(__name__)
@@ -24,12 +24,10 @@ class GrokfastCallbackHandler(TrainerCallback):
self.alpha = alpha
self.lamb = lamb
def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument
def on_train_begin(self, *args_, **kwargs):
self.grads = None
def on_pre_optimizer_step(
self, args_, state, control, **kwargs
): # pylint: disable=unused-argument
def on_pre_optimizer_step(self, args_, state, control, **kwargs):
model = kwargs.pop("model")
self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
return control
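The handler simply threads an EMA state dict (`self.grads`) through `gradfilter_ema` right before each optimizer step. Roughly, a Grokfast-style EMA filter does the following; this is a simplified sketch after the reference implementation, not the vendored `optimizer.py`:

```python
from typing import Dict, Optional

import torch
import torch.nn as nn


def gradfilter_ema_sketch(
    model: nn.Module,
    grads: Optional[Dict[str, torch.Tensor]] = None,
    alpha: float = 0.98,
    lamb: float = 2.0,
) -> Dict[str, torch.Tensor]:
    # Keep an exponential moving average of each parameter's gradient and add
    # a scaled copy of that slow-moving component back into the live gradient.
    if grads is None:
        grads = {
            n: p.grad.detach().clone()
            for n, p in model.named_parameters()
            if p.requires_grad and p.grad is not None
        }
    for n, p in model.named_parameters():
        if p.requires_grad and p.grad is not None:
            grads[n] = grads[n] * alpha + p.grad.detach() * (1 - alpha)
            p.grad = p.grad + grads[n] * lamb
    return grads
```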


@@ -1,7 +1,6 @@
# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
# Reference: https://github.com/ironjr/grokfast
# pylint: skip-file
from collections import deque
from typing import Dict, Literal, Optional


@@ -15,6 +15,7 @@
"""
Plugin init to add KD support to Axolotl.
"""
from typing import Any
from transformers import Trainer
@@ -22,7 +23,7 @@ from transformers import Trainer
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
from .args import KDArgs as KDArgs
class KDPlugin(BasePlugin):
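For context, knowledge distillation trains the student against a teacher's temperature-softened distribution, and the `KDTemperatureSchedulerCallback` imported here varies that temperature over training. A textbook-style sketch of the distillation loss, illustrative rather than Axolotl's exact trainer code:

```python
import torch
import torch.nn.functional as F


def kd_loss_sketch(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    temperature: float = 2.0,
) -> torch.Tensor:
    # KL(teacher || student) on temperature-softened distributions, scaled by T^2
    # so gradient magnitudes stay comparable across temperatures.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature**2
```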


@@ -15,6 +15,7 @@
"""
Plugin args for KD support.
"""
from dataclasses import dataclass
from enum import Enum
@@ -26,8 +27,8 @@ class InferenceServerType(str, Enum):
Online inferences server types to handle different request args
"""
vllm = "vllm" # pylint: disable=invalid-name
sglang = "sglang" # pylint: disable=invalid-name
vllm = "vllm"
sglang = "sglang"
class KDArgs(BaseModel):
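Dropping the `invalid-name` disables is cosmetic: the lowercase members are intentional because `InferenceServerType` mixes in `str`, so values round-trip as plain strings. A small illustration of why that is convenient for config validation (the `ServerArgs` model below is made up):

```python
from enum import Enum

from pydantic import BaseModel


class InferenceServerType(str, Enum):
    vllm = "vllm"
    sglang = "sglang"


class ServerArgs(BaseModel):
    server_type: InferenceServerType = InferenceServerType.vllm


args = ServerArgs(server_type="sglang")  # plain strings from YAML validate directly
assert args.server_type == "sglang"      # the str mixin makes string equality work
```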

Some files were not shown because too many files have changed in this diff.