Compare commits

..

39 Commits

Author SHA1 Message Date
Wing Lian
5b56cc18d5 remove fastapi and pydantic extras 2025-01-31 11:17:42 -05:00
Wing Lian
5c3ac90669 chore: lint 2025-01-30 17:33:42 -05:00
Wing Lian
353ba4e80b set MODAL_IMAGE_BUILDER_VERSION=2024.10 to 2024.10 to test latest builder 2025-01-30 16:58:35 -05:00
Dan Saunders
6f294c3d8d refactor README; hardcode links to quarto docs; add additional quarto doc pages (#2295)
* refactor README; hardcode links to quarto docs; add additional quarto doc pages

* updates

* review comments

* update

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
2025-01-30 12:49:21 -05:00
Wing Lian
6f713226dd make save_safetensors: true the default (#2292)
* make save_safetensors: true the default

* revert change to model output check
2025-01-30 11:48:48 -05:00
Wing Lian
1063d82b51 match the cuda version for 2.4.1 build w/o tmux (#2299) 2025-01-30 11:46:09 -05:00
salman
ac471a697a updating to fused (#2293) 2025-01-30 11:45:56 -05:00
Wing Lian
8779997ba5 native support for modal cloud from CLI (#2237)
* native support for modal cloud from CLI

* do lm_eval in cloud too

* Fix the sub call to lm-eval

* lm_eval option to not post eval, and append not extend

* cache bust when using branch, grab sha of latest image tag, update lm-eval dep

* allow minimal yaml for lm eval

* include modal in requirements

* update link in README to include utm

* pr feedback

* use chat template

* revision support

* apply chat template as arg

* add wandb name support, allow explicit a100-40gb

* cloud is optional

* handle accidental setting of tasks with a single task str

* document the modal cloud yaml for clarity [skip ci]

* cli docs

* support spawn vs remote for lm-eval

* Add support for additional docker commands in modal image build

* cloud config shouldn't be a dir

* Update README.md

Co-authored-by: Charles Frye <cfrye59@gmail.com>

* fix annotation args

---------

Co-authored-by: Charles Frye <cfrye59@gmail.com>
2025-01-30 11:34:02 -05:00
Eric Tang
268543a3be Ray Train Axolotl Integration (#2251)
* current

not clean working version
move torch trainer to do_cli
update code with config changes and clean up
edit config
cleanup
add run name to trainer

* address comments

* use axolotl train in multigpu tests and add ray tests for multi-gpu

* accelerate uses underscores for main_process_port arg

* chore: lint

* fix order of accelerate args

* include ray train in docker images

* current

not clean working version
move torch trainer to do_cli
update code with config changes and clean up
edit config
cleanup
add run name to trainer

* address comments

* use axolotl train in multigpu tests and add ray tests for multi-gpu

* accelerate uses underscores for main_process_port arg

* chore: lint

* fix order of accelerate args

* include ray train in docker images

* fix bf16 resolution behavior

* move dtype logic

* x

Signed-off-by: SumanthRH <sumanthrh@anyscale.com>

* rename

Signed-off-by: SumanthRH <sumanthrh@anyscale.com>

* add to sidebar

Signed-off-by: SumanthRH <sumanthrh@anyscale.com>

* Apply suggestions from code review

Co-authored-by: Eric Tang <46737979+erictang000@users.noreply.github.com>

* Update docs/ray-integration.qmd

Co-authored-by: Eric Tang <46737979+erictang000@users.noreply.github.com>

* pre-commit fixes

Signed-off-by: SumanthRH <sumanthrh@anyscale.com>

* use output_dir instead of hardcoded saves path

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* bugfix storage dir

* change type\ for resources_per_worker

---------

Signed-off-by: SumanthRH <sumanthrh@anyscale.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: SumanthRH <sumanthrh@anyscale.com>
Co-authored-by: Sumanth R Hegde <39546518+SumanthRH@users.noreply.github.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2025-01-29 00:10:19 -05:00
salman
54dd7abfc1 Process reward models (#2241)
* adding model_cfg to set num_labels

* using a num_labels field instead

* linting

* WIP stepwise prompt tokenizer

* this should work?

* trainer working?

* pushing to runpod

* fixing saving

* updating conf

* updating config, adding docs

* adding stepwise supervision docpage

* updating tests

* adding test for dataset

* fixing tests

* linting

* addressing some comments

* adding additional cfg fields support

* updating tests, fixing cfg

* fixing tests

* updating loss

* Update test_process_reward_model_smollm2.py

* updating loss values and seed

* dumb pre-commit
2025-01-29 00:08:33 -05:00
salman
c071a530f7 removing 2.3.1 (#2294) 2025-01-28 23:23:44 -05:00
mashdragon
c015a76a23 Num epochs float (#2282) [skip ci]
* Change num_epochs type to float

* Handle float value for num_epochs in trainer.py
2025-01-28 23:23:26 -05:00
NanoCode012
067b442596 chore: refactor SaveModelCallback to stop handle fractional save_steps (#2291) [skip ci] 2025-01-28 23:22:10 -05:00
Wing Lian
0b52f06227 bump bnb to 0.45.1 (#2289) [skip ci] 2025-01-28 23:21:25 -05:00
Wing Lian
887513285d support for custom lr groups for non-embedding modules (#2213)
* support for custom lr groups for non-embedding modules

invert name check for group modules
include lr_groups in training args
additional conditional for creating optimizer
fix regular params as w weight decay
fix lookup and add docs

* address pr feedback
2025-01-24 12:56:28 -05:00
Wing Lian
20620771f1 Pretrain multipack (#2278)
* fix for pretrain with packing

* fix model name and loss expected

* make sure to check with micro batch size for pretraining

* change loss threshholds based on parametrization

* make tests smaller for CI

* fix pretrain packing

* fix pretrain packing test

* address pr feedback
2025-01-24 12:55:20 -05:00
NanoCode012
6086162488 chore(doc): improve explanation for *_steps and *_strategy (#2270) 2025-01-24 10:07:02 -05:00
mashdragon
b2774af66c Take split param from config in all load_dataset instances (#2281) 2025-01-24 10:06:50 -05:00
NanoCode012
74f9782fc3 chore(doc): fix explanation on gcs creds retrieval (#2272) 2025-01-24 10:05:58 -05:00
Wing Lian
8a7a0b07dc support for latest transformers release 4.48.1 (#2256) 2025-01-23 21:17:57 -05:00
Wing Lian
8fb72cbc0b use the extracted field_messages to parse the role fields (#2265) 2025-01-21 15:39:30 -05:00
Adithya Kamath
bb9d4102c4 Add 5000 line history limit to tmux for docker cloud (#2268) 2025-01-21 15:39:17 -05:00
Wing Lian
af727eedf7 option to not concatenate during pretraining (#2263)
* option to not concatenate during pretraining

* simplify conditional and add doc to config.qmd
2025-01-20 14:07:34 -05:00
jwongTensora
8606093921 fix for indexing error from token/embeddings mismatch (#2257)
Co-authored-by: jwong <jwongTensora@gmail.com>
2025-01-14 22:09:29 -05:00
NanoCode012
cba5a457d9 fix: use text_column even when not packing for pretraining (#2254)
* fix: use text_column even when not packing for pretraining

* feat: update test to check when not packing

* chore: lint

* Update src/axolotl/utils/data/pretraining.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-01-14 22:08:56 -05:00
Wing Lian
19cd83d408 rename references to dpo dataset prep to pref data (#2258) 2025-01-14 22:07:55 -05:00
Dan Saunders
1ed4de73b6 CLI cleanup and documentation (#2244)
* CLI init refactor

* fix

* cleanup and (partial) docs

* Adding documentation and continuing cleanup (in progress)

* remove finetune.py script

* continued cleanup and documentation

* pytest fixes

* review comments

* fix

* Fix

* typing fixes

* make sure the batch dataset patcher for multipack is always loaded when handling datasets

* review comments

* fix

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-01-13 17:55:29 +00:00
Wing Lian
f89e962119 skip over rows in pretraining dataset (#2223)
* skip over rows in pretraining dataset

* update docs
2025-01-13 10:44:45 -05:00
Wing Lian
bc1c9c20e3 assume empty lora dropout means 0.0 and add tests (#2243)
* assume empty lora dropout means 0.0 and add tests

* remove un-necessary arg

* refactor based on pr feedback:

* chore: lint
2025-01-13 10:44:11 -05:00
Wing Lian
dd26cc3c0f add helper to verify the correct model output file exists (#2245)
* add helper to verify the correct model output file exists

* more checks using helper

* chore: lint

* fix import and relora model check

* workaround for trl trainer saves

* remove stray print
2025-01-13 10:43:29 -05:00
Wing Lian
d8b4027200 use 2.5.1 docker images as latest tag as it seems stable (#2198) 2025-01-10 08:35:25 -05:00
Wing Lian
fb3352e21c rename liger test so it properly runs in ci (#2246) 2025-01-09 17:31:43 -05:00
NanoCode012
ed77e7001e feat: add support for data_files in pretraining (#2238) 2025-01-09 21:04:13 +00:00
Wing Lian
7669a03fb4 update upstream HF deps (#2239)
* bump axolotl contribs for upstream main conflicts:

* bump datasets, tokenizer, trl

* remove log workarounds in trl

* bump lm-eval

* remove unsloth_ import from critical path

* remove llama fa2 from conftest

* unsloth breaks with latest upstream
2025-01-09 21:01:59 +00:00
Vincenzo di Cicco
6553683170 Use SequentialSampler if curriculum_sampling is enabled with sample_packing (#2235) 2025-01-09 21:01:22 +00:00
Wing Lian
5e0124e2ab update modal version for ci (#2242) 2025-01-09 21:01:02 +00:00
NanoCode012
2e8d7c1adb fix: mistral nemo does not recognize token_type_ids in forward (#2233) 2025-01-09 21:00:36 +00:00
Wing Lian
3c1921e400 add hf cache caching for GHA (#2247)
* add hf cache caching for GHA

* use modal volume to cache hf data

* make sure to update the cache as we add new fixtures in conftest
2025-01-09 20:59:54 +00:00
Wing Lian
7faf2b6e8e Merge group queue (#2248)
* add support for merge groups

* also lint merge groups
2025-01-09 15:49:00 -05:00
161 changed files with 4536 additions and 3097 deletions

View File

@@ -15,7 +15,7 @@ First of all, thank you for your interest in contributing to axolotl! We appreci
- [Commit Messages](#commit-messages)
- [Additional Resources](#additional-resources)
## Code of Conductcode
## Code of Conduct
All contributors are expected to adhere to our [Code of Conduct](CODE_OF_CONDUCT.md). Please read it before participating in the axolotl community.

View File

@@ -22,18 +22,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "121"
cuda_version: 12.1.1
cudnn_version: 8
python_version: "3.10"
pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.1
cudnn_version: 8
python_version: "3.11"
pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""

View File

@@ -1,6 +1,7 @@
name: lint
on:
# check on PRs, and manual triggers
merge_group:
pull_request:
paths:
- '**.py'

View File

@@ -15,17 +15,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras: mamba-ssm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -36,6 +25,7 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -82,17 +72,6 @@ jobs:
strategy:
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -103,6 +82,7 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -145,10 +125,10 @@ jobs:
strategy:
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.3.1
pytorch: 2.4.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -20,12 +20,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -52,7 +46,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -12,17 +12,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -76,17 +65,6 @@ jobs:
strategy:
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -26,7 +26,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
pytorch_version: ["2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
@@ -98,13 +98,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -129,7 +122,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -1,6 +1,7 @@
name: Tests
on:
# check on push/merge to main, PRs, and manual triggers
merge_group:
push:
branches:
- "main"
@@ -48,7 +49,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
pytorch_version: ["2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
@@ -60,6 +61,15 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -100,6 +110,15 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
@@ -115,6 +134,15 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -156,6 +184,15 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
@@ -183,7 +220,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -191,6 +228,7 @@ jobs:
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
@@ -207,12 +245,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -229,7 +261,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -237,6 +269,7 @@ jobs:
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |

View File

@@ -19,7 +19,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 6.1.0
hooks:
- id: flake8
- repo: https://github.com/PyCQA/pylint

775
README.md
View File

@@ -1,8 +1,8 @@
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="image/axolotl_logo_digital_white.svg">
<source media="(prefers-color-scheme: light)" srcset="image/axolotl_logo_digital_black.svg">
<img alt="Axolotl" src="image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_white.svg">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg">
<img alt="Axolotl" src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
</picture>
</p>
@@ -19,235 +19,99 @@
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
</p>
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
Axolotl is a tool designed to streamline post-training for various AI models.
Post-training refers to any modifications or additional training performed on
pre-trained models - including full model fine-tuning, parameter-efficient tuning (like
LoRA and QLoRA), supervised fine-tuning (SFT), instruction tuning, and alignment
techniques. With support for multiple model architectures and training configurations,
Axolotl makes it easy to get started with these techniques.
Axolotl is designed to work with YAML config files that contain everything you need to
preprocess a dataset, train or fine-tune a model, run model inference or evaluation,
and much more.
Features:
- Train various Huggingface models such as llama, pythia, falcon, mpt
- Supports fullfinetune, lora, qlora, relora, and gptq
- Customize configurations using a simple yaml file or CLI overwrite
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Integrated with [xformers](https://github.com/facebookresearch/xformers), flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb, mlflow or Comet
- And more!
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
## 🚀 Quick Start
<table>
<tr>
<td>
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.10
- PyTorch ≥2.4.1
## Table of Contents
- [Axolotl](#axolotl)
- [Table of Contents](#table-of-contents)
- [Quickstart ⚡](#quickstart-)
- [Edge Builds](#edge-builds-)
- [Axolotl CLI Usage](#axolotl-cli-usage)
- [Badge ❤🏷️](#badge-)
- [Contributing 🤝](#contributing-)
- [Sponsors 🤝❤](#sponsors-)
- [Axolotl supports](#axolotl-supports)
- [Advanced Setup](#advanced-setup)
- [Environment](#environment)
- [Docker](#docker)
- [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu)
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- [LambdaLabs](#lambdalabs)
- [GCP](#gcp)
- [Windows](#windows)
- [Mac](#mac)
- [Google Colab](#google-colab)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
- [Dataset](#dataset)
- [Config](#config)
- [All Config Options](#all-config-options)
- [Train](#train)
- [Preprocess dataset](#preprocess-dataset)
- [Multi-GPU](#multi-gpu)
- [DeepSpeed](#deepspeed)
- [FSDP](#fsdp)
- [FSDP + QLoRA](#fsdp--qlora)
- [Weights \& Biases Logging](#weights--biases-logging)
- [Special Tokens](#special-tokens)
- [Liger Kernel](#liger-kernel)
- [Inference Playground](#inference-playground)
- [Merge LORA to base](#merge-lora-to-base)
- [Common Errors 🧰](#common-errors-)
- [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
- [Need help? 🙋](#need-help-)
### Installation
</td>
<td>
<div align="center">
<img src="image/axolotl_symbol_digital_white.svg" alt="axolotl" width="160">
<div>
<p>
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
</p>
<p>
Go ahead and Axolotl questions!!
</p>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
<img alt="PyTest Status" src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
</div>
</div>
</td>
</tr>
</table>
## Quickstart ⚡
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
```bash
```shell
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# download examples and optionally deepspeed configs to the local path
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
# finetune using lora
axolotl train examples/llama-3/lora-1b.yml
```
### Edge Builds 🏎️
Other installation approaches are described [here](https://axolotl-ai-cloud.github.io/axolotl/docs/installation.html).
If you're looking for the latest features and updates between releases, you'll need to install
from source.
### Your First Fine-tune
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
### Axolotl CLI Usage
We now support a new, more streamlined CLI using [click](https://click.palletsprojects.com/en/stable/).
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/llama-3/lora-1b.yml
# finetune lora
axolotl train examples/llama-3/lora-1b.yml
# inference
axolotl inference examples/llama-3/lora-1b.yml \
--lora-model-dir="./outputs/lora-out"
# gradio
axolotl inference examples/llama-3/lora-1b.yml \
--lora-model-dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
```
We've also added a new command for fetching `examples` and `deepspeed_configs` to your
local machine. This will come in handy when installing `axolotl` from PyPI.
```bash
# Fetch example YAML files (stores in "examples/" folder)
```shell
# Fetch axolotl examples
axolotl fetch examples
# Fetch deepspeed config files (stores in "deepspeed_configs/" folder)
axolotl fetch deepspeed_configs
# Optionally, specify a destination folder
# Or, specify a custom path
axolotl fetch examples --dest path/to/folder
# Train a model using LoRA
axolotl train examples/llama-3/lora-1b.yml
```
### Legacy Usage
<details>
That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/getting-started.html) for a more detailed walkthrough.
<summary>Click to Expand</summary>
## ✨ Key Features
While the Axolotl CLI is the preferred method for interacting with axolotl, we
still support the legacy `-m axolotl.cli.*` usage.
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, and more
- **Easy Configuration**: Simple YAML files to control your training setup
- **Performance Optimizations**: Flash Attention, xformers, multi-GPU training
- **Flexible Dataset Handling**: Use various formats and custom datasets
- **Cloud Ready**: Run on cloud platforms or local hardware
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/llama-3/lora-1b.yml
## 📚 Documentation
# finetune lora
accelerate launch -m axolotl.cli.train examples/llama-3/lora-1b.yml
- [Installation Options](https://axolotl-ai-cloud.github.io/axolotl/docs/installation.html) - Detailed setup instructions for different environments
- [Configuration Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html) - Full configuration options and examples
- [Dataset Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) - Supported formats and how to use them
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
# inference
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
--lora_model_dir="./outputs/lora-out"
## 🤝 Getting Help
# gradio
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
--lora_model_dir="./outputs/lora-out" --gradio
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support
- Check out our [Examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/) directory
- Read our [Debugging Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/debugging.html)
- Need dedicated support? Please contact [wing@axolotl.ai](mailto:wing@axolotl.ai) for options
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
```
## 🌟 Contributing
</details>
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
## Badge ❤🏷️
Building something cool with Axolotl? Consider adding a badge to your model card.
```markdown
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
```
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
## Sponsors 🤝❤
If you love axolotl, consider sponsoring the project by reaching out directly to [wing@axolotl.ai](mailto:wing@axolotl.ai).
---
- [Modal](https://modal.com/) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
---
## Contributing 🤝
Please read the [contributing guide](./.github/CONTRIBUTING.md)
Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) else create a new Issue.
PRs are **greatly welcome**!
Please run the quickstart instructions followed by the below to setup env:
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
# optional: run against all files
pre-commit run --all-files
```
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
</a>
## Axolotl supports
## Supported Models
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
@@ -272,523 +136,16 @@ Thanks to all of our contributors to date. Help drive open source AI progress fo
❌: not supported
❓: untested
## Advanced Setup
## ❤️ Sponsors
### Environment
Thank you to our sponsors who help make Axolotl possible:
#### Docker
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) - Modal lets you run
jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale,
fine-tune large language models, run protein folding simulations, and much more.
```bash
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
Or run on the current files for development:
## 📜 License
```sh
docker compose up -d
```
>[!Tip]
> If you want to debug axolotl or prefer to use Docker as your development environment, see the [debugging guide's section on Docker](docs/debugging.qmd#debugging-with-docker).
<details>
<summary>Docker advanced</summary>
A more powerful Docker command to run would be this:
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest
```
It additionally:
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
* The `--privileged` flag gives all capabilities to the container.
* The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed.
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
</details>
#### Conda/Pip venv
1. Install python >=**3.10**
2. Install pytorch stable https://pytorch.org/get-started/locally/
3. Install Axolotl along with python dependencies
```bash
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Huggingface to use gated models/datasets.
```bash
huggingface-cli login
```
Get the token at huggingface.co/settings/tokens
#### Cloud GPU
For cloud GPU providers that support docker images, use [`axolotlai/axolotl-cloud:main-latest`](https://hub.docker.com/r/axolotlai/axolotl-cloud/tags)
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
#### Bare Metal Cloud GPU
##### LambdaLabs
<details>
<summary>Click to Expand</summary>
1. Install python
```bash
sudo apt update
sudo apt install -y python3.10
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
sudo update-alternatives --config python # pick 3.10 if given option
python -V # should be 3.10
```
2. Install pip
```bash
wget https://bootstrap.pypa.io/get-pip.py
python get-pip.py
```
3. Install Pytorch https://pytorch.org/get-started/locally/
4. Follow instructions on quickstart.
5. Run
```bash
pip3 install protobuf==3.20.3
pip3 install -U --ignore-installed requests Pillow psutil scipy
```
6. Set path
```bash
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
```
</details>
##### GCP
<details>
<summary>Click to Expand</summary>
Use a Deeplearning linux OS with cuda and pytorch installed. Then follow instructions on quickstart.
Make sure to run the below to uninstall xla.
```bash
pip uninstall -y torch_xla[tpu]
```
</details>
#### Windows
Please use WSL or Docker!
#### Mac
Use the below instead of the install method in QuickStart.
```
pip3 install --no-build-isolation -e '.'
```
More info: [mac.md](/docs/mac.qmd)
#### Google Colab
Please use this example [notebook](examples/colab-notebooks/colab-axolotl-example.ipynb).
#### Launching on public clouds via SkyPilot
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
```bash
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
sky check
```
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
```
git clone https://github.com/skypilot-org/skypilot.git
cd skypilot/llm/axolotl
```
Use one command to launch:
```bash
# On-demand
HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
# Managed spot (auto-recovery on preemption)
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
```
#### Launching on public clouds via dstack
To launch on GPU instance (both on-demand and spot instances) on public clouds (GCP, AWS, Azure, Lambda Labs, TensorDock, Vast.ai, and CUDO), you can use [dstack](https://dstack.ai/).
Write a job description in YAML as below:
```yaml
# dstack.yaml
type: task
image: axolotlai/axolotl-cloud:main-latest
env:
- HUGGING_FACE_HUB_TOKEN
- WANDB_API_KEY
commands:
- accelerate launch -m axolotl.cli.train config.yaml
ports:
- 6006
resources:
gpu:
memory: 24GB..
count: 2
```
then, simply run the job with `dstack run` command. Append `--spot` option if you want spot instance. `dstack run` command will show you the instance with cheapest price across multi cloud services:
```bash
pip install dstack
HUGGING_FACE_HUB_TOKEN=xxx WANDB_API_KEY=xxx dstack run . -f dstack.yaml # --spot
```
For further and fine-grained use cases, please refer to the official [dstack documents](https://dstack.ai/docs/) and the detailed description of [axolotl example](https://github.com/dstackai/dstack/tree/master/examples/fine-tuning/axolotl) on the official repository.
### Dataset
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
See [the documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
### Config
See [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
- model
```yaml
base_model: ./llama-7b-hf # local or huggingface repo
```
Note: The code will load the right architecture.
- dataset
```yaml
datasets:
# huggingface repo
- path: vicgalle/alpaca-gpt4
type: alpaca
# huggingface repo with specific configuration/subset
- path: EleutherAI/pile
name: enron_emails
type: completion # format from earlier
field: text # Optional[str] default: text, field to use for completion data
# huggingface repo with multiple named configurations/subsets
- path: bigcode/commitpackft
name:
- ruby
- python
- typescript
type: ... # unimplemented custom format
# chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template
- path: ...
type: chat_template
chat_template: chatml # defaults to tokenizer's chat_template
# local
- path: data.jsonl # or json
ds_type: json # see other options below
type: alpaca
# dataset with splits, but no train split
- path: knowrohit07/know_sql
type: context_qa.load_v2
train_on_split: validation
# loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
...
# Loading Data From a Public URL
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
ds_type: json # this is the default, see other options below.
```
- loading
```yaml
load_in_4bit: true
load_in_8bit: true
bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically.
fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32
tf32: true # require >=ampere
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
float16: true # use instead of fp16 when you don't want AMP
```
Note: Repo does not do 4-bit quantization.
- lora
```yaml
adapter: lora # 'qlora' or leave blank for full finetune
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
```
#### All Config Options
See [these docs](docs/config.qmd) for all config options.
### Train
Run
```bash
accelerate launch -m axolotl.cli.train your_config.yml
```
> [!TIP]
> You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml`
#### Preprocess dataset
You can optionally pre-tokenize dataset with the following before finetuning.
This is recommended for large datasets.
- Set `dataset_prepared_path:` to a local folder for saving and loading pre-tokenized dataset.
- (Optional): Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
- (Optional): Use `--debug` to see preprocessed examples.
```bash
python -m axolotl.cli.preprocess your_config.yml
```
#### Multi-GPU
Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
is the recommended multi-GPU option currently because FSDP may experience
[loss instability](https://github.com/huggingface/transformers/issues/26498).
##### DeepSpeed
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
```yaml
deepspeed: deepspeed_configs/zero1.json
```
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
```
##### FSDP
- llama FSDP
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
##### FSDP + QLoRA
Axolotl supports training with FSDP and QLoRA, see [these docs](docs/fsdp_qlora.qmd) for more information.
##### Weights & Biases Logging
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
- wandb options
```yaml
wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
```
##### Comet Logging
Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
- wandb options
```yaml
use_comet:
comet_api_key:
comet_workspace:
comet_project_name:
comet_experiment_key:
comet_mode:
comet_online:
comet_experiment_config:
```
##### Special Tokens
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
```yml
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
tokens: # these are delimiters
- "<|im_start|>"
- "<|im_end|>"
```
When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
##### Liger Kernel
Liger Kernel: Efficient Triton Kernels for LLM Training
https://github.com/linkedin/Liger-Kernel
Liger (LinkedIn GPU Efficient Runtime) Kernel is a collection of Triton kernels designed specifically for LLM training.
It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. The Liger Kernel
composes well and is compatible with both FSDP and Deepspeed.
```yaml
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```
### Inference Playground
Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.
The config file is the same config file used for training.
Pass the appropriate flag to the inference command, depending upon what kind of model was trained:
- Pretrained LORA:
```bash
python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
```
- Full weights finetune:
```bash
python -m axolotl.cli.inference examples/your_config.yml --base_model="./completed-model"
```
- Full weights finetune w/ a prompt from a text file:
```bash
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
--base_model="./completed-model" --prompter=None --load_in_8bit=True
```
-- With gradio hosting
```bash
python -m axolotl.cli.inference examples/your_config.yml --gradio
```
Please use `--sample_packing False` if you have it on and receive the error similar to below:
> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1
### Merge LORA to base
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
```bash
python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
```
You may need to use the `gpu_memory_limit` and/or `lora_on_cpu` config options to avoid running out of memory. If you still run out of CUDA memory, you can try to merge in system RAM with
```bash
CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
```
although this will be very slow, and using the config options above are recommended instead.
## Common Errors 🧰
See also the [FAQ's](./docs/faq.qmd) and [debugging guide](docs/debugging.qmd).
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
Please reduce any below
- `micro_batch_size`
- `eval_batch_size`
- `gradient_accumulation_steps`
- `sequence_len`
If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.
Using adamw_bnb_8bit might also save you some memory.
> `failed (exitcode: -9)`
Usually means your system has run out of system memory.
Similarly, you should consider reducing the same settings as when you run out of VRAM.
Additionally, look into upgrading your system RAM which should be simpler than GPU upgrades.
> RuntimeError: expected scalar type Float but found Half
Try set `fp16: true`
> NotImplementedError: No operator found for `memory_efficient_attention_forward` ...
Try to turn off xformers.
> accelerate config missing
It's safe to ignore it.
> NCCL Timeouts during training
See the [NCCL](docs/nccl.qmd) guide.
### Tokenization Mismatch b/w Inference & Training
For many formats, Axolotl constructs prompts by concatenating token ids _after_ tokenizing strings. The reason for concatenating token ids rather than operating on strings is to maintain precise accounting for attention masks.
If you decode a prompt constructed by axolotl, you might see spaces between tokens (or lack thereof) that you do not expect, especially around delimiters and special tokens. When you are starting out with a new format, you should always do the following:
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/finetuning/05_tokenizer_gotchas.html) for a concrete example.
## Debugging Axolotl
See [this debugging guide](docs/debugging.qmd) for tips on debugging Axolotl, along with an example configuration for debugging with VSCode.
## Need help? 🙋
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where our community members can help you.
Need dedicated support? Please contact us at [wing@axolotl.ai](ailto:wing@axolotl.ai) for dedicated support options.
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.

View File

@@ -28,16 +28,21 @@ website:
- section: "How-To Guides"
contents:
# TODO Edit folder structure after we have more docs.
- docs/getting-started.qmd
- docs/installation.qmd
- docs/debugging.qmd
- docs/inference.qmd
- docs/multipack.qmd
- docs/fsdp_qlora.qmd
- docs/input_output.qmd
- docs/rlhf.qmd
- docs/nccl.qmd
- docs/mac.qmd
- docs/multi-gpu.qmd
- docs/multi-node.qmd
- docs/unsloth.qmd
- docs/amd_hpc.qmd
- docs/ray-integration.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Reference"
@@ -45,7 +50,6 @@ website:
- docs/config.qmd
- docs/faq.qmd
format:
html:
theme: materia

View File

@@ -8,6 +8,7 @@ ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
@@ -31,9 +32,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -6,5 +6,6 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -23,11 +23,12 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
@@ -48,6 +49,12 @@ cicd_image = (
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 2))
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
@@ -67,6 +74,7 @@ def run_cmd(cmd: str, run_folder: str):
timeout=60 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,
)
def cicd_pytest():
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")

View File

@@ -23,12 +23,13 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
@@ -37,19 +38,21 @@ temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
)
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
).env(df_args)
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
@@ -69,6 +72,7 @@ def run_cmd(cmd: str, run_folder: str):
timeout=60 * 60,
cpu=8.0,
memory=131072,
volumes=VOLUME_CONFIG,
)
def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")

View File

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -20,7 +20,8 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh
chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"]

256
docs/cli.qmd Normal file
View File

@@ -0,0 +1,256 @@
# Axolotl CLI Documentation
The Axolotl CLI provides a streamlined interface for training and fine-tuning large language models. This guide covers
the CLI commands, their usage, and common examples.
### Table of Contents
- Basic Commands
- Command Reference
- fetch
- preprocess
- train
- inference
- merge-lora
- merge-sharded-fsdp-weights
- evaluate
- lm-eval
- Legacy CLI Usage
- Remote Compute with Modal Cloud
- Cloud Configuration
- Running on Modal Cloud
- Cloud Configuration Options
### Basic Commands
All Axolotl commands follow this general structure:
```bash
axolotl <command> [config.yml] [options]
```
The config file can be local or a URL to a raw YAML file.
### Command Reference
#### fetch
Downloads example configurations and deepspeed configs to your local machine.
```bash
# Get example YAML files
axolotl fetch examples
# Get deepspeed config files
axolotl fetch deepspeed_configs
# Specify custom destination
axolotl fetch examples --dest path/to/folder
```
#### preprocess
Preprocesses and tokenizes your dataset before training. This is recommended for large datasets.
```bash
# Basic preprocessing
axolotl preprocess config.yml
# Preprocessing with one GPU
CUDA_VISIBLE_DEVICES="0" axolotl preprocess config.yml
# Debug mode to see processed examples
axolotl preprocess config.yml --debug
# Debug with limited examples
axolotl preprocess config.yml --debug --debug-num-examples 5
```
Configuration options:
```yaml
dataset_prepared_path: Local folder for saving preprocessed data
push_dataset_to_hub: HuggingFace repo to push preprocessed data (optional)
```
#### train
Trains or fine-tunes a model using the configuration specified in your YAML file.
```bash
# Basic training
axolotl train config.yml
# Train and set/override specific options
axolotl train config.yml \
--learning-rate 1e-4 \
--micro-batch-size 2 \
--num-epochs 3
# Training without accelerate
axolotl train config.yml --no-accelerate
# Resume training from checkpoint
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
```
#### inference
Runs inference using your trained model in either CLI or Gradio interface mode.
```bash
# CLI inference with LoRA
axolotl inference config.yml --lora-model-dir="./outputs/lora-out"
# CLI inference with full model
axolotl inference config.yml --base-model="./completed-model"
# Gradio web interface
axolotl inference config.yml --gradio \
--lora-model-dir="./outputs/lora-out"
# Inference with input from file
cat prompt.txt | axolotl inference config.yml \
--base-model="./completed-model"
```
#### merge-lora
Merges trained LoRA adapters into the base model.
```bash
# Basic merge
axolotl merge-lora config.yml
# Specify LoRA directory (usually used with checkpoints)
axolotl merge-lora config.yml --lora-model-dir="./lora-output/checkpoint-100"
# Merge using CPU (if out of GPU memory)
CUDA_VISIBLE_DEVICES="" axolotl merge-lora config.yml
```
Configuration options:
```yaml
gpu_memory_limit: Limit GPU memory usage
lora_on_cpu: Load LoRA weights on CPU
```
#### merge-sharded-fsdp-weights
Merges sharded FSDP model checkpoints into a single combined checkpoint.
```bash
# Basic merge
axolotl merge-sharded-fsdp-weights config.yml
```
#### evaluate
Evaluates a model's performance using metrics specified in the config.
```bash
# Basic evaluation
axolotl evaluate config.yml
```
#### lm-eval
Runs LM Evaluation Harness on your model.
```bash
# Basic evaluation
axolotl lm-eval config.yml
# Evaluate specific tasks
axolotl lm-eval config.yml --tasks arc_challenge,hellaswag
```
Configuration options:
```yaml
lm_eval_tasks: List of tasks to evaluate
lm_eval_batch_size: Batch size for evaluation
output_dir: Directory to save evaluation results
```
### Legacy CLI Usage
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
```bash
# Preprocess
python -m axolotl.cli.preprocess config.yml
# Train
accelerate launch -m axolotl.cli.train config.yml
# Inference
accelerate launch -m axolotl.cli.inference config.yml \
--lora_model_dir="./outputs/lora-out"
# Gradio interface
accelerate launch -m axolotl.cli.inference config.yml \
--lora_model_dir="./outputs/lora-out" --gradio
```
### Remote Compute with Modal Cloud
Axolotl supports running training and inference workloads on Modal cloud infrastructure. This is configured using a
cloud YAML file alongside your regular Axolotl config.
#### Cloud Configuration
Create a cloud config YAML with your Modal settings:
```yaml
# cloud_config.yml
provider: modal
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
gpu_count: 1 # Number of GPUs to use
timeout: 86400 # Maximum runtime in seconds (24 hours)
branch: main # Git branch to use (optional)
volumes: # Persistent storage volumes
- name: axolotl-cache
mount: /workspace/cache
env: # Environment variables
- WANDB_API_KEY
- HF_TOKEN
```
#### Running on Modal Cloud
Commands that support the --cloud flag:
```bash
# Preprocess on cloud
axolotl preprocess config.yml --cloud cloud_config.yml
# Train on cloud
axolotl train config.yml --cloud cloud_config.yml
# Train without accelerate on cloud
axolotl train config.yml --cloud cloud_config.yml --no-accelerate
# Run lm-eval on cloud
axolotl lm-eval config.yml --cloud cloud_config.yml
```
#### Cloud Configuration Options
```yaml
provider: compute provider, currently only `modal` is supported
gpu: GPU type to use
gpu_count: Number of GPUs (default: 1)
memory: RAM in GB (default: 128)
timeout: Maximum runtime in seconds
timeout_preprocess: Preprocessing timeout
branch: Git branch to use
docker_tag: Custom Docker image tag
volumes: List of persistent storage volumes
env: Environment variables to pass
secrets: Secrets to inject
```

View File

@@ -187,6 +187,12 @@ rl:
# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:
# reward modelling: `True` or `False`
reward_model:
# process reward modelling: `True` or `False`
process_reward_model:
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
@@ -244,6 +250,8 @@ total_num_tokens:
sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200
# whether to concatenate samples during pretraining
pretraining_sample_concatenation:
# Use batch flattening for speedups when not using sample_packing
batch_flattening:
@@ -358,10 +366,11 @@ warmup_ratio: 0.05 # cannot use with warmup_steps
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
eval_steps: # Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
save_strategy: # Set to `"no"` to skip checkpoint saves
save_steps: # Leave empty to save at each epoch
eval_strategy: # Set to `"no"` to skip evaluation, `"epoch"` at end of each epoch, leave empty to infer from `eval_steps`.
save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of each epoch, `"best"` when better result is achieved, leave empty to infer from `save_steps`.
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: # Checkpoints saved at a time
# Maximum number of iterations to train for. It precedes num_epochs which means that

View File

@@ -8,14 +8,12 @@ order: 3
IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below.
## pygmalion
```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "value": "..."}]}
```
## chat_template
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.

View File

@@ -19,7 +19,14 @@ For pretraining, there is no prompt template or roles. The only required field
Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:
```{.yaml filename="config.yaml"}
pretraining_dataset: # hf path only
pretraining_dataset:
- name:
path:
split:
text_column: # column in dataset with the data, usually `text`
type: pretrain
trust_remote_code:
skip: # number of rows of data to skip over from the beginning
...
```

View File

@@ -0,0 +1,26 @@
---
title: Stepwise Supervised Format
description: Format for datasets with stepwise completions and labels
order: 3
---
## Stepwise Supervised
The stepwise supervised format is designed for chain-of-thought (COT) reasoning
datasets where each example contains multiple completion steps and a preference label
for each step.
### Example
Here's a simple example of a stepwise supervised dataset entry:
```json
{
"prompt": "Which number is larger, 9.8 or 9.11?",
"completions": [
"The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
"Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."
],
"labels": [true, false]
}
```

155
docs/getting-started.qmd Normal file
View File

@@ -0,0 +1,155 @@
---
title: "Getting Started with Axolotl"
format:
html:
toc: true
toc-depth: 3
number-sections: true
execute:
enabled: false
---
This guide will walk you through your first model fine-tuning project with Axolotl.
## Quick Example {#sec-quick-example}
Let's start by fine-tuning a small language model using LoRA. This example uses a 1B parameter model to ensure it runs on most GPUs.
Assuming `axolotl` is installed (if not, see our [Installation Guide](installation.qmd))
1. Download example configs:
```shell
axolotl fetch examples
```
2. Run the training:
```shell
axolotl train examples/llama-3/lora-1b.yml
```
That's it! Let's understand what just happened.
## Understanding the Process {#sec-understanding}
### The Configuration File {#sec-config}
The YAML configuration file controls everything about your training. Here's what (part of) our example config looks like:
```yaml
base_model: NousResearch/Llama-3.2-1B
# hub_model_id: username/custom_model_name
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
```
See our [Config options](config.qmd) for more details.
### Training {#sec-training}
When you run `axolotl train`, Axolotl:
1. Downloads the base model
2. (If specified) applies LoRA adapter layers
3. Loads and processes the dataset
4. Runs the training loop
5. Saves the trained model and / or LoRA weights
## Your First Custom Training {#sec-custom}
Let's modify the example for your own data:
1. Create a new config file `my_training.yml`:
```yaml
base_model: NousResearch/Nous-Hermes-llama-1b-v1
adapter: lora
# Training settings
micro_batch_size: 2
num_epochs: 3
learning_rate: 0.0003
# Your dataset
datasets:
- path: my_data.jsonl # Your local data file
type: alpaca # Or other format
```
This specific config is for LoRA fine-tuning a model with instruction tuning data using
the `alpaca` dataset format, which has the following format:
```json
{
"instruction": "Write a description of alpacas.",
"input": "",
"output": "Alpacas are domesticated South American camelids..."
}
```
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
format them.
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
format):
```json
{"instruction": "Classify this text", "input": "I love this!", "output": "positive"}
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
```
Please consult the supported [Dataset Formats](dataset-formats/) for more details.
3. Run the training:
```shell
axolotl train my_training.yml
```
## Common Tasks {#sec-common-tasks}
### Testing Your Model {#sec-testing}
After training, test your model:
```shell
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
```
### Preprocessing Data {#sec-preprocessing}
For large datasets, preprocess first:
```shell
axolotl preprocess my_training.yml
```
### Using a UI {#sec-ui}
Launch a Gradio interface:
```shell
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
```
## Next Steps {#sec-next-steps}
Now that you have the basics, you might want to:
- Try different model architectures
- Experiment with hyperparameters
- Use more advanced training methods
- Scale up to larger models
Check our other guides for details on these topics:
- [Configuration Guide](config.qmd) - Full configuration options
- [Dataset Formats](dataset-formats) - Working with different data formats
- [Multi-GPU Training](multi-gpu.qmd)
- [Multi-Node Training](multi-node.qmd)

Binary file not shown.

After

Width:  |  Height:  |  Size: 292 KiB

148
docs/inference.qmd Normal file
View File

@@ -0,0 +1,148 @@
---
title: "Inference Guide"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---
This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
## Quick Start {#sec-quickstart}
### Basic Inference {#sec-basic}
::: {.panel-tabset}
## LoRA Models
```{.bash}
axolotl inference your_config.yml --lora-model-dir="./lora-output-dir"
```
## Full Fine-tuned Models
```{.bash}
axolotl inference your_config.yml --base-model="./completed-model"
```
:::
## Advanced Usage {#sec-advanced}
### Gradio Interface {#sec-gradio}
Launch an interactive web interface:
```{.bash}
axolotl inference your_config.yml --gradio
```
### File-based Prompts {#sec-file-prompts}
Process prompts from a text file:
```{.bash}
cat /tmp/prompt.txt | axolotl inference your_config.yml \
--base-model="./completed-model" --prompter=None
```
### Memory Optimization {#sec-memory}
For large models or limited memory:
```{.bash}
axolotl inference your_config.yml --load-in-8bit=True
```
## Merging LoRA Weights {#sec-merging}
Merge LoRA adapters with the base model:
```{.bash}
axolotl merge-lora your_config.yml --lora-model-dir="./completed-model"
```
### Memory Management for Merging {#sec-memory-management}
::: {.panel-tabset}
## Configuration Options
```{.yaml}
gpu_memory_limit: 20GiB # Adjust based on your GPU
lora_on_cpu: true # Process on CPU if needed
```
## Force CPU Merging
```{.bash}
CUDA_VISIBLE_DEVICES="" axolotl merge-lora ...
```
:::
## Tokenization {#sec-tokenization}
### Common Issues {#sec-tokenization-issues}
::: {.callout-warning}
Tokenization mismatches between training and inference are a common source of problems.
:::
To debug:
1. Check training tokenization:
```{.bash}
axolotl preprocess your_config.yml --debug
```
2. Verify inference tokenization by decoding tokens before model input
3. Compare token IDs between training and inference
### Special Tokens {#sec-special-tokens}
Configure special tokens in your YAML:
```{.yaml}
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
tokens:
- "<|im_start|>"
- "<|im_end|>"
```
## Troubleshooting {#sec-troubleshooting}
### Common Problems {#sec-common-problems}
::: {.panel-tabset}
## Memory Issues
- Use 8-bit loading
- Reduce batch sizes
- Try CPU offloading
## Token Issues
- Verify special tokens
- Check tokenizer settings
- Compare training and inference preprocessing
## Performance Issues
- Verify model loading
- Check prompt formatting
- Ensure temperature/sampling settings
:::
For more details, see our [debugging guide](debugging.qmd).

119
docs/installation.qmd Normal file
View File

@@ -0,0 +1,119 @@
---
title: "Installation Guide"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---
This guide covers all the ways you can install and set up Axolotl for your environment.
## Requirements {#sec-requirements}
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.10
- PyTorch ≥2.4.1
## Installation Methods {#sec-installation-methods}
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
We use `--no-build-isolation` in order to detect the installed PyTorch version (if
installed) in order not to clobber it, and so that we set the correct version of
dependencies that are specific to the PyTorch version or other installed
co-dependencies.
### Edge/Development Build {#sec-edge-build}
For the latest features between releases:
```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
### Docker {#sec-docker}
```{.bash}
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
For development with Docker:
```{.bash}
docker compose up -d
```
::: {.callout-tip}
### Advanced Docker Configuration
```{.bash}
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
--name axolotl --ipc=host \
--ulimit memlock=-1 --ulimit stack=67108864 \
--mount type=bind,src="${PWD}",target=/workspace/axolotl \
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
axolotlai/axolotl:main-latest
```
:::
## Cloud Environments {#sec-cloud}
### Cloud GPU Providers {#sec-cloud-gpu}
For providers supporting Docker:
- Use `axolotlai/axolotl-cloud:main-latest`
- Available on:
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
### Google Colab {#sec-colab}
Use our [example notebook](../examples/colab-notebooks/colab-axolotl-example.ipynb).
## Platform-Specific Instructions {#sec-platform-specific}
### macOS {#sec-macos}
```{.bash}
pip3 install --no-build-isolation -e '.'
```
See @sec-troubleshooting for Mac-specific issues.
### Windows {#sec-windows}
::: {.callout-important}
We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
:::
## Environment Managers {#sec-env-managers}
### Conda/Pip venv {#sec-conda}
1. Install Python ≥3.10
2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Hugging Face:
```{.bash}
huggingface-cli login
```
## Troubleshooting {#sec-troubleshooting}
If you encounter installation issues, see our [FAQ](faq.qmd) and [Debugging Guide](debugging.qmd).

29
docs/lr_groups.qmd Normal file
View File

@@ -0,0 +1,29 @@
---
title: Learning Rate Groups
description: "Setting different learning rates by module name"
---
## Background
Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or groups of
modules in a model.
## Example
```yaml
lr_groups:
- name: o_proj
modules:
- self_attn.o_proj.weight
lr: 1e-6
- name: q_proj
modules:
- model.layers.2.self_attn.q_proj.weight
lr: 1e-5
learning_rate: 2e-5
```
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's
self attention `q_proj` module.

118
docs/multi-gpu.qmd Normal file
View File

@@ -0,0 +1,118 @@
---
title: "Multi-GPU Training Guide"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---
This guide covers advanced training configurations for multi-GPU setups using Axolotl.
## Overview {#sec-overview}
Axolotl supports several methods for multi-GPU training:
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- FSDP + QLoRA
## DeepSpeed {#sec-deepspeed}
DeepSpeed is the recommended approach for multi-GPU training due to its stability and performance. It provides various optimization levels through ZeRO stages.
### Configuration {#sec-deepspeed-config}
Add to your YAML config:
```{.yaml}
deepspeed: deepspeed_configs/zero1.json
```
### Usage {#sec-deepspeed-usage}
```{.bash}
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
```
### ZeRO Stages {#sec-zero-stages}
We provide default configurations for:
- ZeRO Stage 1 (`zero1.json`)
- ZeRO Stage 2 (`zero2.json`)
- ZeRO Stage 3 (`zero3.json`)
Choose based on your memory requirements and performance needs.
## FSDP {#sec-fsdp}
### Basic FSDP Configuration {#sec-fsdp-config}
```{.yaml}
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
## Performance Optimization {#sec-performance}
### Liger Kernel Integration {#sec-liger}
::: {.callout-note}
Liger Kernel provides efficient Triton kernels for LLM training, offering:
- 20% increase in multi-GPU training throughput
- 60% reduction in memory usage
- Compatibility with both FSDP and DeepSpeed
:::
Configuration:
```{.yaml}
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```
## Troubleshooting {#sec-troubleshooting}
### NCCL Issues {#sec-nccl}
For NCCL-related problems, see our [NCCL troubleshooting guide](nccl.qmd).
### Common Problems {#sec-common-problems}
::: {.panel-tabset}
## Memory Issues
- Reduce `micro_batch_size`
- Reduce `eval_batch_size`
- Adjust `gradient_accumulation_steps`
- Consider using a higher ZeRO stage
## Training Instability
- Start with DeepSpeed ZeRO-2
- Monitor loss values
- Check learning rates
:::
For more detailed troubleshooting, see our [debugging guide](debugging.qmd).

93
docs/ray-integration.qmd Normal file
View File

@@ -0,0 +1,93 @@
---
title: Ray Train integration
description: How to use Axolotl with Ray Train
---
Axolotl supports using Ray as an alternative to `accelerate` for orchestrating training. This is especially useful for multi-node training since you only have to setup code and dependencies in a single node and launch training as if you were using a single node.
With the `--use-ray` CLI flag, Axolotl will use Ray Train's [`TorchTrainer`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchTrainer.html#ray.train.torch.TorchTrainer) to run training.
## Ray cluster setup
A prerequisite using the Ray Train integration is to setup a Ray cluster on your desired node(s). For a detailed guide on how you can get started with ray clusters, check the official Ray docs here: https://docs.ray.io/en/latest/cluster/getting-started.html
Every Ray cluster has one _head_ node and a set of worker nodes. The head node is just like any other worker node, but it also runs certain special processes related to scheduling and orchestration. Ray-enabled scripts are run on the head node and depending on the resources (number of CPUs, GPUs, etc) they request, will be scheduled to run certain tasks on the worker nodes. For more on key concepts behind a Ray cluster, you can refer this [doc](https://docs.ray.io/en/latest/cluster/key-concepts.html#cluster-key-concepts).
## Sanity check
To run a sanity check on whether your ray cluster is setup properly, execute the following on the head node:
```bash
ray status
```
The output should have a summary of your Ray cluster - list of all the nodes in your cluster, the number of CPUs and GPUs in your cluster, etc. For example, if you have a cluster with 1 CPU-only head node and 2 4xL40S worker nodes, the output can look like this:
```
Node status
---------------------------------------------------------------
Active:
1 head
Idle:
2 4xL40S:48CPU-384GB
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/96.0 CPU
0.0/8.0 GPU
0B/800.00GiB memory
0B/229.57GiB object_store_memory
Demands:
(no resource demands)
```
You should also be able to see the same on the [Ray dashboard](https://docs.ray.io/en/latest/ray-observability/getting-started.html).
## Configuring training with Ray Train
You can find an example configuration at `configs/llama-3/lora-1b-ray.yaml`.
The key parameters to note here are:
```yaml
...
use_ray: true
ray_num_workers: 4
# optional
resources_per_worker:
GPU: 1
...
```
- `use_ray`: This is the flag that enables the Ray Train integration. You can either use the corresponding `--use-ray` flag in the CLI or set `use_ray` in the config file.
- `ray_num_workers`: This is the number of workers/GPUs to use for training.
- `resources_per_worker`: This is the Ray [resource request](https://docs.ray.io/en/latest/ray-core/scheduling/resources.html) for each worker. This can be used to request a specific GPU type or a custom resource for each worker. For example, if your ray cluster has GPUs of different types, and you only want to use NVIDIA L40S GPUs, you can do
```yaml
resources_per_worker:
accelerator_type:L40S: 0.001
```
## Launching training
You can simply run the following command on the head node:
```bash
axolotl train examples/llama-3/lora-1b-ray.yml --use-ray
```
This will launch training on the head node and workers will be scheduled automatically by Ray Train to run on the appropriate head or worker nodes.
You can also monitor training progress on the Ray dashboard.
Coming back to the example on a Ray cluster with 1 head node and 2 4xL40S worker nodes, let's say you want to make use of all 8 GPUs. You would be able to just set `ray_num_workers: 8` and run the previous command. The Cluster tab will show the following:
![Ray dashboard](./images/ray-cluster-dashboard.png)

47
docs/reward_modelling.qmd Normal file
View File

@@ -0,0 +1,47 @@
---
title: "Reward Modelling"
description: "Reward models are used to guide models towards behaviors which is preferred by humans, by training over large datasets annotated with human preferences. "
---
### Overview
Reward modelling is a technique used to train models to predict the reward or value of a given input. This is particularly useful in reinforcement learning scenarios where the model needs to evaluate the quality of its actions or predictions.
We support the reward modelling techniques supported by `trl`.
### (Outcome) Reward Models
Outcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (e.g. rather than per-turn or per-step).
```yaml
base_model: google/gemma-2-2b
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
reward_model: true
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.1
eval_steps: 100
```
### Process Reward Models (PRM)
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
```yaml
base_model: Qwen/Qwen2.5-3B
model_type: AutoModelForTokenClassification
num_labels: 2
process_reward_model: true
datasets:
- path: trl-lib/math_shepherd
type: stepwise_supervised
split: train
val_set_size: 0.1
eval_steps: 100
```

View File

@@ -46,7 +46,7 @@ output_dir: ./outputs/btlm-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
optimizer: adamw_torch_fused
adam_beta2: 0.95
adam_eps: 0.000000001
max_grad_norm: 1.0

28
examples/cloud/modal.yaml Normal file
View File

@@ -0,0 +1,28 @@
project_name:
volumes:
- name: axolotl-data
mount: /workspace/data
- name: axolotl-artifacts
mount: /workspace/artifacts
# environment variables from local to set as secrets
secrets:
- HF_TOKEN
- WANDB_API_KEY
# Which branch of axolotl to use remotely
branch:
# additional custom commands when building the image
dockerfile_commands:
gpu: h100
gpu_count: 1
# Train specific configurations
memory: 128
timeout: 86400
# Preprocess specific configurations
memory_preprocess: 32
timeout_preprocess: 14400

View File

@@ -27,7 +27,7 @@ wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5

View File

@@ -47,7 +47,7 @@ peft_use_rslora: true
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5

View File

@@ -1,6 +1,7 @@
base_model: google/gemma-2-2b
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

View File

@@ -34,7 +34,7 @@ lora_target_linear: false
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.00001

View File

@@ -42,7 +42,7 @@ output_dir: ./outputs/model-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
adam_beta2: 0.95
adam_eps: 0.00001
max_grad_norm: 1.0

View File

@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.00001

View File

@@ -37,7 +37,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5

View File

@@ -0,0 +1,79 @@
base_model: NousResearch/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed: deepspeed_configs/zero3.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"
use_ray: true
ray_num_workers: 4

View File

@@ -30,7 +30,7 @@ lora_target_linear: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.00001

View File

@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.00001

View File

@@ -47,7 +47,7 @@ wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -41,7 +41,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -43,7 +43,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

View File

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

View File

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

View File

@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 12
num_epochs: 2
optimizer: adamw_torch
optimizer: adamw_torch_fused
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

View File

@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
optimizer: adamw_torch_fused
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

72
examples/qwen2/prm.yaml Normal file
View File

@@ -0,0 +1,72 @@
base_model: Qwen/Qwen2.5-3B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForTokenClassification
num_labels: 2
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
process_reward_model: true
chat_template:
datasets:
- path: trl-lib/math_shepherd
type: stepwise_supervised
step_separator: "\n"
max_completion_length:
train_on_last_step_only: false
val_set_size: 0.2
output_dir: ./outputs/out
remove_unused_columns: false
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 8
eval_batch_size: 8
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32:
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
eval_steps: 100
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -37,7 +37,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -0,0 +1,67 @@
base_model: Qwen/Qwen2.5-0.5B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
reward_model: true
chat_template: qwen_25
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -1,8 +1,8 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.0
triton>=2.3.0
bitsandbytes==0.45.1
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2
xformers>=0.0.23.post1
@@ -13,18 +13,19 @@ liger-kernel==0.5.2
packaging==23.2
peft==0.14.0
transformers==4.47.1
tokenizers>=0.20.1
accelerate==1.2.1
datasets==3.1.0
transformers==4.48.1
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.12.1
trl==0.13.0
optimum==1.16.2
hf_transfer
sentencepiece
gradio==3.50.2
modal==0.70.5
pydantic==2.6.3
addict
fire
@@ -53,7 +54,7 @@ zstandard==0.22.0
fastcore
# lm eval harness
lm_eval==0.4.4
lm_eval==0.4.7
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
@@ -61,4 +62,4 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.2
axolotl-contribs-lgpl==0.0.3

View File

@@ -30,7 +30,7 @@ def parse_dataset(dataset=None, split="train"):
)
ds_cfg["field_messages"] = field_messages
message_fields = features["conversations"][0].keys()
message_fields = features[field_messages][0].keys()
message_field_role = None
for key in ["from", "role"]:
if key in message_fields:

View File

@@ -1,52 +0,0 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import logging
from pathlib import Path
import fire
import transformers
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
do_inference,
do_merge_lora,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.cli.shard import shard
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
LOG = logging.getLogger("axolotl.scripts.finetune")
def do_cli(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
LOG.warning(
str(
PendingDeprecationWarning(
"scripts/finetune.py will be replaced with calling axolotl.cli.train"
)
)
)
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if parsed_cli_args.inference:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.merge_lora:
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.shard:
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
if __name__ == "__main__":
fire.Fire(do_cli)

View File

@@ -1,10 +1,15 @@
dP dP dP
88 88 88
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
88' `88 `8bd8' 88' `88 88 88' `88 88 88
88. .88 .d88b. 88. .88 88 88. .88 88 88
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:

View File

@@ -32,7 +32,6 @@ def parse_requirements():
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
@@ -86,22 +85,8 @@ def parse_requirements():
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")
raise ValueError("axolotl requires torch>=2.4")
except PackageNotFoundError:
pass
@@ -165,5 +150,8 @@ setup(
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],
"ray": [
"ray[train]",
],
},
)

View File

@@ -1,568 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
"""Axolotl CLI module initialization."""
import importlib
import json
import logging
import math
import os
import random
import sys
import tempfile
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import requests
import torch
import yaml
# add src to the pythonpath so we don't need to pip install this
from accelerate.commands.config import config_args
from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import _is_package_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
prepare_plugins,
validate_config,
)
from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
AXOLOTL_LOGO = """
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
"""
def print_legacy_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(ascii_text, font=font)
if is_main_process():
print(ascii_art)
print_dep_versions()
def print_axolotl_text_art(
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
print(AXOLOTL_LOGO)
def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
if is_main_process():
print("*" * 40)
print("**** Axolotl Dependency Versions *****")
for pkg in packages:
pkg_version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
print("*" * 40)
def check_remote_config(config: Union[str, Path]):
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
if not (isinstance(config, str) and config.startswith("https://")):
return config # Return the original value if it's not a valid URL
filename = os.path.basename(urlparse(config).path)
temp_dir = tempfile.mkdtemp()
try:
response = requests.get(config, timeout=30)
response.raise_for_status() # Check for HTTP errors
content = response.content
try:
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
json.loads(content)
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
LOG.warning(
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
)
except json.JSONDecodeError:
# If it's not valid JSON, verify it's valid YAML
try:
yaml.safe_load(content)
except yaml.YAMLError as err:
raise ValueError(
f"Failed to parse the content at {config} as YAML: {err}"
) from err
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
output_path = Path(temp_dir) / filename
with open(output_path, "wb") as file:
file.write(content)
LOG.info(
f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
)
return output_path
except requests.RequestException as err:
# This catches all requests-related exceptions including HTTPError
raise RuntimeError(f"Failed to download {config}: {err}") from err
except Exception as err:
# Catch-all for any other exceptions
raise err
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload(progressbar=True)
try:
model.to(dtype=cfg.torch_dtype)
except RuntimeError:
pass
model.generation_config.do_sample = True
if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def do_inference(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def do_inference_gradio(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
def generate(instruction):
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
temperature=cfg.get("gradio_temperature", 0.9),
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"attention_mask": batch["attention_mask"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
all_text = ""
for new_text in streamer:
all_text += new_text
yield all_text
demo = gr.Interface(
fn=generate,
inputs="textbox",
outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),
)
def choose_config(path: Path):
yaml_files = list(path.glob("*.yml"))
if not yaml_files:
raise ValueError(
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return str(yaml_files[0])
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = str(yaml_files[choice - 1])
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
return chosen_file
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
return not any(el in list2 for el in list1)
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(Path(config))
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
prepare_plugins(cfg)
cfg = validate_config(
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
},
)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)
normalize_cfg_datasets(cfg)
setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
return cfg
def load_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,
tokenizer,
processor=processor,
)
if (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
LOG.info("printing prompters...")
for prompter in prompters:
LOG.info(prompter)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def load_rl_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
) -> TrainDatasetMeta:
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
tokenizer = load_tokenizer(cfg)
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def check_accelerate_default_config():
if Path(config_args.default_yaml_config_file).exists():
LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
)
def check_user_token():
# Skip check if HF_HUB_OFFLINE is set to True
if os.getenv("HF_HUB_OFFLINE") == "1":
LOG.info(
"Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used."
)
return True
# Verify if token is valid
api = HfApi()
try:
user_info = api.whoami()
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False

45
src/axolotl/cli/args.py Normal file
View File

@@ -0,0 +1,45 @@
"""Module for axolotl CLI command arguments."""
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class PreprocessCliArgs:
"""Dataclass with CLI arguments for `axolotl preprocess` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
@dataclass
class TrainerCliArgs:
"""Dataclass with CLI arguments for `axolotl train` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
main_process_port: Optional[int] = field(default=None)
num_processes: Optional[int] = field(default=None)
@dataclass
class EvaluateCliArgs:
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
@dataclass
class InferenceCliArgs:
"""Dataclass with CLI arguments for `axolotl inference` command."""
prompter: Optional[str] = field(default=None)

23
src/axolotl/cli/art.py Normal file
View File

@@ -0,0 +1,23 @@
"""Axolotl ASCII logo utils."""
from axolotl.utils.distributed import is_main_process
AXOLOTL_LOGO = """
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
"""
def print_axolotl_text_art():
"""Prints axolotl ASCII art."""
if is_main_process():
print(AXOLOTL_LOGO)

50
src/axolotl/cli/checks.py Normal file
View File

@@ -0,0 +1,50 @@
"""Various checks for Axolotl CLI."""
import logging
import os
from pathlib import Path
from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from axolotl.logging_config import configure_logging
configure_logging()
LOG = logging.getLogger(__name__)
def check_accelerate_default_config() -> None:
"""Logs at warning level if no accelerate config file is found."""
if Path(config_args.default_yaml_config_file).exists():
LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
)
def check_user_token() -> bool:
"""Checks for HF user info. Check is skipped if HF_HUB_OFFLINE=1.
Returns:
Boolean indicating successful check (i.e., HF_HUB_OFFLINE=1 or HF user info is retrieved).
Raises:
LocalTokenNotFoundError: If HF user info can't be retrieved.
"""
# Skip check if HF_HUB_OFFLINE is set to True
if os.getenv("HF_HUB_OFFLINE") == "1":
LOG.info(
"Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used."
)
return True
# Verify if token is valid
api = HfApi()
try:
user_info = api.whoami()
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False

View File

@@ -0,0 +1,56 @@
"""
launch axolotl in supported cloud platforms
"""
from pathlib import Path
from typing import Union
import yaml
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
"""Load and validate cloud configuration."""
# Load cloud configuration.
with open(cloud_config, encoding="utf-8") as file:
cloud_cfg: DictDefault = DictDefault(yaml.safe_load(file))
return cloud_cfg
def do_cli_preprocess(
cloud_config: Union[Path, str],
config: Union[Path, str],
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.preprocess(config_yaml)
def do_cli_train(
cloud_config: Union[Path, str],
config: Union[Path, str],
accelerate: bool = True,
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.train(config_yaml, accelerate=accelerate)
def do_cli_lm_eval(
cloud_config: Union[Path, str],
config: Union[Path, str],
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.lm_eval(config_yaml)

View File

@@ -0,0 +1,18 @@
"""
base class for cloud platforms from cli
"""
from abc import ABC, abstractmethod
class Cloud(ABC):
"""
Abstract base class for cloud platforms.
"""
@abstractmethod
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
pass
@abstractmethod
def train(self, config_yaml: str, accelerate: bool = True) -> str:
pass

View File

@@ -0,0 +1,282 @@
"""
Modal Cloud support from CLI
"""
import copy
import json
import os
import subprocess # nosec B404
from pathlib import Path
from random import randint
import modal
from axolotl.cli.cloud.base import Cloud
def run_cmd(cmd: str, run_folder: str, volumes=None):
"""Run a command inside a folder, with Modal Volume reloading before and commit on success."""
# Ensure volumes contain latest files.
if volumes:
for _, vol in volumes.items():
vol.reload()
# modal workaround so it doesn't use the automounted axolotl
new_env = copy.deepcopy(os.environ)
if "PYTHONPATH" in new_env:
del new_env["PYTHONPATH"]
# Propagate errors from subprocess.
if exit_code := subprocess.call( # nosec B603
cmd.split(), cwd=run_folder, env=new_env
):
exit(exit_code) # pylint: disable=consider-using-sys-exit
# Commit writes to volume.
if volumes:
for _, vol in volumes.items():
vol.commit()
class ModalCloud(Cloud):
"""
Modal Cloud implementation.
"""
def __init__(self, config, app=None):
self.config = config
if not app:
app = modal.App()
self.app = app
self.volumes = {}
if config.volumes:
for volume_config in config.volumes:
_, mount, vol = self.create_volume(volume_config)
self.volumes[mount] = (vol, volume_config)
def get_env(self):
res = {
"HF_DATASETS_CACHE": "/workspace/data/huggingface-cache/datasets",
"HF_HUB_CACHE": "/workspace/data/huggingface-cache/hub",
}
for key in self.config.get("env", []):
if isinstance(key, str):
if val := os.environ.get(key, ""):
res[key] = val
elif isinstance(key, dict):
(key_, val) = list(key.items())[0]
res[key_] = val
return res
def get_image(self):
docker_tag = "main-py3.11-cu124-2.5.1"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"
# grab the sha256 hash from docker hub for this image+tag
# this ensures that we always get the latest image for this tag, even if it's already cached
try:
manifest = subprocess.check_output( # nosec B602
f"docker manifest inspect {docker_image}",
shell=True,
).decode("utf-8")
sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
except subprocess.CalledProcessError:
sha256_hash = None
# create the image
if sha256_hash:
image = modal.Image.from_registry(f"axolotlai/axolotl@{sha256_hash}")
else:
image = modal.Image.from_registry(docker_image)
dockerfile_commands = []
if self.config.dockerfile_commands:
dockerfile_commands.extend(self.config.dockerfile_commands)
# branch
if self.config.branch:
dockerfile_commands.extend(
[
# Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
]
)
if dockerfile_commands:
image = image.dockerfile_commands(dockerfile_commands)
if env := self.get_env():
image = image.env(env)
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
return image
def get_secrets(self):
res = []
if self.config.secrets:
for key in self.config.get("secrets", []):
# pylint: disable=duplicate-code
if isinstance(key, str):
if val := os.environ.get(key, ""):
res.append(modal.Secret.from_dict({key: val}))
elif isinstance(key, dict):
(key_, val) = list(key.items())[0]
res.append(modal.Secret.from_dict({key_: val}))
return res
def create_volume(self, volume_config):
name = volume_config.name
mount = volume_config.mount
return name, mount, modal.Volume.from_name(name, create_if_missing=True)
def get_ephemeral_disk_size(self):
return 1000 * 525 # 1 TiB
def get_preprocess_timeout(self):
if self.config.timeout_preprocess:
return int(self.config.timeout_preprocess)
return 60 * 60 * 3 # 3 hours
def get_preprocess_memory(self):
memory = 128 # default to 128GiB
if self.config.memory:
memory = int(self.config.memory)
if self.config.memory_preprocess:
memory = int(self.config.memory_preprocess)
return 1024 * memory
def get_preprocess_env(self):
return self.app.function(
image=self.get_image(),
volumes={k: v[0] for k, v in self.volumes.items()},
cpu=8.0,
ephemeral_disk=self.get_ephemeral_disk_size(),
memory=self.get_preprocess_memory(),
timeout=self.get_preprocess_timeout(),
secrets=self.get_secrets(),
)
def preprocess(self, config_yaml: str, *args, **kwargs):
modal_fn = self.get_preprocess_env()(_preprocess)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
volumes={k: v[0] for k, v in self.volumes.items()},
*args,
**kwargs,
)
def get_train_timeout(self):
if self.config.timeout:
return int(self.config.timeout)
return 60 * 60 * 24 # 24 hours
def get_train_gpu(self): # pylint: disable=too-many-return-statements
count = self.config.gpu_count or 1
family = self.config.gpu.lower() or "l40s"
if family == "l40s":
return modal.gpu.L40S(count=count)
if family in ["a100", "a100-40gb"]:
return modal.gpu.A100(count=count, size="40GB")
if family == "a100-80gb":
return modal.gpu.A100(count=count, size="80GB")
if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count)
if family == "h100":
return modal.gpu.H100(count=count)
if family == "t4":
return modal.gpu.T4(count=count)
if family == "l4":
return modal.gpu.L4(count=count)
raise ValueError(f"Unsupported GPU family: {family}")
def get_train_memory(self):
memory = 128 # default to 128GiB
if self.config.memory:
memory = int(self.config.memory)
return 1024 * memory
def get_train_env(self):
return self.app.function(
image=self.get_image(),
volumes={k: v[0] for k, v in self.volumes.items()},
cpu=16.0,
gpu=self.get_train_gpu(),
memory=self.get_train_memory(),
timeout=self.get_train_timeout(),
secrets=self.get_secrets(),
)
def train(self, config_yaml: str, accelerate: bool = True):
modal_fn = self.get_train_env()(_train)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
accelerate=accelerate,
volumes={k: v[0] for k, v in self.volumes.items()},
)
def lm_eval(self, config_yaml: str):
modal_fn = self.get_train_env()(_lm_eval)
with modal.enable_output():
with self.app.run(detach=True):
if self.config.get("spawn", False):
modal_fn_exec = modal_fn.spawn
else:
modal_fn_exec = modal_fn.remote
modal_fn_exec(
config_yaml,
volumes={k: v[0] for k, v in self.volumes.items()},
)
def _preprocess(config_yaml: str, volumes=None):
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
run_cmd(
"axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
run_folder,
volumes,
)
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
if accelerate:
accelerate_args = "--accelerate"
else:
accelerate_args = "--no-accelerate"
run_cmd(
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)
def _lm_eval(config_yaml: str, volumes=None):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
run_cmd(
"axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)

217
src/axolotl/cli/config.py Normal file
View File

@@ -0,0 +1,217 @@
"""Configuration loading and processing."""
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Union
from urllib.parse import urlparse
import requests
import torch
import yaml
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
validate_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = logging.getLogger(__name__)
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
"""
First, determines if the passed config is a valid HTTPS URL. Then, attempts to query
for it and parse its content, first as JSON, then as YAML (YAML is preferred).
Finally, the parsed content is written to a local file and its path is returned.
Args:
config: HTTPS URL to a YAML or JSON file.
Returns:
Either the original `config` if it's not a valid HTTPS URL, or the path to the
downloaded remote config.
Raises:
ValueError: If the remote configuration is neither valid JSON or YAML.
RuntimeError: If some request-related exception occurs from the file download.
Exception: Catch-all for any other exception.
"""
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
if not (isinstance(config, str) and config.startswith("https://")):
return config # Return the original value if it's not a valid URL
filename = os.path.basename(urlparse(config).path)
temp_dir = tempfile.mkdtemp()
try:
response = requests.get(config, timeout=30)
response.raise_for_status() # Check for HTTP errors
content = response.content
try:
# Try parsing as JSON first to catch cases where JSON content is mistakenly
# considered YAML.
json.loads(content)
# Log a warning but do not raise an error; JSON is technically valid YAML.
# This can happen when you forget to point to a raw GitHub link.
LOG.warning(
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
)
except json.JSONDecodeError:
# If it's not valid JSON, verify it's valid YAML
try:
yaml.safe_load(content)
except yaml.YAMLError as err:
raise ValueError(
f"Failed to parse the content at {config} as YAML: {err}"
) from err
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
output_path = Path(temp_dir) / filename
with open(output_path, "wb") as file:
file.write(content)
LOG.info(
f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
)
return output_path
except requests.RequestException as err:
# This catches all requests-related exceptions including HTTPError
raise RuntimeError(f"Failed to download {config}: {err}") from err
except Exception as err:
# Catch-all for any other exceptions
raise err
def choose_config(path: Path) -> str:
"""
Helper method for choosing a `axolotl` config YAML file (considering only files
ending with `.yml` or `.yaml`). If more than one config file exists in the passed
`path`, the user is prompted to choose one.
Args:
path: Directory in which config file(s) are stored.
Returns:
Path to either (1) the sole YAML file, or (2) if more than one YAML files exist,
the user-selected YAML file.
Raises:
ValueError: If no YAML files are found in the given `path`.
"""
yaml_files = list(path.glob("*.yml")) + list(path.glob("*.yaml"))
if not yaml_files:
raise ValueError(
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return str(yaml_files[0])
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = str(yaml_files[choice - 1])
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
return chosen_file
def prepare_plugins(cfg: DictDefault):
"""
Registers the plugins for the given configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
"""
Loads the `axolotl` configuration stored at `config`, validates it, and performs
various setup.
Args:
config: Path (local or remote) to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
Returns:
`DictDefault` mapping configuration keys to values.
"""
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(Path(config))
# Load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
prepare_plugins(cfg)
cfg = validate_config(
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
},
)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)
normalize_cfg_datasets(cfg)
setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
return cfg

View File

@@ -1,6 +1,5 @@
"""
CLI to run training on a model
"""
"""CLI to run evaluation on a model."""
import logging
from pathlib import Path
from typing import Union
@@ -9,35 +8,48 @@ import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
load_rl_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.cli.evaluate")
LOG = logging.getLogger(__name__)
def do_evaluate(cfg, cli_args) -> None:
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
Evaluates a `transformers` model by first loading the dataset(s) specified in the
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
evaluation metrics on the given dataset(s) and writes them to disk.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: CLI arguments.
"""
# pylint: disable=duplicate-code
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if cfg.rl: # and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
evaluate(cfg=cfg, dataset_meta=dataset_meta)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_evaluate`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)

View File

@@ -1,32 +1,267 @@
"""
CLI to run inference on a trained model
"""
"""CLI to run inference on a trained model."""
import importlib
import logging
import sys
from pathlib import Path
from threading import Thread
from typing import Union
import fire
import torch
import transformers
from dotenv import load_dotenv
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli import (
do_inference,
do_inference_gradio,
load_cfg,
print_axolotl_text_art,
from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs):
def get_multi_line_input() -> str:
"""
Gets multi-line input from terminal.
Returns:
Possibly multi-line, possibly empty stdin input as a string.
"""
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
return instruction
def do_inference(
*,
cfg: DictDefault,
cli_args: InferenceCliArgs,
):
"""
Runs inference on the command line in a loop. User input is accepted, a chat template
is (optionally) applied, and the model specified in the `axolotl` config is used to
generate completions according to a default generation config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Inference-specific CLI arguments.
"""
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def do_inference_gradio(
*,
cfg: DictDefault,
cli_args: InferenceCliArgs,
):
"""
Runs inference in a Gradio interface. User input is accepted, a chat template is
(optionally) applied, and the model specified in the `axolotl` config is used to
generate completions according to a default generation config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Inference-specific CLI arguments.
"""
import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
def generate(instruction):
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
temperature=cfg.get("gradio_temperature", 0.9),
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"attention_mask": batch["attention_mask"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
all_text = ""
for new_text in streamer:
all_text += new_text
yield all_text
demo = gr.Interface(
fn=generate,
inputs="textbox",
outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),
)
def do_cli(
config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs
) -> None:
"""
Parses axolotl config, CLI args, and calls `do_inference` or `do_inference_gradio`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, inference=True, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser((TrainerCliArgs))
parser = transformers.HfArgumentParser(InferenceCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.inference = True
if gradio:
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)

View File

@@ -1,18 +1,21 @@
"""CLI definition for various axolotl commands."""
"""Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name
import subprocess # nosec B404
from typing import Optional
import click
import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
build_command,
fetch_from_github,
filter_none_kwargs,
)
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -25,45 +28,89 @@ def cli():
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
def preprocess(config: str, **kwargs):
"""Preprocess datasets before training."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
@filter_none_kwargs
def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
"""
Preprocess datasets before training.
from axolotl.cli.preprocess import do_cli
Args:
config: Path to `axolotl` config YAML file.
cloud: Path to a cloud accelerator configuration file.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if cloud:
from axolotl.cli.cloud import do_cli_preprocess
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU training",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def train(config: str, accelerate: bool, **kwargs):
"""Train or fine-tune a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
do_cli_preprocess(cloud_config=cloud, config=config)
else:
from axolotl.cli.train import do_cli
from axolotl.cli.preprocess import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU training",
)
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs) -> None:
"""
Train or fine-tune a model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
cloud: Path to a cloud accelerator configuration file
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
from axolotl.cli.cloud import do_cli_train
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
if accelerate:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
else:
accelerate_args = []
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port", None)
accelerate_args.append("--main_process_port")
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num-processes")
accelerate_args.append(str(num_processes))
base_cmd = ["accelerate", "launch"]
base_cmd.extend(accelerate_args)
base_cmd.extend(["-m", "axolotl.cli.train"])
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
else:
from axolotl.cli.train import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
@@ -73,10 +120,17 @@ def train(config: str, accelerate: bool, **kwargs):
)
@add_options_from_dataclass(EvaluateCliArgs)
@add_options_from_config(AxolotlInputConfig)
def evaluate(config: str, accelerate: bool, **kwargs):
"""Evaluate a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
@filter_none_kwargs
def evaluate(config: str, accelerate: bool, **kwargs) -> None:
"""
Evaluate a model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
if config:
@@ -96,81 +150,33 @@ def evaluate(config: str, accelerate: bool, **kwargs):
default=False,
help="Use accelerate launch for multi-GPU inference",
)
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing LoRA model",
)
@click.option(
"--base-model",
type=click.Path(exists=True, path_type=str),
help="Path to base model for non-LoRA models",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def inference(
config: str,
accelerate: bool,
lora_model_dir: Optional[str] = None,
base_model: Optional[str] = None,
**kwargs,
):
"""Run inference with a trained model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
del kwargs["inference"] # interferes with inference.do_cli
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if base_model:
kwargs["base_model"] = base_model
@filter_none_kwargs
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
"""
Run inference with a trained model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
gradio: Whether to use Gradio browser interface or command line for inference.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
if config:
base_cmd.append(config)
if gradio:
base_cmd.append("--gradio")
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.inference import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=False,
help="Use accelerate launch for multi-GPU operations",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing model weights to shard",
)
@click.option(
"--save-dir",
type=click.Path(path_type=str),
help="Directory to save sharded weights",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def shard(config: str, accelerate: bool, **kwargs):
"""Shard model weights."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.shard import do_cli
do_cli(config=config, **kwargs)
do_cli(config=config, gradio=gradio, **kwargs)
@cli.command()
@@ -180,20 +186,19 @@ def shard(config: str, accelerate: bool, **kwargs):
default=True,
help="Use accelerate launch for weight merging",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing sharded weights",
)
@click.option(
"--save-path", type=click.Path(path_type=str), help="Path to save merged weights"
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
"""Merge sharded FSDP model weights."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
@filter_none_kwargs
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
"""
Merge sharded FSDP model weights.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = [
"accelerate",
@@ -213,28 +218,18 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing the LoRA model to merge",
)
@click.option(
"--output-dir",
type=click.Path(path_type=str),
help="Directory to save the merged model",
)
def merge_lora(
config: str,
lora_model_dir: Optional[str] = None,
output_dir: Optional[str] = None,
):
"""Merge a trained LoRA into a base model"""
kwargs = {}
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if output_dir:
kwargs["output_dir"] = output_dir
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def merge_lora(config: str, **kwargs) -> None:
"""
Merge trained LoRA adapters into a base model.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
from axolotl.cli.merge_lora import do_cli
do_cli(config=config, **kwargs)
@@ -243,17 +238,24 @@ def merge_lora(
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]):
def fetch(directory: str, dest: Optional[str]) -> None:
"""
Fetch example configs or other resources.
Available directories:
- examples: Example configuration files
- deepspeed_configs: DeepSpeed configuration files
Args:
directory: One of `examples`, `deepspeed_configs`.
dest: Optional destination directory.
"""
fetch_from_github(f"{directory}/", dest)
cli.add_command(lm_eval)
def main():
cli()

View File

@@ -1,6 +1,6 @@
"""
CLI to run merge a trained LoRA into a base model
"""
"""CLI to merge a trained LoRA into a base model."""
import logging
from pathlib import Path
from typing import Union
@@ -8,14 +8,58 @@ import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
def do_merge_lora(*, cfg: DictDefault) -> None:
"""
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
along with the LoRA adapters to combine them into a single base model.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True
LOG.info("Running merge of LoRA with base model...")
model = model.merge_and_unload(progressbar=True)
model.to(dtype=cfg.torch_dtype)
model.generation_config.do_sample = True
if cfg.local_rank == 0:
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various
config values will be overwritten to allow the LoRA merge logic to work as expected
(`load_in_8bit=False`, `load_in4bit=False`, `flash_attention=False`, etc.).
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
Raises:
ValueError: If target directory for LoRA merged model does not exist.
"""
# pylint: disable=duplicate-code
parser = transformers.HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
@@ -46,7 +90,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
parsed_cfg.fsdp = None
parsed_cfg.fsdp_config = None
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
do_merge_lora(cfg=parsed_cfg)
if __name__ == "__main__":

View File

@@ -1,6 +1,5 @@
"""
This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint
"""
"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""
import json
import logging
import os
@@ -25,16 +24,15 @@ from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights")
LOG = logging.getLogger(__name__)
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
"""
A custom planner to cast tensors to bfloat16 on the fly during loading.
"""
"""A custom planner to cast tensors to bfloat16 on the fly during loading."""
def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
tensor.copy_(tensor.to(torch.bfloat16))
@@ -45,11 +43,19 @@ def _distributed_checkpoint_to_merged_weights(
save_path: str,
safe_serialization: bool = False,
max_shard_size: str = "5GB",
):
) -> Path:
"""
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. Will
save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
Args:
checkpoint_dir: Directory where distributed checkpoint is saved.
save_path: Path to save model to.
safe_serialization: Whether to save in safetensors format.
max_shard_size: Max size of model shards to save.
Returns:
Path where model is saved.
"""
state_dict: Dict = {}
@@ -79,6 +85,7 @@ def _distributed_checkpoint_to_merged_weights(
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)
# Save index if sharded
index = None
if state_dict_split.is_sharded:
@@ -135,6 +142,9 @@ def merge_fsdp_weights(
Whether to save the merged weights with safetensors (recommended).
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
Whether to remove the checkpoint directory after merging.
Raises:
ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
"""
checkpoint_dir_ = Path(checkpoint_dir)
from accelerate.state import PartialState
@@ -178,18 +188,21 @@ def merge_fsdp_weights(
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
"""
Parses `axolotl` config, CLI args, and calls `merge_fsdp_weights`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parser = transformers.HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(
config,
**kwargs,
)
parsed_cfg = load_cfg(config, **kwargs)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
merge_fsdp_weights(

View File

@@ -1,6 +1,5 @@
"""
CLI to run training on a model
"""
"""CLI to run preprocessing of a dataset."""
import logging
import warnings
from pathlib import Path
@@ -13,34 +12,31 @@ from colorama import Fore
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
load_rl_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import PreprocessCliArgs
from axolotl.cli.args import PreprocessCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger("axolotl.cli.preprocess")
LOG = logging.getLogger(__name__)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
"""
Preprocesses dataset specified in axolotl config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Preprocessing-specific CLI arguments.
"""
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((PreprocessCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if not parsed_cfg.dataset_prepared_path:
if not cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "preprocess CLI called without dataset_prepared_path set, "
@@ -48,16 +44,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
with disable_datasets_caching():
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if cfg.rl:
load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
load_datasets(cfg=cfg, cli_args=cli_args)
if parsed_cli_args.download:
model_name = parsed_cfg.base_model
if cli_args.download:
model_name = cfg.base_model
with warnings.catch_warnings():
# there are a bunch of useless UserWarnings about
# "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
@@ -74,11 +70,30 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.info(
Fore.GREEN
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
+ f"Success! Preprocessed data path: `dataset_prepared_path: {cfg.dataset_prepared_path}`"
+ Fore.RESET
)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_preprocess`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
do_preprocess(parsed_cfg, parsed_cli_args)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -1,45 +0,0 @@
"""
CLI to shard a trained model into 10GiB chunks
"""
import logging
from pathlib import Path
from typing import Union
import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.scripts")
def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.shard = True
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -1,50 +1,47 @@
"""
CLI to run training on a model
"""
"""CLI to run training on a model."""
import logging
from pathlib import Path
from typing import Union
import fire
from accelerate import Accelerator
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
load_rl_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.cli.train")
LOG = logging.getLogger(__name__)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
return do_train(parsed_cfg, parsed_cli_args)
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
Trains a `transformers` model by first loading the dataset(s) specified in the
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
manager's `post_train_unload` once training completes.
def do_train(cfg, cli_args) -> None:
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Training-specific CLI arguments.
"""
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if cfg.rl: # and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance()
del model
@@ -53,6 +50,64 @@ def do_train(cfg, cli_args) -> None:
plugin_manager.post_train_unload(cfg)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_train`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if parsed_cfg.use_ray:
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
train_loop_config = {"cfg": parsed_cfg.to_dict(), "cli_args": parsed_cli_args}
trainer = TorchTrainer(
ray_train_func,
train_loop_config=train_loop_config,
scaling_config=ScalingConfig(
num_workers=parsed_cfg.ray_num_workers,
resources_per_worker=parsed_cfg.resources_per_worker.to_dict(),
use_gpu=True,
),
run_config=RunConfig(
name=parsed_cfg.ray_run_name,
storage_path=Path(parsed_cfg.output_dir).absolute().as_posix(),
),
)
return trainer.fit()
return do_train(parsed_cfg, parsed_cli_args)
def ray_train_func(kwargs: dict):
# cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
# also renormalize the config now that TorchTrainer has spawned distributed workers
cfg = DictDefault(kwargs["cfg"])
normalize_config(cfg)
# now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype
resolve_dtype(cfg)
# ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
if cfg.deepspeed:
cfg.deepspeed = cfg.deepspeed.to_dict()
# initialize accelerator before model instantiation
Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
kwargs["cfg"] = cfg
do_train(**kwargs)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -1,32 +1,84 @@
"""Utility methods for axoltl CLI."""
"""Utility methods for axolotl CLI."""
import concurrent.futures
import dataclasses
import hashlib
import json
import logging
import typing
from functools import wraps
from pathlib import Path
from types import NoneType
from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin
from typing import Any, Callable, Type, Union, get_args, get_origin
import click
import requests
from pydantic import BaseModel
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
LOG = logging.getLogger("axolotl.cli.utils")
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger(__name__)
def add_options_from_dataclass(config_class: Type[Any]):
"""Create Click options from the fields of a dataclass."""
def strip_optional_type(field_type: type | typing._SpecialForm | None):
"""
Extracts the non-`None` type from an `Optional` / `Union` type.
def decorator(function):
Args:
field_type: Type of field for Axolotl CLI command.
Returns:
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
return field_type
def filter_none_kwargs(func: Callable) -> Callable:
"""
Wraps function to remove `None`-valued `kwargs`.
Args:
func: Function to wrap.
Returns:
Wrapped function.
"""
@wraps(func)
def wrapper(*args, **kwargs) -> Callable:
"""Filters out `None`-valued `kwargs`."""
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
return func(*args, **filtered_kwargs)
return wrapper
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
"""
Create Click options from the fields of a dataclass.
Args:
config_class: Dataclass with fields to parse from the CLI.
Returns:
Function decorator for Axolotl CLI command.
"""
def decorator(function: Callable) -> Callable:
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = field.type
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
field_type = strip_optional_type(field.type)
if field_type == bool:
field_name = field.name.replace("_", "-")
@@ -44,18 +96,29 @@ def add_options_from_dataclass(config_class: Type[Any]):
default=field.default,
help=field.metadata.get("description"),
)(function)
return function
return decorator
def add_options_from_config(config_class: Type[BaseModel]):
"""Create Click options from the fields of a Pydantic model."""
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
"""
Create Click options from the fields of a Pydantic model.
def decorator(function):
Args:
config_class: PyDantic model with fields to parse from the CLI
Returns:
Function decorator for Axolotl CLI command.
"""
def decorator(function: Callable) -> Callable:
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool:
field_type = strip_optional_type(field.annotation)
if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
@@ -66,13 +129,23 @@ def add_options_from_config(config_class: Type[BaseModel]):
function = click.option(
option_name, default=None, help=field.description
)(function)
return function
return decorator
def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
"""Build command list from base command and options."""
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
"""
Build command list from base command and options.
Args:
base_cmd: Command without options.
options: Options to parse and append to base command.
Returns:
List of strings giving shell command.
"""
cmd = base_cmd.copy()
for key, value in options.items():
@@ -92,18 +165,18 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
def download_file(
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> Tuple[str, str]:
) -> tuple[str, str]:
"""
Download a single file and return its processing status.
Args:
file_info: Tuple of (file_path, remote_sha)
raw_base_url: Base URL for raw GitHub content
dest_path: Local destination directory
dir_prefix: Directory prefix to filter files
file_info: Tuple of (file_path, remote_sha).
raw_base_url: Base URL for raw GitHub content.
dest_path: Local destination directory.
dir_prefix: Directory prefix to filter files.
Returns:
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
"""
file_path, remote_sha = file_info
raw_url = f"{raw_base_url}/{file_path}"
@@ -145,16 +218,17 @@ def download_file(
def fetch_from_github(
dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
) -> None:
"""
Sync files from a specific directory in the GitHub repository.
Only downloads files that don't exist locally or have changed.
Args:
dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/')
dest_dir: Local destination directory
max_workers: Maximum number of concurrent downloads
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
'deepspeed_configs/').
dest_dir: Local destination directory.
max_workers: Maximum number of concurrent downloads.
"""
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
@@ -179,7 +253,7 @@ def fetch_from_github(
dest_path = Path(dest_dir) if dest_dir else default_dest
# Keep track of processed files for summary
files_processed: Dict[str, List[str]] = {
files_processed: dict[str, list[str]] = {
"new": [],
"updated": [],
"unchanged": [],
@@ -216,3 +290,28 @@ def fetch_from_github(
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
if files_processed["error"]:
LOG.info(f"Failed files: {len(files_processed['error'])}")
def load_model_and_tokenizer(
*,
cfg: DictDefault,
inference: bool = False,
) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
"""
Helper function for loading a model and tokenizer specified in the given `axolotl`
config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
inference: Boolean denoting inference mode.
Returns:
`transformers` model and tokenizer.
"""
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model...")
model, _ = load_model(cfg, tokenizer, inference=inference)
return model, tokenizer

View File

@@ -1,69 +0,0 @@
"""
shared module for cli specific things
"""
import logging
from dataclasses import dataclass, field
from typing import Optional
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger("axolotl.common.cli")
@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
@dataclass
class EvaluateCliArgs:
"""
dataclass representing the various evaluation arguments
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model and (optionally) peft_config...")
inference = getattr(cli_args, "inference", False)
model, _ = load_model(cfg, tokenizer, inference=inference)
return model, tokenizer

View File

@@ -0,0 +1,140 @@
"""Dataset loading utilities."""
import logging
import math
import random
from dataclasses import dataclass
from typing import Optional, Union
from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
LOG = logging.getLogger(__name__)
@dataclass
class TrainDatasetMeta:
"""Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
"""
Randomly sample `num_samples` samples from `dataset`.
Args:
dataset: Dataset.
num_samples: Number of samples to return.
Returns:
Random sample (with replacement) of examples in `dataset`.
"""
return dataset.select(
[random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec
)
def load_datasets(
*,
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets, calling
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
Returns:
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,
tokenizer,
processor=processor,
)
if (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
LOG.info("check_dataset_labels...")
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
LOG.info("printing prompters...")
for prompter in prompters:
LOG.info(prompter)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def load_preference_datasets(
*,
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
Optionally, logs out debug information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
Returns:
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)

View File

@@ -22,7 +22,6 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
import transformers
from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
@@ -45,6 +44,8 @@ from trl import (
KTOTrainer,
ORPOConfig,
ORPOTrainer,
PRMConfig,
PRMTrainer,
RewardConfig,
RewardTrainer,
)
@@ -244,6 +245,10 @@ class AxolotlTrainingMixins:
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
@@ -339,6 +344,13 @@ class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
"""
@dataclass
class AxolotlPRMConfig(AxolotlTrainingMixins, PRMConfig):
"""
PRM config for PRM training
"""
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
@@ -462,11 +474,95 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
lr_group_modules = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
]
if lr_groups_lookup and any(lr_group_modules):
lr_group_module = lr_group_modules[0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
return optimizer_grouped_parameters
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
@@ -480,59 +576,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -549,6 +599,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
or self.args.lr_groups is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
@@ -608,8 +659,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.state.train_batch_size or self.args.per_device_train_batch_size
)
batch_max_len = train_batch_size * self.args.max_seq_length
if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset)
else:
sampler = RandomSampler(self.train_dataset)
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
sampler,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
@@ -978,12 +1035,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
return super().log(logs, start_time)
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1079,6 +1131,7 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
@@ -1167,22 +1220,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
torch.cuda.empty_cache()
return loss
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
@@ -1191,22 +1228,6 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
@@ -1215,49 +1236,6 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# train metrics should have no prefix, eval should have 'eval_'
prefix = "eval_" if train_eval == "eval" else ""
# accumulate average metrics from sums and lengths
for split in ["chosen", "rejected"]:
if f"count/{split}" in self._stored_metrics[train_eval]:
count_sum = (
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
.sum()
.item()
)
for metric in ["rewards", "logps", "logits"]:
logs[f"{prefix}{metric}/{split}"] = (
torch.Tensor(
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
)
.sum()
.item()
/ count_sum
)
# delete obsolete metric
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
del self._stored_metrics[train_eval][f"count/{split}"]
# calculate reward margin
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
logs[f"{prefix}rewards/margins"] = (
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
)
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
@@ -1266,22 +1244,6 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
@@ -1290,14 +1252,13 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"]
class TrainerBuilderBase(abc.ABC):
@@ -1433,7 +1394,8 @@ class TrainerBuilderBase(abc.ABC):
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for Causal models
Build the HuggingFace training args/trainer for causal models
and reward modelling using TRL.
"""
def get_callbacks(self):
@@ -1508,6 +1470,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
return AxolotlTrainer
def build(self, total_num_steps):
@@ -1764,6 +1728,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
@@ -1897,11 +1862,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
"accelerator_config"
] = self.cfg.accelerator_config
training_args_cls = (
AxolotlTrainingArguments
if not self.cfg.reward_model
else AxolotlRewardConfig
)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
elif self.cfg.process_reward_model:
training_args_cls = AxolotlPRMConfig
else:
training_args_cls = AxolotlTrainingArguments
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
@@ -1935,9 +1902,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not self.cfg.reward_model:
if not (self.cfg.reward_model or self.cfg.process_reward_model):
trainer_kwargs["eval_data_collator"] = eval_data_collator
if not self.cfg.reward_model:
if not (self.cfg.reward_model or self.cfg.process_reward_model):
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
@@ -1948,8 +1915,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs["processing_class"] = self.tokenizer
else:
trainer_kwargs["tokenizer"] = self.tokenizer
if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
if (
not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
and self.cfg.datasets is not None
):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
@@ -1977,6 +1946,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if self.cfg.micro_batch_size > 1:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
if self.cfg.model_config_type == "mamba":
@@ -2035,7 +2008,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
class HFRLTrainerBuilder(TrainerBuilderBase):
"""
Trainer factory class for DPO Trainer
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
"""
def get_callbacks(self):

View File

@@ -52,6 +52,7 @@ class TokenizedPromptDataset(Dataset):
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,

View File

@@ -9,7 +9,6 @@ from typing import Dict, Optional
import torch
from accelerate.logging import get_logger
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
@@ -62,16 +61,13 @@ def evaluate_dataset(
return metrics
def evaluate(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]:
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets
Args:
cfg: Configuration dictionary
cli_args: Command line arguments
dataset_meta: Dataset metadata containing training and evaluation datasets
cfg: Dictionary mapping `axolotl` config keys to values.
dataset_meta: Dataset metadata containing training and evaluation datasets.
Returns:
Tuple containing:
@@ -102,9 +98,7 @@ def evaluate(
# Load model
LOG.debug("loading model for evaluation...")
model, _ = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
model, _ = load_model(cfg, tokenizer, processor=processor)
# Set up trainer
trainer = setup_trainer(

View File

@@ -22,13 +22,6 @@ import inspect
import logging
import sys
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
@@ -46,6 +39,13 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)

View File

@@ -2,9 +2,9 @@
Module for the Plugin for LM Eval Harness
"""
import subprocess # nosec
from datetime import datetime
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
@@ -18,25 +18,20 @@ class LMEvalPlugin(BasePlugin):
return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg):
tasks = ",".join(cfg.lm_eval_tasks)
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
output_path = cfg.output_dir
output_path += "" if cfg.output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
subprocess.run( # nosec
[
"lm_eval",
"--model",
"hf",
"--model_args",
f"pretrained={cfg.output_dir}{fa2}{dtype}",
"--tasks",
tasks,
"--batch_size",
str(cfg.lm_eval_batch_size),
"--output_path",
output_path,
],
check=True,
)
if cfg.lm_eval_post_train:
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=cfg.flash_attention,
output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id,
):
subprocess.run( # nosec
lm_eval_args,
check=True,
)

View File

@@ -13,3 +13,5 @@ class LMEvalArgs(BaseModel):
lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8
lm_eval_post_train: Optional[bool] = True
lm_eval_model: Optional[str] = None

View File

@@ -0,0 +1,119 @@
"""
axolotl CLI for running lm_eval tasks
"""
import subprocess # nosec
from collections import defaultdict
from datetime import datetime
from typing import Optional
import click
import yaml
from axolotl.utils.dict import DictDefault
def build_lm_eval_command(
tasks: list[str],
bfloat16=True,
flash_attention=False,
output_dir="./",
batch_size=8,
wandb_project=None,
wandb_entity=None,
wandb_name=None,
model=None,
revision=None,
apply_chat_template=None,
fewshot_as_multiturn=None,
):
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
if isinstance(tasks, str):
tasks = [tasks]
for task in tasks:
num_fewshot = "-1"
task_parts = task.split(":")
task_name = task_parts[0]
if len(task_parts) == 2:
task_name, num_fewshot = task_parts
tasks_by_num_fewshot[str(num_fewshot)].append(task_name)
for num_fewshot, tasks_list in tasks_by_num_fewshot.items():
tasks_str = ",".join(tasks_list)
num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
pretrained = "pretrained="
pretrained += model if model else output_dir
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
revision = f",revision={revision}" if revision else ""
output_path = output_dir
output_path += "" if output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
lm_eval_args = [
"lm_eval",
"--model",
"hf",
"--model_args",
f"{pretrained}{fa2}{dtype}{revision}",
"--tasks",
tasks_str,
"--batch_size",
str(batch_size),
"--output_path",
output_path,
]
wandb_args = []
if wandb_project:
wandb_args.append(f"project={wandb_project}")
if wandb_entity:
wandb_args.append(f"entity={wandb_entity}")
if wandb_name:
wandb_args.append(f"name={wandb_name}")
if wandb_args:
lm_eval_args.append("--wandb_args")
lm_eval_args.append(",".join(wandb_args))
if apply_chat_template:
lm_eval_args.append("--apply_chat_template")
if num_fewshot_val:
lm_eval_args.append("--num_fewshot")
lm_eval_args.append(str(num_fewshot_val))
if apply_chat_template and fewshot_as_multiturn:
lm_eval_args.append("--fewshot_as_multiturn")
yield lm_eval_args
@click.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
def lm_eval(config: str, cloud: Optional[str] = None):
"""
use lm eval to evaluate a trained language model
"""
if cloud:
from axolotl.cli.cloud import do_cli_lm_eval
do_cli_lm_eval(cloud_config=cloud, config=config)
else:
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=cfg.flash_attention,
output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id,
revision=cfg.revision,
apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn,
):
subprocess.run( # nosec
lm_eval_args,
check=True,
)

View File

@@ -6,7 +6,7 @@ import logging
from transformers import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")

View File

@@ -1,308 +0,0 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128
"""
import inspect
import logging
from transformers import LlamaForCausalLM, Trainer
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
"""
PATCHED_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
else:
loss = self.compute_loss(model, inputs)
"""
ORIGINAL_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
"""
PATCHED_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
"""
def get_training_step_code() -> str:
training_step = inspect.getsource(
Trainer.training_step # pylint: disable=protected-access
)
return training_step
def check_training_step_is_patchable() -> bool:
training_step = get_training_step_code()
training_step, _ = detab_code(training_step)
return ORIGINAL_CONTEXT_CODE in training_step
def patch_training_step_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
training_step = get_training_step_code()
except OSError:
return
Trainer._original_training_step = training_step # pylint: disable=protected-access
training_step, _ = detab_code(training_step)
if ORIGINAL_CONTEXT_CODE not in training_step:
return
# assert (
# ORIGINAL_CONTEXT_CODE in training_step
# ), "Original training_step code not found"
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
training_step = training_step.replace(
"def training_step(",
"def _fixed_training_step(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_step:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step")
Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
)
def get_model_forward_code() -> str:
forward = inspect.getsource(
LlamaForCausalLM.forward # pylint: disable=protected-access
)
return forward
def check_forward_is_patchable() -> bool:
forward = get_model_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_LLAMA_FCLM_CODE in forward
def patch_forward_for_ga():
"""
monkeypatch for fixing the training loop for gradient accumulation
"""
try:
forward = get_model_forward_code()
except OSError:
return
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
return
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
forward = forward.replace(
"def forward(",
"def _fixed_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward")
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)
ORIGINAL_TRAINER_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
PATCHED_TRAINER_CODE = """
disable_deepspeed_no_sync = (
self.accelerator.distributed_type == DistributedType.DEEPSPEED
# and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
)
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_deepspeed_0_16_x():
"""
monkeypatch for fixing the training loop for deepspeed GA
see https://github.com/huggingface/transformers/pull/35157
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)
def patch_flash_attention_forward():
"""
monkeypatch for fixing the forward pass for flash attention to ignore num_items_in_batch
"""
import transformers.modeling_flash_attention_utils
def proxy_flash_attention_forward(*args, **kwargs):
kwargs.pop("num_items_in_batch", None)
return _flash_attention_forward(*args, **kwargs)
transformers.modeling_flash_attention_utils._flash_attention_forward = ( # pylint: disable=protected-access
proxy_flash_attention_forward
)
transformers.models.llama.modeling_llama._flash_attention_forward = ( # pylint: disable=protected-access
proxy_flash_attention_forward
)

View File

@@ -0,0 +1,67 @@
"""
see https://github.com/huggingface/transformers/pull/35834
"""
import logging
from functools import partial
from typing import Optional
import torch
logger = logging.getLogger(__name__)
def fixed_fa_peft_integration_check(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
target_dtype: Optional[torch.dtype] = None,
preferred_dtype: Optional[torch.dtype] = None,
):
"""
PEFT usually casts the layer norms in float32 for training stability reasons
therefore the input hidden states gets silently casted in float32. Hence, we need
cast them back in float16 / bfloat16 just to be sure everything works as expected.
This might slowdown training & inference so it is recommended to not cast the LayerNorms!
Args:
query (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value (`torch.Tensor`):
Input value states to be passed to Flash Attention API
target_dtype (`torch.dtype`, *optional*):
The dtype to convert the attention tensors to. Conversion can be ignored by
not providing the target dtype.
preferred_dtype (`torch.dtype`, *optional*):
The preferred dtype to convert the attention tensors to regardless of the
target dtype.
"""
if target_dtype is None and preferred_dtype is None:
return query, key, value
if preferred_dtype and target_dtype != preferred_dtype:
target_dtype = preferred_dtype
# check if any of query, key, or value are in float32. If so, cast them back to target dtype.
if any(module.dtype == torch.float32 for module in [query, key, value]):
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query = query.to(target_dtype)
key = key.to(target_dtype)
value = value.to(target_dtype)
return query, key, value
def patch_fa_peft_integration():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils.fa_peft_integration_check = partial(
fixed_fa_peft_integration_check, preferred_dtype=None
)

View File

@@ -1,9 +1,7 @@
"""module for patching with unsloth optimizations"""
import inspect
import re
import types
from typing import Tuple
import torch
from accelerate.logging import get_logger
@@ -11,6 +9,8 @@ from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from axolotl.monkeypatch.utils import detab_code
LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_QKV_CODE = """
@@ -93,15 +93,6 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
raise ValueError("Unsupported model type")
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name

View File

@@ -1,7 +1,8 @@
"""
Shared utils for the monkeypatches
"""
from typing import Optional
import re
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
@@ -223,3 +224,12 @@ def patched_prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces

View File

@@ -0,0 +1,116 @@
"""
Module for stepwise datasets, typically including a prompt and reasoning traces,
and (optionally) per-step, or per-prompt-trace labels for reward modelling.
"""
from itertools import chain
from typing import Dict, List, Optional, Union
from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.prompt_tokenizers import IGNORE_INDEX
from axolotl.utils.dict import DictDefault
class StepwiseSupervisedPromptTokenizingStrategy:
"""
Tokenizing strategy for supervised stepwise datasets, typically used for COT-reasoning.
These datasets should include the following columns:
- prompt: the prompt text
- completions: a list of `n` completion steps
- labels: a list of `n` labels indicating the "correctness" of each step
"""
def __init__(
self,
tokenizer,
sequence_len: int = 2048,
step_separator: str = "\n",
max_completion_length: Optional[int] = None,
train_on_last_step_only: bool = False,
):
self.tokenizer = tokenizer
self.sequence_len = sequence_len
self.step_separator = step_separator
self.max_completion_length = max_completion_length
self.train_on_last_step_only = train_on_last_step_only
def tokenize_prompt(
self, prompt: Dict[str, Union[str, List[str]]]
) -> BatchEncoding:
# Inspired by TRL's PRMTRainer
# https://github.com/huggingface/trl/blob/ed7de87dc766478c024b68f12530d1b0e7c3ff23/trl/trainer/prm_trainer.py#L206
prompt_ids = self.tokenizer(prompt["prompt"], add_special_tokens=False)[
"input_ids"
]
completions_ids = [
self.tokenizer(completion, add_special_tokens=False)["input_ids"]
for completion in prompt["completions"]
]
# Handle labels
if self.train_on_last_step_only:
labels = [IGNORE_INDEX] * (len(prompt["labels"]) - 1) + [
int(prompt["labels"][-1])
]
else:
labels = [int(label) for label in prompt["labels"]]
# Add step separators
separator_ids = self.tokenizer.encode(
self.step_separator, add_special_tokens=False
)
completions_ids = [completion + separator_ids for completion in completions_ids]
# Create step-wise labels
labels = [
[IGNORE_INDEX] * (len(completion) - 1) + [label] # type: ignore
for completion, label in zip(completions_ids, labels)
]
# Join all steps
completion_ids = list(chain(*completions_ids))
labels = list(chain(*labels)) # type: ignore
# Handle max lengths
if self.max_completion_length:
completion_ids = completion_ids[: self.max_completion_length]
labels = labels[: self.max_completion_length]
# Add BOS token if model has one
if self.tokenizer.bos_token_id is not None:
prompt_ids = [self.tokenizer.bos_token_id] + prompt_ids
# Combine prompt and completion
input_ids = prompt_ids + completion_ids
full_labels = [IGNORE_INDEX] * len(prompt_ids) + labels
# Apply max sequence length
if self.sequence_len:
input_ids = input_ids[: self.sequence_len]
full_labels = full_labels[: self.sequence_len]
return {
"input_ids": input_ids,
"labels": full_labels,
"attention_mask": [1] * len(input_ids),
}
@property
def supports_batched(self):
return False
def load(
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
ds_cfg: DictDefault,
) -> StepwiseSupervisedPromptTokenizingStrategy:
return StepwiseSupervisedPromptTokenizingStrategy(
tokenizer,
cfg.sequence_len,
step_separator=ds_cfg.get("step_separator", "\n"),
max_completion_length=ds_cfg.max_completion_length,
train_on_last_step_only=ds_cfg.get("train_on_last_step_only", False),
)

View File

@@ -5,21 +5,19 @@ import os
import signal
import sys
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import Tuple, Union
import torch
import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
@@ -39,22 +37,11 @@ src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = get_logger("axolotl.train")
@dataclass
class TrainDatasetMeta:
"""
dataclass to capture the dataset specific options for training
"""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
LOG = get_logger(__name__)
def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
# Load tokenizer
LOG.debug(
@@ -93,9 +80,7 @@ def train(
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
model, peft_config = load_model(cfg, tokenizer, processor=processor)
if model.generation_config is not None:
model.generation_config.do_sample = True
@@ -107,9 +92,7 @@ def train(
model_ref = None # explicit setting to None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(
cfg, tokenizer, inference=cli_args.inference, reference_model=True
)
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
safe_serialization = cfg.save_safetensors is True
@@ -158,7 +141,9 @@ def train(
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
if (
cfg.local_rank == 0 and not cfg.use_ray
): # ray workers don't have access to this signal
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
@@ -276,7 +261,7 @@ def train(
.decode("utf-8")
}
if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model:
if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import gc
import logging
import math
import os
import traceback
from shutil import copyfile
@@ -830,13 +829,6 @@ class SaveModelCallback(TrainerCallback):
# Save
if state.global_step >= state.max_steps:
control.should_save = True
elif (
args.save_strategy == IntervalStrategy.STEPS
and state.save_steps < 1.0
and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0
):
# workaround to save model on fractional save_steps
control.should_save = True
def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs

View File

@@ -1,4 +1,5 @@
"""Module for working with config dicts"""
import json
import logging
import os
from typing import Optional
@@ -56,33 +57,10 @@ def choose_device(cfg):
cfg.device_map = None
def normalize_config(cfg):
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
if cfg.eval_batch_size is None:
cfg.eval_batch_size = cfg.micro_batch_size
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0
cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
"sacrebleu",
"comet",
"ter",
"chrf",
]
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
if cfg.bf16 == "auto":
def resolve_dtype(cfg):
if (
cfg.bf16 == "auto" and not cfg.use_ray
): # if we use ray we want to defer this check to the worker node
if is_torch_bf16_gpu_available():
LOG.debug("bf16 support detected, enabling for this configuration.")
cfg.bf16 = True
@@ -110,6 +88,43 @@ def normalize_config(cfg):
else:
cfg.torch_dtype = torch.float32
def normalize_config(cfg):
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
)
if cfg.eval_batch_size is None:
cfg.eval_batch_size = cfg.micro_batch_size
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0
cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
"sacrebleu",
"comet",
"ter",
"chrf",
]
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
if not cfg.use_ray:
# delay resolving dtype until on worker node when launching with ray
resolve_dtype(cfg)
if cfg.deepspeed:
if isinstance(cfg.deepspeed, str) and os.path.exists(cfg.deepspeed):
ds_config_path = cfg.deepspeed
with open(ds_config_path, encoding="utf-8") as f:
cfg.deepspeed = json.load(f)
if cfg.saves_per_epoch:
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
if save_steps < 1.0: # prevent saves on every step

View File

@@ -128,6 +128,8 @@ class PretrainingDataset(BaseModel):
text_column: Optional[str] = "text"
type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False
data_files: Optional[str] = None
skip: Optional[int] = None
class UserDefinedPrompterType(BaseModel):
@@ -145,6 +147,14 @@ class UserDefinedPrompterType(BaseModel):
field: Optional[str] = None
class LrGroup(BaseModel):
"""Custom learning rate group configuration"""
name: str
modules: List[str]
lr: float
class SFTDataset(BaseModel):
"""SFT configuration subset"""
@@ -226,6 +236,18 @@ class DPODataset(BaseModel):
revision: Optional[str] = None
class StepwiseSupervisedDataset(BaseModel):
"""Stepwise supervised dataset configuration subset"""
path: Optional[str] = None
split: Optional[str] = None
data_files: Optional[List[str]] = None
revision: Optional[str] = None
step_separator: Optional[str] = None
max_completion_length: Optional[int] = None
train_on_last_step_only: Optional[bool] = None
class UserDefinedKTOType(BaseModel):
"""User defined typing for KTO"""
@@ -366,6 +388,13 @@ class LoraConfig(BaseModel):
loraplus_lr_embedding = float(loraplus_lr_embedding)
return loraplus_lr_embedding
@model_validator(mode="before")
@classmethod
def validate_lora_dropout(cls, data):
if data.get("adapter") is not None and data.get("lora_dropout") is None:
data["lora_dropout"] = 0.0
return data
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""
@@ -466,12 +495,13 @@ class HyperparametersConfig(BaseModel):
cosine_min_lr_ratio: Optional[float] = None
cosine_constant_lr_ratio: Optional[float] = None
lr_div_factor: Optional[float] = None
lr_groups: Optional[List[LrGroup]] = None
adam_epsilon: Optional[float] = None
adam_beta1: Optional[float] = None
adam_beta2: Optional[float] = None
max_grad_norm: Optional[float] = None
num_epochs: int = Field(default=1)
num_epochs: float = Field(default=1.0)
@field_validator("batch_size")
@classmethod
@@ -498,7 +528,7 @@ class ModelOutputConfig(BaseModel):
output_dir: str = Field(default="./model-out")
hub_model_id: Optional[str] = None
hub_strategy: Optional[str] = None
save_safetensors: Optional[bool] = None
save_safetensors: Optional[bool] = True
class MLFlowConfig(BaseModel):
@@ -577,6 +607,30 @@ class GradioConfig(BaseModel):
gradio_temperature: Optional[float] = None
class RayConfig(BaseModel):
"""Ray launcher configuration subset"""
use_ray: bool = Field(default=False)
ray_run_name: Optional[str] = Field(
default=None,
metadata={
"help": "The training results will be saved at `saves/ray_run_name`."
},
)
ray_num_workers: int = Field(
default=1,
metadata={
"help": "The number of workers for Ray training. Default is 1 worker."
},
)
resources_per_worker: dict = Field(
default_factory=lambda: {"GPU": 1},
metadata={
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
},
)
# pylint: disable=too-many-public-methods,too-many-ancestors
class AxolotlInputConfig(
ModelInputConfig,
@@ -589,6 +643,7 @@ class AxolotlInputConfig(
CometConfig,
LISAConfig,
GradioConfig,
RayConfig,
RemappedParameters,
DeprecatedParameters,
BaseModel,
@@ -608,12 +663,14 @@ class AxolotlInputConfig(
rl: Optional[RLType] = None
reward_model: Optional[bool] = None
process_reward_model: Optional[bool] = None
num_labels: Optional[int] = None
dpo_use_weighting: Optional[
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
shuffle_merged_datasets: Optional[bool] = True
dataset_prepared_path: Optional[str] = None
dataset_shard_num: Optional[int] = None
@@ -697,6 +754,12 @@ class AxolotlInputConfig(
pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None
multipack_real_batches: Optional[bool] = None
pretraining_sample_concatenation: Optional[bool] = Field(
default=None,
json_schema_extra={
"description": "whether to soft pack/concatenate samples during pretraining",
},
)
batch_flattening: Optional[Union[Literal["auto"], bool]] = None

View File

@@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining,
wrap_pretraining_dataset,
)
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper,
load_prepare_datasets,

View File

@@ -18,10 +18,14 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
tokenizer: PreTrainedTokenizerBase,
max_tokens: int,
examples: Dict[str, List],
text_column: str = "text",
concatenate: bool = True,
) -> Dict[str, List]:
res = tokenizer(
examples["text"],
examples[text_column],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
@@ -30,6 +34,13 @@ def encode_pretraining(
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
targets = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
if not concatenate:
return {
"input_ids": [seq.tolist() for seq in input_ids],
"labels": [seq.tolist() for seq in targets],
"attention_mask": [seq.tolist() for seq in attention_mask],
}
new_input_ids = []
new_labels = []
new_attention_mask = []
@@ -180,7 +191,7 @@ def wrap_pretraining_dataset(
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=max_tokens * batch_size,
pad_to_multiple_of=max_tokens,
multipack_attn=cfg.pretrain_multipack_attn,
)
encode = functools.partial(
@@ -190,13 +201,17 @@ def wrap_pretraining_dataset(
max_seq_length=max_tokens,
batch_size=batch_size,
multipack_attn=cfg.pretrain_multipack_attn,
group_size=cfg.sample_packing_group_size,
bin_size=cfg.sample_packing_bin_size,
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
encode = functools.partial(
encode_pretraining,
tokenizer,
max_tokens,
text_column=cfg.pretraining_dataset[0].text_column or "text",
concatenate=cfg.pretraining_sample_concatenation is True,
)
if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
@@ -230,9 +245,7 @@ def encode_packed_pretraining(
examples: Dict[str, List],
max_seq_length: int = 2048,
batch_size: int = 4,
multipack_attn: Optional[bool] = False,
group_size: int = 100000,
bin_size: int = 200,
multipack_attn: Optional[bool] = True,
) -> Dict[str, List]:
# pylint: disable=duplicate-code
# tokenize all the examples
@@ -243,6 +256,9 @@ def encode_packed_pretraining(
train_dataset,
max_seq_length,
skip_position_ids=not multipack_attn,
# FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
# workaround by using the position id logic for now in trainer
drop_attention_mask=multipack_attn,
)
sampler = MultipackBatchSampler(
@@ -250,8 +266,6 @@ def encode_packed_pretraining(
lengths=get_dataset_lengths(train_dataset),
batch_size=1,
batch_max_len=batch_size * max_seq_length,
group_size=group_size,
bin_size=bin_size,
drop_last=True,
)

View File

@@ -115,7 +115,7 @@ def drop_long_rl_seq(
raise ValueError("Unknown RL type")
def load_prepare_dpo_datasets(cfg):
def load_prepare_preference_datasets(cfg):
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
for i, ds_cfg in enumerate(dataset_cfgs):

View File

@@ -8,6 +8,8 @@ from typing import List, Tuple, Union
from datasets import (
Dataset,
DatasetDict,
Sequence,
Value,
concatenate_datasets,
load_dataset,
load_from_disk,
@@ -88,14 +90,19 @@ def prepare_dataset(cfg, tokenizer, processor=None):
path = cfg.pretraining_dataset
split = "train"
name = None
data_files = None
skip = 0
if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]
skip = cfg.pretraining_dataset[0]["skip"]
if "split" in cfg.pretraining_dataset[0]:
split = cfg.pretraining_dataset[0]["split"]
data_files = cfg.pretraining_dataset[0].get("data_files")
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
@@ -104,8 +111,14 @@ def prepare_dataset(cfg, tokenizer, processor=None):
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip)
train_dataset = wrap_pretraining_dataset(
load_dataset(path, streaming=True, split=split, name=name),
iter_ds,
tokenizer,
cfg,
ds_wrapper_partial,
@@ -456,6 +469,17 @@ def get_dataset_wrapper(
dataset,
**ds_kwargs,
)
elif config_dataset.type.startswith("stepwise_supervised"):
dataset_prompter = UnsupportedPrompter()
ds_strategy = load(config_dataset.type, tokenizer, cfg, config_dataset)
# we need to explicitly cast boolean labels to int
# for compatibility with how trl's PRMTrainer works
dataset = dataset.cast_column("labels", Sequence(Value("int64")))
dataset_wrapper = TokenizedPromptDataset(
ds_strategy,
dataset,
**ds_kwargs,
)
elif ds_strategy := load(
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
):

View File

@@ -107,6 +107,13 @@ def load_dataset_w_config(config_dataset, auth_token):
except (FileNotFoundError, ConnectionError):
pass
# gather extra args from the config
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
else:
load_ds_kwargs["split"] = None
# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
@@ -118,7 +125,7 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
split=None,
**load_ds_kwargs,
)
else:
try:
@@ -130,7 +137,7 @@ def load_dataset_w_config(config_dataset, auth_token):
config_dataset.path,
name=config_dataset.name,
streaming=False,
split=None,
**load_ds_kwargs,
)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
@@ -140,16 +147,13 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
**load_ds_kwargs,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
@@ -173,9 +177,9 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
@@ -184,9 +188,9 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
else:
if isinstance(config_dataset.data_files, str):
@@ -214,7 +218,7 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name,
data_files=fp,
streaming=False,
split=None,
**load_ds_kwargs,
)
if not ds:
raise ValueError("unhandled dataset load")

Some files were not shown because too many files have changed in this diff Show More