Compare commits

..

44 Commits

Author SHA1 Message Date
Dan Saunders
2daa94080c Merge branch 'main' into diff-transformer 2025-01-27 14:46:17 +00:00
Dan Saunders
0e9bfa6dee small fixes, improvements 2025-01-24 19:53:54 +00:00
Dan Saunders
ef38f10274 merging into main 2025-01-24 18:03:27 +00:00
Dan Saunders
66262c3092 moving out all diff attn code to plugin repo 2025-01-24 17:46:11 +00:00
Dan Saunders
016ba124e4 README update 2025-01-23 22:11:35 +00:00
Dan Saunders
7145d52d99 moving diff attn code to separate repo 2025-01-23 21:33:53 +00:00
Dan Saunders
28694219a5 inline comment change 2025-01-14 16:59:43 +00:00
Dan Saunders
fd8ad6fcbf fixing negative component mixing 2025-01-13 19:21:55 +00:00
Dan Saunders
661d71a14b adding diff attn negative component warmup (in progress) 2025-01-10 21:57:31 +00:00
Dan Saunders
6dd47edcb8 fire CLI fixes 2025-01-10 18:24:16 +00:00
Dan Saunders
7aca08ff60 adding guard statements 2025-01-10 16:39:21 +00:00
Dan Saunders
4f804f6d88 adding diff attn callback, adding documentation 2025-01-10 16:28:51 +00:00
Dan Saunders
443327c585 CLI build_command bugfix 2025-01-10 16:28:51 +00:00
Dan Saunders
70c4e6fbe6 updates and cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
2a7f139ad2 pre-commit fix 2025-01-10 16:28:51 +00:00
Dan Saunders
332ce0ae85 fixes and cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
e5fa842ff8 update 2025-01-10 16:28:51 +00:00
Dan Saunders
78e0ec0aa5 changes 2025-01-10 16:28:51 +00:00
Dan Saunders
3bc568eb27 adding registration function 2025-01-10 16:28:51 +00:00
Dan Saunders
eb6611d55f progress on modeling code 2025-01-10 16:28:51 +00:00
Dan Saunders
4ff3328e66 updated custom modeling code 2025-01-10 16:28:51 +00:00
Dan Saunders
a3fd5074a9 fix duplicate-code warnings 2025-01-10 16:28:51 +00:00
Dan Saunders
5b90da0be3 added modeling code; cleanup + refactor 2025-01-10 16:28:51 +00:00
Dan Saunders
fcbfa86373 refactor and fixing test isolation issues 2025-01-10 16:28:51 +00:00
Dan Saunders
0d56582090 adding yaml dumper preserving input config format 2025-01-10 16:28:51 +00:00
Dan Saunders
390cb5742e removing extra pytest xdist args 2025-01-10 16:28:51 +00:00
Dan Saunders
1d935f65c3 moving tests around for flash_attn install 2025-01-10 16:28:51 +00:00
Dan Saunders
66176b3e07 adding split_heads argument for retaining original (Q, K) dimensionanlity 2025-01-10 16:28:51 +00:00
Dan Saunders
505321ac95 isolating problematic test 2025-01-10 16:28:51 +00:00
Dan Saunders
0b382c88da fixes post-rebase 2025-01-10 16:28:51 +00:00
Dan Saunders
ea07a7086e plugin implementation 2025-01-10 16:28:51 +00:00
Dan Saunders
d22e1136bc convert-differential-transformer test coverage 2025-01-10 16:28:51 +00:00
Dan Saunders
63b8e42c6b duplicate code ignore 2025-01-10 16:28:51 +00:00
Dan Saunders
bda1eed59e differential flash attention 2; cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
41ebd93158 moving monkeypatch 2025-01-10 16:28:51 +00:00
Dan Saunders
4c050ce807 pre-commit fix 2025-01-10 16:28:51 +00:00
Dan Saunders
6665acf63d fix model save / load logic 2025-01-10 16:28:51 +00:00
Dan Saunders
2f9fa4c465 various improvemnents 2025-01-10 16:28:51 +00:00
Dan Saunders
849bc94112 various improvemnents 2025-01-10 16:28:51 +00:00
Dan Saunders
e484ec778d training fixes, patching, minor cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
df1504ae14 adding CLI command for convert-diff-transformer 2025-01-10 16:28:51 +00:00
Dan Saunders
7be0d7496c Adding script for doing conversion; fixes and updates 2025-01-10 16:28:51 +00:00
Dan Saunders
13cdffa91f initial diff attn layer / model conversion implementation (support for llama arch) 2025-01-10 16:28:51 +00:00
Dan Saunders
7a4b296f60 Basic evaluate CLI command / codepath (#2188)
* basic evaluate CLI command / codepath

* tests for evaluate CLI command

* fixes and cleanup

* review comments; slightly DRYing up things

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-01-10 16:28:51 +00:00
105 changed files with 1284 additions and 2566 deletions

View File

@@ -15,7 +15,7 @@ First of all, thank you for your interest in contributing to axolotl! We appreci
- [Commit Messages](#commit-messages)
- [Additional Resources](#additional-resources)
## Code of Conduct
## Code of Conductcode
All contributors are expected to adhere to our [Code of Conduct](CODE_OF_CONDUCT.md). Please read it before participating in the axolotl community.

View File

@@ -22,6 +22,18 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "121"
cuda_version: 12.1.1
cudnn_version: 8
python_version: "3.10"
pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.1
cudnn_version: 8
python_version: "3.11"
pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""

View File

@@ -15,6 +15,16 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -72,6 +82,16 @@ jobs:
strategy:
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -125,10 +145,10 @@ jobs:
strategy:
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.3.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -20,6 +20,12 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -12,6 +12,17 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -65,6 +76,17 @@ jobs:
strategy:
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
axolotl_extras:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -26,7 +26,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
@@ -98,6 +98,13 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -49,7 +49,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
@@ -228,7 +228,6 @@ jobs:
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
@@ -245,6 +244,12 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -269,7 +274,6 @@ jobs:
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |

3
.gitignore vendored
View File

@@ -186,3 +186,6 @@ out/
# vim
*.swp
# symlinked to axolotl-artifacts in docker containers
outputs

View File

@@ -19,7 +19,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
rev: 6.0.0
hooks:
- id: flake8
- repo: https://github.com/PyCQA/pylint

785
README.md
View File

@@ -1,8 +1,8 @@
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_white.svg">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg">
<img alt="Axolotl" src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
<source media="(prefers-color-scheme: dark)" srcset="image/axolotl_logo_digital_white.svg">
<source media="(prefers-color-scheme: light)" srcset="image/axolotl_logo_digital_black.svg">
<img alt="Axolotl" src="image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
</picture>
</p>
@@ -19,99 +19,235 @@
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
</p>
Axolotl is a tool designed to streamline post-training for various AI models.
Post-training refers to any modifications or additional training performed on
pre-trained models - including full model fine-tuning, parameter-efficient tuning (like
LoRA and QLoRA), supervised fine-tuning (SFT), instruction tuning, and alignment
techniques. With support for multiple model architectures and training configurations,
Axolotl makes it easy to get started with these techniques.
Axolotl is designed to work with YAML config files that contain everything you need to
preprocess a dataset, train or fine-tune a model, run model inference or evaluation,
and much more.
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
Features:
- Train various Huggingface models such as llama, pythia, falcon, mpt
- Supports fullfinetune, lora, qlora, relora, and gptq
- Customize configurations using a simple yaml file or CLI overwrite
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
- Integrated with [xformers](https://github.com/facebookresearch/xformers), flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb, mlflow or Comet
- And more!
## 🚀 Quick Start
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
</a>
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.10
- PyTorch ≥2.4.1
<table>
<tr>
<td>
### Installation
## Table of Contents
- [Axolotl](#axolotl)
- [Table of Contents](#table-of-contents)
- [Quickstart ⚡](#quickstart-)
- [Edge Builds](#edge-builds-)
- [Axolotl CLI Usage](#axolotl-cli-usage)
- [Badge ❤🏷️](#badge-)
- [Contributing 🤝](#contributing-)
- [Sponsors 🤝❤](#sponsors-)
- [Axolotl supports](#axolotl-supports)
- [Advanced Setup](#advanced-setup)
- [Environment](#environment)
- [Docker](#docker)
- [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu)
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- [LambdaLabs](#lambdalabs)
- [GCP](#gcp)
- [Windows](#windows)
- [Mac](#mac)
- [Google Colab](#google-colab)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
- [Dataset](#dataset)
- [Config](#config)
- [All Config Options](#all-config-options)
- [Train](#train)
- [Preprocess dataset](#preprocess-dataset)
- [Multi-GPU](#multi-gpu)
- [DeepSpeed](#deepspeed)
- [FSDP](#fsdp)
- [FSDP + QLoRA](#fsdp--qlora)
- [Weights \& Biases Logging](#weights--biases-logging)
- [Special Tokens](#special-tokens)
- [Liger Kernel](#liger-kernel)
- [Inference Playground](#inference-playground)
- [Merge LORA to base](#merge-lora-to-base)
- [Common Errors 🧰](#common-errors-)
- [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
- [Need help? 🙋](#need-help-)
```shell
</td>
<td>
<div align="center">
<img src="image/axolotl_symbol_digital_white.svg" alt="axolotl" width="160">
<div>
<p>
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
</p>
<p>
Go ahead and Axolotl questions!!
</p>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
<img alt="PyTest Status" src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
</div>
</div>
</td>
</tr>
</table>
## Quickstart ⚡
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
```bash
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs
# download examples and optionally deepspeed configs to the local path
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
```
Other installation approaches are described [here](https://axolotl-ai-cloud.github.io/axolotl/docs/installation.html).
### Your First Fine-tune
```shell
# Fetch axolotl examples
axolotl fetch examples
# Or, specify a custom path
axolotl fetch examples --dest path/to/folder
# Train a model using LoRA
# finetune using lora
axolotl train examples/llama-3/lora-1b.yml
```
That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/getting-started.html) for a more detailed walkthrough.
### Edge Builds 🏎️
## ✨ Key Features
If you're looking for the latest features and updates between releases, you'll need to install
from source.
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, and more
- **Easy Configuration**: Simple YAML files to control your training setup
- **Performance Optimizations**: Flash Attention, xformers, multi-GPU training
- **Flexible Dataset Handling**: Use various formats and custom datasets
- **Cloud Ready**: Run on cloud platforms or local hardware
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
## 📚 Documentation
### Axolotl CLI Usage
We now support a new, more streamlined CLI using [click](https://click.palletsprojects.com/en/stable/).
- [Installation Options](https://axolotl-ai-cloud.github.io/axolotl/docs/installation.html) - Detailed setup instructions for different environments
- [Configuration Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html) - Full configuration options and examples
- [Dataset Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) - Supported formats and how to use them
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/llama-3/lora-1b.yml
## 🤝 Getting Help
# finetune lora
axolotl train examples/llama-3/lora-1b.yml
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support
- Check out our [Examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/) directory
- Read our [Debugging Guide](https://axolotl-ai-cloud.github.io/axolotl/docs/debugging.html)
- Need dedicated support? Please contact [wing@axolotl.ai](mailto:wing@axolotl.ai) for options
# inference
axolotl inference examples/llama-3/lora-1b.yml \
--lora-model-dir="./outputs/lora-out"
## 🌟 Contributing
# gradio
axolotl inference examples/llama-3/lora-1b.yml \
--lora-model-dir="./outputs/lora-out" --gradio
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
```
## Supported Models
We've also added a new command for fetching `examples` and `deepspeed_configs` to your
local machine. This will come in handy when installing `axolotl` from PyPI.
```bash
# Fetch example YAML files (stores in "examples/" folder)
axolotl fetch examples
# Fetch deepspeed config files (stores in "deepspeed_configs/" folder)
axolotl fetch deepspeed_configs
# Optionally, specify a destination folder
axolotl fetch examples --dest path/to/folder
```
### Legacy Usage
<details>
<summary>Click to Expand</summary>
While the Axolotl CLI is the preferred method for interacting with axolotl, we
still support the legacy `-m axolotl.cli.*` usage.
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/llama-3/lora-1b.yml
# finetune lora
accelerate launch -m axolotl.cli.train examples/llama-3/lora-1b.yml
# inference
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
--lora_model_dir="./outputs/lora-out"
# gradio
accelerate launch -m axolotl.cli.inference examples/llama-3/lora-1b.yml \
--lora_model_dir="./outputs/lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/llama-3/lora-1b.yml
```
</details>
## Badge ❤🏷️
Building something cool with Axolotl? Consider adding a badge to your model card.
```markdown
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
```
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
## Sponsors 🤝❤
If you love axolotl, consider sponsoring the project by reaching out directly to [wing@axolotl.ai](mailto:wing@axolotl.ai).
---
- [Modal](https://modal.com/) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
---
## Contributing 🤝
Please read the [contributing guide](./.github/CONTRIBUTING.md)
Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) else create a new Issue.
PRs are **greatly welcome**!
Please run the quickstart instructions followed by the below to setup env:
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
# optional: run against all files
pre-commit run --all-files
```
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
</a>
## Axolotl supports
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
@@ -136,16 +272,523 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
❌: not supported
❓: untested
## ❤️ Sponsors
## Advanced Setup
Thank you to our sponsors who help make Axolotl possible:
### Environment
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) - Modal lets you run
jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale,
fine-tune large language models, run protein folding simulations, and much more.
#### Docker
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
```bash
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
## 📜 License
Or run on the current files for development:
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
```sh
docker compose up -d
```
>[!Tip]
> If you want to debug axolotl or prefer to use Docker as your development environment, see the [debugging guide's section on Docker](docs/debugging.qmd#debugging-with-docker).
<details>
<summary>Docker advanced</summary>
A more powerful Docker command to run would be this:
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest
```
It additionally:
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
* The `--privileged` flag gives all capabilities to the container.
* The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed.
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
</details>
#### Conda/Pip venv
1. Install python >=**3.10**
2. Install pytorch stable https://pytorch.org/get-started/locally/
3. Install Axolotl along with python dependencies
```bash
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Huggingface to use gated models/datasets.
```bash
huggingface-cli login
```
Get the token at huggingface.co/settings/tokens
#### Cloud GPU
For cloud GPU providers that support docker images, use [`axolotlai/axolotl-cloud:main-latest`](https://hub.docker.com/r/axolotlai/axolotl-cloud/tags)
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
#### Bare Metal Cloud GPU
##### LambdaLabs
<details>
<summary>Click to Expand</summary>
1. Install python
```bash
sudo apt update
sudo apt install -y python3.10
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
sudo update-alternatives --config python # pick 3.10 if given option
python -V # should be 3.10
```
2. Install pip
```bash
wget https://bootstrap.pypa.io/get-pip.py
python get-pip.py
```
3. Install Pytorch https://pytorch.org/get-started/locally/
4. Follow instructions on quickstart.
5. Run
```bash
pip3 install protobuf==3.20.3
pip3 install -U --ignore-installed requests Pillow psutil scipy
```
6. Set path
```bash
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
```
</details>
##### GCP
<details>
<summary>Click to Expand</summary>
Use a Deeplearning linux OS with cuda and pytorch installed. Then follow instructions on quickstart.
Make sure to run the below to uninstall xla.
```bash
pip uninstall -y torch_xla[tpu]
```
</details>
#### Windows
Please use WSL or Docker!
#### Mac
Use the below instead of the install method in QuickStart.
```
pip3 install --no-build-isolation -e '.'
```
More info: [mac.md](/docs/mac.qmd)
#### Google Colab
Please use this example [notebook](examples/colab-notebooks/colab-axolotl-example.ipynb).
#### Launching on public clouds via SkyPilot
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
```bash
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
sky check
```
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
```
git clone https://github.com/skypilot-org/skypilot.git
cd skypilot/llm/axolotl
```
Use one command to launch:
```bash
# On-demand
HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
# Managed spot (auto-recovery on preemption)
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
```
#### Launching on public clouds via dstack
To launch on GPU instance (both on-demand and spot instances) on public clouds (GCP, AWS, Azure, Lambda Labs, TensorDock, Vast.ai, and CUDO), you can use [dstack](https://dstack.ai/).
Write a job description in YAML as below:
```yaml
# dstack.yaml
type: task
image: axolotlai/axolotl-cloud:main-latest
env:
- HUGGING_FACE_HUB_TOKEN
- WANDB_API_KEY
commands:
- accelerate launch -m axolotl.cli.train config.yaml
ports:
- 6006
resources:
gpu:
memory: 24GB..
count: 2
```
then, simply run the job with `dstack run` command. Append `--spot` option if you want spot instance. `dstack run` command will show you the instance with cheapest price across multi cloud services:
```bash
pip install dstack
HUGGING_FACE_HUB_TOKEN=xxx WANDB_API_KEY=xxx dstack run . -f dstack.yaml # --spot
```
For further and fine-grained use cases, please refer to the official [dstack documents](https://dstack.ai/docs/) and the detailed description of [axolotl example](https://github.com/dstackai/dstack/tree/master/examples/fine-tuning/axolotl) on the official repository.
### Dataset
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
See [the documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
### Config
See [examples](examples) for quick start. It is recommended to duplicate and modify to your needs. The most important options are:
- model
```yaml
base_model: ./llama-7b-hf # local or huggingface repo
```
Note: The code will load the right architecture.
- dataset
```yaml
datasets:
# huggingface repo
- path: vicgalle/alpaca-gpt4
type: alpaca
# huggingface repo with specific configuration/subset
- path: EleutherAI/pile
name: enron_emails
type: completion # format from earlier
field: text # Optional[str] default: text, field to use for completion data
# huggingface repo with multiple named configurations/subsets
- path: bigcode/commitpackft
name:
- ruby
- python
- typescript
type: ... # unimplemented custom format
# chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template
- path: ...
type: chat_template
chat_template: chatml # defaults to tokenizer's chat_template
# local
- path: data.jsonl # or json
ds_type: json # see other options below
type: alpaca
# dataset with splits, but no train split
- path: knowrohit07/know_sql
type: context_qa.load_v2
train_on_split: validation
# loading from s3 or gcs
# s3 creds will be loaded from the system default / gcs will attempt to load from gcloud creds, google metadata service, or anon
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above
...
# Loading Data From a Public URL
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
ds_type: json # this is the default, see other options below.
```
- loading
```yaml
load_in_4bit: true
load_in_8bit: true
bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically.
fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32
tf32: true # require >=ampere
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
float16: true # use instead of fp16 when you don't want AMP
```
Note: Repo does not do 4-bit quantization.
- lora
```yaml
adapter: lora # 'qlora' or leave blank for full finetune
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
```
#### All Config Options
See [these docs](docs/config.qmd) for all config options.
### Train
Run
```bash
accelerate launch -m axolotl.cli.train your_config.yml
```
> [!TIP]
> You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml`
#### Preprocess dataset
You can optionally pre-tokenize dataset with the following before finetuning.
This is recommended for large datasets.
- Set `dataset_prepared_path:` to a local folder for saving and loading pre-tokenized dataset.
- (Optional): Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
- (Optional): Use `--debug` to see preprocessed examples.
```bash
python -m axolotl.cli.preprocess your_config.yml
```
#### Multi-GPU
Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
is the recommended multi-GPU option currently because FSDP may experience
[loss instability](https://github.com/huggingface/transformers/issues/26498).
##### DeepSpeed
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
```yaml
deepspeed: deepspeed_configs/zero1.json
```
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
```
##### FSDP
- llama FSDP
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
##### FSDP + QLoRA
Axolotl supports training with FSDP and QLoRA, see [these docs](docs/fsdp_qlora.qmd) for more information.
##### Weights & Biases Logging
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
- wandb options
```yaml
wandb_mode:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
```
##### Comet Logging
Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
- wandb options
```yaml
use_comet:
comet_api_key:
comet_workspace:
comet_project_name:
comet_experiment_key:
comet_mode:
comet_online:
comet_experiment_config:
```
##### Special Tokens
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
```yml
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
tokens: # these are delimiters
- "<|im_start|>"
- "<|im_end|>"
```
When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
##### Liger Kernel
Liger Kernel: Efficient Triton Kernels for LLM Training
https://github.com/linkedin/Liger-Kernel
Liger (LinkedIn GPU Efficient Runtime) Kernel is a collection of Triton kernels designed specifically for LLM training.
It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. The Liger Kernel
composes well and is compatible with both FSDP and Deepspeed.
```yaml
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```
### Inference Playground
Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.
The config file is the same config file used for training.
Pass the appropriate flag to the inference command, depending upon what kind of model was trained:
- Pretrained LORA:
```bash
python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
```
- Full weights finetune:
```bash
python -m axolotl.cli.inference examples/your_config.yml --base_model="./completed-model"
```
- Full weights finetune w/ a prompt from a text file:
```bash
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
--base_model="./completed-model" --prompter=None --load_in_8bit=True
```
-- With gradio hosting
```bash
python -m axolotl.cli.inference examples/your_config.yml --gradio
```
Please use `--sample_packing False` if you have it on and receive the error similar to below:
> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1
### Merge LORA to base
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
```bash
python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
```
You may need to use the `gpu_memory_limit` and/or `lora_on_cpu` config options to avoid running out of memory. If you still run out of CUDA memory, you can try to merge in system RAM with
```bash
CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
```
although this will be very slow, and using the config options above are recommended instead.
## Common Errors 🧰
See also the [FAQ's](./docs/faq.qmd) and [debugging guide](docs/debugging.qmd).
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
Please reduce any below
- `micro_batch_size`
- `eval_batch_size`
- `gradient_accumulation_steps`
- `sequence_len`
If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.
Using adamw_bnb_8bit might also save you some memory.
> `failed (exitcode: -9)`
Usually means your system has run out of system memory.
Similarly, you should consider reducing the same settings as when you run out of VRAM.
Additionally, look into upgrading your system RAM which should be simpler than GPU upgrades.
> RuntimeError: expected scalar type Float but found Half
Try set `fp16: true`
> NotImplementedError: No operator found for `memory_efficient_attention_forward` ...
Try to turn off xformers.
> accelerate config missing
It's safe to ignore it.
> NCCL Timeouts during training
See the [NCCL](docs/nccl.qmd) guide.
### Tokenization Mismatch b/w Inference & Training
For many formats, Axolotl constructs prompts by concatenating token ids _after_ tokenizing strings. The reason for concatenating token ids rather than operating on strings is to maintain precise accounting for attention masks.
If you decode a prompt constructed by axolotl, you might see spaces between tokens (or lack thereof) that you do not expect, especially around delimiters and special tokens. When you are starting out with a new format, you should always do the following:
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/finetuning/05_tokenizer_gotchas.html) for a concrete example.
## Debugging Axolotl
See [this debugging guide](docs/debugging.qmd) for tips on debugging Axolotl, along with an example configuration for debugging with VSCode.
## Need help? 🙋
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where our community members can help you.
Need dedicated support? Please contact us at [wing@axolotl.ai](ailto:wing@axolotl.ai) for dedicated support options.

View File

@@ -28,21 +28,16 @@ website:
- section: "How-To Guides"
contents:
# TODO Edit folder structure after we have more docs.
- docs/getting-started.qmd
- docs/installation.qmd
- docs/debugging.qmd
- docs/inference.qmd
- docs/multipack.qmd
- docs/fsdp_qlora.qmd
- docs/input_output.qmd
- docs/rlhf.qmd
- docs/nccl.qmd
- docs/mac.qmd
- docs/multi-gpu.qmd
- docs/multi-node.qmd
- docs/unsloth.qmd
- docs/amd_hpc.qmd
- docs/ray-integration.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Reference"
@@ -50,6 +45,7 @@ website:
- docs/config.qmd
- docs/faq.qmd
format:
html:
theme: materia

View File

@@ -32,9 +32,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -4,7 +4,6 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/

View File

@@ -1,6 +1,6 @@
"""
modal application to run axolotl gpu tests in Modal
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os
@@ -23,8 +23,8 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),

View File

@@ -23,8 +23,8 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
@@ -38,12 +38,16 @@ temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
).env(df_args)
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
)
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
app = App("Axolotl CI/CD", secrets=[])

View File

@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh

View File

@@ -1,256 +0,0 @@
# Axolotl CLI Documentation
The Axolotl CLI provides a streamlined interface for training and fine-tuning large language models. This guide covers
the CLI commands, their usage, and common examples.
### Table of Contents
- Basic Commands
- Command Reference
- fetch
- preprocess
- train
- inference
- merge-lora
- merge-sharded-fsdp-weights
- evaluate
- lm-eval
- Legacy CLI Usage
- Remote Compute with Modal Cloud
- Cloud Configuration
- Running on Modal Cloud
- Cloud Configuration Options
### Basic Commands
All Axolotl commands follow this general structure:
```bash
axolotl <command> [config.yml] [options]
```
The config file can be local or a URL to a raw YAML file.
### Command Reference
#### fetch
Downloads example configurations and deepspeed configs to your local machine.
```bash
# Get example YAML files
axolotl fetch examples
# Get deepspeed config files
axolotl fetch deepspeed_configs
# Specify custom destination
axolotl fetch examples --dest path/to/folder
```
#### preprocess
Preprocesses and tokenizes your dataset before training. This is recommended for large datasets.
```bash
# Basic preprocessing
axolotl preprocess config.yml
# Preprocessing with one GPU
CUDA_VISIBLE_DEVICES="0" axolotl preprocess config.yml
# Debug mode to see processed examples
axolotl preprocess config.yml --debug
# Debug with limited examples
axolotl preprocess config.yml --debug --debug-num-examples 5
```
Configuration options:
```yaml
dataset_prepared_path: Local folder for saving preprocessed data
push_dataset_to_hub: HuggingFace repo to push preprocessed data (optional)
```
#### train
Trains or fine-tunes a model using the configuration specified in your YAML file.
```bash
# Basic training
axolotl train config.yml
# Train and set/override specific options
axolotl train config.yml \
--learning-rate 1e-4 \
--micro-batch-size 2 \
--num-epochs 3
# Training without accelerate
axolotl train config.yml --no-accelerate
# Resume training from checkpoint
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
```
#### inference
Runs inference using your trained model in either CLI or Gradio interface mode.
```bash
# CLI inference with LoRA
axolotl inference config.yml --lora-model-dir="./outputs/lora-out"
# CLI inference with full model
axolotl inference config.yml --base-model="./completed-model"
# Gradio web interface
axolotl inference config.yml --gradio \
--lora-model-dir="./outputs/lora-out"
# Inference with input from file
cat prompt.txt | axolotl inference config.yml \
--base-model="./completed-model"
```
#### merge-lora
Merges trained LoRA adapters into the base model.
```bash
# Basic merge
axolotl merge-lora config.yml
# Specify LoRA directory (usually used with checkpoints)
axolotl merge-lora config.yml --lora-model-dir="./lora-output/checkpoint-100"
# Merge using CPU (if out of GPU memory)
CUDA_VISIBLE_DEVICES="" axolotl merge-lora config.yml
```
Configuration options:
```yaml
gpu_memory_limit: Limit GPU memory usage
lora_on_cpu: Load LoRA weights on CPU
```
#### merge-sharded-fsdp-weights
Merges sharded FSDP model checkpoints into a single combined checkpoint.
```bash
# Basic merge
axolotl merge-sharded-fsdp-weights config.yml
```
#### evaluate
Evaluates a model's performance using metrics specified in the config.
```bash
# Basic evaluation
axolotl evaluate config.yml
```
#### lm-eval
Runs LM Evaluation Harness on your model.
```bash
# Basic evaluation
axolotl lm-eval config.yml
# Evaluate specific tasks
axolotl lm-eval config.yml --tasks arc_challenge,hellaswag
```
Configuration options:
```yaml
lm_eval_tasks: List of tasks to evaluate
lm_eval_batch_size: Batch size for evaluation
output_dir: Directory to save evaluation results
```
### Legacy CLI Usage
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
```bash
# Preprocess
python -m axolotl.cli.preprocess config.yml
# Train
accelerate launch -m axolotl.cli.train config.yml
# Inference
accelerate launch -m axolotl.cli.inference config.yml \
--lora_model_dir="./outputs/lora-out"
# Gradio interface
accelerate launch -m axolotl.cli.inference config.yml \
--lora_model_dir="./outputs/lora-out" --gradio
```
### Remote Compute with Modal Cloud
Axolotl supports running training and inference workloads on Modal cloud infrastructure. This is configured using a
cloud YAML file alongside your regular Axolotl config.
#### Cloud Configuration
Create a cloud config YAML with your Modal settings:
```yaml
# cloud_config.yml
provider: modal
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
gpu_count: 1 # Number of GPUs to use
timeout: 86400 # Maximum runtime in seconds (24 hours)
branch: main # Git branch to use (optional)
volumes: # Persistent storage volumes
- name: axolotl-cache
mount: /workspace/cache
env: # Environment variables
- WANDB_API_KEY
- HF_TOKEN
```
#### Running on Modal Cloud
Commands that support the --cloud flag:
```bash
# Preprocess on cloud
axolotl preprocess config.yml --cloud cloud_config.yml
# Train on cloud
axolotl train config.yml --cloud cloud_config.yml
# Train without accelerate on cloud
axolotl train config.yml --cloud cloud_config.yml --no-accelerate
# Run lm-eval on cloud
axolotl lm-eval config.yml --cloud cloud_config.yml
```
#### Cloud Configuration Options
```yaml
provider: compute provider, currently only `modal` is supported
gpu: GPU type to use
gpu_count: Number of GPUs (default: 1)
memory: RAM in GB (default: 128)
timeout: Maximum runtime in seconds
timeout_preprocess: Preprocessing timeout
branch: Git branch to use
docker_tag: Custom Docker image tag
volumes: List of persistent storage volumes
env: Environment variables to pass
secrets: Secrets to inject
```

View File

@@ -187,12 +187,6 @@ rl:
# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:
# reward modelling: `True` or `False`
reward_model:
# process reward modelling: `True` or `False`
process_reward_model:
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py

View File

@@ -8,12 +8,14 @@ order: 3
IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below.
## pygmalion
```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "value": "..."}]}
```
## chat_template
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.

View File

@@ -1,26 +0,0 @@
---
title: Stepwise Supervised Format
description: Format for datasets with stepwise completions and labels
order: 3
---
## Stepwise Supervised
The stepwise supervised format is designed for chain-of-thought (COT) reasoning
datasets where each example contains multiple completion steps and a preference label
for each step.
### Example
Here's a simple example of a stepwise supervised dataset entry:
```json
{
"prompt": "Which number is larger, 9.8 or 9.11?",
"completions": [
"The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.",
"Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."
],
"labels": [true, false]
}
```

View File

@@ -1,155 +0,0 @@
---
title: "Getting Started with Axolotl"
format:
html:
toc: true
toc-depth: 3
number-sections: true
execute:
enabled: false
---
This guide will walk you through your first model fine-tuning project with Axolotl.
## Quick Example {#sec-quick-example}
Let's start by fine-tuning a small language model using LoRA. This example uses a 1B parameter model to ensure it runs on most GPUs.
Assuming `axolotl` is installed (if not, see our [Installation Guide](installation.qmd))
1. Download example configs:
```shell
axolotl fetch examples
```
2. Run the training:
```shell
axolotl train examples/llama-3/lora-1b.yml
```
That's it! Let's understand what just happened.
## Understanding the Process {#sec-understanding}
### The Configuration File {#sec-config}
The YAML configuration file controls everything about your training. Here's what (part of) our example config looks like:
```yaml
base_model: NousResearch/Llama-3.2-1B
# hub_model_id: username/custom_model_name
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
```
See our [Config options](config.qmd) for more details.
### Training {#sec-training}
When you run `axolotl train`, Axolotl:
1. Downloads the base model
2. (If specified) applies LoRA adapter layers
3. Loads and processes the dataset
4. Runs the training loop
5. Saves the trained model and / or LoRA weights
## Your First Custom Training {#sec-custom}
Let's modify the example for your own data:
1. Create a new config file `my_training.yml`:
```yaml
base_model: NousResearch/Nous-Hermes-llama-1b-v1
adapter: lora
# Training settings
micro_batch_size: 2
num_epochs: 3
learning_rate: 0.0003
# Your dataset
datasets:
- path: my_data.jsonl # Your local data file
type: alpaca # Or other format
```
This specific config is for LoRA fine-tuning a model with instruction tuning data using
the `alpaca` dataset format, which has the following format:
```json
{
"instruction": "Write a description of alpacas.",
"input": "",
"output": "Alpacas are domesticated South American camelids..."
}
```
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
format them.
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
format):
```json
{"instruction": "Classify this text", "input": "I love this!", "output": "positive"}
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
```
Please consult the supported [Dataset Formats](dataset-formats/) for more details.
3. Run the training:
```shell
axolotl train my_training.yml
```
## Common Tasks {#sec-common-tasks}
### Testing Your Model {#sec-testing}
After training, test your model:
```shell
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
```
### Preprocessing Data {#sec-preprocessing}
For large datasets, preprocess first:
```shell
axolotl preprocess my_training.yml
```
### Using a UI {#sec-ui}
Launch a Gradio interface:
```shell
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
```
## Next Steps {#sec-next-steps}
Now that you have the basics, you might want to:
- Try different model architectures
- Experiment with hyperparameters
- Use more advanced training methods
- Scale up to larger models
Check our other guides for details on these topics:
- [Configuration Guide](config.qmd) - Full configuration options
- [Dataset Formats](dataset-formats) - Working with different data formats
- [Multi-GPU Training](multi-gpu.qmd)
- [Multi-Node Training](multi-node.qmd)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 292 KiB

View File

@@ -1,148 +0,0 @@
---
title: "Inference Guide"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---
This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
## Quick Start {#sec-quickstart}
### Basic Inference {#sec-basic}
::: {.panel-tabset}
## LoRA Models
```{.bash}
axolotl inference your_config.yml --lora-model-dir="./lora-output-dir"
```
## Full Fine-tuned Models
```{.bash}
axolotl inference your_config.yml --base-model="./completed-model"
```
:::
## Advanced Usage {#sec-advanced}
### Gradio Interface {#sec-gradio}
Launch an interactive web interface:
```{.bash}
axolotl inference your_config.yml --gradio
```
### File-based Prompts {#sec-file-prompts}
Process prompts from a text file:
```{.bash}
cat /tmp/prompt.txt | axolotl inference your_config.yml \
--base-model="./completed-model" --prompter=None
```
### Memory Optimization {#sec-memory}
For large models or limited memory:
```{.bash}
axolotl inference your_config.yml --load-in-8bit=True
```
## Merging LoRA Weights {#sec-merging}
Merge LoRA adapters with the base model:
```{.bash}
axolotl merge-lora your_config.yml --lora-model-dir="./completed-model"
```
### Memory Management for Merging {#sec-memory-management}
::: {.panel-tabset}
## Configuration Options
```{.yaml}
gpu_memory_limit: 20GiB # Adjust based on your GPU
lora_on_cpu: true # Process on CPU if needed
```
## Force CPU Merging
```{.bash}
CUDA_VISIBLE_DEVICES="" axolotl merge-lora ...
```
:::
## Tokenization {#sec-tokenization}
### Common Issues {#sec-tokenization-issues}
::: {.callout-warning}
Tokenization mismatches between training and inference are a common source of problems.
:::
To debug:
1. Check training tokenization:
```{.bash}
axolotl preprocess your_config.yml --debug
```
2. Verify inference tokenization by decoding tokens before model input
3. Compare token IDs between training and inference
### Special Tokens {#sec-special-tokens}
Configure special tokens in your YAML:
```{.yaml}
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
tokens:
- "<|im_start|>"
- "<|im_end|>"
```
## Troubleshooting {#sec-troubleshooting}
### Common Problems {#sec-common-problems}
::: {.panel-tabset}
## Memory Issues
- Use 8-bit loading
- Reduce batch sizes
- Try CPU offloading
## Token Issues
- Verify special tokens
- Check tokenizer settings
- Compare training and inference preprocessing
## Performance Issues
- Verify model loading
- Check prompt formatting
- Ensure temperature/sampling settings
:::
For more details, see our [debugging guide](debugging.qmd).

View File

@@ -1,119 +0,0 @@
---
title: "Installation Guide"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---
This guide covers all the ways you can install and set up Axolotl for your environment.
## Requirements {#sec-requirements}
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
- Python ≥3.10
- PyTorch ≥2.4.1
## Installation Methods {#sec-installation-methods}
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
We use `--no-build-isolation` in order to detect the installed PyTorch version (if
installed) in order not to clobber it, and so that we set the correct version of
dependencies that are specific to the PyTorch version or other installed
co-dependencies.
### Edge/Development Build {#sec-edge-build}
For the latest features between releases:
```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
### Docker {#sec-docker}
```{.bash}
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
```
For development with Docker:
```{.bash}
docker compose up -d
```
::: {.callout-tip}
### Advanced Docker Configuration
```{.bash}
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
--name axolotl --ipc=host \
--ulimit memlock=-1 --ulimit stack=67108864 \
--mount type=bind,src="${PWD}",target=/workspace/axolotl \
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
axolotlai/axolotl:main-latest
```
:::
## Cloud Environments {#sec-cloud}
### Cloud GPU Providers {#sec-cloud-gpu}
For providers supporting Docker:
- Use `axolotlai/axolotl-cloud:main-latest`
- Available on:
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
### Google Colab {#sec-colab}
Use our [example notebook](../examples/colab-notebooks/colab-axolotl-example.ipynb).
## Platform-Specific Instructions {#sec-platform-specific}
### macOS {#sec-macos}
```{.bash}
pip3 install --no-build-isolation -e '.'
```
See @sec-troubleshooting for Mac-specific issues.
### Windows {#sec-windows}
::: {.callout-important}
We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
:::
## Environment Managers {#sec-env-managers}
### Conda/Pip venv {#sec-conda}
1. Install Python ≥3.10
2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Hugging Face:
```{.bash}
huggingface-cli login
```
## Troubleshooting {#sec-troubleshooting}
If you encounter installation issues, see our [FAQ](faq.qmd) and [Debugging Guide](debugging.qmd).

View File

@@ -1,118 +0,0 @@
---
title: "Multi-GPU Training Guide"
format:
html:
toc: true
toc-depth: 3
number-sections: true
code-tools: true
execute:
enabled: false
---
This guide covers advanced training configurations for multi-GPU setups using Axolotl.
## Overview {#sec-overview}
Axolotl supports several methods for multi-GPU training:
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- FSDP + QLoRA
## DeepSpeed {#sec-deepspeed}
DeepSpeed is the recommended approach for multi-GPU training due to its stability and performance. It provides various optimization levels through ZeRO stages.
### Configuration {#sec-deepspeed-config}
Add to your YAML config:
```{.yaml}
deepspeed: deepspeed_configs/zero1.json
```
### Usage {#sec-deepspeed-usage}
```{.bash}
accelerate launch -m axolotl.cli.train examples/llama-2/config.yml --deepspeed deepspeed_configs/zero1.json
```
### ZeRO Stages {#sec-zero-stages}
We provide default configurations for:
- ZeRO Stage 1 (`zero1.json`)
- ZeRO Stage 2 (`zero2.json`)
- ZeRO Stage 3 (`zero3.json`)
Choose based on your memory requirements and performance needs.
## FSDP {#sec-fsdp}
### Basic FSDP Configuration {#sec-fsdp-config}
```{.yaml}
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
## Performance Optimization {#sec-performance}
### Liger Kernel Integration {#sec-liger}
::: {.callout-note}
Liger Kernel provides efficient Triton kernels for LLM training, offering:
- 20% increase in multi-GPU training throughput
- 60% reduction in memory usage
- Compatibility with both FSDP and DeepSpeed
:::
Configuration:
```{.yaml}
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
```
## Troubleshooting {#sec-troubleshooting}
### NCCL Issues {#sec-nccl}
For NCCL-related problems, see our [NCCL troubleshooting guide](nccl.qmd).
### Common Problems {#sec-common-problems}
::: {.panel-tabset}
## Memory Issues
- Reduce `micro_batch_size`
- Reduce `eval_batch_size`
- Adjust `gradient_accumulation_steps`
- Consider using a higher ZeRO stage
## Training Instability
- Start with DeepSpeed ZeRO-2
- Monitor loss values
- Check learning rates
:::
For more detailed troubleshooting, see our [debugging guide](debugging.qmd).

View File

@@ -1,93 +0,0 @@
---
title: Ray Train integration
description: How to use Axolotl with Ray Train
---
Axolotl supports using Ray as an alternative to `accelerate` for orchestrating training. This is especially useful for multi-node training since you only have to setup code and dependencies in a single node and launch training as if you were using a single node.
With the `--use-ray` CLI flag, Axolotl will use Ray Train's [`TorchTrainer`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchTrainer.html#ray.train.torch.TorchTrainer) to run training.
## Ray cluster setup
A prerequisite using the Ray Train integration is to setup a Ray cluster on your desired node(s). For a detailed guide on how you can get started with ray clusters, check the official Ray docs here: https://docs.ray.io/en/latest/cluster/getting-started.html
Every Ray cluster has one _head_ node and a set of worker nodes. The head node is just like any other worker node, but it also runs certain special processes related to scheduling and orchestration. Ray-enabled scripts are run on the head node and depending on the resources (number of CPUs, GPUs, etc) they request, will be scheduled to run certain tasks on the worker nodes. For more on key concepts behind a Ray cluster, you can refer this [doc](https://docs.ray.io/en/latest/cluster/key-concepts.html#cluster-key-concepts).
## Sanity check
To run a sanity check on whether your ray cluster is setup properly, execute the following on the head node:
```bash
ray status
```
The output should have a summary of your Ray cluster - list of all the nodes in your cluster, the number of CPUs and GPUs in your cluster, etc. For example, if you have a cluster with 1 CPU-only head node and 2 4xL40S worker nodes, the output can look like this:
```
Node status
---------------------------------------------------------------
Active:
1 head
Idle:
2 4xL40S:48CPU-384GB
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/96.0 CPU
0.0/8.0 GPU
0B/800.00GiB memory
0B/229.57GiB object_store_memory
Demands:
(no resource demands)
```
You should also be able to see the same on the [Ray dashboard](https://docs.ray.io/en/latest/ray-observability/getting-started.html).
## Configuring training with Ray Train
You can find an example configuration at `configs/llama-3/lora-1b-ray.yaml`.
The key parameters to note here are:
```yaml
...
use_ray: true
ray_num_workers: 4
# optional
resources_per_worker:
GPU: 1
...
```
- `use_ray`: This is the flag that enables the Ray Train integration. You can either use the corresponding `--use-ray` flag in the CLI or set `use_ray` in the config file.
- `ray_num_workers`: This is the number of workers/GPUs to use for training.
- `resources_per_worker`: This is the Ray [resource request](https://docs.ray.io/en/latest/ray-core/scheduling/resources.html) for each worker. This can be used to request a specific GPU type or a custom resource for each worker. For example, if your ray cluster has GPUs of different types, and you only want to use NVIDIA L40S GPUs, you can do
```yaml
resources_per_worker:
accelerator_type:L40S: 0.001
```
## Launching training
You can simply run the following command on the head node:
```bash
axolotl train examples/llama-3/lora-1b-ray.yml --use-ray
```
This will launch training on the head node and workers will be scheduled automatically by Ray Train to run on the appropriate head or worker nodes.
You can also monitor training progress on the Ray dashboard.
Coming back to the example on a Ray cluster with 1 head node and 2 4xL40S worker nodes, let's say you want to make use of all 8 GPUs. You would be able to just set `ray_num_workers: 8` and run the previous command. The Cluster tab will show the following:
![Ray dashboard](./images/ray-cluster-dashboard.png)

View File

@@ -1,47 +0,0 @@
---
title: "Reward Modelling"
description: "Reward models are used to guide models towards behaviors which is preferred by humans, by training over large datasets annotated with human preferences. "
---
### Overview
Reward modelling is a technique used to train models to predict the reward or value of a given input. This is particularly useful in reinforcement learning scenarios where the model needs to evaluate the quality of its actions or predictions.
We support the reward modelling techniques supported by `trl`.
### (Outcome) Reward Models
Outcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (e.g. rather than per-turn or per-step).
```yaml
base_model: google/gemma-2-2b
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
reward_model: true
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.1
eval_steps: 100
```
### Process Reward Models (PRM)
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
```yaml
base_model: Qwen/Qwen2.5-3B
model_type: AutoModelForTokenClassification
num_labels: 2
process_reward_model: true
datasets:
- path: trl-lib/math_shepherd
type: stepwise_supervised
split: train
val_set_size: 0.1
eval_steps: 100
```

View File

@@ -46,7 +46,7 @@ output_dir: ./outputs/btlm-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused
optimizer: adamw_torch
adam_beta2: 0.95
adam_eps: 0.000000001
max_grad_norm: 1.0

View File

@@ -1,28 +0,0 @@
project_name:
volumes:
- name: axolotl-data
mount: /workspace/data
- name: axolotl-artifacts
mount: /workspace/artifacts
# environment variables from local to set as secrets
secrets:
- HF_TOKEN
- WANDB_API_KEY
# Which branch of axolotl to use remotely
branch:
# additional custom commands when building the image
dockerfile_commands:
gpu: h100
gpu_count: 1
# Train specific configurations
memory: 128
timeout: 86400
# Preprocess specific configurations
memory_preprocess: 32
timeout_preprocess: 14400

View File

@@ -27,7 +27,7 @@ wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5

View File

@@ -47,7 +47,7 @@ peft_use_rslora: true
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5

View File

@@ -1,7 +1,6 @@
base_model: google/gemma-2-2b
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

View File

@@ -34,7 +34,7 @@ lora_target_linear: false
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

View File

@@ -42,7 +42,7 @@ output_dir: ./outputs/model-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
adam_beta2: 0.95
adam_eps: 0.00001
max_grad_norm: 1.0

View File

@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

View File

@@ -37,7 +37,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5

View File

@@ -1,79 +0,0 @@
base_model: NousResearch/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed: deepspeed_configs/zero3.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"
use_ray: true
ray_num_workers: 4

View File

@@ -30,7 +30,7 @@ lora_target_linear: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

View File

@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001

View File

@@ -47,7 +47,7 @@ wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -41,7 +41,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -43,7 +43,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

View File

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

View File

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

View File

@@ -39,7 +39,7 @@ wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 12
num_epochs: 2
optimizer: adamw_torch_fused
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

View File

@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_fused
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0

View File

@@ -1,72 +0,0 @@
base_model: Qwen/Qwen2.5-3B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForTokenClassification
num_labels: 2
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
process_reward_model: true
chat_template:
datasets:
- path: trl-lib/math_shepherd
type: stepwise_supervised
step_separator: "\n"
max_completion_length:
train_on_last_step_only: false
val_set_size: 0.2
output_dir: ./outputs/out
remove_unused_columns: false
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 8
eval_batch_size: 8
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32:
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
eval_steps: 100
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -37,7 +37,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -1,67 +0,0 @@
base_model: Qwen/Qwen2.5-0.5B
# optionally might have model_type or tokenizer_type
model_type: AutoModelForSequenceClassification
num_labels: 1
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
reward_model: true
chat_template: qwen_25
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
remove_unused_columns: false
sequence_len: 2048
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -38,7 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch_fused
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.1
bitsandbytes==0.45.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2
@@ -25,7 +25,6 @@ hf_transfer
sentencepiece
gradio==3.50.2
modal==0.70.5
pydantic==2.6.3
addict
fire

View File

@@ -1,15 +1,10 @@
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
dP dP dP
88 88 88
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
88' `88 `8bd8' 88' `88 88 88' `88 88 88
88. .88 .d88b. 88. .88 88 88. .88 88 88
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:

View File

@@ -32,6 +32,8 @@ def parse_requirements():
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
triton_version = [req for req in _install_requires if "triton" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
@@ -85,8 +87,24 @@ def parse_requirements():
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(triton_version))
_install_requires.append("triton>=2.3.1")
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
elif (major, minor) >= (2, 2):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.25.post1")
else:
raise ValueError("axolotl requires torch>=2.4")
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.23.post1")
except PackageNotFoundError:
pass
@@ -150,8 +168,5 @@ setup(
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
],
"ray": [
"ray[train]",
],
},
)

View File

@@ -25,8 +25,6 @@ class TrainerCliArgs:
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
main_process_port: Optional[int] = field(default=None)
num_processes: Optional[int] = field(default=None)
@dataclass

View File

@@ -1,56 +0,0 @@
"""
launch axolotl in supported cloud platforms
"""
from pathlib import Path
from typing import Union
import yaml
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
"""Load and validate cloud configuration."""
# Load cloud configuration.
with open(cloud_config, encoding="utf-8") as file:
cloud_cfg: DictDefault = DictDefault(yaml.safe_load(file))
return cloud_cfg
def do_cli_preprocess(
cloud_config: Union[Path, str],
config: Union[Path, str],
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.preprocess(config_yaml)
def do_cli_train(
cloud_config: Union[Path, str],
config: Union[Path, str],
accelerate: bool = True,
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.train(config_yaml, accelerate=accelerate)
def do_cli_lm_eval(
cloud_config: Union[Path, str],
config: Union[Path, str],
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.lm_eval(config_yaml)

View File

@@ -1,18 +0,0 @@
"""
base class for cloud platforms from cli
"""
from abc import ABC, abstractmethod
class Cloud(ABC):
"""
Abstract base class for cloud platforms.
"""
@abstractmethod
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
pass
@abstractmethod
def train(self, config_yaml: str, accelerate: bool = True) -> str:
pass

View File

@@ -1,282 +0,0 @@
"""
Modal Cloud support from CLI
"""
import copy
import json
import os
import subprocess # nosec B404
from pathlib import Path
from random import randint
import modal
from axolotl.cli.cloud.base import Cloud
def run_cmd(cmd: str, run_folder: str, volumes=None):
"""Run a command inside a folder, with Modal Volume reloading before and commit on success."""
# Ensure volumes contain latest files.
if volumes:
for _, vol in volumes.items():
vol.reload()
# modal workaround so it doesn't use the automounted axolotl
new_env = copy.deepcopy(os.environ)
if "PYTHONPATH" in new_env:
del new_env["PYTHONPATH"]
# Propagate errors from subprocess.
if exit_code := subprocess.call( # nosec B603
cmd.split(), cwd=run_folder, env=new_env
):
exit(exit_code) # pylint: disable=consider-using-sys-exit
# Commit writes to volume.
if volumes:
for _, vol in volumes.items():
vol.commit()
class ModalCloud(Cloud):
"""
Modal Cloud implementation.
"""
def __init__(self, config, app=None):
self.config = config
if not app:
app = modal.App()
self.app = app
self.volumes = {}
if config.volumes:
for volume_config in config.volumes:
_, mount, vol = self.create_volume(volume_config)
self.volumes[mount] = (vol, volume_config)
def get_env(self):
res = {
"HF_DATASETS_CACHE": "/workspace/data/huggingface-cache/datasets",
"HF_HUB_CACHE": "/workspace/data/huggingface-cache/hub",
}
for key in self.config.get("env", []):
if isinstance(key, str):
if val := os.environ.get(key, ""):
res[key] = val
elif isinstance(key, dict):
(key_, val) = list(key.items())[0]
res[key_] = val
return res
def get_image(self):
docker_tag = "main-py3.11-cu124-2.5.1"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"
# grab the sha256 hash from docker hub for this image+tag
# this ensures that we always get the latest image for this tag, even if it's already cached
try:
manifest = subprocess.check_output( # nosec B602
f"docker manifest inspect {docker_image}",
shell=True,
).decode("utf-8")
sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
except subprocess.CalledProcessError:
sha256_hash = None
# create the image
if sha256_hash:
image = modal.Image.from_registry(f"axolotlai/axolotl@{sha256_hash}")
else:
image = modal.Image.from_registry(docker_image)
dockerfile_commands = []
if self.config.dockerfile_commands:
dockerfile_commands.extend(self.config.dockerfile_commands)
# branch
if self.config.branch:
dockerfile_commands.extend(
[
# Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
]
)
if dockerfile_commands:
image = image.dockerfile_commands(dockerfile_commands)
if env := self.get_env():
image = image.env(env)
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
return image
def get_secrets(self):
res = []
if self.config.secrets:
for key in self.config.get("secrets", []):
# pylint: disable=duplicate-code
if isinstance(key, str):
if val := os.environ.get(key, ""):
res.append(modal.Secret.from_dict({key: val}))
elif isinstance(key, dict):
(key_, val) = list(key.items())[0]
res.append(modal.Secret.from_dict({key_: val}))
return res
def create_volume(self, volume_config):
name = volume_config.name
mount = volume_config.mount
return name, mount, modal.Volume.from_name(name, create_if_missing=True)
def get_ephemeral_disk_size(self):
return 1000 * 525 # 1 TiB
def get_preprocess_timeout(self):
if self.config.timeout_preprocess:
return int(self.config.timeout_preprocess)
return 60 * 60 * 3 # 3 hours
def get_preprocess_memory(self):
memory = 128 # default to 128GiB
if self.config.memory:
memory = int(self.config.memory)
if self.config.memory_preprocess:
memory = int(self.config.memory_preprocess)
return 1024 * memory
def get_preprocess_env(self):
return self.app.function(
image=self.get_image(),
volumes={k: v[0] for k, v in self.volumes.items()},
cpu=8.0,
ephemeral_disk=self.get_ephemeral_disk_size(),
memory=self.get_preprocess_memory(),
timeout=self.get_preprocess_timeout(),
secrets=self.get_secrets(),
)
def preprocess(self, config_yaml: str, *args, **kwargs):
modal_fn = self.get_preprocess_env()(_preprocess)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
volumes={k: v[0] for k, v in self.volumes.items()},
*args,
**kwargs,
)
def get_train_timeout(self):
if self.config.timeout:
return int(self.config.timeout)
return 60 * 60 * 24 # 24 hours
def get_train_gpu(self): # pylint: disable=too-many-return-statements
count = self.config.gpu_count or 1
family = self.config.gpu.lower() or "l40s"
if family == "l40s":
return modal.gpu.L40S(count=count)
if family in ["a100", "a100-40gb"]:
return modal.gpu.A100(count=count, size="40GB")
if family == "a100-80gb":
return modal.gpu.A100(count=count, size="80GB")
if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count)
if family == "h100":
return modal.gpu.H100(count=count)
if family == "t4":
return modal.gpu.T4(count=count)
if family == "l4":
return modal.gpu.L4(count=count)
raise ValueError(f"Unsupported GPU family: {family}")
def get_train_memory(self):
memory = 128 # default to 128GiB
if self.config.memory:
memory = int(self.config.memory)
return 1024 * memory
def get_train_env(self):
return self.app.function(
image=self.get_image(),
volumes={k: v[0] for k, v in self.volumes.items()},
cpu=16.0,
gpu=self.get_train_gpu(),
memory=self.get_train_memory(),
timeout=self.get_train_timeout(),
secrets=self.get_secrets(),
)
def train(self, config_yaml: str, accelerate: bool = True):
modal_fn = self.get_train_env()(_train)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
accelerate=accelerate,
volumes={k: v[0] for k, v in self.volumes.items()},
)
def lm_eval(self, config_yaml: str):
modal_fn = self.get_train_env()(_lm_eval)
with modal.enable_output():
with self.app.run(detach=True):
if self.config.get("spawn", False):
modal_fn_exec = modal_fn.spawn
else:
modal_fn_exec = modal_fn.remote
modal_fn_exec(
config_yaml,
volumes={k: v[0] for k, v in self.volumes.items()},
)
def _preprocess(config_yaml: str, volumes=None):
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
run_cmd(
"axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
run_folder,
volumes,
)
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
if accelerate:
accelerate_args = "--accelerate"
else:
accelerate_args = "--no-accelerate"
run_cmd(
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)
def _lm_eval(config_yaml: str, volumes=None):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
run_cmd(
"axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)

View File

@@ -19,7 +19,7 @@ from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> dict[str, float]:
"""
Evaluates a `transformers` model by first loading the dataset(s) specified in the
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
@@ -39,7 +39,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
evaluate(cfg=cfg, dataset_meta=dataset_meta)
return evaluate(cfg=cfg, dataset_meta=dataset_meta)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:

View File

@@ -8,6 +8,7 @@ import click
import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.cli.plugins import setup_plugin_commands
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
@@ -15,7 +16,6 @@ from axolotl.cli.utils import (
fetch_from_github,
filter_none_kwargs,
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -28,28 +28,21 @@ def cli():
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
def preprocess(config: str, **kwargs) -> None:
"""
Preprocess datasets before training.
Args:
config: Path to `axolotl` config YAML file.
cloud: Path to a cloud accelerator configuration file.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if cloud:
from axolotl.cli.cloud import do_cli_preprocess
from axolotl.cli.preprocess import do_cli
do_cli_preprocess(cloud_config=cloud, config=config)
else:
from axolotl.cli.preprocess import do_cli
do_cli(config=config, **kwargs)
do_cli(config=config, **kwargs)
@cli.command()
@@ -59,56 +52,32 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
default=True,
help="Use accelerate launch for multi-GPU training",
)
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs) -> None:
def train(config: str, accelerate: bool, **kwargs) -> None:
"""
Train or fine-tune a model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
cloud: Path to a cloud accelerator configuration file
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
from axolotl.cli.cloud import do_cli_train
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
if accelerate:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
else:
accelerate_args = []
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port", None)
accelerate_args.append("--main_process_port")
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num-processes")
accelerate_args.append(str(num_processes))
base_cmd = ["accelerate", "launch"]
base_cmd.extend(accelerate_args)
base_cmd.extend(["-m", "axolotl.cli.train"])
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
else:
from axolotl.cli.train import do_cli
from axolotl.cli.train import do_cli
do_cli(config=config, **kwargs)
do_cli(config=config, **kwargs)
@cli.command()
@@ -227,6 +196,7 @@ def merge_lora(config: str, **kwargs) -> None:
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
@@ -253,7 +223,7 @@ def fetch(directory: str, dest: Optional[str]) -> None:
fetch_from_github(f"{directory}/", dest)
cli.add_command(lm_eval)
setup_plugin_commands(cli)
def main():

View File

@@ -0,0 +1,36 @@
"""Module for adding click CLI commands from axolotl plugins."""
import logging
import click
from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass
from axolotl.logging_config import configure_logging
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
configure_logging()
LOG = logging.getLogger(__name__)
def setup_plugin_commands(cli: click.core.Group) -> None:
"""
Setup CLI commands for available plugins.
Args:
cli: Click CLI object to add plugin CLI options to.
"""
try:
from axolotl_diff_transformer.convert_diff_transformer import do_cli
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
do_cli(config=config, **kwargs)
except ImportError as exc:
LOG.debug("axolotl-diff-transformer not found: %s", exc)

View File

@@ -5,7 +5,6 @@ from pathlib import Path
from typing import Union
import fire
from accelerate import Accelerator
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
@@ -16,7 +15,6 @@ from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
@@ -65,47 +63,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
return_remaining_strings=True
)
if parsed_cfg.use_ray:
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
train_loop_config = {"cfg": parsed_cfg.to_dict(), "cli_args": parsed_cli_args}
trainer = TorchTrainer(
ray_train_func,
train_loop_config=train_loop_config,
scaling_config=ScalingConfig(
num_workers=parsed_cfg.ray_num_workers,
resources_per_worker=parsed_cfg.resources_per_worker.to_dict(),
use_gpu=True,
),
run_config=RunConfig(
name=parsed_cfg.ray_run_name,
storage_path=Path(parsed_cfg.output_dir).absolute().as_posix(),
),
)
return trainer.fit()
return do_train(parsed_cfg, parsed_cli_args)
def ray_train_func(kwargs: dict):
# cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
# also renormalize the config now that TorchTrainer has spawned distributed workers
cfg = DictDefault(kwargs["cfg"])
normalize_config(cfg)
# now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype
resolve_dtype(cfg)
# ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
if cfg.deepspeed:
cfg.deepspeed = cfg.deepspeed.to_dict()
# initialize accelerator before model instantiation
Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
kwargs["cfg"] = cfg
do_train(**kwargs)
do_train(parsed_cfg, parsed_cli_args)
if __name__ == "__main__":

View File

@@ -157,6 +157,8 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
if isinstance(value, bool):
if value:
cmd.append(f"--{key}")
else:
cmd.append(f"--no{key}")
else:
cmd.extend([f"--{key}", str(value)])

View File

@@ -44,8 +44,6 @@ from trl import (
KTOTrainer,
ORPOConfig,
ORPOTrainer,
PRMConfig,
PRMTrainer,
RewardConfig,
RewardTrainer,
)
@@ -299,7 +297,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
so it can't be used as a mixin.
"""
@@ -344,13 +342,6 @@ class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
"""
@dataclass
class AxolotlPRMConfig(AxolotlTrainingMixins, PRMConfig):
"""
PRM config for PRM training
"""
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
@@ -1253,14 +1244,6 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"]
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
@@ -1394,8 +1377,7 @@ class TrainerBuilderBase(abc.ABC):
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for causal models
and reward modelling using TRL.
Build the HuggingFace training args/trainer for Causal models
"""
def get_callbacks(self):
@@ -1470,8 +1452,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
return AxolotlTrainer
def build(self, total_num_steps):
@@ -1862,13 +1842,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
"accelerator_config"
] = self.cfg.accelerator_config
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
elif self.cfg.process_reward_model:
training_args_cls = AxolotlPRMConfig
else:
training_args_cls = AxolotlTrainingArguments
training_args_cls = (
AxolotlTrainingArguments
if not self.cfg.reward_model
else AxolotlRewardConfig
)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
@@ -1902,9 +1880,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if eval_data_collator := self.build_collator(
training_args, is_eval=True, **data_collator_kwargs
):
if not (self.cfg.reward_model or self.cfg.process_reward_model):
if not self.cfg.reward_model:
trainer_kwargs["eval_data_collator"] = eval_data_collator
if not (self.cfg.reward_model or self.cfg.process_reward_model):
if not self.cfg.reward_model:
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
@@ -1915,10 +1893,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs["processing_class"] = self.tokenizer
else:
trainer_kwargs["tokenizer"] = self.tokenizer
if (
not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
and self.cfg.datasets is not None
):
if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
@@ -2008,7 +1984,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
class HFRLTrainerBuilder(TrainerBuilderBase):
"""
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
Trainer factory class for DPO Trainer
"""
def get_callbacks(self):

View File

@@ -52,7 +52,6 @@ class TokenizedPromptDataset(Dataset):
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,

View File

@@ -4,7 +4,7 @@ import csv
import os
import sys
from pathlib import Path
from typing import Dict, Optional
from typing import Optional
import torch
from accelerate.logging import get_logger
@@ -26,7 +26,7 @@ LOG = get_logger("axolotl.evaluate")
def evaluate_dataset(
trainer, dataset, dataset_type: str, flash_optimum: bool = False
) -> Optional[Dict[str, float]]:
) -> Optional[dict[str, float]]:
"""Helper function to evaluate a single dataset safely.
Args:
@@ -61,7 +61,7 @@ def evaluate_dataset(
return metrics
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> dict[str, float]:
"""
Evaluate a model on training and validation datasets

View File

@@ -43,10 +43,12 @@ def merge_input_args():
input_args: List[str] = plugin_manager.get_input_args()
plugin_classes = []
dynamic_input = ""
for plugin_args in input_args:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
plugin_classes.append(plugin_cls)
if dynamic_input:
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
@@ -62,4 +64,5 @@ def merge_input_args():
"AxolotlConfigWCapabilities"
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase

View File

@@ -2,9 +2,9 @@
Module for the Plugin for LM Eval Harness
"""
import subprocess # nosec
from datetime import datetime
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
@@ -18,20 +18,25 @@ class LMEvalPlugin(BasePlugin):
return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg):
if cfg.lm_eval_post_train:
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=cfg.flash_attention,
output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id,
):
subprocess.run( # nosec
lm_eval_args,
check=True,
)
tasks = ",".join(cfg.lm_eval_tasks)
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
output_path = cfg.output_dir
output_path += "" if cfg.output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
subprocess.run( # nosec
[
"lm_eval",
"--model",
"hf",
"--model_args",
f"pretrained={cfg.output_dir}{fa2}{dtype}",
"--tasks",
tasks,
"--batch_size",
str(cfg.lm_eval_batch_size),
"--output_path",
output_path,
],
check=True,
)

View File

@@ -13,5 +13,3 @@ class LMEvalArgs(BaseModel):
lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8
lm_eval_post_train: Optional[bool] = True
lm_eval_model: Optional[str] = None

View File

@@ -1,119 +0,0 @@
"""
axolotl CLI for running lm_eval tasks
"""
import subprocess # nosec
from collections import defaultdict
from datetime import datetime
from typing import Optional
import click
import yaml
from axolotl.utils.dict import DictDefault
def build_lm_eval_command(
tasks: list[str],
bfloat16=True,
flash_attention=False,
output_dir="./",
batch_size=8,
wandb_project=None,
wandb_entity=None,
wandb_name=None,
model=None,
revision=None,
apply_chat_template=None,
fewshot_as_multiturn=None,
):
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
if isinstance(tasks, str):
tasks = [tasks]
for task in tasks:
num_fewshot = "-1"
task_parts = task.split(":")
task_name = task_parts[0]
if len(task_parts) == 2:
task_name, num_fewshot = task_parts
tasks_by_num_fewshot[str(num_fewshot)].append(task_name)
for num_fewshot, tasks_list in tasks_by_num_fewshot.items():
tasks_str = ",".join(tasks_list)
num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
pretrained = "pretrained="
pretrained += model if model else output_dir
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
revision = f",revision={revision}" if revision else ""
output_path = output_dir
output_path += "" if output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
lm_eval_args = [
"lm_eval",
"--model",
"hf",
"--model_args",
f"{pretrained}{fa2}{dtype}{revision}",
"--tasks",
tasks_str,
"--batch_size",
str(batch_size),
"--output_path",
output_path,
]
wandb_args = []
if wandb_project:
wandb_args.append(f"project={wandb_project}")
if wandb_entity:
wandb_args.append(f"entity={wandb_entity}")
if wandb_name:
wandb_args.append(f"name={wandb_name}")
if wandb_args:
lm_eval_args.append("--wandb_args")
lm_eval_args.append(",".join(wandb_args))
if apply_chat_template:
lm_eval_args.append("--apply_chat_template")
if num_fewshot_val:
lm_eval_args.append("--num_fewshot")
lm_eval_args.append(str(num_fewshot_val))
if apply_chat_template and fewshot_as_multiturn:
lm_eval_args.append("--fewshot_as_multiturn")
yield lm_eval_args
@click.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
def lm_eval(config: str, cloud: Optional[str] = None):
"""
use lm eval to evaluate a trained language model
"""
if cloud:
from axolotl.cli.cloud import do_cli_lm_eval
do_cli_lm_eval(cloud_config=cloud, config=config)
else:
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=cfg.flash_attention,
output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id,
revision=cfg.revision,
apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn,
):
subprocess.run( # nosec
lm_eval_args,
check=True,
)

View File

@@ -1,116 +0,0 @@
"""
Module for stepwise datasets, typically including a prompt and reasoning traces,
and (optionally) per-step, or per-prompt-trace labels for reward modelling.
"""
from itertools import chain
from typing import Dict, List, Optional, Union
from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.prompt_tokenizers import IGNORE_INDEX
from axolotl.utils.dict import DictDefault
class StepwiseSupervisedPromptTokenizingStrategy:
"""
Tokenizing strategy for supervised stepwise datasets, typically used for COT-reasoning.
These datasets should include the following columns:
- prompt: the prompt text
- completions: a list of `n` completion steps
- labels: a list of `n` labels indicating the "correctness" of each step
"""
def __init__(
self,
tokenizer,
sequence_len: int = 2048,
step_separator: str = "\n",
max_completion_length: Optional[int] = None,
train_on_last_step_only: bool = False,
):
self.tokenizer = tokenizer
self.sequence_len = sequence_len
self.step_separator = step_separator
self.max_completion_length = max_completion_length
self.train_on_last_step_only = train_on_last_step_only
def tokenize_prompt(
self, prompt: Dict[str, Union[str, List[str]]]
) -> BatchEncoding:
# Inspired by TRL's PRMTRainer
# https://github.com/huggingface/trl/blob/ed7de87dc766478c024b68f12530d1b0e7c3ff23/trl/trainer/prm_trainer.py#L206
prompt_ids = self.tokenizer(prompt["prompt"], add_special_tokens=False)[
"input_ids"
]
completions_ids = [
self.tokenizer(completion, add_special_tokens=False)["input_ids"]
for completion in prompt["completions"]
]
# Handle labels
if self.train_on_last_step_only:
labels = [IGNORE_INDEX] * (len(prompt["labels"]) - 1) + [
int(prompt["labels"][-1])
]
else:
labels = [int(label) for label in prompt["labels"]]
# Add step separators
separator_ids = self.tokenizer.encode(
self.step_separator, add_special_tokens=False
)
completions_ids = [completion + separator_ids for completion in completions_ids]
# Create step-wise labels
labels = [
[IGNORE_INDEX] * (len(completion) - 1) + [label] # type: ignore
for completion, label in zip(completions_ids, labels)
]
# Join all steps
completion_ids = list(chain(*completions_ids))
labels = list(chain(*labels)) # type: ignore
# Handle max lengths
if self.max_completion_length:
completion_ids = completion_ids[: self.max_completion_length]
labels = labels[: self.max_completion_length]
# Add BOS token if model has one
if self.tokenizer.bos_token_id is not None:
prompt_ids = [self.tokenizer.bos_token_id] + prompt_ids
# Combine prompt and completion
input_ids = prompt_ids + completion_ids
full_labels = [IGNORE_INDEX] * len(prompt_ids) + labels
# Apply max sequence length
if self.sequence_len:
input_ids = input_ids[: self.sequence_len]
full_labels = full_labels[: self.sequence_len]
return {
"input_ids": input_ids,
"labels": full_labels,
"attention_mask": [1] * len(input_ids),
}
@property
def supports_batched(self):
return False
def load(
tokenizer: PreTrainedTokenizer,
cfg: DictDefault,
ds_cfg: DictDefault,
) -> StepwiseSupervisedPromptTokenizingStrategy:
return StepwiseSupervisedPromptTokenizingStrategy(
tokenizer,
cfg.sequence_len,
step_separator=ds_cfg.get("step_separator", "\n"),
max_completion_length=ds_cfg.max_completion_length,
train_on_last_step_only=ds_cfg.get("train_on_last_step_only", False),
)

View File

@@ -141,9 +141,7 @@ def train(
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if (
cfg.local_rank == 0 and not cfg.use_ray
): # ray workers don't have access to this signal
if cfg.local_rank == 0:
def terminate_handler(_, __, model_weakref):
if model_weakref() is not None:
@@ -261,7 +259,7 @@ def train(
.decode("utf-8")
}
if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model or cfg.process_reward_model:
if cfg.rl is not None or cfg.reward_model:
dataset_tags = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import gc
import logging
import math
import os
import traceback
from shutil import copyfile
@@ -829,6 +830,13 @@ class SaveModelCallback(TrainerCallback):
# Save
if state.global_step >= state.max_steps:
control.should_save = True
elif (
args.save_strategy == IntervalStrategy.STEPS
and state.save_steps < 1.0
and state.global_step % math.ceil(state.save_steps * state.max_steps) == 0
):
# workaround to save model on fractional save_steps
control.should_save = True
def on_train_end( # pylint: disable=unused-argument
self, args, state, control, **kwargs

View File

@@ -1,5 +1,4 @@
"""Module for working with config dicts"""
import json
import logging
import os
from typing import Optional
@@ -57,38 +56,6 @@ def choose_device(cfg):
cfg.device_map = None
def resolve_dtype(cfg):
if (
cfg.bf16 == "auto" and not cfg.use_ray
): # if we use ray we want to defer this check to the worker node
if is_torch_bf16_gpu_available():
LOG.debug("bf16 support detected, enabling for this configuration.")
cfg.bf16 = True
else:
LOG.debug("bf16 support not detected, disabling for this configuration.")
cfg.bf16 = False
if cfg.fp16 is None:
cfg.fp16 = True
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
else:
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
if cfg.bf16:
cfg.fp16 = False
if cfg.bf16 or cfg.bfloat16:
cfg.torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
cfg.torch_dtype = torch.float16
else:
cfg.torch_dtype = torch.float32
def normalize_config(cfg):
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
@@ -115,15 +82,33 @@ def normalize_config(cfg):
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
if not cfg.use_ray:
# delay resolving dtype until on worker node when launching with ray
resolve_dtype(cfg)
if cfg.bf16 == "auto":
if is_torch_bf16_gpu_available():
LOG.debug("bf16 support detected, enabling for this configuration.")
cfg.bf16 = True
else:
LOG.debug("bf16 support not detected, disabling for this configuration.")
cfg.bf16 = False
if cfg.fp16 is None:
cfg.fp16 = True
if cfg.deepspeed:
if isinstance(cfg.deepspeed, str) and os.path.exists(cfg.deepspeed):
ds_config_path = cfg.deepspeed
with open(ds_config_path, encoding="utf-8") as f:
cfg.deepspeed = json.load(f)
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
else:
torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
if cfg.bf16:
cfg.fp16 = False
if cfg.bf16 or cfg.bfloat16:
cfg.torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
cfg.torch_dtype = torch.float16
else:
cfg.torch_dtype = torch.float32
if cfg.saves_per_epoch:
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)

View File

@@ -236,18 +236,6 @@ class DPODataset(BaseModel):
revision: Optional[str] = None
class StepwiseSupervisedDataset(BaseModel):
"""Stepwise supervised dataset configuration subset"""
path: Optional[str] = None
split: Optional[str] = None
data_files: Optional[List[str]] = None
revision: Optional[str] = None
step_separator: Optional[str] = None
max_completion_length: Optional[int] = None
train_on_last_step_only: Optional[bool] = None
class UserDefinedKTOType(BaseModel):
"""User defined typing for KTO"""
@@ -501,7 +489,7 @@ class HyperparametersConfig(BaseModel):
adam_beta1: Optional[float] = None
adam_beta2: Optional[float] = None
max_grad_norm: Optional[float] = None
num_epochs: float = Field(default=1.0)
num_epochs: int = Field(default=1)
@field_validator("batch_size")
@classmethod
@@ -528,7 +516,7 @@ class ModelOutputConfig(BaseModel):
output_dir: str = Field(default="./model-out")
hub_model_id: Optional[str] = None
hub_strategy: Optional[str] = None
save_safetensors: Optional[bool] = True
save_safetensors: Optional[bool] = None
class MLFlowConfig(BaseModel):
@@ -607,30 +595,6 @@ class GradioConfig(BaseModel):
gradio_temperature: Optional[float] = None
class RayConfig(BaseModel):
"""Ray launcher configuration subset"""
use_ray: bool = Field(default=False)
ray_run_name: Optional[str] = Field(
default=None,
metadata={
"help": "The training results will be saved at `saves/ray_run_name`."
},
)
ray_num_workers: int = Field(
default=1,
metadata={
"help": "The number of workers for Ray training. Default is 1 worker."
},
)
resources_per_worker: dict = Field(
default_factory=lambda: {"GPU": 1},
metadata={
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
},
)
# pylint: disable=too-many-public-methods,too-many-ancestors
class AxolotlInputConfig(
ModelInputConfig,
@@ -643,7 +607,6 @@ class AxolotlInputConfig(
CometConfig,
LISAConfig,
GradioConfig,
RayConfig,
RemappedParameters,
DeprecatedParameters,
BaseModel,
@@ -663,14 +626,12 @@ class AxolotlInputConfig(
rl: Optional[RLType] = None
reward_model: Optional[bool] = None
process_reward_model: Optional[bool] = None
num_labels: Optional[int] = None
dpo_use_weighting: Optional[
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
shuffle_merged_datasets: Optional[bool] = True
dataset_prepared_path: Optional[str] = None
dataset_shard_num: Optional[int] = None

View File

@@ -8,8 +8,6 @@ from typing import List, Tuple, Union
from datasets import (
Dataset,
DatasetDict,
Sequence,
Value,
concatenate_datasets,
load_dataset,
load_from_disk,
@@ -469,17 +467,6 @@ def get_dataset_wrapper(
dataset,
**ds_kwargs,
)
elif config_dataset.type.startswith("stepwise_supervised"):
dataset_prompter = UnsupportedPrompter()
ds_strategy = load(config_dataset.type, tokenizer, cfg, config_dataset)
# we need to explicitly cast boolean labels to int
# for compatibility with how trl's PRMTrainer works
dataset = dataset.cast_column("labels", Sequence(Value("int64")))
dataset_wrapper = TokenizedPromptDataset(
ds_strategy,
dataset,
**ds_kwargs,
)
elif ds_strategy := load(
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
):

View File

@@ -138,9 +138,7 @@ def load_model_config(cfg):
config_kwargs = {}
if cfg.revision_of_model:
config_kwargs["revision"] = cfg.revision_of_model
if cfg.num_labels:
# num_labels is used to initialize classifier models
config_kwargs["num_labels"] = cfg.num_labels
try:
model_config = AutoConfig.from_pretrained(
model_config_name,
@@ -814,6 +812,7 @@ class ModelLoader:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,

View File

@@ -374,7 +374,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
if cfg.sample_packing_eff_est:
total_num_steps = (
# match count to len est in dataloader
int(
(
math.floor(
0.99
* cfg.total_num_tokens

157
src/axolotl/utils/yaml.py Normal file
View File

@@ -0,0 +1,157 @@
"""Utilities for YAML files."""
from collections import OrderedDict
from typing import Any, Dict, List, Set, Tuple, Union
import yaml
class YAMLOrderTracker:
"""Tracks the order of keys and section breaks in YAML files."""
def __init__(self, yaml_path: str):
self.yaml_path = yaml_path
self.structure, self.needs_break = self._parse_yaml_structure()
def _get_indentation_level(self, line: str) -> int:
"""Get the indentation level of a line."""
return len(line) - len(line.lstrip())
def _parse_yaml_structure(
self,
) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
"""Parse the YAML file to extract structure and identify section breaks."""
with open(self.yaml_path, "r", encoding="utf-8") as file:
contents = file.readlines()
structure: OrderedDict = OrderedDict()
needs_break = set() # Track which keys should have a break before them
current_path = []
last_indentation = -1
had_empty_line = False
for line in contents:
# Track empty lines and comments
if not line.strip() or line.strip().startswith("#"):
had_empty_line = True
continue
# Get indentation level and content
indentation = self._get_indentation_level(line)
content = line.strip()
# Skip lines that don't define keys
if ":" not in content:
continue
# Extract key
key = content.split(":")[0].strip()
# If this is a top-level key and we had an empty line, mark it
if indentation == 0:
if had_empty_line:
needs_break.add(key)
had_empty_line = False
# Handle indentation changes
if indentation > last_indentation:
current_path.append(key)
elif indentation < last_indentation:
levels_up = (last_indentation - indentation) // 2
current_path = current_path[:-levels_up]
current_path[-1] = key
else:
if current_path:
current_path[-1] = key
# Update structure
current_dict = structure
for path_key in current_path[:-1]:
if path_key not in current_dict:
current_dict[path_key] = OrderedDict()
current_dict = current_dict[path_key]
if current_path:
if current_path[-1] not in current_dict:
current_dict[current_path[-1]] = OrderedDict()
last_indentation = indentation
return structure, needs_break
class OrderedDumper(yaml.SafeDumper):
"""Custom YAML dumper that maintains dictionary order."""
def represent_none(self, _):
"""Represent None values as empty fields."""
return self.represent_scalar("tag:yaml.org,2002:null", "")
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
"""Custom representer for dictionaries that maintains order."""
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
"""Reorder a dictionary based on a reference structure."""
ordered = OrderedDict()
# First add keys that are in the reference order
for key in reference_structure:
if key in data:
if isinstance(reference_structure[key], dict) and isinstance(
data[key], dict
):
ordered[key] = reorder_dict(data[key], reference_structure[key])
else:
ordered[key] = data[key]
# Then add any remaining keys that weren't in the reference
for key in data:
if key not in ordered:
ordered[key] = data[key]
return ordered
def dump_yaml_preserved_order(
data: Dict, reference_yaml_path: str, output_path: str
) -> None:
"""Dump YAML file while preserving nested order and normalized spacing."""
# Get reference structure and spacing
tracker = YAMLOrderTracker(reference_yaml_path)
# Reorder the data
ordered_data = reorder_dict(data, tracker.structure)
# Register the custom representers
OrderedDumper.add_representer(type(None), represent_none)
OrderedDumper.add_representer(dict, ordered_dict_representer)
OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)
# First dump to string
yaml_str = yaml.dump(
ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
)
# Add spacing according to reference
lines = yaml_str.split("\n")
result_lines: List[str] = []
current_line = 0
while current_line < len(lines):
line = lines[current_line]
if line.strip() and ":" in line and not line.startswith(" "): # Top-level key
key = line.split(":")[0].strip()
if key in tracker.needs_break:
# Add single empty line before this key
if result_lines and result_lines[-1] != "":
result_lines.append("")
result_lines.append(line)
current_line += 1
# Write the final result
with open(output_path, "w", encoding="utf-8") as file:
file.write("\n".join(result_lines))

View File

@@ -1,5 +1 @@
/* css styles */
img[alt="Axolotl"] {
content: url("https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg") !important;
}

View File

@@ -43,14 +43,12 @@ class BaseCliTest:
result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called
assert mock.call_args.args[0] == [
assert mock.call_args.args[0][:5] == [
"accelerate",
"launch",
"-m",
f"axolotl.cli.{command}",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0

View File

@@ -23,6 +23,7 @@ def test_build_command():
"--batch-size",
"8",
"--debug",
"--nouse-fp16",
]

View File

@@ -22,7 +22,7 @@ def fixture_cfg():
"output_dir": "./model-out",
"warmup_steps": 10,
"gradient_checkpointing": False,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"sequence_len": 2048,
"rl": True,
"adam_beta1": 0.998,

View File

@@ -39,7 +39,7 @@ def min_cfg(temp_dir):
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"output_dir": temp_dir,
"lr_scheduler": "cosine",
"save_safetensors": True,

View File

@@ -48,7 +48,7 @@ class LigerIntegrationTestCase:
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
@@ -93,7 +93,7 @@ class LigerIntegrationTestCase:
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",

View File

@@ -74,13 +74,15 @@ class TestMultiGPULlama:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"accelerate",
"launch",
"--num-processes",
"2",
"--main-process-port",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@@ -137,13 +139,15 @@ class TestMultiGPULlama:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"accelerate",
"launch",
"--num-processes",
"2",
"--main-process-port",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@@ -210,13 +214,15 @@ class TestMultiGPULlama:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"accelerate",
"launch",
"--num-processes",
"2",
"--main-process-port",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@@ -287,13 +293,15 @@ class TestMultiGPULlama:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"accelerate",
"launch",
"--num-processes",
"2",
"--main-process-port",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@@ -331,7 +339,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp": [
@@ -359,13 +367,15 @@ class TestMultiGPULlama:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"accelerate",
"launch",
"--num-processes",
"2",
"--main-process-port",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@@ -401,7 +411,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp": [
@@ -429,13 +439,15 @@ class TestMultiGPULlama:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"accelerate",
"launch",
"--num-processes",
"2",
"--main-process-port",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@@ -480,7 +492,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp": [
@@ -508,13 +520,15 @@ class TestMultiGPULlama:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"accelerate",
"launch",
"--num-processes",
"2",
"--main-process-port",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@@ -575,7 +589,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / deepspeed),
@@ -591,13 +605,15 @@ class TestMultiGPULlama:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"accelerate",
"launch",
"--num-processes",
"2",
"--main-process-port",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@@ -648,7 +664,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
@@ -664,13 +680,15 @@ class TestMultiGPULlama:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"accelerate",
"launch",
"--num-processes",
"2",
"--main-process-port",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@@ -721,7 +739,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
@@ -737,13 +755,15 @@ class TestMultiGPULlama:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"accelerate",
"launch",
"--num-processes",
"2",
"--main-process-port",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -52,7 +52,7 @@ class TestMultiGPUQwen2:
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"bf16": "auto",
@@ -86,12 +86,14 @@ class TestMultiGPUQwen2:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"accelerate",
"launch",
"--num-processes",
"2",
"--main-process-port",
"--main_process_port",
f"{get_torch_dist_unique_port()}",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -1,137 +0,0 @@
"""
E2E tests for multigpu post-training use Ray Train
"""
import logging
import os
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
class TestMultiGPURay:
"""
Test cases for AnyScale Ray post training
"""
def test_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"use_ray": True,
"ray_num_workers": 2,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--use-ray",
"--ray-num-workers",
"2",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
)
def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
"use_tensorboard": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--use-ray",
"--ray-num-workers",
"2",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)

View File

@@ -12,7 +12,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
from ..utils import check_model_output_exists, require_torch_2_3_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -23,6 +23,7 @@ class Test4dMultipackLlama(unittest.TestCase):
Test case for Llama models using 4d attention with multipack
"""
@require_torch_2_3_1
@with_temp_dir
def test_sdp_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code
@@ -52,7 +53,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
@@ -96,7 +97,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,

View File

@@ -56,7 +56,7 @@ class TestFusedLlama(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 10,
"save_steps": 5,

View File

@@ -56,7 +56,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 10,
"save_steps": 5,
@@ -96,7 +96,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 10,
"save_steps": 5,

View File

@@ -61,7 +61,7 @@ class TestLoraLlama(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
@@ -116,7 +116,7 @@ class TestLoraLlama(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)

View File

@@ -55,7 +55,7 @@ class TestMistral(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
@@ -96,7 +96,7 @@ class TestMistral(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,

View File

@@ -48,7 +48,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"embedding_lr_scale": 0.5,
"lr_scheduler": "cosine",
"save_safetensors": True,
@@ -92,7 +92,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"embedding_lr": 0.000005,
"lr_scheduler": "cosine",
"save_safetensors": True,

View File

@@ -57,7 +57,7 @@ class TestFalcon(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
@@ -110,7 +110,7 @@ class TestFalcon(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
@@ -149,7 +149,7 @@ class TestFalcon(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,

View File

@@ -62,7 +62,7 @@ class TestPretrainLlama:
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",

View File

@@ -52,7 +52,7 @@ class TestLoadModelUtils:
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)

View File

@@ -54,7 +54,7 @@ class TestLoraLlama(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
}

View File

@@ -51,7 +51,7 @@ class TestMamba(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,

View File

@@ -56,7 +56,7 @@ class TestMistral(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
@@ -95,7 +95,7 @@ class TestMistral(unittest.TestCase):
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,

Some files were not shown because too many files have changed in this diff Show More