Compare commits
78 Commits
custom-mod
...
diffusion-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
64f349b7bb | ||
|
|
260ebe4c93 | ||
|
|
63d2280999 | ||
|
|
b210db2d15 | ||
|
|
556a69118f | ||
|
|
8569675b26 | ||
|
|
c10eb811fa | ||
|
|
0eef385b1a | ||
|
|
077b5a4358 | ||
|
|
ecbe8b2b61 | ||
|
|
234b7b3126 | ||
|
|
130ef7c51a | ||
|
|
e19be0c2d9 | ||
|
|
479a454ae3 | ||
|
|
0a9341acde | ||
|
|
d8b63804bc | ||
|
|
3156c605d4 | ||
|
|
d1de6f5f3d | ||
|
|
48b7ae1677 | ||
|
|
506e3a3907 | ||
|
|
09145de8fa | ||
|
|
e0a2523a3b | ||
|
|
3d45620008 | ||
|
|
ce20e838b5 | ||
|
|
d4d84d48af | ||
|
|
9b12c05660 | ||
|
|
686933194e | ||
|
|
d12b461d19 | ||
|
|
d6b81b3683 | ||
|
|
05f1b4b2e8 | ||
|
|
7cfc80ec77 | ||
|
|
0da6a95efa | ||
|
|
2c8497e489 | ||
|
|
f70d4de8c7 | ||
|
|
0ae06d756d | ||
|
|
2974670bf8 | ||
|
|
50f2b94d50 | ||
|
|
eb2c87b525 | ||
|
|
4db7f023c6 | ||
|
|
4273d5cf7e | ||
|
|
c5e5aba547 | ||
|
|
9d5c95db6f | ||
|
|
ca796fb56e | ||
|
|
597953bef0 | ||
|
|
39fbd3b2b5 | ||
|
|
46dfacf255 | ||
|
|
4bce713b39 | ||
|
|
d09290f2f4 | ||
|
|
e442ff22aa | ||
|
|
ba3dba3e4f | ||
|
|
97e86c6d47 | ||
|
|
784f8c0e95 | ||
|
|
e3177c3210 | ||
|
|
70faea331f | ||
|
|
8021c718ce | ||
|
|
42f5e6f9e9 | ||
|
|
ab49d16e34 | ||
|
|
33d094721c | ||
|
|
a54c1be972 | ||
|
|
5691992d34 | ||
|
|
e758343cac | ||
|
|
deac7b18a1 | ||
|
|
10946afae7 | ||
|
|
5639552064 | ||
|
|
cda3c82351 | ||
|
|
7c3b428f23 | ||
|
|
01a6bd1a0e | ||
|
|
41709822a7 | ||
|
|
02a37199ee | ||
|
|
7026cd5e9e | ||
|
|
eb0a8a7775 | ||
|
|
294c7fe7a6 | ||
|
|
7b68dfafd7 | ||
|
|
32a7890231 | ||
|
|
563f5eed7a | ||
|
|
6ec282094d | ||
|
|
09dda462ab | ||
|
|
bb1cae1a20 |
7
.github/CONTRIBUTING.md
vendored
7
.github/CONTRIBUTING.md
vendored
@@ -57,6 +57,13 @@ We welcome ideas for improvements and new features. To suggest an enhancement, o
|
||||
5. Push your branch to your fork on GitHub.
|
||||
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
|
||||
|
||||
#### Skipping CI Checks
|
||||
|
||||
You can skip certain CI checks by including specific keywords in your commit messages:
|
||||
|
||||
- `[skip ci]` or `skip ci` - Skips all CI checks for that commit
|
||||
- `[skip-e2e]` or `skip-e2e` - Skips only end-to-end tests while running other CI checks. You may also include this in the title of your PR to disable end-to-end tests for the entire PR.
|
||||
|
||||
## Style Guidelines
|
||||
|
||||
### Code Style
|
||||
|
||||
27
.github/workflows/base.yml
vendored
27
.github/workflows/base.yml
vendored
@@ -54,7 +54,7 @@ jobs:
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.6.3
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
@@ -64,9 +64,16 @@ jobs:
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: nightly
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base-nightly"
|
||||
dockerfile: "Dockerfile-base"
|
||||
# - cuda: "128"
|
||||
# cuda_version: 12.8.1
|
||||
# cudnn_version: ""
|
||||
# python_version: "3.11"
|
||||
# pytorch: nightly
|
||||
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
# dockerfile: "Dockerfile-base-nightly"
|
||||
# # "next" is for release candidates of pytorch
|
||||
# - cuda: "128"
|
||||
# cuda_version: 12.8.1
|
||||
@@ -122,6 +129,13 @@ jobs:
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -129,6 +143,13 @@ jobs:
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
23
.github/workflows/main.yml
vendored
23
.github/workflows/main.yml
vendored
@@ -24,12 +24,13 @@ jobs:
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras: vllm
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
@@ -97,6 +98,12 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
@@ -150,6 +157,18 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
4
.github/workflows/preview-docs.yml
vendored
4
.github/workflows/preview-docs.yml
vendored
@@ -53,7 +53,7 @@ jobs:
|
||||
|
||||
- name: Netlify Publish
|
||||
uses: nwtgck/actions-netlify@v3.0
|
||||
if: ${{ secrets.NETLIFY_AUTH_TOKEN != '' }}
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
id: netlify
|
||||
with:
|
||||
publish-dir: './_site'
|
||||
@@ -68,7 +68,7 @@ jobs:
|
||||
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
|
||||
|
||||
- name: Update PR with preview link
|
||||
if: ${{ steps.netlify.outcome == 'success' && secrets.NETLIFY_AUTH_TOKEN != '' }}
|
||||
if: ${{ steps.netlify.outcome == 'success' }}
|
||||
uses: marocchino/sticky-pull-request-comment@v2
|
||||
with:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
49
.github/workflows/tests.yml
vendored
49
.github/workflows/tests.yml
vendored
@@ -105,7 +105,8 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||
|
||||
@@ -179,21 +180,52 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v --durations=10 tests/patched/
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/cli/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
gate-skip-e2e:
|
||||
needs: [pre-commit, pytest, pytest-sdist]
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
skip: ${{ steps.compute.outputs.skip }}
|
||||
steps:
|
||||
- uses: actions/github-script@v7
|
||||
id: compute
|
||||
with:
|
||||
script: |
|
||||
const token = /\[skip-e2e\]/i;
|
||||
let msg = '';
|
||||
if (context.eventName === 'push') {
|
||||
msg = context.payload.head_commit?.message || '';
|
||||
} else if (context.eventName === 'pull_request') {
|
||||
const { owner, repo } = context.repo;
|
||||
const prNumber = context.payload.pull_request.number;
|
||||
const commits = await github.paginate(
|
||||
github.rest.pulls.listCommits,
|
||||
{ owner, repo, pull_number: prNumber, per_page: 100 }
|
||||
);
|
||||
msg = commits.at(-1)?.commit?.message || '';
|
||||
}
|
||||
const title = context.payload.pull_request?.title || '';
|
||||
const body = context.payload.pull_request?.body || '';
|
||||
const skip = token.test(msg) || token.test(title) || token.test(body);
|
||||
core.setOutput('skip', String(skip));
|
||||
|
||||
docker-e2e-tests-1st:
|
||||
# Run this job first as a gate for running the remainder of the test matrix
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
|
||||
if: >
|
||||
github.repository_owner == 'axolotl-ai-cloud' &&
|
||||
(github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
|
||||
needs.gate-skip-e2e.outputs.skip != 'true'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
needs: [pre-commit, pytest, pytest-sdist]
|
||||
needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -239,13 +271,16 @@ jobs:
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
docker-e2e-tests:
|
||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
|
||||
if: >
|
||||
github.repository_owner == 'axolotl-ai-cloud' &&
|
||||
(github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
|
||||
needs.gate-skip-e2e.outputs.skip != 'true'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
# Only run the remainder of the matrix if the first e2e check passed;
|
||||
# this is to save on wasted compute costs for known failures that get caught in the first run
|
||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||
needs: [pre-commit, pytest, gate-skip-e2e, docker-e2e-tests-1st]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
@@ -3,7 +3,7 @@ default_language_version:
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
rev: v6.0.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
@@ -23,11 +23,11 @@ repos:
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/pylint-dev/pylint
|
||||
rev: v3.3.7
|
||||
rev: v3.3.8
|
||||
hooks:
|
||||
- id: pylint
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.17.0
|
||||
rev: v1.17.1
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
|
||||
@@ -185,7 +185,6 @@ datasets:
|
||||
| `flash_attention` | `false` | Use flash attention |
|
||||
| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
|
||||
| `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
|
||||
| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations |
|
||||
| `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
|
||||
| `sdp_attention` | `false` | Use scaled dot product |
|
||||
| `s2_attention` | `false` | Use shifted sparse attention |
|
||||
|
||||
@@ -296,7 +296,6 @@
|
||||
# flash_attention:
|
||||
# flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
# flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||
# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
|
||||
# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
||||
# # Whether to use scaled-dot-product attention
|
||||
# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
@@ -541,7 +540,6 @@ xformers_attention: ${XFORMERS_ATTENTION}
|
||||
flash_attention: ${FLASH_ATTENTION}
|
||||
flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
|
||||
flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
|
||||
flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV}
|
||||
flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
|
||||
sdp_attention: ${SDP_ATTENTION}
|
||||
s2_attention: ${S2_ATTENTION}
|
||||
|
||||
10
CITATION.cff
Normal file
10
CITATION.cff
Normal file
@@ -0,0 +1,10 @@
|
||||
cff-version: 1.2.0
|
||||
type: software
|
||||
title: "Axolotl: Post-Training for AI Models"
|
||||
message: "If you use this software, please cite it as below."
|
||||
authors:
|
||||
- name: "Axolotl maintainers and contributors"
|
||||
repository-code: "https://github.com/axolotl-ai-cloud/axolotl"
|
||||
url: "https://axolotl.ai/"
|
||||
license: Apache-2.0
|
||||
date-released: "2023-05-30"
|
||||
33
README.md
33
README.md
@@ -25,17 +25,28 @@
|
||||
|
||||
## 🎉 Latest Updates
|
||||
|
||||
- 2025/07: Voxtral with mistral-common tokenizer support has been integrated in Axolotl. Read the [docs](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral)!
|
||||
- 2025/07: TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
||||
- 2025/07:
|
||||
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
|
||||
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
|
||||
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
|
||||
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
|
||||
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Expand older updates</summary>
|
||||
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
|
||||
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
|
||||
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
|
||||
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
|
||||
|
||||
</details>
|
||||
|
||||
## ✨ Overview
|
||||
|
||||
Axolotl is a tool designed to streamline post-training for various AI models.
|
||||
@@ -138,6 +149,20 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
|
||||
|
||||
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
|
||||
|
||||
## 📝 Citing Axolotl
|
||||
|
||||
If you use Axolotl in your research or projects, please cite it as follows:
|
||||
|
||||
```bibtex
|
||||
@software{axolotl,
|
||||
title = {Axolotl: Post-Training for AI Models},
|
||||
author = {{Axolotl maintainers and contributors}},
|
||||
url = {https://github.com/axolotl-ai-cloud/axolotl},
|
||||
license = {Apache-2.0},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
|
||||
## 📜 License
|
||||
|
||||
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
10
TODO.md
10
TODO.md
@@ -1,10 +0,0 @@
|
||||
# todo list
|
||||
|
||||
- [] Validation of parameters for combinations that won't work
|
||||
|
||||
|
||||
|
||||
## things that are known not to work
|
||||
|
||||
- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
|
||||
- adamw_bnb_8bit doesn't play well with FSDP offload
|
||||
16
_quarto.yml
16
_quarto.yml
@@ -35,25 +35,30 @@ quartodoc:
|
||||
- cli.train
|
||||
- cli.evaluate
|
||||
- cli.args
|
||||
- cli.art
|
||||
- cli.checks
|
||||
- cli.config
|
||||
- cli.delinearize_llama4
|
||||
- cli.inference
|
||||
- cli.merge_lora
|
||||
- cli.merge_sharded_fsdp_weights
|
||||
- cli.preprocess
|
||||
- cli.sweeps
|
||||
- cli.utils
|
||||
- cli.quantize
|
||||
- cli.vllm_serve
|
||||
- cli.cloud.base
|
||||
- cli.cloud.modal_
|
||||
- cli.quantize
|
||||
- cli.utils
|
||||
- cli.utils.args
|
||||
- cli.utils.fetch
|
||||
- cli.utils.load
|
||||
- cli.utils.sweeps
|
||||
- cli.utils.train
|
||||
- title: Trainers
|
||||
desc: Training implementations
|
||||
contents:
|
||||
- core.trainers.base
|
||||
- core.trainers.trl
|
||||
- core.trainers.mamba
|
||||
- core.trainers.relora
|
||||
- core.trainers.dpo.trainer
|
||||
- core.trainers.grpo.trainer
|
||||
- core.trainers.grpo.sampler
|
||||
@@ -269,7 +274,7 @@ website:
|
||||
- docs/dataset_preprocessing.qmd
|
||||
- docs/multipack.qmd
|
||||
- docs/mixed_precision.qmd
|
||||
- docs/gradient_accumulation.qmd
|
||||
- docs/optimizers.qmd
|
||||
|
||||
- section: "Advanced Features"
|
||||
contents:
|
||||
@@ -279,6 +284,7 @@ website:
|
||||
- docs/custom_integrations.qmd
|
||||
- docs/sequence_parallelism.qmd
|
||||
- docs/gradient_checkpointing.qmd
|
||||
- docs/nd_parallelism.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
contents:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
set -e
|
||||
|
||||
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
|
||||
pytest -v -n2 \
|
||||
pytest -v --durations=10 -n2 \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
|
||||
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
|
||||
/workspace/axolotl/tests/e2e/multigpu/ \
|
||||
|
||||
@@ -65,6 +65,9 @@ GPU_CONFIG = f"L40S:{N_GPUS}"
|
||||
def run_cmd(cmd: str, run_folder: str):
|
||||
import subprocess # nosec
|
||||
|
||||
sp_env = os.environ.copy()
|
||||
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
|
||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
|
||||
@@ -16,7 +16,10 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
|
||||
&& apt-get install -y --no-install-recommends \
|
||||
wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
|
||||
ibverbs-providers ibverbs-utils infiniband-diags \
|
||||
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
|
||||
&& rm -rf /var/cache/apt/archives \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
@@ -34,7 +37,7 @@ WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
|
||||
python3 -m pip cache purge
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ COPY scripts/motd /etc/motd
|
||||
RUN pip install jupyterlab notebook ipywidgets && \
|
||||
jupyter lab clean
|
||||
RUN apt update && \
|
||||
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \
|
||||
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
||||
rm -rf /var/cache/apt/archives && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
mkdir -p ~/.ssh && \
|
||||
|
||||
26
docs/cli.qmd
26
docs/cli.qmd
@@ -23,6 +23,20 @@ axolotl <command> [config.yml] [options]
|
||||
|
||||
The config file can be local or a URL to a raw YAML file.
|
||||
|
||||
### Launcher Arguments
|
||||
|
||||
For commands that support multi-GPU (`train`, `evaluate`, ...), you can pass launcher-specific arguments using the `--` separator:
|
||||
|
||||
```bash
|
||||
# Pass torchrun arguments
|
||||
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
|
||||
|
||||
# Pass accelerate arguments
|
||||
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml --num_processes=4
|
||||
```
|
||||
|
||||
Arguments after `--` are passed directly to the launcher (torchrun, accelerate launch, etc.).
|
||||
|
||||
## Command Reference
|
||||
|
||||
### fetch
|
||||
@@ -80,7 +94,11 @@ axolotl train config.yml \
|
||||
--num-epochs 3
|
||||
|
||||
# Training without accelerate
|
||||
axolotl train config.yml --no-accelerate
|
||||
axolotl train config.yml --launcher python
|
||||
|
||||
# Pass launcher-specific arguments using -- separator
|
||||
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
|
||||
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml
|
||||
|
||||
# Resume training from checkpoint
|
||||
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
|
||||
@@ -175,6 +193,9 @@ Evaluates a model's performance (loss etc) on the train and eval datasets.
|
||||
```bash
|
||||
# Basic evaluation
|
||||
axolotl evaluate config.yml
|
||||
|
||||
# Evaluation with launcher arguments
|
||||
axolotl evaluate config.yml --launcher torchrun -- --nproc_per_node=2
|
||||
```
|
||||
|
||||
### lm-eval
|
||||
@@ -287,9 +308,6 @@ axolotl preprocess config.yml --cloud cloud_config.yml
|
||||
# Train on cloud
|
||||
axolotl train config.yml --cloud cloud_config.yml
|
||||
|
||||
# Train without accelerate on cloud
|
||||
axolotl train config.yml --cloud cloud_config.yml --no-accelerate
|
||||
|
||||
# Run lm-eval on cloud
|
||||
axolotl lm-eval config.yml --cloud cloud_config.yml
|
||||
```
|
||||
|
||||
@@ -212,10 +212,11 @@ Instead of passing `tools` via the system prompt, an alternative method would be
|
||||
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||
:::
|
||||
|
||||
Example config for Llama4:
|
||||
```yaml
|
||||
chat_template: llama4
|
||||
datasets:
|
||||
- path: ...
|
||||
- path: Nanobit/text-tools-2k-test
|
||||
type: chat_template
|
||||
# field_tools: tools # default is `tools`
|
||||
```
|
||||
|
||||
@@ -69,11 +69,19 @@ export NCCL_BUFFSIZE=2097152
|
||||
|
||||
Run the following on each node:
|
||||
|
||||
### Option 1: New Axolotl CLI with launcher args (Recommended)
|
||||
|
||||
```bash
|
||||
axolotl train config.yaml --launcher torchrun -- --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port"
|
||||
```
|
||||
|
||||
### Option 2: Direct torchrun (Legacy)
|
||||
|
||||
```bash
|
||||
torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml
|
||||
```
|
||||
|
||||
Please make sure to substitute the placeholder variables.
|
||||
Please make sure to substitute the placeholder variables:
|
||||
|
||||
- `num_nodes`: Number of nodes (containing GPUs)
|
||||
- `gpu_per_node`: Number of gpus per node
|
||||
@@ -81,8 +89,6 @@ Please make sure to substitute the placeholder variables.
|
||||
- `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400)
|
||||
- `rdzv_id`: A unique job ID that is used by the job across nodes.
|
||||
|
||||
::: {.callout-note}
|
||||
You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood
|
||||
:::
|
||||
The new CLI approach (Option 1) is recommended as it provides consistent argument handling and works seamlessly with other Axolotl CLI features.
|
||||
|
||||
More info on the available configs can be found on the Pytorch docs [here](https://pytorch.org/docs/stable/elastic/run.html)
|
||||
|
||||
@@ -13,10 +13,13 @@ format:
|
||||
- [Pixtral](#sec-pixtral)
|
||||
- [Llava-1.5](#sec-llava-15)
|
||||
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||
- [Voxtral](#sec-voxtral)
|
||||
- [Gemma-3](#sec-gemma-3)
|
||||
- [Gemma-3n](#sec-gemma-3n)
|
||||
- [Qwen2-VL](#sec-qwen2-vl)
|
||||
- [Qwen2.5-VL](#sec-qwen25-vl)
|
||||
- [SmolVLM2](#sec-smolvlm2)
|
||||
- [LFM2-VL](#sec-lfm2-vl)
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -31,7 +34,7 @@ skip_prepare_dataset: true
|
||||
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
|
||||
sample_packing: false # not yet supported with multimodal
|
||||
|
||||
chat_template: # see in next section
|
||||
chat_template: # see in next section if specified
|
||||
|
||||
# example dataset
|
||||
datasets:
|
||||
@@ -97,6 +100,16 @@ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||
chat_template: mistral_v7_tekken
|
||||
```
|
||||
|
||||
### Voxtral {#sec-voxtral}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Voxtral-Mini-3B-2507
|
||||
```
|
||||
|
||||
### Gemma-3 {#sec-gemma-3}
|
||||
|
||||
::: {.callout-tip}
|
||||
@@ -143,6 +156,26 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
chat_template: qwen2_vl # same as qwen2-vl
|
||||
```
|
||||
|
||||
### SmolVLM2 {#sec-smolvlm2}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install `num2words` via `pip3 install num2words==0.5.14`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct
|
||||
```
|
||||
|
||||
### LFM2-VL {#sec-lfm2-vl}
|
||||
|
||||
::: {.callout-warning}
|
||||
Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: LiquidAI/LFM2-VL-450M
|
||||
```
|
||||
|
||||
## Dataset Format
|
||||
|
||||
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
|
||||
@@ -181,6 +214,20 @@ You may need to install `librosa` via `pip3 install librosa==0.11.0`.
|
||||
|
||||
:::
|
||||
|
||||
### Video
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
This is not well tested at the moment. We welcome contributors!
|
||||
|
||||
:::
|
||||
|
||||
For video loading, you can use the following keys within `content` alongside `"type": "video"`:
|
||||
|
||||
- `"path": "/path/to/video.mp4"`
|
||||
- `"url": "https://example.com/video.mp4"`
|
||||
- `"video": np.ndarray | list[PIL.Image.Image] | torch.Tensor` (or list of the aforementioned)
|
||||
|
||||
### Example
|
||||
|
||||
Here is an example of a multi-modal dataset:
|
||||
|
||||
108
docs/nd_parallelism.qmd
Normal file
108
docs/nd_parallelism.qmd
Normal file
@@ -0,0 +1,108 @@
|
||||
---
|
||||
title: "N-D Parallelism (Beta)"
|
||||
---
|
||||
|
||||
Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
|
||||
|
||||
- A model's weights are too large to fit on a single GPU's memory.
|
||||
- A model's activations, especially with very long contexts, are too large for a single GPU.
|
||||
- You want to accelerate training by using multiple GPUs or nodes.
|
||||
|
||||
or combinations of the above!
|
||||
|
||||
## Core Concepts
|
||||
|
||||
Parallelism strategies can be combined. The key is understanding how each one divides the workload. PyTorch's `DeviceMesh` is the modern way to manage these combinations, creating a logical grid of your GPUs and assigning different parallel strategies to different dimensions of the grid.
|
||||
|
||||
### Data Parallelism {#sec-dp}
|
||||
|
||||
Data Parallelism focuses on splitting the global data batch across GPUs.
|
||||
|
||||
- Distributed Data Parallel (DDP): The classic approach. The full model is replicated on every GPU. Each GPU processes a different slice of the data batch. Gradients are then averaged across all GPUs after the backward pass to keep the models synchronized. This can substantially improve data throughput compared to single-device training, but requires that each GPU is able to hold the entire model, its gradients, and optimizer states.
|
||||
|
||||
- [Fully Sharded Data Parallel (FSDP)](multi-gpu.qmd#fully-sharded-data-parallel-(fsdp)): A highly memory-efficient form of data parallelism (inspired by DeepSpeed's ZeRO). Instead of replicating the model, FSDP shards the model's *parameters, gradients, and optimizer states* across the GPUs in the data-parallel group. During computation, each GPU receives the specific parameters it needs via an `all_gather` operation just before they are used, and they can be discarded immediately after (`reshard-after-forward`).
|
||||
- FSDP maps to ZeRO stages:
|
||||
- ZeRO-2 (`reshard_after_forward=False`): Shards gradients and optimizer states. Model weights are replicated on each GPU.
|
||||
- ZeRO-3 (`reshard_after_forward=True`): Shards gradients, optimizer states, AND model parameters. This provides the most memory savings at the cost of more communication (re-gathering parameters for both forward and backward passes).
|
||||
|
||||
### [Experimental] Tensor Parallelism (TP) {#sec-tp}
|
||||
|
||||
Also known as "horizontal model parallelism," as described in the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Instead of splitting the batch, TP splits the model's layers themselves across GPUs.
|
||||
|
||||
- How it works: For a linear layer `Y = XA`, the weight matrix `A` is split column-wise (`A = [A_1, A_2]`). The computation becomes `Y_1 = XA_1` and `Y_2 = XA_2`, which can happen in parallel on different GPUs. The final output `Y` is simply the concatenation of `Y_1` and `Y_2`. Check [this comment](https://github.com/huggingface/transformers/issues/10321#issuecomment-783543530) for more detailed info.
|
||||
- Requirement: TP involves frequent, small communications within a forward/backward pass. It requires a very fast interconnect between GPUs (e.g., NVLink) and is typically not recommended across different nodes.
|
||||
|
||||
### Context Parallelism (CP) {#sec-cp}
|
||||
|
||||
Context Parallelism, also called [Sequence Parallelism](sequence_parallelism.qmd), addresses the memory bottleneck from long sequences. The input sequence itself is split along the sequence length dimension and distributed across GPUs.
|
||||
|
||||
- How it works: If you have a sequence of 8192 tokens and a `context_parallel_size` of 4, each GPU will only handle a chunk of 2048 tokens.
|
||||
- The Challenge: Attention is not local; every token needs to "attend to" every other token. Splitting the sequence breaks this.
|
||||
- The Solution (`ring-flash-attention`): An efficient communication protocol is used. To compute attention for its local sequence chunk, each GPU passes its Key-Value (KV) cache to its neighbor in a "ring." After `N-1` steps, every GPU has seen the KV-cache from all other GPUs, allowing it to compute the correct attention values for its chunk. This is implemented using the highly optimized `flash-attention` kernel at each step.
|
||||
|
||||
### Hybrid Sharding Data Parallel (HSDP) {#sec-hsdp}
|
||||
|
||||
HSDP is a 2D strategy that intelligently combines FSDP and DDP, typically for multi-node training.
|
||||
|
||||
- Intra-Node (within a machine): Use FSDP. This is efficient because GPUs on the same node have fast interconnects (NVLink), making the `all_gather` operations for sharded parameters fast.
|
||||
- Inter-Node (across machines): Use DDP. The gradient synchronization between nodes is less frequent than FSDP's parameter gathering, making it a better fit for the slower node-to-node network (e.g., Ethernet/Infiniband).
|
||||
- Example: With 2 nodes of 8 GPUs each (16 total), you could have `dp_shard_size=8` (FSDP within each node) and `dp_replicate_size=2` (DDP across the two nodes).
|
||||
|
||||
## Usage
|
||||
|
||||
```yaml
|
||||
# FSDP config. See https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
# ...
|
||||
|
||||
# The number of GPUs to shard the model parameters across (FSDP dimension).
|
||||
dp_shard_size: 4
|
||||
|
||||
# The number of times to replicate the sharded model (DDP dimension).
|
||||
dp_replicate_size: 2
|
||||
|
||||
# Number of GPUs for Tensor Parallelism.
|
||||
tensor_parallel_size: 1 # (default is 1, no TP)
|
||||
|
||||
# Number of GPUs for Context/Sequence Parallelism.
|
||||
context_parallel_size: 1 # (default is 1, no CP)
|
||||
```
|
||||
|
||||
Note: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size`.
|
||||
|
||||
## Examples
|
||||
|
||||
::: {.callout-tip}
|
||||
See our example configs [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/distributed-parallel).
|
||||
:::
|
||||
|
||||
1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
|
||||
- You want FSDP within each node and DDP across nodes.
|
||||
- Set `dp_shard_size: 4` and `dp_replicate_size: 2`.
|
||||
|
||||
2. FSDP + TP on a single 8-GPU node:
|
||||
- You want to split the model across 4 GPUs using FSDP, and further split each layer across 2 GPUs with TP.
|
||||
- Set `dp_shard_size: 4` and `tensor_parallel_size: 2`.
|
||||
|
||||
3. FSDP + CP on a single 8-GPU node for long context:
|
||||
- You want to shard the model across all 8 GPUs and also split the sequence length across all 8 GPUs.
|
||||
- Set `dp_shard_size: 8` and `context_parallel_size: 8`. Note: this means the data parallel group and context parallel group are the same. A more common setup might be to shard across a smaller group.
|
||||
|
||||
## Support Matrix
|
||||
|
||||
This matrix describes how different parallelism methods can be combined in Axolotl.
|
||||
|
||||
| Combination | `dp_replicate_size` | `dp_shard_size` | `tp_size` | `cp_size` | Status & Notes |
|
||||
| --- | :---: | :---: |:---:|:---:|---|
|
||||
| **FSDP** (ZeRO-3) | 1 | >1 | 1 | 1 | ✅ Fully supported. Shards model across all GPUs. |
|
||||
| **HSDP** | >1 | >1 | 1 | 1 | ✅ Fully supported. FSDP intra-node, DDP inter-node. |
|
||||
| **FSDP + TP** | 1 | >1 | >1 | 1 | ✅ **2D Parallelism**. Shards the model across a `dp_shard` group, and TP-splits layers within the `tp` group. |
|
||||
| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |
|
||||
| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |
|
||||
| **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. |
|
||||
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP or CP is currently not supported. You should use FSDP + TP/CP instead (`dp_shard_size > 1`). |
|
||||
| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |
|
||||
|
||||
- `tp_size` refers to `tensor_parallel_size`
|
||||
- `cp_size` refers to `context_parallel_size`
|
||||
129
docs/optimizers.qmd
Normal file
129
docs/optimizers.qmd
Normal file
@@ -0,0 +1,129 @@
|
||||
---
|
||||
title: Optimizers
|
||||
description: Configuring optimizers
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
Axolotl supports all optimizers supported by [transformers OptimizerNames](https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/training_args.py#L142-L187)
|
||||
|
||||
Here is a list of optimizers supported by transformers as of `v4.54.0`:
|
||||
|
||||
- `adamw_torch`
|
||||
- `adamw_torch_fused`
|
||||
- `adamw_torch_xla`
|
||||
- `adamw_torch_npu_fused`
|
||||
- `adamw_apex_fused`
|
||||
- `adafactor`
|
||||
- `adamw_anyprecision`
|
||||
- `adamw_torch_4bit`
|
||||
- `adamw_torch_8bit`
|
||||
- `ademamix`
|
||||
- `sgd`
|
||||
- `adagrad`
|
||||
- `adamw_bnb_8bit`
|
||||
- `adamw_8bit` # alias for adamw_bnb_8bit
|
||||
- `ademamix_8bit`
|
||||
- `lion_8bit`
|
||||
- `lion_32bit`
|
||||
- `paged_adamw_32bit`
|
||||
- `paged_adamw_8bit`
|
||||
- `paged_ademamix_32bit`
|
||||
- `paged_ademamix_8bit`
|
||||
- `paged_lion_32bit`
|
||||
- `paged_lion_8bit`
|
||||
- `rmsprop`
|
||||
- `rmsprop_bnb`
|
||||
- `rmsprop_bnb_8bit`
|
||||
- `rmsprop_bnb_32bit`
|
||||
- `galore_adamw`
|
||||
- `galore_adamw_8bit`
|
||||
- `galore_adafactor`
|
||||
- `galore_adamw_layerwise`
|
||||
- `galore_adamw_8bit_layerwise`
|
||||
- `galore_adafactor_layerwise`
|
||||
- `lomo`
|
||||
- `adalomo`
|
||||
- `grokadamw`
|
||||
- `schedule_free_radam`
|
||||
- `schedule_free_adamw`
|
||||
- `schedule_free_sgd`
|
||||
- `apollo_adamw`
|
||||
- `apollo_adamw_layerwise`
|
||||
- `stable_adamw`
|
||||
|
||||
|
||||
## Custom Optimizers
|
||||
|
||||
Enable custom optimizers by passing a string to the `optimizer` argument. Each optimizer will receive beta and epsilon args, however, some may accept additional args which are detailed below.
|
||||
|
||||
### optimi_adamw
|
||||
|
||||
```yaml
|
||||
optimizer: optimi_adamw
|
||||
```
|
||||
|
||||
### ao_adamw_4bit
|
||||
|
||||
Deprecated: Please use `adamw_torch_4bit`.
|
||||
|
||||
### ao_adamw_8bit
|
||||
|
||||
Deprecated: Please use `adamw_torch_8bit`.
|
||||
|
||||
### ao_adamw_fp8
|
||||
|
||||
|
||||
```yaml
|
||||
optimizer: ao_adamw_fp8
|
||||
```
|
||||
|
||||
### adopt_adamw
|
||||
|
||||
GitHub: [https://github.com/iShohei220/adopt](https://github.com/iShohei220/adopt)
|
||||
Paper: [https://arxiv.org/abs/2411.02853](https://arxiv.org/abs/2411.02853)
|
||||
|
||||
```yaml
|
||||
optimizer: adopt_adamw
|
||||
```
|
||||
|
||||
### came_pytorch
|
||||
|
||||
GitHub: [https://github.com/yangluo7/CAME/tree/master](https://github.com/yangluo7/CAME/tree/master)
|
||||
Paper: [https://arxiv.org/abs/2307.02047](https://arxiv.org/abs/2307.02047)
|
||||
|
||||
```yaml
|
||||
optimizer: came_pytorch
|
||||
|
||||
# optional args (defaults below)
|
||||
adam_beta1: 0.9
|
||||
adam_beta2: 0.999
|
||||
adam_beta3: 0.9999
|
||||
adam_epsilon: 1e-30
|
||||
adam_epsilon2: 1e-16
|
||||
```
|
||||
|
||||
### muon
|
||||
|
||||
Blog: [https://kellerjordan.github.io/posts/muon/](https://kellerjordan.github.io/posts/muon/)
|
||||
Paper: [https://arxiv.org/abs/2502.16982v1](https://arxiv.org/abs/2502.16982v1)
|
||||
|
||||
```yaml
|
||||
optimizer: muon
|
||||
```
|
||||
|
||||
### dion
|
||||
|
||||
Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient
|
||||
orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
|
||||
|
||||
GitHub: [https://github.com/microsoft/dion](https://github.com/microsoft/dion)
|
||||
Paper: [https://arxiv.org/pdf/2504.05295](https://arxiv.org/pdf/2504.05295)
|
||||
Note: Implementation written for PyTorch 2.7+ for DTensor
|
||||
|
||||
```yaml
|
||||
optimizer: dion
|
||||
dion_lr: 0.01
|
||||
dion_momentum: 0.95
|
||||
lr: 0.00001 # learning rate for embeddings and parameters that fallback to AdamW
|
||||
```
|
||||
@@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file:
|
||||
|
||||
```yaml
|
||||
# Set to a divisor (> 1) of the number of GPUs available
|
||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
context_parallel_size: 4 # Split sequences across 4 GPUs
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||
@@ -30,7 +30,7 @@ heads_k_stride: 1
|
||||
ring_attn_func:
|
||||
```
|
||||
|
||||
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||
The `context_parallel_size` should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
- With 8 GPUs, valid values would be 2, 4, or 8
|
||||
- With 4 GPUs, valid values would be 2 or 4
|
||||
@@ -66,7 +66,7 @@ sequence_len: 8192
|
||||
|
||||
...
|
||||
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||
@@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
|
||||
|
||||
## Effect on Batch Size
|
||||
|
||||
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||
When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:
|
||||
|
||||
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||
- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
|
||||
- The number of batches processed per step decreases
|
||||
|
||||
For example:
|
||||
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||
- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||
|
||||
58
examples/LiquidAI/README.md
Normal file
58
examples/LiquidAI/README.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# Finetune Liquid Foundation Models 2 (LFM2) with Axolotl
|
||||
|
||||
[Liquid Foundation Models 2 (LFM2)](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) are a family of small, open-weight models from [Liquid AI](https://www.liquid.ai/) focused on quality, speed, and memory efficiency. Liquid AI released text-only [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) and text+vision [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa) models.
|
||||
|
||||
LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference.
|
||||
|
||||
This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from pip:
|
||||
```bash
|
||||
# Ensure you have a compatible version of Pytorch installed
|
||||
pip3 install packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Run one of the finetuning examples below.
|
||||
|
||||
**LFM2**
|
||||
```bash
|
||||
# FFT SFT (1x48GB @ 25GiB)
|
||||
axolotl train examples/LiquidAI/lfm2-350m-fft.yaml
|
||||
```
|
||||
|
||||
**LFM2-VL**
|
||||
```bash
|
||||
# LoRA SFT (1x48GB @ 2.7GiB)
|
||||
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
|
||||
```
|
||||
|
||||
### TIPS
|
||||
|
||||
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
|
||||
```bash
|
||||
pip uninstall -y causal-conv1d
|
||||
```
|
||||
|
||||
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- **Dataset Formats**:
|
||||
- For LFM2 models, the dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
- For LFM2-VL models, Axolotl follows the multi-content Messages format. See our [Multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) for details.
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
|
||||
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -2,7 +2,6 @@ base_model: LiquidAI/LFM2-350M
|
||||
|
||||
chunked_cross_entropy: true
|
||||
|
||||
chat_template: tokenizer_default
|
||||
eot_tokens:
|
||||
- "<|im_end|>"
|
||||
datasets:
|
||||
58
examples/LiquidAI/lfm2-vl-lora.yaml
Normal file
58
examples/LiquidAI/lfm2-vl-lora.yaml
Normal file
@@ -0,0 +1,58 @@
|
||||
base_model: LiquidAI/LFM2-VL-450M
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForImageTextToText
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -20,7 +20,7 @@ min_sample_len: 200_000
|
||||
sample_packing: true
|
||||
|
||||
tiled_mlp: true
|
||||
sequence_parallel_degree: 8
|
||||
context_parallel_size: 8
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
|
||||
53
examples/arcee/README.md
Normal file
53
examples/arcee/README.md
Normal file
@@ -0,0 +1,53 @@
|
||||
# Finetune ArceeAI's AFM with Axolotl
|
||||
|
||||
[Arcee Foundation Models (AFM)](https://huggingface.co/collections/arcee-ai/afm-45b-68823397c351603014963473) are a family of 4.5B parameter open weight models trained by Arcee.ai.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the AFM model.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as AFM is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/arcee/afm-4.5b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 7.8GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, the official Arcee.ai team recommends `top_p: 0.95`, `temperature: 0.5`, `top_k: 50`, and `repeat_penalty: 1.1`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [AFM Blog](https://docs.arcee.ai/arcee-foundation-models/introduction-to-arcee-foundation-models)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
64
examples/arcee/afm-4.5b-qlora.yaml
Normal file
64
examples/arcee/afm-4.5b-qlora.yaml
Normal file
@@ -0,0 +1,64 @@
|
||||
base_model: arcee-ai/AFM-4.5B
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -47,7 +47,6 @@ logging_steps: 1
|
||||
flash_attention: true
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -10,17 +10,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
Here is an example of how to install from pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0+)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
52
examples/distributed-parallel/README.md
Normal file
52
examples/distributed-parallel/README.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# ND Parallelism Examples
|
||||
|
||||
This directory contains example configurations for training models using ND Parallelism in Axolotl. These examples demonstrate how to compose different parallelism strategies (FSDP, TP, CP, HSDP) for efficient multi-GPU training.
|
||||
|
||||
## Quick Start
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Run the command below:
|
||||
|
||||
```bash
|
||||
# Train Qwen3 8B with FSDP + TP + CP on a single 8-GPU node
|
||||
axolotl train examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
|
||||
|
||||
# Train Llama 3.1 8B with HSDP + TP on 2 nodes (16 GPUs total)
|
||||
axolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
|
||||
```
|
||||
|
||||
## Example Configurations
|
||||
|
||||
### Single Node (8 GPUs)
|
||||
|
||||
**Qwen3 8B with FSDP + TP + CP** ([qwen3-8b-fsdp-tp-cp.yaml](./qwen3-8b-fsdp-tp-cp.yaml))
|
||||
- Uses all 3 parallelism dimensions on a single node
|
||||
- Ideal for: when model weights, activations, and/or context are too large to fit on single GPU
|
||||
|
||||
```yaml
|
||||
dp_shard_size: 2 # FSDP across 2 GPUs
|
||||
tensor_parallel_size: 2 # TP across 2 GPUs
|
||||
context_parallel_size: 2 # CP across 2 GPUs
|
||||
# Total: 2 × 2 × 2 = 8 GPUs
|
||||
```
|
||||
|
||||
### Multi-Node
|
||||
|
||||
**Llama 3.1 8B with HSDP + TP** ([llama-3_1-8b-hsdp-tp.yaml](./llama-3_1-8b-hsdp-tp.yaml))
|
||||
- FSDP & TP within nodes, DDP across nodes to minimize inter-node communication
|
||||
- Ideal for: Scaling to multiple nodes while maintaining training efficiency
|
||||
|
||||
```yaml
|
||||
dp_shard_size: 4 # FSDP within each 4-GPU group
|
||||
tensor_parallel_size: 2 # TP within each node
|
||||
dp_replicate_size: 2 # DDP across 2 groups
|
||||
# Total: (4 × 2) × 2 = 16 GPUs (2 nodes)
|
||||
```
|
||||
|
||||
## Learn More
|
||||
|
||||
- [ND Parallelism Documentation](https://docs.axolotl.ai/docs/nd_parallelism.html)
|
||||
- [Blog: Accelerate ND-Parallel Guide](https://huggingface.co/blog/accelerate-nd-parallel)
|
||||
- [Multi-GPU Training Guide](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
47
examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
Normal file
47
examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
base_model: meta-llama/Llama-3.1-8B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
dp_shard_size: 4
|
||||
dp_replicate_size: 2
|
||||
tensor_parallel_size: 2
|
||||
# context_parallel_size: 2
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
reshard_after_forward: true
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/ndp-out/
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 2
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-6
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
46
examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
Normal file
46
examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
Normal file
@@ -0,0 +1,46 @@
|
||||
base_model: Qwen/Qwen3-8B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
dp_shard_size: 2
|
||||
# dp_replicate_size: 1
|
||||
context_parallel_size: 2
|
||||
tensor_parallel_size: 2
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
|
||||
reshard_after_forward: true
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/ndp-out/
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1 # must be 1 when using context parallel
|
||||
num_epochs: 2
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-6
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
|
||||
special_tokens:
|
||||
@@ -4,17 +4,14 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Gemma3n is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
Here is an example of how to install from pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min recommended)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. In addition to Axolotl's requirements, Gemma-3n requires:
|
||||
|
||||
105
examples/gpt-oss/README.md
Normal file
105
examples/gpt-oss/README.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# Finetune OpenAI's GPT-OSS with Axolotl
|
||||
|
||||
[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
|
||||
|
||||
```bash
|
||||
# LoRA SFT linear layers (1x48GB @ ~44GiB)
|
||||
axolotl train examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
|
||||
|
||||
# FFT SFT with offloading (2x24GB @ ~21GiB/GPU)
|
||||
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
|
||||
|
||||
# FFT SFT (8x48GB @ ~36GiB/GPU or 4x80GB @ ~46GiB/GPU)
|
||||
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
|
||||
```
|
||||
|
||||
Note: Memory usage taken from `device_mem_reserved(gib)` from logs.
|
||||
|
||||
### Training 120B
|
||||
|
||||
On 8xH100s, make sure you have ~3TB of free disk space. With each checkpoint clocking in at ~720GB, along with the base
|
||||
model, and final model output, you may need at least 3TB of free disk space to keep at least 2 checkpoints.
|
||||
|
||||
```bash
|
||||
# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
|
||||
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
|
||||
```
|
||||
|
||||
ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
|
||||
See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
|
||||
|
||||
```bash
|
||||
sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
|
||||
```
|
||||
|
||||
When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your
|
||||
configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to
|
||||
merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded
|
||||
weights to `{output_dir}/merged`.
|
||||
|
||||
```bash
|
||||
axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
|
||||
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
|
||||
```
|
||||
|
||||
|
||||
### Inferencing your fine-tuned model
|
||||
|
||||
GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
|
||||
for more information about using a special vllm-openai docker image for inferencing with vLLM.
|
||||
|
||||
SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing
|
||||
SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:
|
||||
|
||||
```bash
|
||||
python3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8
|
||||
```
|
||||
|
||||
### Tool use
|
||||
|
||||
GPT-OSS has a comprehensive tool understanding. Axolotl supports tool calling datasets for Supervised Fine-tuning.
|
||||
|
||||
Here is an example dataset config:
|
||||
```yaml
|
||||
datasets:
|
||||
- path: Nanobit/text-tools-2k-test
|
||||
type: chat_template
|
||||
```
|
||||
|
||||
See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-tools-2k-test) for the sample dataset.
|
||||
|
||||
Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.
|
||||
|
||||
### TIPS
|
||||
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [GPT-OSS Blog](https://openai.com/index/introducing-gpt-oss/)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
68
examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
Normal file
68
examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
# the original mxfp4 quantized model is not supported with FSDP cpu_ram_efficient_loading
|
||||
# FSDP cpu_ram_efficient_loading is used to reduce the initial CPU memory usage when loading the model
|
||||
base_model: axolotl-ai-co/gpt-oss-120b-dequantized
|
||||
|
||||
use_kernels: false
|
||||
|
||||
dp_shard_size: 16 # requires 2x8xH100 nodes
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/Multilingual-Thinking
|
||||
type: chat_template
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
save_total_limit: 2 # the 120B model can use up to 720GB of disk space per checkpoint, so let's only keep the last 2
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.03
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: true
|
||||
state_dict_type: SHARDED_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
||||
reshard_after_forward: true
|
||||
cpu_ram_efficient_loading: true
|
||||
58
examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml
Normal file
58
examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml
Normal file
@@ -0,0 +1,58 @@
|
||||
base_model: openai/gpt-oss-20b
|
||||
use_kernels: false
|
||||
model_quantization_config: Mxfp4Config
|
||||
model_quantization_config_kwargs:
|
||||
dequantize: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/Multilingual-Thinking
|
||||
type: chat_template
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.03
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
|
||||
# choose the zero3 configuration that best fits your system capabilities
|
||||
deepspeed: deepspeed_configs/zero3_bf16.json
|
||||
68
examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
Normal file
68
examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: openai/gpt-oss-20b
|
||||
use_kernels: true
|
||||
model_quantization_config: Mxfp4Config
|
||||
model_quantization_config_kwargs:
|
||||
dequantize: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/Multilingual-Thinking
|
||||
type: chat_template
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.03
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: true
|
||||
state_dict_type: SHARDED_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
||||
reshard_after_forward: true
|
||||
# cpu_ram_efficient_loading: true
|
||||
|
||||
# cpu_ram_efficient_loading cannot be used with MXFP4 model quantization.
|
||||
# It can only be used with a dequantized model like `axolotl-ai-co/gpt-oss-120b-dequantized`
|
||||
64
examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
Normal file
64
examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
Normal file
@@ -0,0 +1,64 @@
|
||||
base_model: openai/gpt-oss-20b
|
||||
use_kernels: false
|
||||
model_quantization_config: Mxfp4Config
|
||||
model_quantization_config_kwargs:
|
||||
dequantize: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/Multilingual-Thinking
|
||||
type: chat_template
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.03
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
state_dict_type: SHARDED_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
||||
reshard_after_forward: true
|
||||
# cpu_ram_efficient_loading: true
|
||||
67
examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
Normal file
67
examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
Normal file
@@ -0,0 +1,67 @@
|
||||
base_model: openai/gpt-oss-20b
|
||||
use_kernels: true
|
||||
model_quantization_config: Mxfp4Config
|
||||
model_quantization_config_kwargs:
|
||||
dequantize: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/Multilingual-Thinking
|
||||
type: chat_template
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
adapter: lora
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
|
||||
lora_target_linear: true
|
||||
|
||||
# TODO: not supported for now, see peft#2710
|
||||
#lora_target_parameters: # target the experts in the last two layers
|
||||
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-4
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
warmup_ratio: 0.1
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
@@ -1,7 +0,0 @@
|
||||
# Liquid Foundation Models 2
|
||||
|
||||
LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release.
|
||||
|
||||
```bash
|
||||
pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
|
||||
```
|
||||
@@ -45,7 +45,6 @@ logging_steps: 1
|
||||
flash_attention: true
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -49,7 +49,6 @@ logging_steps: 1
|
||||
flash_attention: true
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -25,9 +25,12 @@ lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
relora_steps: 150
|
||||
relora_warmup_ratio: 0.1
|
||||
relora: true
|
||||
relora_prune_ratio: 0.9
|
||||
relora_cpu_offload: false
|
||||
jagged_restart_steps: 150
|
||||
jagged_restart_warmup_steps: 10
|
||||
jagged_restart_anneal_steps: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
57
examples/llama-3/diffusion-3.2-1b-pretrain.yaml
Normal file
57
examples/llama-3/diffusion-3.2-1b-pretrain.yaml
Normal file
@@ -0,0 +1,57 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
pretraining_dataset:
|
||||
- path: wikitext
|
||||
name: wikitext-103-raw-v1
|
||||
type: completion
|
||||
field: text
|
||||
|
||||
plugins:
|
||||
- diffusion.DiffusionPlugin
|
||||
noise_schedule: cosine
|
||||
min_mask_ratio: 0.15
|
||||
max_mask_ratio: 0.85
|
||||
eps: 5e-4
|
||||
importance_weighting: true
|
||||
mask_token_id: 128002
|
||||
generate_samples: true
|
||||
generation_interval: 10
|
||||
|
||||
output_dir: ./outputs/model-out
|
||||
|
||||
sequence_len: 512
|
||||
sample_packing: true
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 4
|
||||
max_steps: 10000
|
||||
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 3e-4
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
sdp_attention: true
|
||||
|
||||
warmup_steps: 1000
|
||||
|
||||
save_strategy: steps
|
||||
save_steps: 1000
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
58
examples/llama-3/diffusion-3.2-1b-sft.yaml
Normal file
58
examples/llama-3/diffusion-3.2-1b-sft.yaml
Normal file
@@ -0,0 +1,58 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
val_set_size: 0.05
|
||||
|
||||
plugins:
|
||||
- diffusion.DiffusionPlugin
|
||||
noise_schedule: cosine
|
||||
min_mask_ratio: 0.1
|
||||
max_mask_ratio: 0.9
|
||||
num_diffusion_steps: 128
|
||||
eps: 1e-3
|
||||
importance_weighting: true
|
||||
mask_token_id: 128002
|
||||
|
||||
output_dir: ./outputs/model-out
|
||||
|
||||
sequence_len: 512
|
||||
sample_packing: true
|
||||
eval_sample_packing: true
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 1e-5
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
sdp_attention: true
|
||||
|
||||
warmup_steps: 1000
|
||||
|
||||
save_strategy: steps
|
||||
eval_strategy: steps
|
||||
save_steps: 500
|
||||
eval_steps: 500
|
||||
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -8,17 +8,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Magistral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
Here is an example of how to install from pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
@@ -27,7 +27,6 @@ sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: false
|
||||
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
|
||||
@@ -26,7 +26,6 @@ lora_model_dir:
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
|
||||
@@ -26,7 +26,6 @@ lora_model_dir:
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
|
||||
66
examples/slurm/README.md
Normal file
66
examples/slurm/README.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# SLURM Multi-Node Training
|
||||
|
||||
This directory contains an example SLURM script for running Axolotl training jobs across multiple nodes in a SLURM cluster.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Access to a SLURM cluster with GPU nodes
|
||||
- Axolotl installed on all nodes (see [installation docs](https://docs.axolotl.ai/docs/installation.html))
|
||||
|
||||
## Usage
|
||||
|
||||
### Standard SLURM Clusters
|
||||
|
||||
1. Copy [`axolotl.slurm`](./axolotl.slurm) to your working directory.
|
||||
2. Place your Axolotl config file (`train.yaml`) in the same directory.
|
||||
3. Set the appropriate environment variables for the job:
|
||||
```bash
|
||||
export HF_TOKEN="your-huggingface-token"
|
||||
|
||||
# metric tracking
|
||||
# export WANDB_API_KEY="your-wandb-api-key"
|
||||
# ...
|
||||
```
|
||||
4. Submit the job:
|
||||
```bash
|
||||
sbatch --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=<master-node>,PRIMARY_PORT=29400 axolotl.slurm
|
||||
```
|
||||
|
||||
Where:
|
||||
- `NUM_NODES`: Number of nodes to use
|
||||
- `NUM_TRAINERS`: GPUs per node (typically 8)
|
||||
- `PRIMARY_ADDR`: Hostname/IP of the master node
|
||||
- `PRIMARY_PORT`: Port for distributed training (default: 29400)
|
||||
|
||||
5. (Optional) Run other slurm commands:
|
||||
```bash
|
||||
# check job info
|
||||
scontrol show job axolotl-cli
|
||||
|
||||
# check job queue
|
||||
squeue
|
||||
|
||||
# check cluster status
|
||||
sinfo
|
||||
```
|
||||
|
||||
### RunPod Instant Clusters
|
||||
|
||||
Axolotl works with RunPod Instant Clusters. This feature provides managed SLURM clusters with zero configuration.
|
||||
|
||||
1. **Deploy a SLURM Cluster**:
|
||||
- Go to [RunPod Instant Clusters](https://console.runpod.io/cluster)
|
||||
- Click "Create a Cluster"
|
||||
- Choose your GPU type, node count, and region
|
||||
- Choose an [Axolotl cloud docker image](https://docs.axolotl.ai/docs/docker.html#cloud)
|
||||
- Deploy the cluster
|
||||
|
||||
2. **Connect to the Controller Node**: Find the controller node in the RunPod console and connect via SSH
|
||||
|
||||
3. **Follow the instructions in [Standard SLURM Clusters](#standard-slurm-clusters)**
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Axolotl Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [SLURM Documentation](https://slurm.schedmd.com/documentation.html)
|
||||
- [RunPod SLURM Clusters Guide](https://docs.runpod.io/instant-clusters/slurm-clusters)
|
||||
20
examples/slurm/axolotl.slurm
Normal file
20
examples/slurm/axolotl.slurm
Normal file
@@ -0,0 +1,20 @@
|
||||
#!/bin/bash
|
||||
# Prior to running this script, export your HF_TOKEN and WANDB_API_KEY to your environment; i.e.
|
||||
# export HF_TOKEN="..."
|
||||
# export WANDB_API_KEY="..."
|
||||
#
|
||||
|
||||
# ---------- SBATCH commands ---------- #
|
||||
#SBATCH --job-name=axolotl-slurm-multinode
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --nodes=$NUM_NODES
|
||||
#SBATCH --gpus-per-task=8
|
||||
#SBATCH --cpus-per-task=128
|
||||
|
||||
export TORCH_DIST_INIT_BARRIER=0
|
||||
|
||||
srun axolotl preprocess train.yaml
|
||||
|
||||
srun axolotl train train.yaml --launcher torchrun -- \
|
||||
--nproc_per_node=$NUM_TRAINERS --nnodes=$NUM_NODES \
|
||||
--rdzv_id axolotl-cli --rdzv_backend c10d --rdzv_endpoint "${PRIMARY_ADDR}:${PRIMARY_PORT}" --rdzv-conf="join_timeout=1800"
|
||||
49
examples/smolvlm2/README.md
Normal file
49
examples/smolvlm2/README.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# Finetune SmolVLM2 with Axolotl
|
||||
|
||||
[SmolVLM2](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7) are a family of lightweight, open-source multimodal models from HuggingFace designed to analyze and understand video, image, and text content.
|
||||
|
||||
These models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M.
|
||||
|
||||
This guide shows how to fine-tune SmolVLM2 models with Axolotl.
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from pip:
|
||||
```bash
|
||||
# Ensure you have a compatible version of Pytorch installed
|
||||
pip3 install packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Install an extra dependency:
|
||||
|
||||
```bash
|
||||
pip3 install num2words==0.5.14
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# LoRA SFT (1x48GB @ 6.8GiB)
|
||||
axolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml
|
||||
```
|
||||
|
||||
## TIPS
|
||||
|
||||
- **Dataset Format**: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on [Multimodal Formats](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
- **Dataset Loading**: Read more on how to prepare and load your own datasets in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [SmolVLM2 Blog](https://huggingface.co/blog/smolvlm2)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
56
examples/smolvlm2/smolvlm2-2B-lora.yaml
Normal file
56
examples/smolvlm2/smolvlm2-2B-lora.yaml
Normal file
@@ -0,0 +1,56 @@
|
||||
base_model: HuggingFaceTB/SmolVLM2-2.2B-Instruct
|
||||
trust_remote_code: true
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.text_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -6,17 +6,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Voxtral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
Here is an example of how to install from pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Please install the below.
|
||||
|
||||
@@ -1,30 +1,33 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.46.0
|
||||
triton>=3.0.0
|
||||
bitsandbytes==0.47.0
|
||||
# triton 3.4.0 is not compatible with CCE
|
||||
triton>=3.0.0,<3.4.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.6.0
|
||||
liger-kernel==0.6.1
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
|
||||
huggingface_hub>=0.33.0
|
||||
peft==0.16.0
|
||||
transformers==4.54.0
|
||||
peft==0.17.0
|
||||
transformers==4.55.2
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.9.0
|
||||
accelerate==1.10.0
|
||||
datasets==4.0.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.19.1
|
||||
trl==0.21.0
|
||||
hf_xet==1.1.5
|
||||
kernels==0.9.0
|
||||
trackio
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
sentencepiece
|
||||
gradio==5.23.3
|
||||
gradio==5.41.1
|
||||
|
||||
modal==1.0.2
|
||||
pydantic==2.10.6
|
||||
@@ -66,6 +69,6 @@ torchao==0.12.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.3
|
||||
axolotl-contribs-mit==0.0.5
|
||||
|
||||
mistral-common==1.8.3
|
||||
|
||||
@@ -44,8 +44,13 @@ add_keys_to_authorized() {
|
||||
chmod 700 -R ~/.ssh
|
||||
}
|
||||
|
||||
# Set SSH port
|
||||
if [ ! -z "$SSH_PORT" ]; then
|
||||
sed -i "s/#Port 22/Port $SSH_PORT/" /etc/ssh/sshd_config
|
||||
fi
|
||||
|
||||
if [[ $PUBLIC_KEY ]]; then
|
||||
# runpod
|
||||
# runpod, prime intellect
|
||||
add_keys_to_authorized "$PUBLIC_KEY"
|
||||
# Start the SSH service in the background
|
||||
service ssh start
|
||||
@@ -76,5 +81,13 @@ if [ ! -L "/workspace/axolotl/outputs" ]; then
|
||||
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
|
||||
fi
|
||||
|
||||
# start the runpod slurm init
|
||||
SLURM_INIT="${SLURM_INIT:-/slurm-init.sh}"
|
||||
|
||||
if [[ -f "$SLURM_INIT" ]]; then
|
||||
echo "[entrypoint] running $SLURM_INIT..."
|
||||
bash "$SLURM_INIT"
|
||||
fi
|
||||
|
||||
# Execute the passed arguments (CMD)
|
||||
exec "$@"
|
||||
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
|
||||
)
|
||||
|
||||
3
setup.py
3
setup.py
@@ -72,12 +72,13 @@ def parse_requirements(extras_require_map):
|
||||
extras_require_map.pop("vllm")
|
||||
else:
|
||||
_install_requires.append("xformers==0.0.31")
|
||||
extras_require_map["vllm"] = ["vllm>=0.10.0"]
|
||||
elif (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.29.post3")
|
||||
# since we only support 2.6.0+cu126
|
||||
_dependency_links.append("https://download.pytorch.org/whl/cu126")
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||
extras_require_map.pop("vllm")
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
|
||||
@@ -4,4 +4,4 @@ import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.12.0.dev"
|
||||
__version__ = "0.13.0.dev"
|
||||
|
||||
@@ -30,8 +30,6 @@ class TrainerCliArgs:
|
||||
debug_num_examples: int = field(default=0)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
shard: bool = field(default=False)
|
||||
main_process_port: Optional[int] = field(default=None)
|
||||
num_processes: Optional[int] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -42,6 +40,12 @@ class VllmServeCliArgs:
|
||||
default=None,
|
||||
metadata={"help": "Number of tensor parallel workers to use."},
|
||||
)
|
||||
data_parallel_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
|
||||
},
|
||||
)
|
||||
host: Optional[str] = field(
|
||||
default=None, # nosec B104
|
||||
metadata={"help": "Host address to run the server on."},
|
||||
|
||||
@@ -3,7 +3,7 @@ launch axolotl in supported cloud platforms
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -11,7 +11,7 @@ from axolotl.cli.cloud.modal_ import ModalCloud
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
|
||||
def load_cloud_cfg(cloud_config: Path | str) -> DictDefault:
|
||||
"""Load and validate cloud configuration."""
|
||||
# Load cloud configuration.
|
||||
with open(cloud_config, encoding="utf-8") as file:
|
||||
@@ -20,8 +20,8 @@ def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
|
||||
|
||||
|
||||
def do_cli_preprocess(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
cloud_config: Path | str,
|
||||
config: Path | str,
|
||||
) -> None:
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
@@ -31,9 +31,10 @@ def do_cli_preprocess(
|
||||
|
||||
|
||||
def do_cli_train(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
accelerate: bool = True,
|
||||
cloud_config: Path | str,
|
||||
config: Path | str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
cwd=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@@ -44,12 +45,18 @@ def do_cli_train(
|
||||
local_dirs = {}
|
||||
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
|
||||
local_dirs = {"/workspace/mounts": cwd}
|
||||
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
|
||||
cloud.train(
|
||||
config_yaml,
|
||||
launcher=launcher,
|
||||
launcher_args=launcher_args,
|
||||
local_dirs=local_dirs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def do_cli_lm_eval(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
cloud_config: Path | str,
|
||||
config: Path | str,
|
||||
) -> None:
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
|
||||
@@ -3,6 +3,7 @@ base class for cloud platforms from cli
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class Cloud(ABC):
|
||||
@@ -15,5 +16,12 @@ class Cloud(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def train(self, config_yaml: str, accelerate: bool = True) -> str:
|
||||
def train(
|
||||
self,
|
||||
config_yaml: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
local_dirs: dict[str, str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -8,7 +8,7 @@ import os
|
||||
import subprocess # nosec B404
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
from typing import Optional
|
||||
from typing import Literal
|
||||
|
||||
import modal
|
||||
|
||||
@@ -230,8 +230,9 @@ class ModalCloud(Cloud):
|
||||
def train(
|
||||
self,
|
||||
config_yaml: str,
|
||||
accelerate: bool = True,
|
||||
local_dirs: Optional[dict[str, str]] = None,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
local_dirs: dict[str, str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
modal_fn = self.get_train_env(local_dirs)(_train)
|
||||
@@ -239,7 +240,8 @@ class ModalCloud(Cloud):
|
||||
with self.app.run(detach=True):
|
||||
modal_fn.remote(
|
||||
config_yaml,
|
||||
accelerate=accelerate,
|
||||
launcher=launcher,
|
||||
launcher_args=launcher_args,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
**kwargs,
|
||||
)
|
||||
@@ -270,20 +272,35 @@ def _preprocess(config_yaml: str, volumes=None):
|
||||
)
|
||||
|
||||
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||
def _train(
|
||||
config_yaml: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
volumes=None,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/mounts"
|
||||
if accelerate:
|
||||
accelerate_args = "--accelerate"
|
||||
|
||||
launcher_args = launcher_args or []
|
||||
|
||||
# Build the base command
|
||||
if launcher == "accelerate":
|
||||
launcher_arg = "--launcher accelerate"
|
||||
elif launcher == "torchrun":
|
||||
launcher_arg = "--launcher torchrun"
|
||||
else:
|
||||
accelerate_args = "--no-accelerate"
|
||||
num_processes_args = ""
|
||||
if num_processes := kwargs.pop("num_processes", None):
|
||||
num_processes_args = f"--num-processes {num_processes}"
|
||||
launcher_arg = "--launcher python"
|
||||
|
||||
# Build launcher args string
|
||||
launcher_args_str = ""
|
||||
if launcher_args:
|
||||
launcher_args_str = "-- " + " ".join(launcher_args)
|
||||
|
||||
run_cmd(
|
||||
f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml",
|
||||
f"axolotl train {launcher_arg} /workspace/mounts/config.yaml {launcher_args_str}".strip(),
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
@@ -153,15 +153,14 @@ def prepare_plugins(cfg: DictDefault):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
for plugin_name in cfg["plugins"]:
|
||||
plugin_manager.register(plugin_name)
|
||||
for plugin in plugin_manager.plugins.values():
|
||||
plugin.register(cfg)
|
||||
|
||||
|
||||
def plugin_set_cfg(cfg: DictDefault):
|
||||
if cfg.get("plugins"):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.cfg = cfg
|
||||
# now that we have the finalized cfg, register the plugins individually
|
||||
for plugin in plugin_manager.plugins.values():
|
||||
plugin.register(cfg)
|
||||
|
||||
|
||||
def load_cfg(
|
||||
@@ -200,14 +199,13 @@ def load_cfg(
|
||||
# If there are any options passed in the cli, if it is something that seems valid
|
||||
# from the yaml, then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
for k, _ in kwargs.items():
|
||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||
if k in cfg_keys or not cfg.strict:
|
||||
# handle booleans
|
||||
if isinstance(cfg[k], bool):
|
||||
cfg[k] = bool(kwargs[k])
|
||||
for key, value in kwargs.items():
|
||||
# If not strict, allow writing to cfg even if it's not in the yml already
|
||||
if key in cfg_keys or not cfg.strict:
|
||||
if isinstance(cfg[key], bool):
|
||||
cfg[key] = bool(value)
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
cfg[key] = value
|
||||
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Generator, Union
|
||||
import fire
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from dotenv import load_dotenv
|
||||
from transformers import AutoProcessor
|
||||
|
||||
|
||||
@@ -152,5 +151,4 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -5,7 +5,6 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
@@ -13,7 +12,6 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.evaluate import evaluate
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -30,9 +28,6 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
check_accelerate_default_config()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
@@ -64,5 +59,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Union
|
||||
import fire
|
||||
import torch
|
||||
import transformers
|
||||
from dotenv import load_dotenv
|
||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
|
||||
from axolotl.cli.args import InferenceCliArgs
|
||||
@@ -268,5 +267,4 @@ def do_cli(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -4,12 +4,9 @@
|
||||
|
||||
import os
|
||||
import subprocess # nosec B404
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
import click
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import axolotl
|
||||
@@ -21,13 +18,14 @@ from axolotl.cli.args import (
|
||||
VllmServeCliArgs,
|
||||
)
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.sweeps import generate_sweep_configs
|
||||
from axolotl.cli.utils import (
|
||||
add_options_from_config,
|
||||
add_options_from_dataclass,
|
||||
build_command,
|
||||
fetch_from_github,
|
||||
filter_none_kwargs,
|
||||
generate_config_files,
|
||||
launch_training,
|
||||
)
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import patch_optimized_env
|
||||
@@ -36,12 +34,19 @@ from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
LAUNCHER_COMMAND_MAPPING = {
|
||||
"accelerate": ["accelerate", "launch"],
|
||||
"torchrun": ["torchrun"],
|
||||
}
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||
def cli():
|
||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||
print_axolotl_text_art()
|
||||
load_dotenv()
|
||||
patch_optimized_env()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -50,7 +55,7 @@ def cli():
|
||||
@add_options_from_dataclass(PreprocessCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
Preprocess datasets before training.
|
||||
|
||||
@@ -60,7 +65,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
patch_optimized_env()
|
||||
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_preprocess
|
||||
@@ -72,12 +76,15 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
help="Use accelerate launch for multi-GPU training",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for multi-GPU training",
|
||||
)
|
||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
@@ -88,126 +95,82 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
@click.pass_context
|
||||
def train(
|
||||
ctx: click.Context,
|
||||
config: str,
|
||||
accelerate: bool,
|
||||
cloud: Optional[str] = None,
|
||||
sweep: Optional[str] = None,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
cloud: str | None = None,
|
||||
sweep: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Train or fine-tune a model.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for multi-GPU training ("accelerate", "torchrun", or "python").
|
||||
cloud: Path to a cloud accelerator configuration file
|
||||
sweep: Path to YAML config for sweeping hyperparameters.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if "use_ray" in kwargs and kwargs["use_ray"]:
|
||||
accelerate = False
|
||||
if sweep:
|
||||
# load the sweep configuration yaml file
|
||||
with open(sweep, "r", encoding="utf-8") as fin:
|
||||
sweep_config: dict[str, list] = yaml.safe_load(fin)
|
||||
with open(config, "r", encoding="utf-8") as fin:
|
||||
base_config: dict[str, list] = yaml.safe_load(fin)
|
||||
# Handle Ray launcher override
|
||||
_launcher = None if kwargs.get("use_ray") else launcher
|
||||
|
||||
# generate all possible configurations
|
||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||
|
||||
def iter_configs():
|
||||
for perm in permutations:
|
||||
# open temp directory for temporary configurations
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
with open(
|
||||
Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
|
||||
) as fout:
|
||||
yaml.dump(perm, fout)
|
||||
yield str(Path(temp_dir) / "config.yaml")
|
||||
|
||||
else:
|
||||
|
||||
def iter_configs():
|
||||
yield config
|
||||
|
||||
for cfg_file in iter_configs():
|
||||
# handle errors from subprocess so we can continue rest of sweeps
|
||||
# Process each configuration
|
||||
for cfg_file, is_group in generate_config_files(config, sweep):
|
||||
try:
|
||||
if accelerate:
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
cwd = os.getcwd()
|
||||
do_cli_train(
|
||||
cloud_config=cloud,
|
||||
config=config,
|
||||
accelerate=True,
|
||||
cwd=cwd,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
accelerate_args = []
|
||||
if "main_process_port" in kwargs:
|
||||
main_process_port = kwargs.pop("main_process_port", None)
|
||||
accelerate_args.append("--main_process_port")
|
||||
accelerate_args.append(str(main_process_port))
|
||||
if "num_processes" in kwargs:
|
||||
num_processes = kwargs.pop("num_processes", None)
|
||||
accelerate_args.append("--num_processes")
|
||||
accelerate_args.append(str(num_processes))
|
||||
|
||||
base_cmd = ["accelerate", "launch"]
|
||||
base_cmd.extend(accelerate_args)
|
||||
base_cmd.extend(["-m", "axolotl.cli.train"])
|
||||
if cfg_file:
|
||||
base_cmd.append(cfg_file)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
else:
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
do_cli_train(
|
||||
cloud_config=cloud, config=config, accelerate=False, **kwargs
|
||||
)
|
||||
else:
|
||||
from axolotl.cli.train import do_cli
|
||||
|
||||
do_cli(config=cfg_file, **kwargs)
|
||||
use_exec = is_group is not True
|
||||
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
|
||||
except subprocess.CalledProcessError as exc:
|
||||
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
||||
if not sweep:
|
||||
raise exc
|
||||
finally:
|
||||
# Only delete temp files, not the original config
|
||||
if cfg_file != config:
|
||||
os.unlink(cfg_file)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
help="Use accelerate launch for multi-GPU training",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for multi-GPU evaluation",
|
||||
)
|
||||
@add_options_from_dataclass(EvaluateCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def evaluate(config: str, accelerate: bool, **kwargs) -> None:
|
||||
@click.pass_context
|
||||
def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs):
|
||||
"""
|
||||
Evaluate a model.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for multi-GPU evaluation ("accelerate", "torchrun", or "python").
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
if accelerate:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if launcher in LAUNCHER_COMMAND_MAPPING:
|
||||
base_cmd = (
|
||||
LAUNCHER_COMMAND_MAPPING[launcher]
|
||||
+ launcher_args
|
||||
+ ["-m", "axolotl.cli.evaluate"]
|
||||
)
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
@@ -218,30 +181,42 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None:
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=False,
|
||||
help="Use accelerate launch for multi-GPU inference",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for multi-GPU inference",
|
||||
)
|
||||
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
|
||||
@click.pass_context
|
||||
def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kwargs):
|
||||
"""
|
||||
Run inference with a trained model.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for multi-GPU inference ("accelerate", "torchrun", or "python").
|
||||
gradio: Whether to use Gradio browser interface or command line for inference.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
if accelerate:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if launcher in LAUNCHER_COMMAND_MAPPING:
|
||||
base_cmd = (
|
||||
LAUNCHER_COMMAND_MAPPING[launcher]
|
||||
+ launcher_args
|
||||
+ ["-m", "axolotl.cli.inference"]
|
||||
)
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
if gradio:
|
||||
@@ -254,33 +229,42 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
|
||||
do_cli(config=config, gradio=gradio, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
help="Use accelerate launch for weight merging",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for weight merging",
|
||||
)
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
|
||||
@click.pass_context
|
||||
def merge_sharded_fsdp_weights(
|
||||
ctx: click.Context, config: str, launcher: str, **kwargs
|
||||
):
|
||||
"""
|
||||
Merge sharded FSDP model weights.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for weight merging ("accelerate", "torchrun", or "python").
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
if accelerate:
|
||||
base_cmd = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"-m",
|
||||
"axolotl.cli.merge_sharded_fsdp_weights",
|
||||
]
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if launcher in LAUNCHER_COMMAND_MAPPING:
|
||||
base_cmd = (
|
||||
LAUNCHER_COMMAND_MAPPING[launcher]
|
||||
+ launcher_args
|
||||
+ ["-m", "axolotl.cli.merge_sharded_fsdp_weights"]
|
||||
)
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
@@ -296,7 +280,7 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def merge_lora(config: str, **kwargs) -> None:
|
||||
def merge_lora(config: str, **kwargs):
|
||||
"""
|
||||
Merge trained LoRA adapters into a base model.
|
||||
|
||||
@@ -313,7 +297,7 @@ def merge_lora(config: str, **kwargs) -> None:
|
||||
@cli.command()
|
||||
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
|
||||
@click.option("--dest", help="Destination directory")
|
||||
def fetch(directory: str, dest: Optional[str]) -> None:
|
||||
def fetch(directory: str, dest: Optional[str]):
|
||||
"""
|
||||
Fetch example configs or other resources.
|
||||
|
||||
@@ -351,7 +335,7 @@ def quantize(config: str, **cli_args: QuantizeCliArgs):
|
||||
@cli.command()
|
||||
@click.argument("model", type=click.Path(exists=True, path_type=str))
|
||||
@click.argument("output", type=click.Path(exists=False, path_type=str))
|
||||
def delinearize_llama4(model: str, output: str) -> None:
|
||||
def delinearize_llama4(model: str, output: str):
|
||||
from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4
|
||||
|
||||
do_delinearize_llama4(model, output)
|
||||
@@ -365,5 +349,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
main()
|
||||
|
||||
@@ -4,7 +4,6 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
@@ -70,7 +69,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
load_in_8bit=False,
|
||||
load_in_4bit=False,
|
||||
flash_attention=False,
|
||||
sequence_parallel_degree=None,
|
||||
context_parallel_size=None,
|
||||
deepspeed=None,
|
||||
fsdp=None,
|
||||
fsdp_config=None,
|
||||
@@ -88,5 +87,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -10,6 +10,7 @@ import fire
|
||||
import torch
|
||||
import torch.distributed.checkpoint as dist_cp
|
||||
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
@@ -17,13 +18,13 @@ from accelerate.utils import (
|
||||
WEIGHTS_NAME,
|
||||
is_torch_version,
|
||||
)
|
||||
from dotenv import load_dotenv
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.train import determine_last_checkpoint
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -144,7 +145,6 @@ def merge_fsdp_weights(
|
||||
ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
|
||||
"""
|
||||
checkpoint_dir_ = Path(checkpoint_dir)
|
||||
from accelerate.state import PartialState
|
||||
|
||||
if not is_torch_version(">=", "2.3.0"):
|
||||
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
|
||||
@@ -181,7 +181,6 @@ def merge_fsdp_weights(
|
||||
if remove_checkpoint_dir:
|
||||
LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
|
||||
shutil.rmtree(checkpoint_dir_)
|
||||
state.wait_for_everyone()
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
@@ -196,13 +195,33 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
|
||||
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
||||
if not fsdp_dir.exists():
|
||||
checkpoint_dir = determine_last_checkpoint(parsed_cfg, update=False)
|
||||
if checkpoint_dir:
|
||||
fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
|
||||
if not fsdp_dir.exists():
|
||||
raise ValueError(
|
||||
f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}"
|
||||
)
|
||||
|
||||
output_path = str(Path(parsed_cfg.output_dir) / "merged")
|
||||
merge_fsdp_weights(
|
||||
checkpoint_dir=str(fsdp_dir),
|
||||
output_path=str(Path(parsed_cfg.output_dir) / "merged"),
|
||||
output_path=output_path,
|
||||
safe_serialization=True,
|
||||
)
|
||||
state = PartialState()
|
||||
state.wait_for_everyone()
|
||||
LOG.info(
|
||||
f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
|
||||
main_process_only=True,
|
||||
)
|
||||
LOG.info(
|
||||
"Merged weights are only the safetensors and doesn't include the model configuration "
|
||||
f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
|
||||
main_process_only=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -9,7 +9,6 @@ import fire
|
||||
import transformers
|
||||
from accelerate import init_empty_weights
|
||||
from colorama import Fore
|
||||
from dotenv import load_dotenv
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.cli.args import PreprocessCliArgs
|
||||
@@ -109,5 +108,4 @@ def do_cli(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Union
|
||||
|
||||
import fire
|
||||
from accelerate import Accelerator
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
@@ -16,7 +15,6 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.train import train
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -31,9 +29,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Training-specific CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
|
||||
check_accelerate_default_config()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
check_user_token()
|
||||
@@ -122,5 +117,4 @@ def ray_train_func(kwargs: dict):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -1,330 +0,0 @@
|
||||
"""Utility methods for axolotl CLI."""
|
||||
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import json
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from types import NoneType
|
||||
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
|
||||
import click
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def strip_optional_type(field_type: type | str | None):
|
||||
"""
|
||||
Extracts the non-`None` type from an `Optional` / `Union` type.
|
||||
|
||||
Args:
|
||||
field_type: Type of field for Axolotl CLI command.
|
||||
|
||||
Returns:
|
||||
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
|
||||
returns the input type unchanged.
|
||||
"""
|
||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||
field_type = next(
|
||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||
)
|
||||
|
||||
return field_type
|
||||
|
||||
|
||||
def filter_none_kwargs(func: Callable) -> Callable:
|
||||
"""
|
||||
Wraps function to remove `None`-valued `kwargs`.
|
||||
|
||||
Args:
|
||||
func: Function to wrap.
|
||||
|
||||
Returns:
|
||||
Wrapped function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Callable:
|
||||
"""Filters out `None`-valued `kwargs`."""
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
return func(*args, **filtered_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a dataclass.
|
||||
|
||||
Args:
|
||||
config_class: Dataclass with fields to parse from the CLI.
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process dataclass fields in reverse order for correct option ordering
|
||||
for field in reversed(dataclasses.fields(config_class)):
|
||||
field_type = strip_optional_type(field.type)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = field.name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{field.name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
type=field_type,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a Pydantic model.
|
||||
|
||||
Args:
|
||||
config_class: PyDantic model with fields to parse from the CLI
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process model fields in reverse order for correct option ordering
|
||||
for name, field in reversed(config_class.model_fields.items()):
|
||||
field_type = strip_optional_type(field.annotation)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Build command list from base command and options.
|
||||
|
||||
Args:
|
||||
base_cmd: Command without options.
|
||||
options: Options to parse and append to base command.
|
||||
|
||||
Returns:
|
||||
List of strings giving shell command.
|
||||
"""
|
||||
cmd = base_cmd.copy()
|
||||
|
||||
for key, value in options.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
key = key.replace("_", "-")
|
||||
|
||||
if isinstance(value, bool):
|
||||
if value:
|
||||
cmd.append(f"--{key}")
|
||||
else:
|
||||
cmd.extend([f"--{key}", str(value)])
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def download_file(
|
||||
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Download a single file and return its processing status.
|
||||
|
||||
Args:
|
||||
file_info: Tuple of (file_path, remote_sha).
|
||||
raw_base_url: Base URL for raw GitHub content.
|
||||
dest_path: Local destination directory.
|
||||
dir_prefix: Directory prefix to filter files.
|
||||
|
||||
Returns:
|
||||
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
|
||||
"""
|
||||
file_path, remote_sha = file_info
|
||||
raw_url = f"{raw_base_url}/{file_path}"
|
||||
dest_file = dest_path / file_path.split(dir_prefix)[-1]
|
||||
|
||||
# Check if file exists and needs updating
|
||||
if dest_file.exists():
|
||||
with open(dest_file, "rb") as file:
|
||||
content = file.read()
|
||||
# Calculate git blob SHA
|
||||
blob = b"blob " + str(len(content)).encode() + b"\0" + content
|
||||
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
|
||||
|
||||
if local_sha == remote_sha:
|
||||
print(f"Skipping {file_path} (unchanged)")
|
||||
return file_path, "unchanged"
|
||||
|
||||
print(f"Updating {file_path}")
|
||||
status = "new"
|
||||
else:
|
||||
print(f"Downloading {file_path}")
|
||||
status = "new"
|
||||
|
||||
# Create directories if needed
|
||||
dest_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Download and save file
|
||||
try:
|
||||
response = requests.get(raw_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(dest_file, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
return file_path, status
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error downloading {file_path}: {str(request_error)}")
|
||||
return file_path, "error"
|
||||
|
||||
|
||||
def fetch_from_github(
|
||||
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
|
||||
) -> None:
|
||||
"""
|
||||
Sync files from a specific directory in the GitHub repository.
|
||||
Only downloads files that don't exist locally or have changed.
|
||||
|
||||
Args:
|
||||
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
|
||||
'deepspeed_configs/').
|
||||
dest_dir: Local destination directory.
|
||||
max_workers: Maximum number of concurrent downloads.
|
||||
"""
|
||||
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
|
||||
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
|
||||
|
||||
# Get repository tree with timeout
|
||||
response = requests.get(api_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
tree = json.loads(response.text)
|
||||
|
||||
# Filter for files and get their SHA
|
||||
files = {
|
||||
item["path"]: item["sha"]
|
||||
for item in tree["tree"]
|
||||
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
|
||||
}
|
||||
|
||||
if not files:
|
||||
raise click.ClickException(f"No files found in {dir_prefix}")
|
||||
|
||||
# Default destination directory is the last part of dir_prefix
|
||||
default_dest = Path(dir_prefix.rstrip("/"))
|
||||
dest_path = Path(dest_dir) if dest_dir else default_dest
|
||||
|
||||
# Keep track of processed files for summary
|
||||
files_processed: dict[str, list[str]] = {
|
||||
"new": [],
|
||||
"updated": [],
|
||||
"unchanged": [],
|
||||
"error": [],
|
||||
}
|
||||
|
||||
# Process files in parallel using ThreadPoolExecutor
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_file = {
|
||||
executor.submit(
|
||||
download_file,
|
||||
(file_path, remote_sha),
|
||||
raw_base_url,
|
||||
dest_path,
|
||||
dir_prefix,
|
||||
): file_path
|
||||
for file_path, remote_sha in files.items()
|
||||
}
|
||||
|
||||
# Process completed tasks as they finish
|
||||
for future in concurrent.futures.as_completed(future_to_file):
|
||||
file_path = future_to_file[future]
|
||||
try:
|
||||
file_path, status = future.result()
|
||||
files_processed[status].append(file_path)
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error processing {file_path}: {str(request_error)}")
|
||||
files_processed["error"].append(file_path)
|
||||
|
||||
# Log summary
|
||||
LOG.info("\nSync Summary:")
|
||||
LOG.info(f"New files: {len(files_processed['new'])}")
|
||||
LOG.info(f"Updated files: {len(files_processed['updated'])}")
|
||||
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
|
||||
if files_processed["error"]:
|
||||
LOG.info(f"Failed files: {len(files_processed['error'])}")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
inference: bool = False,
|
||||
) -> tuple[
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
|
||||
ProcessorMixin | None,
|
||||
]:
|
||||
"""
|
||||
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
|
||||
config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
inference: Boolean denoting inference mode.
|
||||
|
||||
Returns:
|
||||
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
|
||||
"""
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
LOG.info("loading model...")
|
||||
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
|
||||
model, _ = model_loader.load()
|
||||
|
||||
processor = None
|
||||
if cfg.is_multimodal:
|
||||
LOG.info("loading processor...")
|
||||
processor = load_processor(cfg, tokenizer)
|
||||
|
||||
return model, tokenizer, processor
|
||||
23
src/axolotl/cli/utils/__init__.py
Normal file
23
src/axolotl/cli/utils/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Init for axolotl.cli.utils module."""
|
||||
|
||||
from .args import (
|
||||
add_options_from_config,
|
||||
add_options_from_dataclass,
|
||||
filter_none_kwargs,
|
||||
)
|
||||
from .fetch import fetch_from_github
|
||||
from .load import load_model_and_tokenizer
|
||||
from .sweeps import generate_sweep_configs
|
||||
from .train import build_command, generate_config_files, launch_training
|
||||
|
||||
__all__ = [
|
||||
"filter_none_kwargs",
|
||||
"add_options_from_dataclass",
|
||||
"add_options_from_config",
|
||||
"build_command",
|
||||
"generate_config_files",
|
||||
"generate_sweep_configs",
|
||||
"load_model_and_tokenizer",
|
||||
"launch_training",
|
||||
"fetch_from_github",
|
||||
]
|
||||
120
src/axolotl/cli/utils/args.py
Normal file
120
src/axolotl/cli/utils/args.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Utilities for axolotl CLI args."""
|
||||
|
||||
import dataclasses
|
||||
from functools import wraps
|
||||
from types import NoneType
|
||||
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
|
||||
import click
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def _strip_optional_type(field_type: type | str | None):
|
||||
"""
|
||||
Extracts the non-`None` type from an `Optional` / `Union` type.
|
||||
|
||||
Args:
|
||||
field_type: Type of field for Axolotl CLI command.
|
||||
|
||||
Returns:
|
||||
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
|
||||
returns the input type unchanged.
|
||||
"""
|
||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||
field_type = next(
|
||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||
)
|
||||
|
||||
return field_type
|
||||
|
||||
|
||||
def filter_none_kwargs(func: Callable) -> Callable:
|
||||
"""
|
||||
Wraps function to remove `None`-valued `kwargs`.
|
||||
|
||||
Args:
|
||||
func: Function to wrap.
|
||||
|
||||
Returns:
|
||||
Wrapped function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Callable:
|
||||
"""Filters out `None`-valued `kwargs`."""
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
return func(*args, **filtered_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a dataclass.
|
||||
|
||||
Args:
|
||||
config_class: Dataclass with fields to parse from the CLI.
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process dataclass fields in reverse order for correct option ordering
|
||||
for field in reversed(dataclasses.fields(config_class)):
|
||||
field_type = _strip_optional_type(field.type)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = field.name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{field.name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
type=field_type,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a Pydantic model.
|
||||
|
||||
Args:
|
||||
config_class: PyDantic model with fields to parse from the CLI
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process model fields in reverse order for correct option ordering
|
||||
for name, field in reversed(config_class.model_fields.items()):
|
||||
field_type = _strip_optional_type(field.annotation)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
142
src/axolotl/cli/utils/fetch.py
Normal file
142
src/axolotl/cli/utils/fetch.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Utilities for axolotl fetch CLI command."""
|
||||
|
||||
import concurrent.futures
|
||||
import hashlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import requests
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _download_file(
|
||||
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Download a single file and return its processing status.
|
||||
|
||||
Args:
|
||||
file_info: Tuple of (file_path, remote_sha).
|
||||
raw_base_url: Base URL for raw GitHub content.
|
||||
dest_path: Local destination directory.
|
||||
dir_prefix: Directory prefix to filter files.
|
||||
|
||||
Returns:
|
||||
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
|
||||
"""
|
||||
file_path, remote_sha = file_info
|
||||
raw_url = f"{raw_base_url}/{file_path}"
|
||||
dest_file = dest_path / file_path.split(dir_prefix)[-1]
|
||||
|
||||
# Check if file exists and needs updating
|
||||
if dest_file.exists():
|
||||
with open(dest_file, "rb") as file:
|
||||
content = file.read()
|
||||
# Calculate git blob SHA
|
||||
blob = b"blob " + str(len(content)).encode() + b"\0" + content
|
||||
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
|
||||
|
||||
if local_sha == remote_sha:
|
||||
print(f"Skipping {file_path} (unchanged)")
|
||||
return file_path, "unchanged"
|
||||
|
||||
print(f"Updating {file_path}")
|
||||
status = "updated"
|
||||
else:
|
||||
print(f"Downloading {file_path}")
|
||||
status = "new"
|
||||
|
||||
# Create directories if needed
|
||||
dest_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Download and save file
|
||||
try:
|
||||
response = requests.get(raw_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(dest_file, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
return file_path, status
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error downloading {file_path}: {str(request_error)}")
|
||||
return file_path, "error"
|
||||
|
||||
|
||||
def fetch_from_github(
|
||||
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
|
||||
) -> None:
|
||||
"""
|
||||
Sync files from a specific directory in the GitHub repository.
|
||||
Only downloads files that don't exist locally or have changed.
|
||||
|
||||
Args:
|
||||
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
|
||||
'deepspeed_configs/').
|
||||
dest_dir: Local destination directory.
|
||||
max_workers: Maximum number of concurrent downloads.
|
||||
"""
|
||||
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
|
||||
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
|
||||
|
||||
# Get repository tree with timeout
|
||||
response = requests.get(api_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
tree = json.loads(response.text)
|
||||
|
||||
# Filter for files and get their SHA
|
||||
files = {
|
||||
item["path"]: item["sha"]
|
||||
for item in tree["tree"]
|
||||
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
|
||||
}
|
||||
|
||||
if not files:
|
||||
raise click.ClickException(f"No files found in {dir_prefix}")
|
||||
|
||||
# Default destination directory is the last part of dir_prefix
|
||||
default_dest = Path(dir_prefix.rstrip("/"))
|
||||
dest_path = Path(dest_dir) if dest_dir else default_dest
|
||||
|
||||
# Keep track of processed files for summary
|
||||
files_processed: dict[str, list[str]] = {
|
||||
"new": [],
|
||||
"updated": [],
|
||||
"unchanged": [],
|
||||
"error": [],
|
||||
}
|
||||
|
||||
# Process files in parallel using ThreadPoolExecutor
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_file = {
|
||||
executor.submit(
|
||||
_download_file,
|
||||
(file_path, remote_sha),
|
||||
raw_base_url,
|
||||
dest_path,
|
||||
dir_prefix,
|
||||
): file_path
|
||||
for file_path, remote_sha in files.items()
|
||||
}
|
||||
|
||||
# Process completed tasks as they finish
|
||||
for future in concurrent.futures.as_completed(future_to_file):
|
||||
file_path = future_to_file[future]
|
||||
try:
|
||||
file_path, status = future.result()
|
||||
files_processed[status].append(file_path)
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error processing {file_path}: {str(request_error)}")
|
||||
files_processed["error"].append(file_path)
|
||||
|
||||
# Log summary
|
||||
LOG.info("\nSync Summary:")
|
||||
LOG.info(f"New files: {len(files_processed['new'])}")
|
||||
LOG.info(f"Updated files: {len(files_processed['updated'])}")
|
||||
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
|
||||
if files_processed["error"]:
|
||||
LOG.info(f"Failed files: {len(files_processed['error'])}")
|
||||
52
src/axolotl/cli/utils/load.py
Normal file
52
src/axolotl/cli/utils/load.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Utilities for model, tokenizer, etc. loading."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
inference: bool = False,
|
||||
) -> tuple[
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
|
||||
ProcessorMixin | None,
|
||||
]:
|
||||
"""
|
||||
Helper function for loading a model, tokenizer, and processor specified in the
|
||||
given `axolotl` config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
inference: Boolean denoting inference mode.
|
||||
|
||||
Returns:
|
||||
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
|
||||
"""
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
LOG.info("loading model...")
|
||||
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
|
||||
model, _ = model_loader.load()
|
||||
|
||||
processor = None
|
||||
if cfg.is_multimodal:
|
||||
LOG.info("loading processor...")
|
||||
processor = load_processor(cfg, tokenizer)
|
||||
|
||||
return model, tokenizer, processor
|
||||
220
src/axolotl/cli/utils/train.py
Normal file
220
src/axolotl/cli/utils/train.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Utilities for axolotl train CLI command."""
|
||||
|
||||
import os
|
||||
import subprocess # nosec
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Any, Iterator, Literal
|
||||
|
||||
import yaml
|
||||
|
||||
from axolotl.cli.utils.sweeps import generate_sweep_configs
|
||||
|
||||
|
||||
def _add_default_rdzv_args(launcher_args: list[str]) -> list[str]:
|
||||
"""
|
||||
Add default RDZV arguments if rdzv_endpoint is set but rdzv_backend/rdzv_id are missing.
|
||||
|
||||
Args:
|
||||
launcher_args: List of launcher arguments
|
||||
|
||||
Returns:
|
||||
Updated launcher args with defaults added if needed
|
||||
"""
|
||||
args = launcher_args.copy()
|
||||
|
||||
# Check if rdzv_endpoint is present
|
||||
has_rdzv_endpoint = any("--rdzv_endpoint" in arg for arg in args)
|
||||
|
||||
if has_rdzv_endpoint:
|
||||
# Check if rdzv_backend is already provided
|
||||
has_rdzv_backend = any("--rdzv_backend" in arg for arg in args)
|
||||
if not has_rdzv_backend:
|
||||
args.extend(["--rdzv_backend", "c10d"])
|
||||
|
||||
# Check if rdzv_id is already provided
|
||||
has_rdzv_id = any("--rdzv_id" in arg for arg in args)
|
||||
if not has_rdzv_id:
|
||||
import uuid
|
||||
|
||||
args.extend(["--rdzv_id", str(uuid.uuid4())[:8]])
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Build command list from base command and options.
|
||||
|
||||
Args:
|
||||
base_cmd: Command without options.
|
||||
options: Options to parse and append to base command.
|
||||
|
||||
Returns:
|
||||
List of strings giving shell command.
|
||||
"""
|
||||
cmd = base_cmd.copy()
|
||||
|
||||
for key, value in options.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
key = key.replace("_", "-")
|
||||
cmd.append(f"--{key}={value}")
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
|
||||
"""
|
||||
Generate list of configuration files to process. Yields a tuple of the configuration file name and a boolean indicating
|
||||
whether this is a group of configurations (i.e., a sweep).
|
||||
|
||||
Args:
|
||||
config: Base configuration file
|
||||
sweep: Sweep configuration file
|
||||
"""
|
||||
|
||||
if not sweep:
|
||||
yield config, False
|
||||
return
|
||||
|
||||
# Load sweep and base configurations
|
||||
with open(sweep, "r", encoding="utf-8") as fin:
|
||||
sweep_config: dict[str, list] = yaml.safe_load(fin)
|
||||
with open(config, "r", encoding="utf-8") as fin:
|
||||
base_config: dict[str, list] = yaml.safe_load(fin)
|
||||
|
||||
# Generate all possible configurations
|
||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||
is_group = len(permutations) > 1
|
||||
for permutation in permutations:
|
||||
# pylint: disable=consider-using-with
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
suffix=".yaml",
|
||||
delete=False,
|
||||
encoding="utf-8",
|
||||
)
|
||||
yaml.dump(permutation, temp_file)
|
||||
temp_file.close()
|
||||
yield temp_file.name, is_group
|
||||
|
||||
|
||||
def launch_training(
|
||||
cfg_file: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] | None,
|
||||
cloud: str | None,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training with the given configuration."""
|
||||
launcher_args = launcher_args or []
|
||||
|
||||
if cloud:
|
||||
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
|
||||
elif launcher:
|
||||
if launcher == "accelerate":
|
||||
_launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||
elif launcher == "torchrun":
|
||||
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||
elif launcher == "python":
|
||||
_launch_python_training(cfg_file, kwargs)
|
||||
elif launcher is None:
|
||||
# handle ray train launch
|
||||
_launch_python_training(cfg_file, kwargs)
|
||||
|
||||
|
||||
def _launch_cloud_training(
|
||||
cloud: str,
|
||||
cfg_file: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] | None,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Execute training via cloud launcher."""
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
launcher_args = launcher_args or []
|
||||
cwd = os.getcwd() if launcher else None
|
||||
|
||||
do_cli_train(
|
||||
cloud_config=cloud,
|
||||
config=cfg_file,
|
||||
launcher=launcher or "accelerate",
|
||||
launcher_args=launcher_args,
|
||||
cwd=cwd,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _launch_accelerate_training(
|
||||
cfg_file: str,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training via accelerate launcher."""
|
||||
launcher_args = launcher_args or []
|
||||
internal_launcher_args = []
|
||||
|
||||
# Extract launcher-specific arguments from kwargs (legacy support)
|
||||
if "main_process_port" in kwargs:
|
||||
main_process_port = kwargs.pop("main_process_port")
|
||||
internal_launcher_args.extend(["--main_process_port", str(main_process_port)])
|
||||
|
||||
if "num_processes" in kwargs:
|
||||
num_processes = kwargs.pop("num_processes")
|
||||
internal_launcher_args.extend(["--num_processes", str(num_processes)])
|
||||
|
||||
# Combine internal args with user-provided launcher args
|
||||
all_launcher_args = internal_launcher_args + launcher_args
|
||||
|
||||
base_cmd = (
|
||||
["accelerate", "launch"] + all_launcher_args + ["-m", "axolotl.cli.train"]
|
||||
)
|
||||
if cfg_file:
|
||||
base_cmd.append(cfg_file)
|
||||
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
if use_exec:
|
||||
# make sure to flush stdout and stderr before replacing the process
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
||||
else:
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
|
||||
|
||||
def _launch_torchrun_training(
|
||||
cfg_file: str,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training via torchrun launcher."""
|
||||
launcher_args = launcher_args or []
|
||||
|
||||
# Add default RDZV arguments if rdzv_endpoint is set
|
||||
launcher_args = _add_default_rdzv_args(launcher_args)
|
||||
|
||||
base_cmd = ["torchrun"] + launcher_args + ["-m", "axolotl.cli.train"]
|
||||
if cfg_file:
|
||||
base_cmd.append(cfg_file)
|
||||
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
if use_exec:
|
||||
# make sure to flush stdout and stderr before replacing the process
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
||||
else:
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
|
||||
|
||||
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
|
||||
"""Execute training via python launcher."""
|
||||
from axolotl.cli.train import do_cli
|
||||
|
||||
do_cli(config=cfg_file, **kwargs)
|
||||
@@ -2,12 +2,10 @@
|
||||
CLI to start the vllm server for online RL
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import trl
|
||||
from trl.scripts.vllm_serve import ScriptArguments
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
@@ -42,13 +40,17 @@ def do_vllm_serve(
|
||||
|
||||
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
|
||||
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
|
||||
tensor_parallel_size = 1
|
||||
data_parallel_size = 1
|
||||
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
data_parallel_size = (
|
||||
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
|
||||
)
|
||||
if cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size:
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
if cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size:
|
||||
data_parallel_size = (
|
||||
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
|
||||
)
|
||||
host = cli_args.get("host") or cfg.vllm.host
|
||||
port = cli_args.get("port") or cfg.vllm.port
|
||||
gpu_memory_utilization = (
|
||||
@@ -81,63 +83,3 @@ def do_vllm_serve(
|
||||
enable_reasoning=enable_reasoning,
|
||||
)
|
||||
vllm_serve_main(vllm_script_args)
|
||||
|
||||
|
||||
def patch_vllm_worker():
|
||||
from multiprocessing.connection import Connection
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
def llm_worker(
|
||||
script_args: AxolotlScriptArguments,
|
||||
data_parallel_rank: int,
|
||||
master_port: int,
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
# Set required environment variables for DP to work with vLLM
|
||||
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
|
||||
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
|
||||
|
||||
llm = LLM(
|
||||
model=script_args.model,
|
||||
revision=script_args.revision,
|
||||
tensor_parallel_size=script_args.tensor_parallel_size,
|
||||
gpu_memory_utilization=script_args.gpu_memory_utilization,
|
||||
enforce_eager=script_args.enforce_eager,
|
||||
dtype=script_args.dtype,
|
||||
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
|
||||
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
|
||||
# This is particularly useful here because we generate completions from the same prompts.
|
||||
enable_prefix_caching=script_args.enable_prefix_caching,
|
||||
kv_cache_dtype=script_args.kv_cache_dtype,
|
||||
max_model_len=script_args.max_model_len,
|
||||
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
|
||||
enable_reasoning=script_args.enable_reasoning,
|
||||
reasoning_parser=script_args.reasoning_parser,
|
||||
)
|
||||
|
||||
# Send ready signal to parent process
|
||||
connection.send({"status": "ready"})
|
||||
|
||||
while True:
|
||||
# Wait for commands from the parent process
|
||||
try:
|
||||
command = connection.recv()
|
||||
except KeyboardInterrupt:
|
||||
llm.collective_rpc(method="close_communicator")
|
||||
break
|
||||
|
||||
# Handle commands
|
||||
if command["type"] in ["call", "fire_and_forget"]:
|
||||
method_name = command["method"]
|
||||
args, kwargs = command.get("args", ()), command.get("kwargs", {})
|
||||
method = getattr(llm, method_name)
|
||||
result = method(*args, **kwargs)
|
||||
if command["type"] == "call":
|
||||
connection.send(result)
|
||||
elif command["type"] == "shutdown":
|
||||
break
|
||||
|
||||
trl.scripts.vllm_serve.llm_worker = llm_worker
|
||||
|
||||
@@ -13,4 +13,5 @@ MOE_ARCH_BLOCK = {
|
||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||
"deepseek_v2": "DeepseekV2MoE",
|
||||
"gpt_oss": "GptOssDecoderLayer",
|
||||
}
|
||||
|
||||
@@ -27,18 +27,18 @@ import torch
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
)
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
GCCallback,
|
||||
GPUStatsCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveModelOnFirstStepCallback,
|
||||
)
|
||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||
from axolotl.utils.distributed import build_parallelism_config
|
||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@@ -139,8 +139,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if self.cfg.save_first_step:
|
||||
callbacks.append(SaveModelOnFirstStepCallback())
|
||||
|
||||
callbacks.append(GPUStatsCallback(cfg=self.cfg))
|
||||
|
||||
if self.cfg.profiler_steps:
|
||||
callbacks.append(
|
||||
PytorchProfilerCallback(
|
||||
@@ -268,27 +266,24 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
optimizer_cls = MuonOptimizerFactory
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "dion":
|
||||
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
|
||||
DionOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls = DionOptimizerFactory
|
||||
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
|
||||
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
_, device_mesh = build_parallelism_config(self.cfg)
|
||||
if device_mesh is not None:
|
||||
optimizer_kwargs["device_mesh"] = device_mesh
|
||||
elif self.cfg.optimizer == "optimi_adamw":
|
||||
from optimi import AdamW
|
||||
|
||||
optimizer_kwargs["foreach"] = False
|
||||
optimizer_cls = AdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_4bit":
|
||||
# TODO remove 20250401
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||
|
||||
optimizer_cls = AdamW4bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
|
||||
LOG.warning(
|
||||
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
||||
)
|
||||
elif self.cfg.optimizer == "ao_adamw_8bit":
|
||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||
|
||||
optimizer_cls = AdamW8bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||
|
||||
@@ -435,7 +430,11 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||
if self.cfg.accelerator_config:
|
||||
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||
**self.cfg.accelerator_config
|
||||
)
|
||||
else:
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.activation_offloading is True:
|
||||
@@ -495,10 +494,20 @@ class TrainerBuilderBase(abc.ABC):
|
||||
"include_tokens_per_second",
|
||||
"weight_decay",
|
||||
"seed",
|
||||
"dion_momentum",
|
||||
"dion_rank_fraction",
|
||||
"dion_rank_multiple_of",
|
||||
]:
|
||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||
|
||||
arg_map = {
|
||||
"dion_learning_rate": "dion_lr",
|
||||
}
|
||||
for kwarg, cfg_arg in arg_map.items():
|
||||
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
|
||||
training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)
|
||||
|
||||
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
||||
training_args_kwargs["average_tokens_across_devices"] = False
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import transformers
|
||||
from transformers import (
|
||||
DataCollatorWithFlattening,
|
||||
EarlyStoppingCallback,
|
||||
Trainer,
|
||||
)
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||
|
||||
@@ -19,7 +20,6 @@ from axolotl.core.trainers import (
|
||||
AxolotlPRMTrainer,
|
||||
AxolotlRewardTrainer,
|
||||
AxolotlTrainer,
|
||||
ReLoRATrainer,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
@@ -44,6 +44,7 @@ from axolotl.utils.collators import (
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -58,7 +59,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
|
||||
if self.cfg.relora_steps:
|
||||
if self.cfg.relora:
|
||||
callbacks.append(ReLoRACallback(self.cfg))
|
||||
|
||||
if (
|
||||
@@ -131,14 +132,24 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
if trainer_cls:
|
||||
return trainer_cls
|
||||
if self.cfg.relora_steps:
|
||||
return ReLoRATrainer
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return AxolotlMambaTrainer
|
||||
if self.cfg.reward_model:
|
||||
return AxolotlRewardTrainer
|
||||
if self.cfg.process_reward_model:
|
||||
return AxolotlPRMTrainer
|
||||
|
||||
if self.cfg.trainer_cls:
|
||||
# override the trainer cls
|
||||
try:
|
||||
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
|
||||
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
|
||||
return trainer_cls
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
raise ValueError(
|
||||
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
|
||||
) from e
|
||||
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
@@ -271,20 +282,25 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.sample_packing_eff_est
|
||||
)
|
||||
|
||||
if self.cfg.relora_steps:
|
||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||
training_arguments_kwargs["relora_warmup_steps"] = (
|
||||
self.cfg.relora_warmup_steps
|
||||
)
|
||||
if self.cfg.relora_anneal_steps:
|
||||
training_arguments_kwargs["relora_anneal_steps"] = (
|
||||
self.cfg.relora_anneal_steps
|
||||
)
|
||||
if self.cfg.relora and self.cfg.jagged_restart_steps:
|
||||
if self.cfg.relora_prune_ratio:
|
||||
training_arguments_kwargs["relora_prune_ratio"] = (
|
||||
self.cfg.relora_prune_ratio
|
||||
)
|
||||
|
||||
if self.cfg.jagged_restart_steps:
|
||||
training_arguments_kwargs["jagged_restart_steps"] = (
|
||||
self.cfg.jagged_restart_steps
|
||||
)
|
||||
if self.cfg.jagged_restart_warmup_steps:
|
||||
training_arguments_kwargs["jagged_restart_warmup_steps"] = (
|
||||
self.cfg.jagged_restart_warmup_steps
|
||||
)
|
||||
if self.cfg.jagged_restart_anneal_steps:
|
||||
training_arguments_kwargs["jagged_restart_anneal_steps"] = (
|
||||
self.cfg.jagged_restart_anneal_steps
|
||||
)
|
||||
|
||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
|
||||
training_arguments_kwargs["lisa_step_interval"] = (
|
||||
@@ -348,7 +364,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
||||
self.cfg.sequence_len / multiple
|
||||
)
|
||||
else:
|
||||
elif self.cfg.pad_to_sequence_len is None:
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
||||
@@ -370,10 +386,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
**data_collator_kwargs,
|
||||
)
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "processing_class" in sig.parameters:
|
||||
if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer):
|
||||
trainer_kwargs["processing_class"] = self.tokenizer
|
||||
elif "tokenizer" in sig.parameters:
|
||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
|
||||
if (
|
||||
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
|
||||
and self.cfg.datasets is not None
|
||||
|
||||
@@ -15,6 +15,7 @@ from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders.utils import ensure_dtype
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
@@ -53,7 +54,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||
sequence_parallel=self.cfg.context_parallel_size > 1
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
|
||||
@@ -72,6 +73,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
if self.cfg.trainer_cls:
|
||||
# override the trainer cls
|
||||
try:
|
||||
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
|
||||
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
raise ValueError(
|
||||
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
|
||||
) from e
|
||||
|
||||
return trainer_cls, trainer_cls_args
|
||||
|
||||
def _build_training_arguments(self, total_num_steps):
|
||||
|
||||
@@ -5,9 +5,7 @@
|
||||
|
||||
from .base import AxolotlTrainer
|
||||
from .dpo.trainer import AxolotlDPOTrainer
|
||||
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
|
||||
from .mamba import AxolotlMambaTrainer
|
||||
from .relora import ReLoRATrainer
|
||||
from .trl import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
|
||||
@@ -10,8 +10,11 @@ from functools import partial, wraps
|
||||
from typing import Any, Callable, Literal, Optional
|
||||
|
||||
import datasets
|
||||
import safetensors
|
||||
import torch
|
||||
from accelerate.state import AcceleratorState
|
||||
from datasets import Dataset
|
||||
from peft import PeftModel
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
DataLoader,
|
||||
@@ -19,14 +22,17 @@ from torch.utils.data import (
|
||||
Sampler,
|
||||
SequentialSampler,
|
||||
)
|
||||
from transformers import Trainer
|
||||
from transformers import PreTrainedModel, Trainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.mixins import (
|
||||
ActivationOffloadingMixin,
|
||||
CheckpointSaveMixin,
|
||||
DistributedParallelMixin,
|
||||
OptimizerMixin,
|
||||
PackingMixin,
|
||||
RngLoaderMixin,
|
||||
@@ -37,6 +43,8 @@ from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_tagging,
|
||||
)
|
||||
from axolotl.utils import get_not_null
|
||||
from axolotl.utils.bench import get_gpu_memory_usage
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
@@ -50,6 +58,7 @@ class AxolotlTrainer(
|
||||
RngLoaderMixin,
|
||||
CheckpointSaveMixin,
|
||||
ActivationOffloadingMixin,
|
||||
DistributedParallelMixin,
|
||||
Trainer,
|
||||
):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
@@ -73,7 +82,9 @@ class AxolotlTrainer(
|
||||
super().__init__(*_args, **kwargs)
|
||||
|
||||
self.train_data_collator = self.data_collator
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
self._stored_metrics = defaultdict(
|
||||
lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
|
||||
)
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
@@ -511,7 +522,18 @@ class AxolotlTrainer(
|
||||
|
||||
@wraps(Trainer.create_accelerator_and_postprocess)
|
||||
def create_accelerator_and_postprocess(self):
|
||||
res = super().create_accelerator_and_postprocess()
|
||||
# cleanup the PartialState states so Accelerate automatically configures everything from the env vars
|
||||
accelerator_config = self.args.accelerator_config.to_dict()
|
||||
use_configured_state = accelerator_config.get("use_configured_state", False)
|
||||
if not use_configured_state:
|
||||
AcceleratorState._reset_state( # pylint: disable=protected-access
|
||||
reset_partial_state=True
|
||||
)
|
||||
|
||||
super().create_accelerator_and_postprocess()
|
||||
|
||||
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
|
||||
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
if (
|
||||
@@ -520,8 +542,6 @@ class AxolotlTrainer(
|
||||
):
|
||||
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
|
||||
|
||||
return res
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def additional_accelerator_args(
|
||||
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
|
||||
@@ -555,18 +575,63 @@ class AxolotlTrainer(
|
||||
"""
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
|
||||
# Add reduced stored metrics to logs
|
||||
for key, metric_data in self._stored_metrics[train_eval].items():
|
||||
values = torch.tensor(metric_data["values"])
|
||||
reduction_type = metric_data["reduction"]
|
||||
|
||||
if reduction_type == "mean":
|
||||
logs[key] = values.mean().item()
|
||||
elif reduction_type == "min":
|
||||
logs[key] = values.min().item()
|
||||
elif reduction_type == "max":
|
||||
logs[key] = values.max().item()
|
||||
elif reduction_type == "sum":
|
||||
logs[key] = values.sum().item()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Metric reduction must be one of [mean, min, max, sum]"
|
||||
)
|
||||
|
||||
logs[key] = round(logs[key], 4)
|
||||
|
||||
if is_main_process():
|
||||
# Add memory usage
|
||||
try:
|
||||
active, allocated, reserved = get_gpu_memory_usage()
|
||||
logs["memory/max_mem_active(gib)"] = round(active, 2)
|
||||
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
|
||||
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
|
||||
except (ValueError, TypeError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
return super().log(logs, start_time)
|
||||
|
||||
def store_metrics(
|
||||
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||
self,
|
||||
metrics: dict[str, float] | dict[str, tuple[int | float, str]],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
reduction: Literal["mean", "min", "max", "sum"] = "mean",
|
||||
) -> None:
|
||||
"""
|
||||
Store metrics with specified reduction type.
|
||||
|
||||
Args:
|
||||
metrics: Dictionary of metric names to values, or metric names to (value,
|
||||
reduction_type) tuples.
|
||||
train_eval: Whether this is for training or evaluation.
|
||||
"""
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
if isinstance(value, tuple):
|
||||
metric_value, metric_reduction = value
|
||||
else:
|
||||
metric_value, metric_reduction = value, reduction
|
||||
|
||||
self._stored_metrics[train_eval][key]["values"].append(metric_value)
|
||||
self._stored_metrics[train_eval][key]["reduction"] = metric_reduction
|
||||
|
||||
def _save_checkpoint(self, model, trial, **kwargs):
|
||||
# make sure the checkpoint dir exists, since trainer is flakey
|
||||
@@ -575,3 +640,64 @@ class AxolotlTrainer(
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||
# If we are executing this function, we are the process zero, so we don't check for that.
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
LOG.info(f"Saving model checkpoint to {output_dir}")
|
||||
supported_classes = (
|
||||
(PreTrainedModel,)
|
||||
if not is_peft_available()
|
||||
else (PreTrainedModel, PeftModel)
|
||||
)
|
||||
# Save a trained model and configuration using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
if not isinstance(self.model, supported_classes):
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
if isinstance(
|
||||
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
|
||||
supported_classes,
|
||||
):
|
||||
self.accelerator.unwrap_model(
|
||||
self.model, keep_torch_compile=False
|
||||
).save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
)
|
||||
else:
|
||||
LOG.info(
|
||||
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
|
||||
)
|
||||
if self.args.save_safetensors:
|
||||
safetensors.torch.save_file(
|
||||
state_dict,
|
||||
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
|
||||
metadata={"format": "pt"},
|
||||
)
|
||||
else:
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
else:
|
||||
self.model.save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
is_main_process=self.accelerator.is_main_process,
|
||||
)
|
||||
|
||||
if self.processing_class is not None:
|
||||
self.processing_class.save_pretrained(output_dir)
|
||||
elif (
|
||||
self.data_collator is not None
|
||||
and hasattr(self.data_collator, "tokenizer")
|
||||
and self.data_collator.tokenizer is not None
|
||||
):
|
||||
LOG.info(
|
||||
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
||||
)
|
||||
self.data_collator.tokenizer.save_pretrained(output_dir)
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||
|
||||
@@ -8,7 +8,11 @@ import torch
|
||||
from torch import nn
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.mixins import (
|
||||
DistributedParallelMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
)
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
||||
from axolotl.core.trainers.utils import (
|
||||
sanitize_kwargs_for_ds_tagging,
|
||||
@@ -17,7 +21,12 @@ from axolotl.core.trainers.utils import (
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, DPOTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DPOTrainer,
|
||||
DistributedParallelMixin,
|
||||
):
|
||||
"""Extend the base DPOTrainer for axolotl helpers."""
|
||||
|
||||
|
||||
@@ -49,7 +49,8 @@ class GRPOStrategy:
|
||||
|
||||
if trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
|
||||
if trl.vllm_mode:
|
||||
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
|
||||
if trl.vllm_mode == "colocate":
|
||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||
vllm_cfg.gpu_memory_utilization
|
||||
@@ -82,8 +83,13 @@ class GRPOStrategy:
|
||||
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
||||
|
||||
if cfg.sequence_parallel_degree > 1:
|
||||
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
|
||||
if cfg.context_parallel_size > 1:
|
||||
grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size
|
||||
|
||||
if trl.importance_sampling_level is not None:
|
||||
grpo_args_kwargs["importance_sampling_level"] = (
|
||||
trl.importance_sampling_level
|
||||
)
|
||||
|
||||
if trl.reward_weights:
|
||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||
|
||||
@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""Axolotl GRPO Config for GRPO training"""
|
||||
|
||||
sequence_parallel_degree: int | None = None
|
||||
context_parallel_size: int | None = None
|
||||
|
||||
@@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
- Data is properly distributed across SP groups.
|
||||
|
||||
In the table below, the values represent dataset indices. Each SP group has
|
||||
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
|
||||
`context_parallel_size = 2` GPUs working together on the same data. There are 2
|
||||
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
||||
|
||||
Sequence Parallel Groups
|
||||
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
rank: Rank of current process.
|
||||
batch_size: Number of samples per batch.
|
||||
repeat_count: How many times to repeat the full sampling process.
|
||||
sequence_parallel_degree: Number of ranks in a sequence parallel group.
|
||||
context_parallel_size: Number of ranks in a sequence parallel group.
|
||||
shuffle: Whether to shuffle the dataset.
|
||||
seed: Random seed for shuffling.
|
||||
drop_last: Whether to drop the last incomplete batch.
|
||||
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
rank: int,
|
||||
batch_size: int = 1,
|
||||
repeat_count: int = 1,
|
||||
sequence_parallel_degree: int = 1,
|
||||
context_parallel_size: int = 1,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
@@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
||||
self.rank = rank
|
||||
|
||||
# Sequence parallelism parameters
|
||||
self.sequence_parallel_degree = sequence_parallel_degree
|
||||
self.num_sp_groups = world_size // sequence_parallel_degree
|
||||
self.sp_group_id = rank // sequence_parallel_degree
|
||||
self.context_parallel_size = context_parallel_size
|
||||
self.num_sp_groups = world_size // context_parallel_size
|
||||
self.sp_group_id = rank // context_parallel_size
|
||||
|
||||
# Adjust dataset size for distributed sampling
|
||||
self.num_samples = len(self.dataset)
|
||||
|
||||
@@ -43,7 +43,11 @@ from trl.trainer.grpo_trainer import RewardFunc, nanstd
|
||||
from trl.trainer.utils import pad
|
||||
|
||||
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
|
||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||
from axolotl.core.trainers.mixins import (
|
||||
DistributedParallelMixin,
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
)
|
||||
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin
|
||||
from axolotl.monkeypatch.ring_attn import get_ring_attn_group
|
||||
|
||||
@@ -53,7 +57,12 @@ if is_peft_available():
|
||||
|
||||
|
||||
class AxolotlGRPOTrainer(
|
||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, GRPOTrainer
|
||||
RngLoaderMixin,
|
||||
SchedulerMixin,
|
||||
OptimizerMixin,
|
||||
OptimizerInitMixin,
|
||||
DistributedParallelMixin,
|
||||
GRPOTrainer,
|
||||
):
|
||||
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||
|
||||
@@ -100,7 +109,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
|
||||
# Get number of SP groups (number of processes divided by SP degree)
|
||||
num_processes = self.accelerator.num_processes
|
||||
num_sp_groups = num_processes // self.args.sequence_parallel_degree
|
||||
num_sp_groups = num_processes // self.args.context_parallel_size
|
||||
|
||||
# Calculate batch size per SP group (not per process)
|
||||
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
|
||||
@@ -130,7 +139,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
|
||||
f"With sequence parallelism (degree {self.args.context_parallel_size}), "
|
||||
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
||||
f"must be evenly divisible by the number of generations per prompt "
|
||||
f"({self.num_generations}). Given the current eval batch size, "
|
||||
@@ -167,9 +176,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
rank=self.rank,
|
||||
batch_size=effective_batch_size
|
||||
// self.num_generations
|
||||
// self.args.sequence_parallel_degree,
|
||||
// self.args.context_parallel_size,
|
||||
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
|
||||
sequence_parallel_degree=self.args.sequence_parallel_degree,
|
||||
context_parallel_size=self.args.context_parallel_size,
|
||||
shuffle=True,
|
||||
seed=self.args.seed,
|
||||
drop_last=True,
|
||||
@@ -235,7 +244,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
||||
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
||||
# slice each batch along the sequence dimension).
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if self.args.context_parallel_size > 1:
|
||||
return dataloader
|
||||
|
||||
# Otherwise prepare with accelerator
|
||||
@@ -308,18 +317,18 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||
all_prompts_text = gather_object(prompts_text)
|
||||
if self.accelerator.is_main_process:
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if self.args.context_parallel_size > 1:
|
||||
# Calculate sequence parallel group information
|
||||
world_size = self.accelerator.num_processes
|
||||
sequence_parallel_degree = self.args.sequence_parallel_degree
|
||||
num_sp_groups = world_size // sequence_parallel_degree
|
||||
context_parallel_size = self.args.context_parallel_size
|
||||
num_sp_groups = world_size // context_parallel_size
|
||||
|
||||
# Since processes in the same SP group have the same prompts, we need to ensure
|
||||
# we only take one copy of each prompt from each SP group
|
||||
ordered_set_of_prompts = []
|
||||
for sp_group_id in range(num_sp_groups):
|
||||
# Get the first process from each SP group (typically the group leader)
|
||||
group_leader_rank = sp_group_id * sequence_parallel_degree
|
||||
group_leader_rank = sp_group_id * context_parallel_size
|
||||
|
||||
# Extract prompts from this SP group, accounting for num_generations duplicates
|
||||
# We only need prompts from one rank in each SP group
|
||||
@@ -335,7 +344,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
||||
# prompt individually.
|
||||
ordered_set_of_prompts = all_prompts_text[
|
||||
:: self.num_generations * self.args.sequence_parallel_degree
|
||||
:: self.num_generations * self.args.context_parallel_size
|
||||
]
|
||||
|
||||
with profiling_context(self, "vLLM.generate"):
|
||||
@@ -352,14 +361,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
)
|
||||
else:
|
||||
completion_ids = [None] * (
|
||||
len(all_prompts_text) // self.args.sequence_parallel_degree
|
||||
len(all_prompts_text) // self.args.context_parallel_size
|
||||
)
|
||||
|
||||
# Broadcast the completions from the main process to all processes
|
||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||
|
||||
# Determine the appropriate slice based on sequence parallelism
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if self.args.context_parallel_size > 1:
|
||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||
|
||||
@@ -583,7 +592,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
advantages = advantages / (std_grouped_rewards + 1e-4)
|
||||
|
||||
# Slice to keep only the local part of the data
|
||||
if self.args.sequence_parallel_degree > 1:
|
||||
if self.args.context_parallel_size > 1:
|
||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import torch
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
|
||||
|
||||
# pylint: disable=too-many-ancestors
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
"""Mamba specific trainer to handle loss calculation"""
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
from .activation_checkpointing import ActivationOffloadingMixin
|
||||
from .checkpoints import CheckpointSaveMixin
|
||||
from .distributed_parallel import DistributedParallelMixin
|
||||
from .optimizer import OptimizerMixin
|
||||
from .packing import PackingMixin
|
||||
from .rng_state_loader import RngLoaderMixin
|
||||
|
||||
@@ -13,9 +13,11 @@ class CheckpointSaveMixin(Trainer):
|
||||
def _save_optimizer_and_scheduler(self, output_dir):
|
||||
try:
|
||||
super()._save_optimizer_and_scheduler(output_dir)
|
||||
except NotImplementedError as exc:
|
||||
LOG.warning(
|
||||
except (NotImplementedError, KeyError) as exc:
|
||||
# TODO: fix fsdp2 optimizer saving
|
||||
LOG.warning_once(
|
||||
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
|
||||
"Optimizer and scheduler states were not saved - resuming from checkpoints "
|
||||
"for this training run will not be possible."
|
||||
"for this training run will not be possible.",
|
||||
main_process_only=True,
|
||||
)
|
||||
|
||||
33
src/axolotl/core/trainers/mixins/distributed_parallel.py
Normal file
33
src/axolotl/core/trainers/mixins/distributed_parallel.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
Mixin for correctly saving fsdp
|
||||
"""
|
||||
|
||||
from accelerate import PartialState
|
||||
from transformers import Trainer
|
||||
|
||||
|
||||
class DistributedParallelMixin(Trainer):
|
||||
"""
|
||||
Mixin for correctly saving fsdp
|
||||
"""
|
||||
|
||||
def _save(self, output_dir: str | None = None, state_dict=None):
|
||||
if (
|
||||
state_dict is None
|
||||
and self.accelerator.parallelism_config
|
||||
and self.accelerator.parallelism_config.dp_shard_enabled
|
||||
):
|
||||
state_dict = self.accelerator.get_state_dict(self.model)
|
||||
super()._save(output_dir, state_dict=state_dict)
|
||||
|
||||
def create_accelerator_and_postprocess(self):
|
||||
super().create_accelerator_and_postprocess()
|
||||
if (
|
||||
self.accelerator.distributed_type == "FSDP"
|
||||
and self.accelerator.state.fsdp_plugin is None
|
||||
):
|
||||
# pylint: disable=protected-access
|
||||
# handle Context Parallelism without FSDP
|
||||
self.accelerator.state.distributed_type = "MULTI_GPU"
|
||||
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"
|
||||
PartialState().distributed_type = "MULTI_GPU"
|
||||
@@ -7,6 +7,7 @@ from transformers.trainer import Trainer
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schedulers import (
|
||||
JaggedLRRestartScheduler,
|
||||
RexLR,
|
||||
get_cosine_schedule_with_min_lr,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
@@ -113,7 +114,7 @@ class SchedulerMixin(Trainer):
|
||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||
super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||
else:
|
||||
if use_cosine_quadratic:
|
||||
LOG.warning(
|
||||
@@ -123,4 +124,22 @@ class SchedulerMixin(Trainer):
|
||||
LOG.warning(
|
||||
"axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||
|
||||
if self.args.jagged_restart_steps:
|
||||
warmup_steps = (
|
||||
self.args.jagged_restart_warmup_steps or 10
|
||||
)
|
||||
anneal_steps = (
|
||||
self.args.jagged_restart_anneal_steps or 1
|
||||
)
|
||||
if not self.lr_scheduler:
|
||||
super().create_scheduler(num_training_steps, optimizer)
|
||||
self.lr_scheduler = JaggedLRRestartScheduler( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
self.lr_scheduler,
|
||||
self.args.jagged_restart_steps,
|
||||
warmup_steps,
|
||||
anneal_steps,
|
||||
min_lr_scale=self.args.cosine_min_lr_ratio or 0.001,
|
||||
)
|
||||
|
||||
return self.lr_scheduler # type: ignore
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
"""Module for ReLoRA trainer"""
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from axolotl.core.trainers.base import AxolotlTrainer
|
||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||
|
||||
|
||||
class ReLoRATrainer(AxolotlTrainer):
|
||||
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
|
||||
|
||||
tag_names = ["axolotl", "relora"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: torch.optim.Optimizer | None = None,
|
||||
) -> LRScheduler:
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
lr_scheduler: LRScheduler = super().create_scheduler(
|
||||
num_training_steps, optimizer
|
||||
)
|
||||
|
||||
if self.args.relora_steps:
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
anneal_steps = (
|
||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler( # type: ignore
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
anneal_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
self.lr_scheduler = lr_scheduler # type: ignore
|
||||
|
||||
return self.lr_scheduler # type: ignore
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user