Compare commits


2 Commits

Author       SHA1        Message                                      Date
Wing Lian    b5198d8734  granite chat multipack support and example   2025-08-02 20:57:00 -04:00
Wing Lian    4ab6a1bd7e  add support for granite chat templates       2025-08-02 11:29:03 -04:00
142 changed files with 12211 additions and 16493 deletions


@@ -12,6 +12,5 @@ reviews:
auto_review:
enabled: true
drafts: false
auto_incremental_review: true
chat:
auto_reply: true


@@ -57,13 +57,6 @@ We welcome ideas for improvements and new features. To suggest an enhancement, o
5. Push your branch to your fork on GitHub.
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
#### Skipping CI Checks
You can skip certain CI checks by including specific keywords in your commit messages:
- `[skip ci]` or `skip ci` - Skips all CI checks for that commit
- `[skip-e2e]` or `skip-e2e` - Skips only end-to-end tests while running other CI checks. You may also include this in the title of your PR to disable end-to-end tests for the entire PR.
## Style Guidelines
### Code Style


@@ -54,7 +54,7 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.8.1
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
@@ -64,16 +64,9 @@ jobs:
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.8.0
pytorch: nightly
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
# python_version: "3.11"
# pytorch: nightly
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base-nightly"
dockerfile: "Dockerfile-base-nightly"
# # "next" is for release candidates of pytorch
# - cuda: "128"
# cuda_version: 12.8.1
@@ -129,13 +122,6 @@ jobs:
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -143,13 +129,6 @@ jobs:
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
steps:
- name: Checkout
uses: actions/checkout@v4


@@ -24,13 +24,12 @@ jobs:
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
@@ -98,12 +97,6 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
@@ -157,18 +150,6 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout


@@ -105,8 +105,7 @@ jobs:
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
@@ -180,52 +179,21 @@ jobs:
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
gate-skip-e2e:
needs: [pre-commit, pytest, pytest-sdist]
runs-on: ubuntu-latest
outputs:
skip: ${{ steps.compute.outputs.skip }}
steps:
- uses: actions/github-script@v7
id: compute
with:
script: |
const token = /\[skip-e2e\]/i;
let msg = '';
if (context.eventName === 'push') {
msg = context.payload.head_commit?.message || '';
} else if (context.eventName === 'pull_request') {
const { owner, repo } = context.repo;
const prNumber = context.payload.pull_request.number;
const commits = await github.paginate(
github.rest.pulls.listCommits,
{ owner, repo, pull_number: prNumber, per_page: 100 }
);
msg = commits.at(-1)?.commit?.message || '';
}
const title = context.payload.pull_request?.title || '';
const body = context.payload.pull_request?.body || '';
const skip = token.test(msg) || token.test(title) || token.test(body);
core.setOutput('skip', String(skip));
docker-e2e-tests-1st:
# Run this job first as a gate for running the remainder of the test matrix
if: >
github.repository_owner == 'axolotl-ai-cloud' &&
(github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
needs.gate-skip-e2e.outputs.skip != 'true'
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
needs: [pre-commit, pytest, pytest-sdist]
strategy:
fail-fast: false
@@ -271,16 +239,13 @@ jobs:
modal run cicd.e2e_tests
docker-e2e-tests:
if: >
github.repository_owner == 'axolotl-ai-cloud' &&
(github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
needs.gate-skip-e2e.outputs.skip != 'true'
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
# Only run the remainder of the matrix if the first e2e check passed;
# this is to save on wasted compute costs for known failures that get caught in the first run
needs: [pre-commit, pytest, gate-skip-e2e, docker-e2e-tests-1st]
needs: [pre-commit, pytest, docker-e2e-tests-1st]
strategy:
fail-fast: false


@@ -3,7 +3,7 @@ default_language_version:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
rev: v5.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
@@ -23,11 +23,11 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.8
rev: v3.3.7
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.17.1
rev: v1.17.0
hooks:
- id: mypy
additional_dependencies:


@@ -185,6 +185,7 @@ datasets:
| `flash_attention` | `false` | Use flash attention |
| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
| `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations |
| `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
| `sdp_attention` | `false` | Use scaled dot product |
| `s2_attention` | `false` | Use shifted sparse attention |


@@ -296,6 +296,7 @@
# flash_attention:
# flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
# flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# # Whether to use scaled-dot-product attention
# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
@@ -540,6 +541,7 @@ xformers_attention: ${XFORMERS_ATTENTION}
flash_attention: ${FLASH_ATTENTION}
flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV}
flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
sdp_attention: ${SDP_ATTENTION}
s2_attention: ${S2_ATTENTION}


@@ -1,10 +0,0 @@
cff-version: 1.2.0
type: software
title: "Axolotl: Post-Training for AI Models"
message: "If you use this software, please cite it as below."
authors:
- name: "Axolotl maintainers and contributors"
repository-code: "https://github.com/axolotl-ai-cloud/axolotl"
url: "https://axolotl.ai/"
license: Apache-2.0
date-released: "2023-05-30"


@@ -25,28 +25,17 @@
## 🎉 Latest Updates
- 2025/07:
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support have been integrated into Axolotl!
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed, and FSDP has been added to enable Arctic Long Sequence Training (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
<details>
<summary>Expand older updates</summary>
- 2025/07: Voxtral with mistral-common tokenizer support has been integrated in Axolotl. Read the [docs](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral)!
- 2025/07: TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed, and FSDP has been added to enable Arctic Long Sequence Training (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
</details>
## ✨ Overview
Axolotl is a tool designed to streamline post-training for various AI models.
@@ -149,20 +138,6 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
## 📝 Citing Axolotl
If you use Axolotl in your research or projects, please cite it as follows:
```bibtex
@software{axolotl,
title = {Axolotl: Post-Training for AI Models},
author = {{Axolotl maintainers and contributors}},
url = {https://github.com/axolotl-ai-cloud/axolotl},
license = {Apache-2.0},
year = {2023}
}
```
## 📜 License
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.

TODO.md (new file)

@@ -0,0 +1,10 @@
# todo list
- [ ] Validation of parameters for combinations that won't work
## things that are known not to work
- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
- adamw_bnb_8bit doesn't play well with FSDP offload


@@ -274,7 +274,6 @@ website:
- docs/dataset_preprocessing.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd
- section: "Advanced Features"
contents:


@@ -37,7 +37,7 @@ WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge


@@ -212,11 +212,10 @@ Instead of passing `tools` via the system prompt, an alternative method would be
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
:::
Example config for Llama4:
```yaml
chat_template: llama4
datasets:
- path: Nanobit/text-tools-2k-test
- path: ...
type: chat_template
# field_tools: tools # default is `tools`
```
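For reference, here is a sketch of what a single entry in the `tools` column could look like, written in YAML using the common OpenAI-style function-tool layout. The field names and the tool itself are illustrative only; check the JSON-schema link and the tool-use docs above for the exact format expected.

```yaml
# illustrative tool definition (OpenAI-style function tool, expressed as YAML)
- type: function
  function:
    name: get_current_weather        # hypothetical tool name
    description: Get the current weather for a city
    parameters:                      # JSON-schema object describing the arguments
      type: object
      properties:
        city:
          type: string
          description: Name of the city
      required:
        - city
```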


@@ -13,13 +13,10 @@ format:
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Voxtral](#sec-voxtral)
- [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
## Usage
@@ -34,7 +31,7 @@ skip_prepare_dataset: true
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
sample_packing: false # not yet supported with multimodal
chat_template: # see in next section if specified
chat_template: # see in next section
# example dataset
datasets:
@@ -100,16 +97,6 @@ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
chat_template: mistral_v7_tekken
```
### Voxtral {#sec-voxtral}
::: {.callout-tip}
Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'`
:::
```yaml
base_model: mistralai/Voxtral-Mini-3B-2507
```
### Gemma-3 {#sec-gemma-3}
::: {.callout-tip}
@@ -156,26 +143,6 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}
Please make sure to install `num2words` via `pip3 install num2words==0.5.14`
:::
```yaml
base_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct
```
### LFM2-VL {#sec-lfm2-vl}
::: {.callout-warning}
Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
:::
```yaml
base_model: LiquidAI/LFM2-VL-450M
```
## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
@@ -214,20 +181,6 @@ You may need to install `librosa` via `pip3 install librosa==0.11.0`.
:::
### Video
::: {.callout-warning}
This is not well tested at the moment. We welcome contributors!
:::
For video loading, you can use the following keys within `content` alongside `"type": "video"`:
- `"path": "/path/to/video.mp4"`
- `"url": "https://example.com/video.mp4"`
- `"video": np.ndarray | list[PIL.Image.Image] | torch.Tensor` (or list of the aforementioned)
### Example
Here is an example of a multi-modal dataset:


@@ -1,6 +1,4 @@
---
title: "N-D Parallelism (Beta)"
---
# N-D Parallelism
Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
@@ -73,10 +71,6 @@ Note: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size
## Examples
::: {.callout-tip}
See our example configs [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/distributed-parallel).
:::
1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
- You want FSDP within each node and DDP across nodes.
- Set `dp_shard_size: 4` and `dp_replicate_size: 2`.
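   A minimal sketch of the relevant keys for this layout (other settings such as datasets and `fsdp_config` are independent of this choice):

   ```yaml
   # HSDP across 2 nodes x 4 GPUs: FSDP shards within each node, DDP replicates across nodes
   dp_shard_size: 4      # shard model states across the 4 GPUs of each node
   dp_replicate_size: 2  # replicate the sharded group across the 2 nodes
   # Total: 4 x 2 = 8 GPUs
   ```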
@@ -101,7 +95,7 @@ This matrix describes how different parallelism methods can be combined in Axolo
| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |
| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |
| **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. |
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP or CP is currently not supported. You should use FSDP + TP/CP instead (`dp_shard_size > 1`). |
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP/CP without FSDP is inefficient and complex. You should use FSDP instead (`dp_shard_size > 1`). |
| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |
- `tp_size` refers to `tensor_parallel_size`


@@ -1,129 +0,0 @@
---
title: Optimizers
description: Configuring optimizers
---
## Overview
Axolotl supports all optimizers supported by [transformers OptimizerNames](https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/training_args.py#L142-L187)
Here is a list of optimizers supported by transformers as of `v4.54.0`:
- `adamw_torch`
- `adamw_torch_fused`
- `adamw_torch_xla`
- `adamw_torch_npu_fused`
- `adamw_apex_fused`
- `adafactor`
- `adamw_anyprecision`
- `adamw_torch_4bit`
- `adamw_torch_8bit`
- `ademamix`
- `sgd`
- `adagrad`
- `adamw_bnb_8bit`
- `adamw_8bit` # alias for adamw_bnb_8bit
- `ademamix_8bit`
- `lion_8bit`
- `lion_32bit`
- `paged_adamw_32bit`
- `paged_adamw_8bit`
- `paged_ademamix_32bit`
- `paged_ademamix_8bit`
- `paged_lion_32bit`
- `paged_lion_8bit`
- `rmsprop`
- `rmsprop_bnb`
- `rmsprop_bnb_8bit`
- `rmsprop_bnb_32bit`
- `galore_adamw`
- `galore_adamw_8bit`
- `galore_adafactor`
- `galore_adamw_layerwise`
- `galore_adamw_8bit_layerwise`
- `galore_adafactor_layerwise`
- `lomo`
- `adalomo`
- `grokadamw`
- `schedule_free_radam`
- `schedule_free_adamw`
- `schedule_free_sgd`
- `apollo_adamw`
- `apollo_adamw_layerwise`
- `stable_adamw`
## Custom Optimizers
Enable custom optimizers by passing a string to the `optimizer` argument. Each optimizer will receive beta and epsilon args, however, some may accept additional args which are detailed below.
### optimi_adamw
```yaml
optimizer: optimi_adamw
```
### ao_adamw_4bit
Deprecated: Please use `adamw_torch_4bit`.
### ao_adamw_8bit
Deprecated: Please use `adamw_torch_8bit`.
### ao_adamw_fp8
```yaml
optimizer: ao_adamw_fp8
```
### adopt_adamw
GitHub: [https://github.com/iShohei220/adopt](https://github.com/iShohei220/adopt)
Paper: [https://arxiv.org/abs/2411.02853](https://arxiv.org/abs/2411.02853)
```yaml
optimizer: adopt_adamw
```
### came_pytorch
GitHub: [https://github.com/yangluo7/CAME/tree/master](https://github.com/yangluo7/CAME/tree/master)
Paper: [https://arxiv.org/abs/2307.02047](https://arxiv.org/abs/2307.02047)
```yaml
optimizer: came_pytorch
# optional args (defaults below)
adam_beta1: 0.9
adam_beta2: 0.999
adam_beta3: 0.9999
adam_epsilon: 1e-30
adam_epsilon2: 1e-16
```
### muon
Blog: [https://kellerjordan.github.io/posts/muon/](https://kellerjordan.github.io/posts/muon/)
Paper: [https://arxiv.org/abs/2502.16982v1](https://arxiv.org/abs/2502.16982v1)
```yaml
optimizer: muon
```
### dion
Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient
orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
GitHub: [https://github.com/microsoft/dion](https://github.com/microsoft/dion)
Paper: [https://arxiv.org/pdf/2504.05295](https://arxiv.org/pdf/2504.05295)
Note: Implementation written for PyTorch 2.7+ for DTensor
```yaml
optimizer: dion
dion_lr: 0.01
dion_momentum: 0.95
lr: 0.00001 # learning rate for embeddings and parameters that fallback to AdamW
```


@@ -1,58 +0,0 @@
# Finetune Liquid Foundation Models 2 (LFM2) with Axolotl
[Liquid Foundation Models 2 (LFM2)](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) are a family of small, open-weight models from [Liquid AI](https://www.liquid.ai/) focused on quality, speed, and memory efficiency. Liquid AI released text-only [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) and text+vision [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa) models.
LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference.
This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Run one of the finetuning examples below.
**LFM2**
```bash
# FFT SFT (1x48GB @ 25GiB)
axolotl train examples/LiquidAI/lfm2-350m-fft.yaml
```
**LFM2-VL**
```bash
# LoRA SFT (1x48GB @ 2.7GiB)
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
```
### TIPS
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
```bash
pip uninstall -y causal-conv1d
```
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
- **Dataset Formats**:
- For LFM2 models, the dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For LFM2-VL models, Axolotl follows the multi-content Messages format. See our [Multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) for details.
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)


@@ -1,58 +0,0 @@
base_model: LiquidAI/LFM2-VL-450M
trust_remote_code: true
model_type: AutoModelForImageTextToText
processor_type: AutoProcessor
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config


@@ -1,53 +0,0 @@
# Finetune ArceeAI's AFM with Axolotl
[Arcee Foundation Models (AFM)](https://huggingface.co/collections/arcee-ai/afm-45b-68823397c351603014963473) are a family of 4.5B parameter open weight models trained by Arcee.ai.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
Thanks to the team at Arcee.ai for using Axolotl for supervised fine-tuning of the AFM model.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as AFM is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Run the finetuning example:
```bash
axolotl train examples/arcee/afm-4.5b-qlora.yaml
```
This config uses about 7.8GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official Arcee.ai team recommends `top_p: 0.95`, `temperature: 0.5`, `top_k: 50`, and `repeat_penalty: 1.1`.
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true` from the config (see the sketch after these tips).
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
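As a sketch of the full-finetuning tip above (illustrative only, not an official example), the QLoRA-specific keys in `examples/arcee/afm-4.5b-qlora.yaml` would simply be dropped:

```yaml
base_model: arcee-ai/AFM-4.5B

# removed for full finetuning:
# adapter: qlora
# load_in_4bit: true

# all other settings (datasets, optimizer, schedule, ...) can stay as in the QLoRA example,
# but expect VRAM usage well above the ~7.8GiB of the QLoRA run
```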
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [AFM Blog](https://docs.arcee.ai/arcee-foundation-models/introduction-to-arcee-foundation-models)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)


@@ -1,64 +0,0 @@
base_model: arcee-ai/AFM-4.5B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config


@@ -47,6 +47,7 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_ratio: 0.1

File diff suppressed because it is too large.


@@ -10,14 +10,17 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from pip:
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.6.0+)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Run the finetuning example:


@@ -1,52 +0,0 @@
# ND Parallelism Examples
This directory contains example configurations for training models using ND Parallelism in Axolotl. These examples demonstrate how to compose different parallelism strategies (FSDP, TP, CP, HSDP) for efficient multi-GPU training.
## Quick Start
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Run the command below:
```bash
# Train Qwen3 8B with FSDP + TP + CP on a single 8-GPU node
axolotl train examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
# Train Llama 3.1 8B with HSDP + TP on 2 nodes (16 GPUs total)
axolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
```
## Example Configurations
### Single Node (8 GPUs)
**Qwen3 8B with FSDP + TP + CP** ([qwen3-8b-fsdp-tp-cp.yaml](./qwen3-8b-fsdp-tp-cp.yaml))
- Uses all 3 parallelism dimensions on a single node
- Ideal for: when model weights, activations, and/or context are too large to fit on single GPU
```yaml
dp_shard_size: 2 # FSDP across 2 GPUs
tensor_parallel_size: 2 # TP across 2 GPUs
context_parallel_size: 2 # CP across 2 GPUs
# Total: 2 × 2 × 2 = 8 GPUs
```
### Multi-Node
**Llama 3.1 8B with HSDP + TP** ([llama-3_1-8b-hsdp-tp.yaml](./llama-3_1-8b-hsdp-tp.yaml))
- FSDP & TP within nodes, DDP across nodes to minimize inter-node communication
- Ideal for: Scaling to multiple nodes while maintaining training efficiency
```yaml
dp_shard_size: 4 # FSDP within each 4-GPU group
tensor_parallel_size: 2 # TP within each node
dp_replicate_size: 2 # DDP across 2 groups
# Total: (4 × 2) × 2 = 16 GPUs (2 nodes)
```
## Learn More
- [ND Parallelism Documentation](https://docs.axolotl.ai/docs/nd_parallelism.html)
- [Blog: Accelerate ND-Parallel Guide](https://huggingface.co/blog/accelerate-nd-parallel)
- [Multi-GPU Training Guide](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)


@@ -1,47 +0,0 @@
base_model: meta-llama/Llama-3.1-8B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
dp_shard_size: 4
dp_replicate_size: 2
tensor_parallel_size: 2
# context_parallel_size: 2
dataset_prepared_path: last_run_prepared
special_tokens:
pad_token: <|end_of_text|>
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
reshard_after_forward: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/ndp-out/
sequence_len: 2048
sample_packing: true
flash_attention: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch_fused
lr_scheduler: constant_with_warmup
learning_rate: 2e-6
bf16: true
tf32: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1


@@ -1,46 +0,0 @@
base_model: Qwen/Qwen3-8B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
dp_shard_size: 2
# dp_replicate_size: 1
context_parallel_size: 2
tensor_parallel_size: 2
dataset_prepared_path: last_run_prepared
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
reshard_after_forward: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/ndp-out/
sequence_len: 8192
sample_packing: true
flash_attention: true
gradient_accumulation_steps: 1
micro_batch_size: 1 # must be 1 when using context parallel
num_epochs: 2
optimizer: adamw_torch_fused
lr_scheduler: constant_with_warmup
learning_rate: 2e-6
bf16: true
tf32: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:


@@ -4,14 +4,17 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Gemma3n is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from pip:
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.6.0 min recommended)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. In addition to Axolotl's requirements, Gemma-3n requires:


@@ -1,125 +0,0 @@
# Finetune OpenAI's GPT-OSS with Axolotl
[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
```bash
# LoRA SFT linear layers (1x48GB @ ~44GiB)
axolotl train examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
# FFT SFT with offloading (2x24GB @ ~21GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
# FFT SFT (8x48GB @ ~36GiB/GPU or 4x80GB @ ~46GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
```
Note: Memory usage taken from `device_mem_reserved(gib)` from logs.
### Training 120B
On 8xH100s, make sure you have ~3TB of free disk space: each checkpoint clocks in at ~720GB, so keeping at least 2 checkpoints alongside the base model and the final model output requires roughly 3TB.
```bash
# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
```
To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node
training of the 120B model using Baseten Truss. You can read more about this recipe on
[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can
be found on their
[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training).
ERRATA: Transformers saves the model architecture name prefixed with `FSDP`, which needs to be manually renamed in `config.json`.
See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
```bash
sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
```
When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your
configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to
merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded
weights to `{output_dir}/merged`.
```bash
axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
```
### Inferencing your fine-tuned model
#### vLLM
GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
for more information about using a special vllm-openai docker image for inferencing with vLLM.
Optionally, vLLM can be installed from nightly:
```bash
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
```
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
```bash
vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8
```
#### SGLang
SGLang has day-0 support on main; see https://github.com/sgl-project/sglang/issues/8833 for information on installing
SGLang from source. Once you've installed SGLang, run the following command to launch an SGLang server:
```bash
python3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8
```
### Tool use
GPT-OSS has comprehensive tool-use understanding. Axolotl supports tool-calling datasets for supervised fine-tuning.
Here is an example dataset config:
```yaml
datasets:
- path: Nanobit/text-tools-2k-test
type: chat_template
```
See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-tools-2k-test) for the sample dataset.
Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.
### TIPS
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [GPT-OSS Blog](https://openai.com/index/introducing-gpt-oss/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)


@@ -1,68 +0,0 @@
# the original mxfp4 quantized model is not supported with FSDP cpu_ram_efficient_loading
# FSDP cpu_ram_efficient_loading is used to reduce the initial CPU memory usage when loading the model
base_model: axolotl-ai-co/gpt-oss-120b-dequantized
use_kernels: false
dp_shard_size: 16 # requires 2x8xH100 nodes
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
save_total_limit: 2 # the 120B model can use up to 720GB of disk space per checkpoint, so let's only keep the last 2
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
fsdp_version: 2
fsdp_config:
offload_params: true
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
cpu_ram_efficient_loading: true


@@ -1,58 +0,0 @@
base_model: openai/gpt-oss-20b
use_kernels: false
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
# choose the zero3 configuration that best fits your system capabilities
deepspeed: deepspeed_configs/zero3_bf16.json


@@ -1,68 +0,0 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: ./outputs/last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
fsdp_version: 2
fsdp_config:
offload_params: true
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
# cpu_ram_efficient_loading: true
# cpu_ram_efficient_loading cannot be used with MXFP4 model quantization.
# It can only be used with a dequantized model like `axolotl-ai-co/gpt-oss-120b-dequantized`


@@ -1,64 +0,0 @@
base_model: openai/gpt-oss-20b
use_kernels: false
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: ./outputs/last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
# cpu_ram_efficient_loading: true


@@ -1,67 +0,0 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
lora_target_linear: true
# TODO: not supported for now, see peft#2710
#lora_target_parameters: # target the experts in the last two layers
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-4
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:
eot_tokens:
- "<|end|>"

examples/lfm2/README.md (new file)

@@ -0,0 +1,7 @@
# Liquid Foundation Models 2
LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release.
```bash
pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
```


@@ -2,6 +2,7 @@ base_model: LiquidAI/LFM2-350M
chunked_cross_entropy: true
chat_template: tokenizer_default
eot_tokens:
- "<|im_end|>"
datasets:


@@ -45,6 +45,7 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_ratio: 0.1


@@ -49,6 +49,7 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_ratio: 0.1


@@ -8,14 +8,17 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Magistral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from pip:
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Run the finetuning example:


@@ -27,6 +27,7 @@ sequence_len: 2048
sample_packing: true
eval_sample_packing: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05


@@ -26,6 +26,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05


@@ -26,6 +26,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05


@@ -1,66 +0,0 @@
# SLURM Multi-Node Training
This directory contains an example SLURM script for running Axolotl training jobs across multiple nodes in a SLURM cluster.
## Prerequisites
- Access to a SLURM cluster with GPU nodes
- Axolotl installed on all nodes (see [installation docs](https://docs.axolotl.ai/docs/installation.html))
## Usage
### Standard SLURM Clusters
1. Copy [`axolotl.slurm`](./axolotl.slurm) to your working directory.
2. Place your Axolotl config file (`train.yaml`) in the same directory.
3. Set the appropriate environment variables for the job:
```bash
export HF_TOKEN="your-huggingface-token"
# metric tracking
# export WANDB_API_KEY="your-wandb-api-key"
# ...
```
4. Submit the job:
```bash
sbatch --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=<master-node>,PRIMARY_PORT=29400 axolotl.slurm
```
Where:
- `NUM_NODES`: Number of nodes to use
- `NUM_TRAINERS`: GPUs per node (typically 8)
- `PRIMARY_ADDR`: Hostname/IP of the master node
- `PRIMARY_PORT`: Port for distributed training (default: 29400)
5. (Optional) Run other slurm commands:
```bash
# check job info
scontrol show job axolotl-cli
# check job queue
squeue
# check cluster status
sinfo
```
### RunPod Instant Clusters
Axolotl works with RunPod Instant Clusters. This feature provides managed SLURM clusters with zero configuration.
1. **Deploy a SLURM Cluster**:
- Go to [RunPod Instant Clusters](https://console.runpod.io/cluster)
- Click "Create a Cluster"
- Choose your GPU type, node count, and region
- Choose an [Axolotl cloud docker image](https://docs.axolotl.ai/docs/docker.html#cloud)
- Deploy the cluster
2. **Connect to the Controller Node**: Find the controller node in the RunPod console and connect via SSH
3. **Follow the instructions in [Standard SLURM Clusters](#standard-slurm-clusters)**
## Additional Resources
- [Axolotl Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [SLURM Documentation](https://slurm.schedmd.com/documentation.html)
- [RunPod SLURM Clusters Guide](https://docs.runpod.io/instant-clusters/slurm-clusters)


@@ -1,20 +0,0 @@
#!/bin/bash
# Prior to running this script, export your HF_TOKEN and WANDB_API_KEY to your environment; i.e.
# export HF_TOKEN="..."
# export WANDB_API_KEY="..."
#
# ---------- SBATCH commands ---------- #
#SBATCH --job-name=axolotl-slurm-multinode
#SBATCH --ntasks-per-node=1
#SBATCH --nodes=$NUM_NODES
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-task=128
export TORCH_DIST_INIT_BARRIER=0
srun axolotl preprocess train.yaml
srun axolotl train train.yaml --launcher torchrun -- \
--nproc_per_node=$NUM_TRAINERS --nnodes=$NUM_NODES \
--rdzv_id axolotl-cli --rdzv_backend c10d --rdzv_endpoint "${PRIMARY_ADDR}:${PRIMARY_PORT}" --rdzv-conf="join_timeout=1800"


@@ -1,49 +0,0 @@
# Finetune SmolVLM2 with Axolotl
[SmolVLM2](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7) are a family of lightweight, open-source multimodal models from HuggingFace designed to analyze and understand video, image, and text content.
These models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M.
This guide shows how to fine-tune SmolVLM2 models with Axolotl.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install an extra dependency:
```bash
pip3 install num2words==0.5.14
```
3. Run the finetuning example:
```bash
# LoRA SFT (1x48GB @ 6.8GiB)
axolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml
```
## TIPS
- **Dataset Format**: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on [Multimodal Formats](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
- **Dataset Loading**: Read more on how to prepare and load your own datasets in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [SmolVLM2 Blog](https://huggingface.co/blog/smolvlm2)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)


@@ -1,56 +0,0 @@
base_model: HuggingFaceTB/SmolVLM2-2.2B-Instruct
trust_remote_code: true
processor_type: AutoProcessor
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.text_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config


@@ -6,14 +6,17 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main, as Voxtral support is only available in nightly builds, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from pip:
Here is an example of how to install from main with pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Please install the dependencies listed below.

View File

@@ -1,9 +1,8 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.47.0
# triton 3.4.0 is not compatible with CCE
triton>=3.0.0,<3.4.0
bitsandbytes==0.46.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
@@ -13,21 +12,19 @@ liger-kernel==0.6.1
packaging==23.2
huggingface_hub>=0.33.0
peft==0.17.0
transformers==4.55.2
peft==0.16.0
transformers==4.54.1
tokenizers>=0.21.1
accelerate==1.10.0
accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
datasets==4.0.0
deepspeed>=0.17.0
trl==0.21.0
trl==0.20.0
hf_xet==1.1.5
kernels==0.9.0
trackio
optimum==1.16.2
hf_transfer
sentencepiece
gradio==5.41.1
gradio==5.23.3
modal==1.0.2
pydantic==2.10.6
@@ -69,11 +66,6 @@ torchao==0.12.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.5
axolotl-contribs-mit==0.0.3
mistral-common==1.8.3
# TUI dependencies
textual==1.0.0
rich==14.1.0
tree_sitter_ruby==0.23.1

View File

@@ -44,13 +44,8 @@ add_keys_to_authorized() {
chmod 700 -R ~/.ssh
}
# Set SSH port
if [ ! -z "$SSH_PORT" ]; then
sed -i "s/#Port 22/Port $SSH_PORT/" /etc/ssh/sshd_config
fi
if [[ $PUBLIC_KEY ]]; then
# runpod, prime intellect
# runpod
add_keys_to_authorized "$PUBLIC_KEY"
# Start the SSH service in the background
service ssh start
@@ -81,13 +76,5 @@ if [ ! -L "/workspace/axolotl/outputs" ]; then
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
fi
# start the runpod slurm init
SLURM_INIT="${SLURM_INIT:-/slurm-init.sh}"
if [[ -f "$SLURM_INIT" ]]; then
echo "[entrypoint] running $SLURM_INIT..."
bash "$SLURM_INIT"
fi
# Execute the passed arguments (CMD)
exec "$@"

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"'
)

View File

@@ -118,9 +118,9 @@ def get_package_version():
extras_require = {
"flash-attn": ["flash-attn==2.8.3"],
"flash-attn": ["flash-attn==2.8.2"],
"ring-flash-attn": [
"flash-attn==2.8.3",
"flash-attn==2.8.2",
"ring-flash-attn>=0.1.7",
"yunchang==0.6.0",
],

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.13.0.dev"
__version__ = "0.12.0.dev"

View File

@@ -40,12 +40,6 @@ class VllmServeCliArgs:
default=None,
metadata={"help": "Number of tensor parallel workers to use."},
)
data_parallel_size: Optional[int] = field(
default=None,
metadata={
"help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
},
)
host: Optional[str] = field(
default=None, # nosec B104
metadata={"help": "Host address to run the server on."},

View File

@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res
def get_image(self):
docker_tag = "main-py3.11-cu126-2.7.1"
docker_tag = "main-py3.11-cu124-2.6.0"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"
@@ -200,7 +200,7 @@ class ModalCloud(Cloud):
if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count)
if family == "h100":
return f"H100:{count}"
return modal.gpu.H100(count=count)
if family == "t4":
return modal.gpu.T4(count=count)
if family == "l4":

View File

@@ -153,14 +153,15 @@ def prepare_plugins(cfg: DictDefault):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
for plugin in plugin_manager.plugins.values():
plugin.register(cfg)
def plugin_set_cfg(cfg: DictDefault):
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
plugin_manager.cfg = cfg
# now that we have the finalized cfg, register the plugins individually
for plugin in plugin_manager.plugins.values():
plugin.register(cfg)
def load_cfg(

View File

@@ -64,7 +64,7 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer

View File

@@ -123,10 +123,9 @@ def train(
_launcher = None if kwargs.get("use_ray") else launcher
# Process each configuration
for cfg_file, is_group in generate_config_files(config, sweep):
for cfg_file in generate_config_files(config, sweep):
try:
use_exec = is_group is not True
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
except subprocess.CalledProcessError as exc:
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:
@@ -344,26 +343,6 @@ def delinearize_llama4(model: str, output: str):
cli.add_command(lm_eval)
@cli.command()
def tui():
"""
Launch the Axolotl Terminal User Interface (TUI).
Provides an interactive interface for configuration management,
training monitoring, dataset handling, and model operations.
"""
try:
from axolotl.tui.app import run
run()
except ImportError:
click.echo(
"TUI dependencies not installed. Install with: pip install textual rich"
)
except Exception as e:
click.echo(f"Error launching TUI: {e}")
def main():
cli()

View File

@@ -10,7 +10,6 @@ import fire
import torch
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
from accelerate import PartialState
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
@@ -24,7 +23,6 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.config import load_cfg
from axolotl.utils.logging import get_logger
from axolotl.utils.train import determine_last_checkpoint
LOG = get_logger(__name__)
@@ -145,6 +143,7 @@ def merge_fsdp_weights(
ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
"""
checkpoint_dir_ = Path(checkpoint_dir)
from accelerate.state import PartialState
if not is_torch_version(">=", "2.3.0"):
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
@@ -181,6 +180,7 @@ def merge_fsdp_weights(
if remove_checkpoint_dir:
LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
shutil.rmtree(checkpoint_dir_)
state.wait_for_everyone()
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
@@ -195,32 +195,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
parsed_cfg = load_cfg(config, **kwargs)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
if not fsdp_dir.exists():
checkpoint_dir = determine_last_checkpoint(parsed_cfg, update=False)
if checkpoint_dir:
fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
if not fsdp_dir.exists():
raise ValueError(
f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}"
)
output_path = str(Path(parsed_cfg.output_dir) / "merged")
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=output_path,
output_path=str(Path(parsed_cfg.output_dir) / "merged"),
safe_serialization=True,
)
state = PartialState()
state.wait_for_everyone()
LOG.info(
f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
main_process_only=True,
)
LOG.info(
"Merged weights are only the safetensors and doesn't include the model configuration "
f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
main_process_only=True,
)
if __name__ == "__main__":

View File

@@ -97,8 +97,7 @@ def do_cli(
"""
# pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
is_preprocess = kwargs.pop("is_preprocess", True)
parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -3,12 +3,11 @@
import random
from copy import deepcopy
from itertools import product
from typing import Any
def generate_sweep_configs(
base_config: dict[str, list], sweeps_config: dict[str, list]
) -> list[dict[str, Any]]:
) -> list[dict[str, list]]:
"""
Recursively generates all possible configurations by applying sweeps to the base config.

View File

@@ -2,9 +2,7 @@
import os
import subprocess # nosec
import sys
import tempfile
from pathlib import Path
from typing import Any, Iterator, Literal
import yaml
@@ -66,18 +64,10 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
return cmd
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
"""
Generate list of configuration files to process. Yields a tuple of the configuration file name and a boolean indicating
whether this is a group of configurations (i.e., a sweep).
Args:
config: Base configuration file
sweep: Sweep configuration file
"""
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
"""Generate list of configuration files to process."""
if not sweep:
yield config, False
yield config
return
# Load sweep and base configurations
@@ -88,13 +78,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
# Generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
is_group = len(permutations) > 1
base_output_dir = base_config.get("output_dir", "./model-out")
for idx, permutation in enumerate(permutations, start=1):
permutation_dir = Path(permutation.get("output_dir", base_output_dir))
permutation_id = f"sweep{idx:04d}"
permutation["output_dir"] = str(permutation_dir / permutation_id)
for permutation in permutations:
# pylint: disable=consider-using-with
temp_file = tempfile.NamedTemporaryFile(
mode="w",
@@ -104,7 +88,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
)
yaml.dump(permutation, temp_file)
temp_file.close()
yield temp_file.name, is_group
yield temp_file.name
def launch_training(
@@ -113,7 +97,6 @@ def launch_training(
cloud: str | None,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
) -> None:
"""Execute training with the given configuration."""
launcher_args = launcher_args or []
@@ -122,14 +105,11 @@ def launch_training(
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
elif launcher:
if launcher == "accelerate":
_launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
_launch_accelerate_training(cfg_file, kwargs, launcher_args)
elif launcher == "torchrun":
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
_launch_torchrun_training(cfg_file, kwargs, launcher_args)
elif launcher == "python":
_launch_python_training(cfg_file, kwargs)
elif launcher is None:
# handle ray train launch
_launch_python_training(cfg_file, kwargs)
def _launch_cloud_training(
@@ -156,10 +136,7 @@ def _launch_cloud_training(
def _launch_accelerate_training(
cfg_file: str,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
) -> None:
"""Execute training via accelerate launcher."""
launcher_args = launcher_args or []
@@ -184,20 +161,11 @@ def _launch_accelerate_training(
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
if use_exec:
# make sure to flush stdout and stderr before replacing the process
sys.stdout.flush()
sys.stderr.flush()
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
else:
subprocess.run(cmd, check=True) # nosec B603
subprocess.run(cmd, check=True) # nosec B603
def _launch_torchrun_training(
cfg_file: str,
kwargs: dict,
launcher_args: list[str] | None = None,
use_exec: bool = False,
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
) -> None:
"""Execute training via torchrun launcher."""
launcher_args = launcher_args or []
@@ -210,13 +178,7 @@ def _launch_torchrun_training(
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
if use_exec:
# make sure to flush stdout and stderr before replacing the process
sys.stdout.flush()
sys.stderr.flush()
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
else:
subprocess.run(cmd, check=True) # nosec B603
subprocess.run(cmd, check=True) # nosec B603
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:

View File

@@ -2,10 +2,12 @@
CLI to start the vllm server for online RL
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
import trl
from trl.scripts.vllm_serve import ScriptArguments
from axolotl.cli.config import load_cfg
@@ -40,17 +42,13 @@ def do_vllm_serve(
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
tensor_parallel_size = 1
data_parallel_size = 1
if cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size:
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)
if cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size:
data_parallel_size = (
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
)
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)
data_parallel_size = (
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
)
host = cli_args.get("host") or cfg.vllm.host
port = cli_args.get("port") or cfg.vllm.port
gpu_memory_utilization = (
@@ -83,3 +81,63 @@ def do_vllm_serve(
enable_reasoning=enable_reasoning,
)
vllm_serve_main(vllm_script_args)
def patch_vllm_worker():
from multiprocessing.connection import Connection
from vllm import LLM
def llm_worker(
script_args: AxolotlScriptArguments,
data_parallel_rank: int,
master_port: int,
connection: Connection,
) -> None:
# Set required environment variables for DP to work with vLLM
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
llm = LLM(
model=script_args.model,
revision=script_args.revision,
tensor_parallel_size=script_args.tensor_parallel_size,
gpu_memory_utilization=script_args.gpu_memory_utilization,
enforce_eager=script_args.enforce_eager,
dtype=script_args.dtype,
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=script_args.enable_prefix_caching,
kv_cache_dtype=script_args.kv_cache_dtype,
max_model_len=script_args.max_model_len,
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
enable_reasoning=script_args.enable_reasoning,
reasoning_parser=script_args.reasoning_parser,
)
# Send ready signal to parent process
connection.send({"status": "ready"})
while True:
# Wait for commands from the parent process
try:
command = connection.recv()
except KeyboardInterrupt:
llm.collective_rpc(method="close_communicator")
break
# Handle commands
if command["type"] in ["call", "fire_and_forget"]:
method_name = command["method"]
args, kwargs = command.get("args", ()), command.get("kwargs", {})
method = getattr(llm, method_name)
result = method(*args, **kwargs)
if command["type"] == "call":
connection.send(result)
elif command["type"] == "shutdown":
break
trl.scripts.vllm_serve.llm_worker = llm_worker

View File

@@ -13,5 +13,4 @@ MOE_ARCH_BLOCK = {
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"gpt_oss": "GptOssDecoderLayer",
}

View File

@@ -24,10 +24,12 @@ from pathlib import Path
from typing import Any
import torch
from accelerate import PartialState
from transformers import (
TrainerCallback,
)
from transformers.trainer_pt_utils import AcceleratorConfig
from transformers.training_args import OptimizerNames
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
@@ -38,7 +40,6 @@ from axolotl.utils.callbacks import (
SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.distributed import build_parallelism_config
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
@@ -266,24 +267,27 @@ class TrainerBuilderBase(abc.ABC):
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
DionOptimizerFactory,
)
optimizer_cls = DionOptimizerFactory
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
optimizer_kwargs.update(adam_kwargs)
_, device_mesh = build_parallelism_config(self.cfg)
if device_mesh is not None:
optimizer_kwargs["device_mesh"] = device_mesh
elif self.cfg.optimizer == "optimi_adamw":
from optimi import AdamW
optimizer_kwargs["foreach"] = False
optimizer_cls = AdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_4bit":
# TODO remove 20250401
from torchao.prototype.low_bit_optim import AdamW4bit
optimizer_cls = AdamW4bit
optimizer_kwargs.update(adam_kwargs)
LOG.warning(
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
)
elif self.cfg.optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
optimizer_cls = AdamW8bit
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
@@ -429,12 +433,30 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
def _configure_accelerator_config(self, training_args_kwargs: dict):
partial_state = PartialState()
has_pc_attr = (
hasattr(partial_state, "parallelism_config")
and partial_state.parallelism_config
)
has_pc_key = (
"parallelism_config"
in partial_state._shared_state # pylint: disable=protected-access
and partial_state._shared_state[ # pylint: disable=protected-access
"parallelism_config"
]
)
use_configured_state = has_pc_attr or has_pc_key
if self.cfg.accelerator_config:
use_configured_state = self.cfg.accelerator_config.pop(
"use_configured_state", use_configured_state
)
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
**self.cfg.accelerator_config
use_configured_state=use_configured_state, **self.cfg.accelerator_config
)
else:
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
use_configured_state=use_configured_state,
)
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.activation_offloading is True:
@@ -494,20 +516,10 @@ class TrainerBuilderBase(abc.ABC):
"include_tokens_per_second",
"weight_decay",
"seed",
"dion_momentum",
"dion_rank_fraction",
"dion_rank_multiple_of",
]:
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
training_args_kwargs[arg] = getattr(self.cfg, arg)
arg_map = {
"dion_learning_rate": "dion_lr",
}
for kwarg, cfg_arg in arg_map.items():
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
training_args_kwargs["average_tokens_across_devices"] = False

View File

@@ -43,7 +43,6 @@ from axolotl.utils.collators import (
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@@ -137,18 +136,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
if self.cfg.trainer_cls:
# override the trainer cls
try:
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
return trainer_cls
except (ImportError, AttributeError, ValueError) as e:
raise ValueError(
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
) from e
return AxolotlTrainer
def build(self, total_num_steps):
@@ -363,7 +350,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
self.cfg.sequence_len / multiple
)
elif self.cfg.pad_to_sequence_len is None:
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple

View File

@@ -15,7 +15,6 @@ from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
@@ -73,16 +72,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
if self.cfg.trainer_cls:
# override the trainer cls
try:
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
except (ImportError, AttributeError, ValueError) as e:
raise ValueError(
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
) from e
return trainer_cls, trainer_cls_args
def _build_training_arguments(self, total_num_steps):

View File

@@ -5,6 +5,7 @@
from .base import AxolotlTrainer
from .dpo.trainer import AxolotlDPOTrainer
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
from .mamba import AxolotlMambaTrainer
from .trl import (
AxolotlCPOTrainer,

View File

@@ -10,11 +10,8 @@ from functools import partial, wraps
from typing import Any, Callable, Literal, Optional
import datasets
import safetensors
import torch
from accelerate.state import AcceleratorState
from datasets import Dataset
from peft import PeftModel
from torch.utils.data import (
BatchSampler,
DataLoader,
@@ -22,10 +19,8 @@ from torch.utils.data import (
Sampler,
SequentialSampler,
)
from transformers import PreTrainedModel, Trainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
from trl.trainer.utils import pad_to_length
from typing_extensions import override
@@ -520,18 +515,7 @@ class AxolotlTrainer(
@wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self):
# cleanup the PartialState states so Accelerate automatically configures everything from the env vars
accelerator_config = self.args.accelerator_config.to_dict()
use_configured_state = accelerator_config.get("use_configured_state", False)
if not use_configured_state:
AcceleratorState._reset_state( # pylint: disable=protected-access
reset_partial_state=True
)
super().create_accelerator_and_postprocess()
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
res = super().create_accelerator_and_postprocess()
if self.is_fsdp_enabled:
if (
@@ -540,6 +524,8 @@ class AxolotlTrainer(
):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
return res
# pylint: disable=unused-argument
def additional_accelerator_args(
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
@@ -581,10 +567,10 @@ class AxolotlTrainer(
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_mem_active(gib)"] = round(active, 2)
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
except (ValueError, TypeError, FileNotFoundError):
logs["memory/max_memory_active"] = active
logs["memory/max_memory_allocated"] = allocated
logs["memory/device_memory_reserved"] = reserved
except (ValueError, FileNotFoundError):
pass
del self._stored_metrics[train_eval]
@@ -604,64 +590,3 @@ class AxolotlTrainer(
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
supported_classes = (
(PreTrainedModel,)
if not is_peft_available()
else (PreTrainedModel, PeftModel)
)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()
if isinstance(
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
supported_classes,
):
self.accelerator.unwrap_model(
self.model, keep_torch_compile=False
).save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
)
else:
LOG.info(
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
)
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict,
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
metadata={"format": "pt"},
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
is_main_process=self.accelerator.is_main_process,
)
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
self.data_collator.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -2,7 +2,6 @@
Mixin for correctly saving fsdp
"""
from accelerate import PartialState
from transformers import Trainer
@@ -19,15 +18,3 @@ class DistributedParallelMixin(Trainer):
):
state_dict = self.accelerator.get_state_dict(self.model)
super()._save(output_dir, state_dict=state_dict)
def create_accelerator_and_postprocess(self):
super().create_accelerator_and_postprocess()
if (
self.accelerator.distributed_type == "FSDP"
and self.accelerator.state.fsdp_plugin is None
):
# pylint: disable=protected-access
# handle Context Parallelism without FSDP
self.accelerator.state.distributed_type = "MULTI_GPU"
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"
PartialState().distributed_type = "MULTI_GPU"

View File

@@ -243,18 +243,3 @@ class AxolotlTrainingMixins:
)
# end of multi-modal section
dion_learning_rate: float | None = field(
default=None,
metadata={"help": "The learning rate for Dion"},
)
dion_momentum: float | None = field(
default=None,
metadata={"help": "The momentum for Dion"},
)
dion_rank_fraction: float | None = field(
default=None,
)
dion_rank_multiple_of: int | None = field(
default=None,
)

View File

@@ -26,11 +26,9 @@ import traceback
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel, Trainer
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -76,8 +74,8 @@ class BasePlugin:
def __init__(self):
"""Initializes the BasePlugin."""
def register(self, cfg: dict): # pylint: disable=unused-argument
"""Registers the plugin with the given configuration as an unparsed dict.
def register(self, cfg: DictDefault): # pylint: disable=unused-argument
"""Registers the plugin with the given configuration.
Args:
cfg: The configuration for the plugin.
@@ -643,24 +641,3 @@ class BaseOptimizerFactory:
self, opt_model, training_args, **optimizer_kwargs
) -> Optimizer | None:
pass
# duplicated from transformers
def get_decay_parameter_names(self, model) -> list[str]:
"""
Get all parameter names that weight decay will be applied to.
This function filters out parameters in two ways:
1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
2. By parameter name patterns (containing 'bias', or variation of 'norm')
"""
forbidden_name_patterns = [
r"bias",
r"layernorm",
r"rmsnorm",
r"(?:^|\.)norm(?:$|\.)",
r"_norm(?:$|\.)",
]
decay_parameters = get_parameter_names(
model, [nn.LayerNorm], forbidden_name_patterns
)
return decay_parameters

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"
```
## Usage
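A minimal YAML sketch for enabling the plugin in a training config (the plugin path follows the integration's module name; the rest of your config stays as-is):
```yaml
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

cut_cross_entropy: true
```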
@@ -31,7 +31,6 @@ plugins:
## Supported Models
- arcee
- cohere
- cohere2
- gemma
@@ -42,17 +41,13 @@ plugins:
- gemma3n_text
- glm
- glm4
- gpt_oss
- granite
- granitemoe
- hunyuan_v1_dense
- hunyuan_v1_moe
- llama
- llama4
- llama4_text
- mistral
- mistral3
- mixtral
- mllama
- phi
- phi3

View File

@@ -34,7 +34,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"`'
)

View File

@@ -284,12 +284,12 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
return sample
def _tokenize_single_prompt(self, prompt):
target_token_ids = prompt.get("target_token_ids", None)
logprobs = prompt.pop(self.logprobs_field)
target_token_ids = prompt.pop("target_token_ids")
tokenized_prompt = super()._tokenize_single_prompt(prompt)
if target_token_ids is not None:
tokenized_prompt["target_token_ids"] = target_token_ids
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt["target_token_ids"] = target_token_ids
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt

View File

@@ -14,7 +14,6 @@ from typing import Callable
import torch
from bitsandbytes.functional import QuantState
from torch import nn
from torch.distributed.tensor import DTensor
from .geglu import geglu_backward, geglu_forward
from .quantize import dequantize
@@ -26,7 +25,6 @@ def get_lora_parameters(
proj: nn.Module,
) -> tuple[
torch.Tensor,
torch.Tensor | None,
QuantState | None,
torch.Tensor | None,
torch.Tensor | None,
@@ -39,54 +37,39 @@ def get_lora_parameters(
proj: The projection module to extract parameters from.
Returns:
A tuple containing the base weights, quantization state, LoRA A and B weights,
scaling factor, and base layer bias. Quant state, weights, and bias may be
`None` if not available.
A tuple containing the base weight matrix, quantization state, LoRA A matrix,
LoRA B matrix, and scaling factor. States and matrices may be None if not
available.
"""
# For DPO or disabled adapters
base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
W = base_layer.weight
b = base_layer.bias
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
quant_state = getattr(W, "quant_state", None)
return W, b, quant_state, None, None, None
quant_state = getattr(W, "quant_state", None)
return W, quant_state, None, None, None
active_adapter = (
proj.active_adapters[0]
if hasattr(proj, "active_adapters")
else proj.active_adapter
)
linear_A = proj.lora_A[active_adapter]
linear_B = proj.lora_B[active_adapter]
# This manual unsharding is needed for FSDP2 + LoRA kernels compatibility.
# We fuse linear layers + LoRA adapters calculations into a single
# torch.autograd.Function, bypassing the registered unshard / reshard behavior.
# Note that we don't apply resharding later in this module (it gets messy quickly),
# but LoRA parameters are generally small enough that this is not an issue.
if isinstance(linear_A.weight, DTensor):
linear_A.unshard()
linear_B.unshard()
A = linear_A.weight
B = linear_B.weight
A = proj.lora_A[active_adapter].weight
B = proj.lora_B[active_adapter].weight
s = proj.scaling[active_adapter]
return W, b, quant_state, A, B, s
quant_state = getattr(W, "quant_state", None)
return W, quant_state, A, B, s
def matmul_lora(
X: torch.Tensor,
W: torch.Tensor,
b: torch.Tensor | None,
W_quant: QuantState | None,
A: torch.Tensor | None,
B: torch.Tensor | None,
s: float | None,
W_quant: QuantState,
A: torch.Tensor,
B: torch.Tensor,
s: float,
out: torch.Tensor | None = None,
) -> torch.Tensor:
"""
@@ -107,22 +90,20 @@ def matmul_lora(
dtype = X.dtype
W = dequantize(W.t(), W_quant)
reshape = False
if X.dim() == 3:
batch, seq_len, _ = X.shape
X = X.view(-1, X.shape[-1])
reshape = True
else:
reshape = False
out = torch.matmul(X, W, out=out)
if W_quant is not None:
del W
if A is not None:
A, B = A.t().to(dtype), B.t().to(dtype) # type: ignore[union-attr]
out += s * X @ A @ B
if b is not None:
out += b
A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
return out.view(batch, seq_len, -1) if reshape else out
@@ -136,20 +117,17 @@ class LoRA_MLP(torch.autograd.Function):
ctx,
X: torch.Tensor,
gate_weight: torch.Tensor,
gate_bias: torch.Tensor | None,
gate_quant: QuantState | None,
gate_quant: object | None,
gate_A: torch.Tensor | None,
gate_B: torch.Tensor | None,
gate_scale: float,
up_weight: torch.Tensor,
up_bias: torch.Tensor | None,
up_quant: QuantState | None,
up_quant: object | None,
up_A: torch.Tensor | None,
up_B: torch.Tensor | None,
up_scale: float,
down_weight: torch.Tensor,
down_bias: torch.Tensor | None,
down_quant: QuantState | None,
down_quant: object | None,
down_A: torch.Tensor | None,
down_B: torch.Tensor | None,
down_scale: float,
@@ -164,22 +142,20 @@ class LoRA_MLP(torch.autograd.Function):
ctx: Autograd context
X: Input features
gate_weight: Gate projection weight
gate_bias: Gate projection bias
gate_quant: Gate quantization state
gate_A: Gate LoRA A matrix
gate_B: Gate LoRA B matrix
gate_scale: Gate LoRA scale
up_weight: Up projection weight
up_quant: Up projection quantization state
up_A: Up projection LoRA A matrix
up_B: Up projection LoRA B matrix
up_scale: Up projection LoRA scale
down_weight: Down projection weight
down_bias: Down projection bias
down_quant: Down projection quantization state
down_A: Down projection LoRA A matrix
down_B: Down projection LoRA B matrix
down_scale: Down projection LoRA scale
up_weight: Up-projection weight
up_quant: Up-projection quantization state
up_A: Up-projection LoRA A matrix
up_B: Up-projection LoRA B matrix
up_scale: Up-projection LoRA scale
down_weight: Down-projection weight
down_quant: Down-projection quantization state
down_A: Down-projection LoRA A matrix
down_B: Down-projection LoRA B matrix
down_scale: Down-projection LoRA scale
activation_fn: Forward activation function
activation_fn_backward: Backward activation function
inplace: Whether to perform operations in-place
@@ -188,17 +164,15 @@ class LoRA_MLP(torch.autograd.Function):
Output transformed by multi-layer perceptron and activation function
"""
# Compute projections
gate = matmul_lora(
X, gate_weight, gate_bias, gate_quant, gate_A, gate_B, gate_scale
)
up = matmul_lora(X, up_weight, up_bias, up_quant, up_A, up_B, up_scale)
gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)
# Activation
hidden = activation_fn(gate, up)
# Down projection
output = matmul_lora(
hidden, down_weight, down_bias, down_quant, down_A, down_B, down_scale
hidden, down_weight, down_quant, down_A, down_B, down_scale
)
# Save for backward
@@ -221,26 +195,22 @@ class LoRA_MLP(torch.autograd.Function):
torch.Tensor | None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
None,
]:
"""
Performs backward pass computation for LoRA MLP.
@@ -252,7 +222,7 @@ class LoRA_MLP(torch.autograd.Function):
Returns:
Tuple containing gradients for all inputs from forward pass:
- Input gradient tensor (or `None`)
- `None` for weights/biases/quantization states
- `None` for weights/quantization states
- LoRA A/B matrix gradients (or `None`)
- `None` for scaling factors
- `None` for activation functions and flags
@@ -295,10 +265,9 @@ class LoRA_MLP(torch.autograd.Function):
dtype = X.dtype
# Down projection
grad_down = matmul_lora(
DW = matmul_lora(
grad_output,
down_weight.t(),
None,
down_quant,
down_B,
down_A,
@@ -306,7 +275,7 @@ class LoRA_MLP(torch.autograd.Function):
)
# Activation backward
h, grad_gate, grad_up = ctx.activation_fn_backward(grad_down, gate, up)
h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)
# Initialize and compute LoRA gradients
d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
@@ -346,8 +315,8 @@ class LoRA_MLP(torch.autograd.Function):
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
# Gate projection gradients
gate_weight = dequantize(gate_weight, gate_quant)
dX += grad_gate @ gate_weight
gate_weight = dequantize(gate_weight.t(), gate_quant)
dX += grad_gate @ gate_weight.t()
del gate_weight
if gate_A is not None and gate_B is not None:
@@ -365,26 +334,22 @@ class LoRA_MLP(torch.autograd.Function):
dX,
None,
None,
None,
d_gate_A.t() if d_gate_A is not None else None,
d_gate_B.t() if d_gate_B is not None else None,
None,
None,
None,
None,
d_up_A.t() if d_up_A is not None else None,
d_up_B.t() if d_up_B is not None else None,
None,
None,
None,
None,
d_down_A.t() if d_down_A is not None else None,
d_down_B.t() if d_down_B is not None else None,
None,
None,
None,
None,
None,
)
@@ -399,26 +364,23 @@ def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.
Returns:
Output tensor after applying LoRA-adapted MLP with SwiGLU activation
"""
gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(
X,
gateW,
gateb,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upb,
upW_quant,
upA,
upB,
upS,
downW,
downb,
downW_quant,
downA,
downB,
@@ -442,25 +404,22 @@ def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.T
Returns:
Output tensor after applying LoRA-adapted MLP with GEGLU activation
"""
gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
out = LoRA_MLP.apply(
X,
gateW,
gateb,
gateW_quant,
gateA,
gateB,
gateS,
upW,
upb,
upW_quant,
upA,
upB,
upS,
downW,
downb,
downW_quant,
downA,
downB,
@@ -487,19 +446,16 @@ class LoRA_QKV(torch.autograd.Function):
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
q_weight: torch.Tensor,
q_bias: torch.Tensor | None,
q_quant: QuantState | None,
q_A: torch.Tensor | None,
q_B: torch.Tensor | None,
q_scale: float,
k_weight: torch.Tensor,
k_bias: torch.Tensor | None,
k_quant: QuantState | None,
k_A: torch.Tensor | None,
k_B: torch.Tensor | None,
k_scale: float,
v_weight: torch.Tensor,
v_bias: torch.Tensor | None,
v_quant: QuantState | None,
v_A: torch.Tensor | None,
v_B: torch.Tensor | None,
@@ -513,19 +469,16 @@ class LoRA_QKV(torch.autograd.Function):
ctx: Autograd context
X: Input tensor
q_weight: Query projection weight
q_bias: Query projection bias
q_quant: Query quantization state
q_A: Query LoRA A matrix
q_B: Query LoRA B matrix
q_scale: Query LoRA scale
k_weight: Key projection weight
k_bias: Key projection bias
k_quant: Key quantization state
k_A: Key LoRA A matrix
k_B: Key LoRA B matrix
k_scale: Key LoRA scale
v_weight: Value projection weight
v_bias: Value projection bias
v_quant: Value quantization state
v_A: Value LoRA A matrix
v_B: Value LoRA B matrix
@@ -535,21 +488,20 @@ class LoRA_QKV(torch.autograd.Function):
Returns:
Tuple of (Query, Key, Value) projection tensors
"""
Q = matmul_lora(X, q_weight, q_bias, q_quant, q_A, q_B, q_scale)
K = matmul_lora(X, k_weight, k_bias, k_quant, k_A, k_B, k_scale)
V = matmul_lora(X, v_weight, v_bias, v_quant, v_A, v_B, v_scale)
Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)
ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
ctx.scales = (q_scale, k_scale, v_scale)
ctx.quants = (q_quant, k_quant, v_quant)
ctx.weights = (q_weight, k_weight, v_weight)
ctx.biases = (q_bias, k_bias, v_bias)
ctx.inplace = inplace
return Q, K, V
@staticmethod
@torch_amp_custom_bwd
@torch_amp_custom_fwd
def backward(
ctx: torch.autograd.function.FunctionCtx,
q_grad: torch.Tensor,
@@ -559,19 +511,16 @@ class LoRA_QKV(torch.autograd.Function):
torch.Tensor,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
None,
None,
None,
torch.Tensor | None,
torch.Tensor | None,
None,
@@ -659,31 +608,31 @@ class LoRA_QKV(torch.autograd.Function):
# Transpose gradients if needed
if d_A_q is not None:
d_A_q = d_A_q.t()
d_B_q = d_B_q.t() # type: ignore[union-attr]
if d_B_q is not None:
d_B_q = d_B_q.t()
if d_A_k is not None:
d_A_k = d_A_k.t()
d_B_k = d_B_k.t() # type: ignore[union-attr]
if d_B_k is not None:
d_B_k = d_B_k.t()
if d_A_v is not None:
d_A_v = d_A_v.t()
d_B_v = d_B_v.t() # type: ignore[union-attr]
if d_B_v is not None:
d_B_v = d_B_v.t()
return (
grad_X.view(batch, seq_len, -1),
None,
None,
None,
d_A_q,
d_B_q,
None,
None,
None,
None,
d_A_k,
d_B_k,
None,
None,
None,
None,
d_A_v,
d_B_v,
None,
@@ -704,25 +653,22 @@ def apply_lora_qkv(
Returns:
Tuple of (Query, Key, Value) projection tensors
"""
QW, Qb, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, Kb, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, Vb, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
Q, K, V = LoRA_QKV.apply(
X,
QW,
Qb,
QW_quant,
QA,
QB,
QS,
KW,
Kb,
KW_quant,
KA,
KB,
KS,
VW,
Vb,
VW_quant,
VA,
VB,
@@ -742,11 +688,10 @@ class LoRA_O(torch.autograd.Function):
ctx: torch.autograd.function.FunctionCtx,
X: torch.Tensor,
W: torch.Tensor,
b: torch.Tensor,
W_quant: QuantState | None,
A: torch.Tensor,
B: torch.Tensor,
s: float,
A: torch.Tensor | None,
B: torch.Tensor | None,
S: float,
) -> torch.Tensor:
"""
Forward pass for output projection with LoRA.
@@ -755,20 +700,19 @@ class LoRA_O(torch.autograd.Function):
ctx: Autograd context
X: Input tensor
W: Output projection weight
b: Output projection bias
W_quant: Weight quantization state
A: LoRA A matrix
B: LoRA B matrix
s: LoRA scaling factor
S: LoRA scaling factor
Returns:
Output projection result
Output projection tensor
"""
XW = matmul_lora(X, W, b, W_quant, A, B, s)
XW = matmul_lora(X, W, W_quant, A, B, S)
ctx.custom_saved_tensors = (
W,
W_quant,
s,
S,
)
ctx.save_for_backward(A, B, X)
@@ -783,9 +727,8 @@ class LoRA_O(torch.autograd.Function):
torch.Tensor,
None,
None,
None,
torch.Tensor,
torch.Tensor,
torch.Tensor | None,
torch.Tensor | None,
None,
]:
"""
@@ -798,7 +741,7 @@ class LoRA_O(torch.autograd.Function):
Returns:
Tuple containing gradients for all forward inputs
"""
W, W_quant, s = ctx.custom_saved_tensors
W, W_quant, S = ctx.custom_saved_tensors
A, B, X = ctx.saved_tensors
batch, seq_len, hd = X.shape
@@ -808,19 +751,17 @@ class LoRA_O(torch.autograd.Function):
# Weight projection
dY_X = X.t() @ dY
d_A = s * dY_X @ B
d_B = s * A @ dY_X
d_A = S * dY_X @ B
d_B = S * A @ dY_X
# Get derivative for dX
W = dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
dX += dY @ B.to(dtype) @ (S * A.to(dtype))
A, B = A.to(dtype), B.to(dtype)
dX += s * dY @ B @ A
# W, b, W_quant, A, B, s
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None
# W, W_quant, A, B, S
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
@@ -833,7 +774,7 @@ def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
Returns:
Transformed output tensor
"""
OW, Ob, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
output = LoRA_O.apply(X, OW, Ob, OW_quant, OA, OB, OS)
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)
return output

View File

@@ -76,7 +76,6 @@ def load_lora(
config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
lora_target_modules = cfg.lora_target_modules or []
lora_target_parameters = cfg.lora_target_parameters or []
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
@@ -107,7 +106,6 @@ def load_lora(
r=cfg.lora_r,
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
target_parameters=lora_target_parameters,
layers_to_transform=cfg.peft_layers_to_transform,
layers_pattern=cfg.peft_layers_pattern,
lora_dropout=cfg.lora_dropout,

View File

@@ -1,13 +1,26 @@
"""Shared constants for axolotl.loaders module"""
from transformers import AutoModelForImageTextToText
from transformers.models.auto.modeling_auto import (
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
from transformers import (
Gemma3ForConditionalGeneration,
Gemma3nForConditionalGeneration,
Llama4ForConditionalGeneration,
LlavaForConditionalGeneration,
Mistral3ForConditionalGeneration,
MllamaForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
)
MULTIMODAL_AUTO_MODEL_MAPPING = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)
MULTIMODAL_AUTO_MODEL_MAPPING["lfm2-vl"] = AutoModelForImageTextToText
MULTIMODAL_AUTO_MODEL_MAPPING = {
"mllama": MllamaForConditionalGeneration,
"llama4": Llama4ForConditionalGeneration,
"llava": LlavaForConditionalGeneration,
"qwen2_vl": Qwen2VLForConditionalGeneration,
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
"mistral3": Mistral3ForConditionalGeneration,
"gemma3": Gemma3ForConditionalGeneration,
"gemma3n": Gemma3nForConditionalGeneration,
}
try:
from transformers import VoxtralForConditionalGeneration

View File

@@ -1,5 +1,5 @@
"""
Model loader class implementation for loading, configuring, and patching various models.
"""Model loader class implementation for loading, configuring, and patching various
models.
"""
import gc
@@ -13,7 +13,7 @@ import peft
import torch
import transformers
import transformers.modeling_utils
from accelerate import init_empty_weights
from accelerate import PartialState, init_empty_weights
from accelerate.parallelism_config import ParallelismConfig
from peft import (
PeftConfig,
@@ -22,10 +22,8 @@ from peft import (
PeftModelForCausalLM,
prepare_model_for_kbit_training,
)
from torch.distributed import DeviceMesh
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForVision2Seq,
AwqConfig,
BitsAndBytesConfig,
@@ -51,11 +49,7 @@ from axolotl.loaders.utils import (
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
build_parallelism_config,
get_device_count,
get_device_type,
)
from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
from axolotl.utils.logging import get_logger
from axolotl.utils.model_shard_quant import load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
@@ -93,7 +87,6 @@ class ModelLoader:
use_parallel_config: bool | None = False
parallelism_config: ParallelismConfig | None = None
device_mesh: DeviceMesh | None = None
def __init__(
self,
@@ -209,11 +202,8 @@ class ModelLoader:
self._set_device_map_config()
if self.cfg.revision_of_model:
self.model_kwargs["revision"] = self.cfg.revision_of_model
if self.cfg.use_kernels:
self.model_kwargs["use_kernels"] = self.cfg.use_kernels
self._set_quantization_config()
self._set_attention_config()
self._check_model_requirements()
def _apply_post_model_load_setup(self):
"""Configure the model after it has been loaded."""
@@ -310,10 +300,7 @@ class ModelLoader:
)
# Handle DeepSpeed Zero3
if (
is_deepspeed_zero3_enabled()
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
):
if is_deepspeed_zero3_enabled():
self._set_z3_leaf_modules()
# Apply gradient checkpointing if needed
@@ -418,12 +405,85 @@ class ModelLoader:
gc.collect()
torch.cuda.empty_cache()
@staticmethod
def _get_parallel_config_kwargs(
world_size: int,
tensor_parallel_size: int = 1,
context_parallel_size: int = 1,
dp_shard_size: int | None = None,
dp_replicate_size: int | None = None,
is_fsdp: bool = False,
):
pc_kwargs = {}
remaining_world_size = world_size
if tensor_parallel_size and tensor_parallel_size > 1:
pc_kwargs["tp_size"] = tensor_parallel_size
remaining_world_size = remaining_world_size // tensor_parallel_size
if context_parallel_size and context_parallel_size > 1:
pc_kwargs["cp_size"] = context_parallel_size
remaining_world_size = remaining_world_size // context_parallel_size
if dp_shard_size is None and dp_replicate_size in (None, 1):
if remaining_world_size > 1:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if dp_replicate_size and dp_replicate_size > 1:
pc_kwargs["dp_replicate_size"] = dp_replicate_size
remaining_world_size = remaining_world_size // dp_replicate_size
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
if not is_fsdp:
raise ValueError(
"dp_shard_size was configured without a corresponding fsdp_config! "
"Please ensure you have configured FSDP using fsdp_config."
)
pc_kwargs["dp_shard_size"] = dp_shard_size
remaining_world_size = remaining_world_size // dp_shard_size
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
pc_kwargs["dp_replicate_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
if "dp_shard_size" not in pc_kwargs and is_fsdp:
pc_kwargs["dp_shard_size"] = remaining_world_size
remaining_world_size = 1
if remaining_world_size > 1:
raise ValueError(
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
f"{pc_kwargs}"
)
return pc_kwargs
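# Illustrative sketch: with world_size=8, tensor_parallel_size=2, and
# context_parallel_size=2 (and no explicit dp_shard_size / dp_replicate_size),
# the leftover factor of 2 is assigned to dp_shard_size, so this helper
# returns {"tp_size": 2, "cp_size": 2, "dp_shard_size": 2}.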
def _set_parallel_config(self):
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
parallelism_config, device_mesh = build_parallelism_config(self.cfg)
if parallelism_config:
self.parallelism_config = parallelism_config
self.device_mesh = device_mesh
pc_kwargs = ModelLoader._get_parallel_config_kwargs(
get_world_size(),
self.cfg.tensor_parallel_size,
self.cfg.context_parallel_size,
self.cfg.dp_shard_size,
self.cfg.dp_replicate_size,
bool(self.cfg.fsdp or self.cfg.fsdp_config),
)
if pc_kwargs:
self.parallelism_config = ParallelismConfig(
**pc_kwargs,
)
device_mesh = self.parallelism_config.build_device_mesh("cuda")
partial_state = PartialState()
# fmt: off
partial_state._shared_state["parallelism_config"] = ( # pylint: disable=protected-access
self.parallelism_config
)
partial_state._shared_state["device_mesh"] = ( # pylint: disable=protected-access
device_mesh
)
# fmt: on
def _set_auto_model_loader(self):
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
@@ -434,8 +494,6 @@ class ModelLoader:
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
self.model_config.model_type, AutoModelForVision2Seq
)
if isinstance(self.auto_model_loader, str):
self.auto_model_loader = AutoModelForImageTextToText
def _set_device_map_config(self):
"""Setup `device_map` according to config"""
@@ -507,17 +565,8 @@ class ModelLoader:
def _set_quantization_config(self):
"""Set up quantization config (bitsandbytes, awq, gptq, etc.)"""
if self.cfg.model_quantization_config == "Mxfp4Config":
from transformers import Mxfp4Config
mxfp4_kwargs = {}
if self.cfg.model_quantization_config_kwargs:
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
else:
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
if self.cfg.gptq:
if not hasattr(self.model_config, "quantization_config"):
@@ -552,9 +601,7 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
"load_in_4bit", False
):
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
@@ -580,9 +627,7 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
elif self.cfg.adapter == "lora" and self.model_kwargs.get(
"load_in_8bit", False
):
elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
bnb_config = {
"load_in_8bit": True,
}
@@ -603,9 +648,7 @@ class ModelLoader:
def _set_attention_config(self):
"""Sample packing uses custom FA2 patch"""
if self.cfg.attn_implementation:
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
elif self.cfg.flex_attention:
if self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flex_attention"
@@ -632,16 +675,6 @@ class ModelLoader:
if self.cfg.low_cpu_mem_usage:
self.model_kwargs["low_cpu_mem_usage"] = True
def _check_model_requirements(self):
if self.cfg.model_config_type in ["lfm2-vl", "lfm2"]:
from transformers.utils.import_utils import is_causal_conv1d_available
if is_causal_conv1d_available():
raise ImportError(
"The 'causal-conv1d' package is installed but causes compatibility issues with LFM2 models. "
"Please uninstall it by running: `pip uninstall -y causal-conv1d`"
)
def _configure_zero3_memory_efficient_loading(
self,
) -> HfTrainerDeepSpeedConfig | None:
@@ -688,7 +721,7 @@ class ModelLoader:
if self.cfg.tensor_parallel_size > 1:
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
self.model_kwargs["tp_plan"] = "auto"
self.model_kwargs["device_mesh"] = self.device_mesh
self.model_kwargs["device_mesh"] = PartialState().device_mesh
if "device_map" in self.model_kwargs:
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
@@ -704,18 +737,6 @@ class ModelLoader:
elif self.is_qlora_and_fsdp_enabled:
skip_move_to_device = True
if (
self.cfg.tensor_parallel_size <= 1
and self.cfg.fsdp_config.cpu_ram_efficient_loading
and self.cfg.fsdp_version == 2
):
# setting device_map for TP is not supported
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if local_rank == 0:
self.model_kwargs["device_map"] = "cpu"
else:
self.model_kwargs["device_map"] = "meta"
if (
self.is_qlora_and_fsdp_enabled
and self.cfg.fsdp_config.cpu_ram_efficient_loading
@@ -824,9 +845,6 @@ class ModelLoader:
self.model._tp_size = self.cfg.tensor_parallel_size
self.model._device_mesh = self.model_kwargs["device_mesh"]
if self.cfg.experimental_skip_move_to_device is not None:
skip_move_to_device = self.cfg.experimental_skip_move_to_device
return skip_move_to_device
def _set_z3_leaf_modules(self):

View File

@@ -65,7 +65,6 @@ class PatchManager:
self._patch_llama_derived_model()
self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch()
self._apply_fsdp2_bnb_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
@@ -73,19 +72,11 @@ class PatchManager:
self._apply_voxtral_patches()
def _apply_transformers_patches(self):
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
patch_evaluation_loop,
patch_maybe_log_save_evaluate,
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
patch_prepare_from_posids,
)
patch_fsdp2 = (
self.cfg.torch_compile
and self.cfg.fsdp_config
and self.cfg.fsdp_version == 2
)
patch_evaluation_loop(patch_fsdp2)
patch_maybe_log_save_evaluate()
patch_prepare_from_posids()
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
@@ -112,14 +103,6 @@ class PatchManager:
def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations."""
if self.cfg.context_parallel_size > 1 or (
self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
):
from axolotl.monkeypatch.accelerate.parallelism_config import (
patch_parallelism_config,
)
patch_parallelism_config()
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
@@ -277,21 +260,6 @@ class PatchManager:
has_remote_code=has_remote_code,
)
def _apply_fsdp2_bnb_patches(self):
"""Apply FSDP2 BNB patches."""
if (
self.cfg.fsdp_config
and str(self.cfg.fsdp_version) == "2"
and self.cfg.adapter == "qlora"
):
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
)
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:
from axolotl.monkeypatch.tiled_mlp import (
@@ -362,21 +330,31 @@ class PatchManager:
patch_self_attn_lora()
def _patch_llama_flash_attention(self):
def _patch_llama_flash_attention(self, packed=False):
"""Apply Flash Attention patches for LLaMA models."""
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
)
if self.cfg.s2_attention:
if packed:
if self.cfg.device not in ["mps", "cpu"] and not self.inference:
LOG.info("patching with flash attention for sample packing")
replace_llama_attn_with_flash_attn(
packed=True,
cross_entropy=self.cfg.flash_attn_cross_entropy,
rms_norm=self.cfg.flash_attn_rms_norm,
)
elif self.cfg.s2_attention:
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
replace_llama_attn_with_flash_attn(
packed=False,
cross_entropy=self.cfg.flash_attn_cross_entropy,
rms_norm=self.cfg.flash_attn_rms_norm,
use_shifted_sparse_attn=True,
)
elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm:
replace_llama_attn_with_flash_attn(
packed=False,
cross_entropy=self.cfg.flash_attn_cross_entropy,
rms_norm=self.cfg.flash_attn_rms_norm,
)
@@ -407,7 +385,7 @@ class PatchManager:
and self.cfg.sample_packing
):
if self.cfg.flash_attention:
self._patch_llama_flash_attention()
self._patch_llama_flash_attention(packed=self.cfg.sample_packing)
elif self.cfg.xformers_attention:
self._patch_llama_xformers_attention()
elif self.cfg.sample_packing:
@@ -430,12 +408,17 @@ class PatchManager:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
is_xformers_swiglu_available,
replace_llama_mlp_with_swiglu,
replace_llama_qkv_with_fused,
)
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
LOG.info("Patching with SwiGLU...")
replace_llama_mlp_with_swiglu(model)
if self.cfg.flash_attn_fuse_qkv:
LOG.info("Patching with fused QKV...")
replace_llama_qkv_with_fused(model)
def _apply_unsloth_patches(self, model):
"""Apply unsloth optimization patches."""
if self.cfg.unsloth_lora_mlp:

View File

@@ -7,7 +7,6 @@ import functools
import sys
import torch
import torch.distributed as dist
from torch import nn
from axolotl.utils.bench import log_gpu_memory_usage
@@ -37,49 +36,25 @@ def fsdp2_load_full_state_dict(
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for param_name, sharded_meta_param in meta_sharded_sd.items():
full_tensor = None
if _accelerator.is_main_process:
full_tensor = full_sd[param_name]
full_tensor = full_tensor.to(sharded_meta_param.dtype)
for param_name, full_tensor in full_sd.items():
sharded_meta_param = meta_sharded_sd.get(param_name)
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
if hasattr(sharded_meta_param, "device_mesh"):
device_mesh = sharded_meta_param.device_mesh
if _accelerator.is_main_process:
full_tensor = full_tensor.to(device_mesh.device_type)
else:
full_tensor = torch.empty(
sharded_meta_param.size(),
device=device_mesh.device_type,
dtype=sharded_meta_param.dtype,
)
sharded_param = distribute_tensor(
full_tensor,
device_mesh,
sharded_meta_param.device_mesh,
sharded_meta_param.placements,
src_data_rank=0,
)
else:
# Non-sharded parameters
if _accelerator.is_main_process:
sharded_param = full_tensor.to(torch.device("cuda"))
else:
# broadcast manually
sharded_param = torch.empty_like(
sharded_meta_param,
device=torch.device("cuda"),
dtype=sharded_meta_param.dtype,
)
dist.broadcast(sharded_param, src=0)
sharded_param = full_tensor
if offload_to_cpu:
sharded_param = sharded_param.cpu()
sharded_sd[param_name] = nn.Parameter(sharded_param)
del full_tensor
full_sd[param_name] = None
model.load_state_dict(sharded_sd, assign=True, strict=True)
end_time = time.time()
LOG.debug(
@@ -187,7 +162,7 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
# Linear4Bit keeps its bias term in fp32. If the weight dtype is bf16 we are not able to
# wrap this, so we must ensure the bias has the same dtype as the weight
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
if module.base_layer.bias is not None:
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
log_bias_dtype_mismatch = True
module.base_layer.bias.data = module.base_layer.bias.data.to(

View File

@@ -1,77 +0,0 @@
"""
workaround to allow parallelism config for pure CP
"""
# pylint: disable=protected-access
import os
import warnings
from accelerate import DistributedType
def _validate_accelerator(self, accelerator):
_warnings = set()
if not accelerator.multi_device and self.total_size == 1:
# No distributed setup, valid parallelism config
return
# We need this to ensure DDP works
if self.total_size == 1:
self._set_size("dp_replicate", accelerator.num_processes)
if self.total_size != accelerator.num_processes:
raise ValueError(
f"ParallelismConfig total_size ({self.total_size}) does not match "
f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
f"dp_shard_size/tp_size/cp_size."
)
# allow parallelism config when not using fsdp if using pure context parallelism
allow_parallelism_config = False
if (
self.cp_size > 1 # pylint: disable=chained-comparison
and self.dp_shard_size <= 1
and os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true"
):
allow_parallelism_config = True
if (
self.total_size > 1
and not allow_parallelism_config
and not (accelerator.is_fsdp2 or accelerator.multi_device)
):
raise ValueError(
f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
)
for parallelism, size in self._sizes.items():
if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
_warnings.add(
f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
)
if _warnings and accelerator.is_main_process:
warnings.warn(
"ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
UserWarning,
)
def patched_is_fsdp2(self) -> bool:
"""
Patched version of is_fsdp2 that guards against a None fsdp_plugin.
"""
# The new logic checks if fsdp_plugin exists before accessing its attributes
return (
self.distributed_type == DistributedType.FSDP
and self.fsdp_plugin
and self.fsdp_plugin.fsdp_version == 2
)
def patch_parallelism_config():
from accelerate.accelerator import AcceleratorState, ParallelismConfig
ParallelismConfig._validate_accelerator = _validate_accelerator
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)

View File

@@ -1,144 +0,0 @@
"""
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2 and allows
our LoRA / QLoRA Triton kernels to work with FSDP2.
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
Params4bit parameters.
"""
import importlib
import inspect
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# pylint: disable=protected-access
def apply_init_sharded_param_patch():
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
original_source = inspect.getsource(FSDPParam._init_sharded_param)
original_source, _ = detab_code(original_source)
# Define the replacement
original_param_creation = """ self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)"""
patched_param_creation = """ import bitsandbytes as bnb
if isinstance(param, bnb.nn.modules.Params4bit):
self.sharded_param = bnb.nn.modules.Params4bit(
data=sharded_param,
requires_grad=param.requires_grad,
quant_state=param.quant_state,
blocksize=param.blocksize,
compress_statistics=param.compress_statistics,
quant_type=param.quant_type,
quant_storage=param.quant_storage,
module=param.module,
bnb_quantized=param.bnb_quantized,
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
else:
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)"""
# Apply the replacement
if original_param_creation in original_source:
patched_source = original_source.replace(
original_param_creation, patched_param_creation
)
patched_source = patched_source.replace(
"def _init_sharded_param(",
"def patched_init_sharded_param(",
1,
)
# Load necessary imports
module_name = FSDPParam.__module__
module = importlib.import_module(module_name)
items_to_import = []
for item in dir(module):
if item in patched_source:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(patched_source, globals()) # pylint: disable=exec-used # nosec B102
# Replace the method
FSDPParam._init_sharded_param = patched_init_sharded_param # pylint: disable=undefined-variable # noqa: F821
LOG.info("Successfully applied FSDP _init_sharded_param patch")
else:
LOG.warning("Could not find target code for _init_sharded_param patching")
def apply_init_unsharded_param_patch():
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# Get original source
original_source = inspect.getsource(FSDPParam.init_unsharded_param)
original_source, _ = detab_code(original_source)
# Define the replacement
original_param_creation = """ self._unsharded_param = nn.Parameter(
unsharded_param, requires_grad=self.sharded_param.requires_grad
)"""
patched_param_creation = """ import bitsandbytes as bnb
local_tensor = self.sharded_param._local_tensor
if isinstance(local_tensor, bnb.nn.modules.Params4bit):
self._unsharded_param = bnb.nn.modules.Params4bit(
data=unsharded_param,
requires_grad=self.sharded_param.requires_grad,
quant_state=local_tensor.quant_state,
blocksize=local_tensor.blocksize,
compress_statistics=local_tensor.compress_statistics,
quant_type=local_tensor.quant_type,
quant_storage=local_tensor.quant_storage,
module=local_tensor.module,
bnb_quantized=local_tensor.bnb_quantized,
)
else:
self._unsharded_param = nn.Parameter(
unsharded_param, requires_grad=self.sharded_param.requires_grad
)"""
# Apply the replacement
if original_param_creation in original_source:
patched_source = original_source.replace(
original_param_creation, patched_param_creation
)
patched_source = patched_source.replace(
"def init_unsharded_param(",
"def patched_init_unsharded_param(",
1,
)
# Load necessary imports
module_name = FSDPParam.__module__
module = importlib.import_module(module_name)
items_to_import = []
for item in dir(module):
if item in patched_source:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(patched_source, globals()) # pylint: disable=exec-used # nosec B102
# Replace the method
FSDPParam.init_unsharded_param = patched_init_unsharded_param # pylint: disable=undefined-variable # noqa: F821
LOG.info("Successfully applied FSDP init_unsharded_param patch")
else:
LOG.warning("Could not find target code for patching")

View File

@@ -3,26 +3,39 @@
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
import warnings
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaAttention,
)
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
)
from transformers.models.llama.modeling_llama import (
LlamaMLP,
apply_rotary_pos_emb,
repeat_kv,
)
from axolotl.monkeypatch.utils import set_module_name
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
from axolotl.utils.logging import get_logger
try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
except ImportError:
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
)
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
)
@@ -69,6 +82,19 @@ def replace_llama_mlp_with_swiglu(model):
set_module_name(model, name, mlp)
def replace_llama_qkv_with_fused(model):
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
qkv = FusedAttention(
module.config,
module.q_proj,
module.k_proj,
module.v_proj,
module.o_proj,
)
set_module_name(model, name, qkv)
def patch_fa_llama_cross_entropy():
LOG.info(
"patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
@@ -116,6 +142,7 @@ def patch_llama_rms_norm():
def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
rms_norm: Optional[bool] = False,
use_shifted_sparse_attn: Optional[bool] = False,
@@ -127,6 +154,16 @@ def replace_llama_attn_with_flash_attn(
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
flashattn_forward_with_s2attn
)
else:
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
flashattn_forward
)
if packed:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = (
llama_model_forward
)
# skip only if explicitly disabled
if cross_entropy:
@@ -137,6 +174,49 @@ def replace_llama_attn_with_flash_attn(
patch_llama_rms_norm()
class FusedAttention(LlamaAttention):
"""
Fused QKV Attention layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
q: torch.nn.Linear, # pylint: disable=invalid-name
k: torch.nn.Linear, # pylint: disable=invalid-name
v: torch.nn.Linear, # pylint: disable=invalid-name
o: torch.nn.Linear, # pylint: disable=invalid-name
):
super().__init__(config)
self.config = config
self.init_device = next(iter(q.state_dict().values())).device
# define equivalent fused qkv projection
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
self.qkv_proj = torch.nn.Linear(
q.in_features, sum(self.out_features), device=self.init_device, bias=False
)
self.o_proj = o
# overwrite initialized weights with pretrained weights
self.qkv_proj.weight.data = torch.cat(
(q.weight.data, k.weight.data, v.weight.data), dim=0
)
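        # forward() later splits the fused output back into q/k/v chunks of
        # self.out_features (see the FusedAttention branch in flashattn_forward)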
def _post_training(self, model, name):
q_proj, k_proj, v_proj = torch.split(
self.qkv_proj.weight.data, self.out_features, dim=0
)
new_attn = LlamaAttention(self.config)
new_attn.q_proj.weight.data = q_proj
new_attn.k_proj.weight.data = k_proj
new_attn.v_proj.weight.data = v_proj
new_attn.o_proj.weight.data = self.o_proj.weight.data
set_module_name(model, name, new_attn)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
@@ -275,3 +355,576 @@ def flashattn_forward_with_s2attn(
.reshape(bsz, q_len, nheads, self.head_dim)
)
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
def flashattn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
if isinstance(self, FusedAttention):
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
self.out_features, dim=-1
)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# flash-attn v2 start
#
if self.training:
# during training q,k,v always have same seqlen
assert key_states.shape == query_states.shape
is_causal = True
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
dropout_p=dropout_rate,
softmax_scale=None,
causal=True,
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
query_states,
key_states,
value_states,
qkvpacked=True,
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
cu_seqlens_q,
max_seqlen_q,
dropout_p=dropout_rate,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if attention_mask is None or attention_mask.all().item():
output = flash_attn_kvpacked_func(
query_states,
torch.stack([key_states, value_states], 2),
dropout_p=dropout_rate,
causal=is_causal,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
_,
_,
output_pad_fn,
) = generate_qkv(
query_states,
key_states,
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=dropout_rate,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
attn_output = output
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
#
# flash-attn v2 end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask
)
def output_pad_fn(output_unpad):
return pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q_unpad.device,
)
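        # e.g. batch_size=2, seqlen_q=5 -> cu_seqlens_q = [0, 5, 10]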
max_seqlen_q = seqlen_q
def output_pad_fn(output_unpad):
return rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=k_unpad.device,
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
if kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
return (
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
)
return (
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
)
def llama_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[ # pylint: disable=unused-argument
torch.LongTensor
] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
padding_mask = None
else:
if 0 in attention_mask:
padding_mask = attention_mask
else:
padding_mask = None
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
transformers.logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(
*inputs,
)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
padding_mask,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
"""
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cu_seqlens (`torch.Tensor`, *optional*): cumulative sequence lengths when packing
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs

View File

@@ -156,11 +156,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return Llama4TextAttention
if model_type == "mistral3":
from transformers.models.mistral.modeling_mistral import MistralAttention
return MistralAttention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
@@ -395,6 +390,7 @@ def apply_lora_kernel_patches(
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
@@ -404,8 +400,7 @@ def apply_lora_kernel_patches(
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
)
if cfg.lora_o_kernel:
# Output patching
@@ -414,6 +409,7 @@ def apply_lora_kernel_patches(
]
can_patch_o = all(
hasattr(module, "lora_A")
and getattr(module, "base_layer", module).bias is None
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
@@ -422,14 +418,14 @@ def apply_lora_kernel_patches(
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
)
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
if cfg.lora_mlp_kernel:
# MLP patching
can_patch_mlp = all(
hasattr(proj, "lora_A")
and getattr(proj, "base_layer", proj).bias is None
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
)
@@ -439,8 +435,7 @@ def apply_lora_kernel_patches(
layer.mlp.forward = types.MethodType(apply_fn, mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters and no "
"lora_magnitude_vector (DoRA)"
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
)
LOG.setLevel(original_level)

View File

@@ -3,14 +3,53 @@
# pylint: disable=duplicate-code
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralAttention as OriginalMistralAttention,
)
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer,
)
from transformers.models.mistral.modeling_mistral import (
apply_rotary_pos_emb,
repeat_kv,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def replace_mistral_attn_with_flash_attn(
packed: Optional[bool] = False,
):
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
flashattn_forward
)
if packed:
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
MistralDecoderLayer
)
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
mistral_model_forward
)
def patch_mistral_cross_entropy():
from flash_attn.losses.cross_entropy import CrossEntropyLoss
@@ -18,3 +57,604 @@ def patch_mistral_cross_entropy():
transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
@torch.jit.script
def _make_sliding_window_causal_mask(
bsz: int,
tgt_len: int,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: int = 4096,
):
"""
Make causal mask used for sliding window attention
"""
tensor = torch.full(
(tgt_len, tgt_len),
fill_value=1,
device=device,
)
mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
mask = torch.triu(mask, diagonal=-sliding_window + 1)
mask = torch.log(mask).to(dtype)
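    # e.g. tgt_len=4, sliding_window=2, no cache: row i may attend to columns i-1..i
    # [[  0., -inf, -inf, -inf],
    #  [  0.,   0., -inf, -inf],
    #  [-inf,   0.,   0., -inf],
    #  [-inf, -inf,   0.,   0.]]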
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
sliding_window,
): # pylint: disable=unused-argument
# [bsz, seq_len]
if attention_mask is None or sliding_window is None:
return attention_mask
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
sliding_window_mask = _make_sliding_window_causal_mask(
bsz=input_shape[0],
tgt_len=input_shape[1],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
sliding_window=sliding_window,
)
attention_mask = attention_mask + sliding_window_mask
else:
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
return attention_mask
def flashattn_forward(
self: OriginalMistralAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
use_sliding_windows = (
getattr(self.config, "sliding_window") is not None
and kv_seq_len > self.config.sliding_window
)
if use_sliding_windows:
window_size = (self.config.sliding_window, self.config.sliding_window)
else:
window_size = (-1, -1)
if past_key_value is not None:
# Activate slicing cache only if the config has a `sliding_window` attribute set
if (
hasattr(self.config, "sliding_window")
and kv_seq_len > self.config.sliding_window
):
slicing_tokens = kv_seq_len - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)
past_key_value = (past_key, past_value) if use_cache else None
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if self.training:
# during training q,k,v always have same seqlen
assert key_states.shape == query_states.shape
is_causal = True
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
dropout_p=dropout_rate,
softmax_scale=None,
causal=True,
window_size=window_size,
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
query_states,
key_states,
value_states,
qkvpacked=True,
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
cu_seqlens_q,
max_seqlen_q,
dropout_p=dropout_rate,
softmax_scale=None,
causal=is_causal,
window_size=window_size,
)
output = output_pad_fn(output_unpad)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if attention_mask is None or attention_mask.all().item():
output = flash_attn_kvpacked_func(
query_states,
torch.stack([key_states, value_states], 2),
dropout_p=dropout_rate,
causal=is_causal,
window_size=window_size,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
_,
_,
output_pad_fn,
) = generate_qkv(
query_states,
key_states,
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=dropout_rate,
softmax_scale=None,
causal=is_causal,
window_size=window_size,
)
output = output_pad_fn(output_unpad)
attn_output = output
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask
)
def output_pad_fn(output_unpad):
return pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q_unpad.device,
)
max_seqlen_q = seqlen_q
def output_pad_fn(output_unpad):
return rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=k_unpad.device,
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
if kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
return (
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
)
return (
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
)
def mistral_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[ # pylint: disable=unused-argument
torch.LongTensor
] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
transformers.logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = (
self._gradient_checkpointing_func( # pylint: disable=protected-access
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
cu_seqlens,
max_seqlen,
)
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class MistralDecoderLayer(OriginalMistralDecoderLayer):
"""
patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cu_seqlens (`torch.Tensor`, *optional*): cumulative sequence lengths when packing
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs

View File

@@ -36,8 +36,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"glm",
"glm4",
"smollm3",
"gpt_oss",
"arcee",
"granite",
"granitemoe",
]

View File

@@ -20,15 +20,12 @@ from ring_flash_attn import ring_flash_attn_func
from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
try: # pylint: disable=duplicate-code
try:
from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
try:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
except ImportError:
_flash_supports_window = True
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

View File

@@ -15,15 +15,12 @@ import torch
import torch.distributed as dist
from torch.distributed import DeviceMesh
try: # pylint: disable=duplicate-code
try:
from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
try:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
except ImportError:
_flash_supports_window = True
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger

View File

@@ -0,0 +1,78 @@
"""
fix for FSDP2 evals when using torch.compile
"""
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
ORIGINAL_TRAINER_CODE = """
model.eval()
"""
PATCHED_TRAINER_CODE = """
if hasattr(model, "eval") and callable(model.eval):
self.model.eval()
"""
def get_evaluation_loop_code() -> str:
training_loop = inspect.getsource(Trainer.evaluation_loop)
return training_loop
def check_evaluation_loop_is_patchable() -> bool:
eval_loop = get_evaluation_loop_code()
eval_loop, _ = detab_code(eval_loop)
return ORIGINAL_TRAINER_CODE in eval_loop
def patch_evaluation_loop_for_fsdp2():
"""
monkeypatch for fixing the eval loop for fsdp2 with torch.compile
"""
try:
evaluation_loop = get_evaluation_loop_code()
except OSError:
return
Trainer._original_evaluation_loop = ( # pylint: disable=protected-access
evaluation_loop
)
evaluation_loop, _ = detab_code(evaluation_loop)
if ORIGINAL_TRAINER_CODE not in evaluation_loop:
return
evaluation_loop = evaluation_loop.replace(
ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
)
evaluation_loop = evaluation_loop.replace(
"def evaluation_loop(",
"def _fixed_evaluation_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in evaluation_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(evaluation_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer.evaluation_loop = ( # pylint: disable=protected-access
_fixed_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
)
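
A hedged usage sketch for this new patch; the module path is assumed (the file path is not visible in this view), and the call is meant to run once before the Trainer is built.

```python
# Hypothetical usage; the import path is an assumption, not shown in this diff.
from axolotl.monkeypatch.trainer_fsdp2_eval import patch_evaluation_loop_for_fsdp2

patch_evaluation_loop_for_fsdp2()  # installs the guarded model.eval() call
# ...then construct the Trainer and evaluate with FSDP2 + torch.compile as usual
```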

View File

@@ -0,0 +1,87 @@
"""
Monkey patch to fix transformers.modeling_flash_attention_utils.
see https://github.com/huggingface/transformers/pull/39653/files
"""
import sys
import torch
def _prepare_from_posids(query, key, value, position_ids):
"""
This function returns necessary arguments to call `flash_attn_varlen_func`.
All three query, key, value states will be flattened.
Cumulative lengths of each example in the batch will be extracted from position_ids.
NOTE: ideally cumulative lengths should be prepared at the data collator stage
Arguments:
query (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
position_ids (`torch.Tensor`):
Int tensor of shape (batch_size, sequence_length) giving each token's position within its packed sequence; positions restart from 0 at every sample boundary.
Return:
query (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
key (`torch.Tensor`):
Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
value (`torch.Tensor`):
Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
indices_q (`torch.Tensor`):
The indices of non-masked tokens from the flattened input target sequence.
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
position_ids = position_ids.flatten()
indices_q = torch.arange(
position_ids.size(0), device=position_ids.device, dtype=torch.int32
)
cu_seq_lens = torch.cat(
(
indices_q[position_ids == 0],
torch.tensor(
position_ids.size(), device=position_ids.device, dtype=torch.int32
),
)
)
# NOTE: With torch compile, this will cause a graph break if you don't set
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
# for some models (e.g. qwen2-vl).
max_length = cu_seq_lens.diff().max().item()
return (
query,
key,
value,
indices_q,
(cu_seq_lens, cu_seq_lens),
(max_length, max_length),
)
def patch_prepare_from_posids():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._prepare_from_posids = ( # pylint: disable=protected-access
_prepare_from_posids
)
setattr(
sys.modules["transformers.modeling_flash_attention_utils"],
"_prepare_from_posids",
_prepare_from_posids,
)
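
A small sanity check, separate from the patch itself, showing what `_prepare_from_posids` returns for two packed sequences; per the NOTE above, set `torch._dynamo.config.capture_scalar_outputs = True` if this runs under torch.compile.

```python
# Illustrative check of _prepare_from_posids: two packed samples of lengths 3 and 2.
import torch

position_ids = torch.tensor([[0, 1, 2, 0, 1]])
q = k = v = torch.zeros(1, 5, 4, 8)  # (batch, seq_len, num_heads, head_dim)

_, _, _, _, (cu_seqlens, _), (max_len, _) = _prepare_from_posids(q, k, v, position_ids)
assert cu_seqlens.tolist() == [0, 3, 5]  # cumulative boundaries of the two samples
assert max_len == 3                      # longest packed sample
```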

View File

@@ -1,165 +0,0 @@
"""
Module for patching transformers Trainer loss calculation to use nanmean.
This is needed for context parallelism since chunks of the input sequences may be fully
masked and return NaNs in the loss calculation.
Also includes a patch for FSDP2 + torch.compile. We need to bundle this together with
the other evaluation_loop patch because we can't patch the same code twice without
raising an OSError.
"""
import importlib
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
ORIGINAL_EVAL_CODE = {
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()',
"array": 'metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()',
}
PATCHED_EVAL_CODE = {
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()',
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
}
ORIGINAL_FSDP2_CODE = """
model.eval()
"""
PATCHED_FSDP2_CODE = """
if hasattr(model, "eval") and callable(model.eval):
self.model.eval()
"""
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
def check_evaluation_loop_is_patchable() -> bool:
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values())
def check_evaluation_loop_is_fsdp2_patchable() -> bool:
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
return ORIGINAL_FSDP2_CODE in evaluation_loop_source
# pylint: disable=protected-access
def patch_evaluation_loop(patch_fsdp2: bool):
"""Patch the evaluation_loop method."""
# Check if already patched
if hasattr(Trainer, "_original_evaluation_loop"):
LOG.info("Trainer.evaluation_loop already patched")
return
# Check if the patterns exist
try:
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
except OSError:
return
Trainer._original_evaluation_loop = evaluation_loop_source
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
# Apply the nanmean patches
evaluation_loop_source = evaluation_loop_source.replace(
ORIGINAL_EVAL_CODE["list"], PATCHED_EVAL_CODE["list"]
)
evaluation_loop_source = evaluation_loop_source.replace(
ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"]
)
# Apply FSDP2 eval guard patch if needed
if patch_fsdp2 and ORIGINAL_FSDP2_CODE in evaluation_loop_source:
evaluation_loop_source = evaluation_loop_source.replace(
ORIGINAL_FSDP2_CODE, PATCHED_FSDP2_CODE
)
LOG.info("Applied FSDP2 eval guard patch to evaluation_loop")
# Rename the function to avoid conflicts
evaluation_loop_source = evaluation_loop_source.replace(
"def evaluation_loop(",
"def axolotl_evaluation_loop(",
1,
)
# Get the module for necessary imports
module_name = Trainer.__module__
module = importlib.import_module(module_name)
# Import necessary items from the module
items_to_import = []
for item in dir(module):
if item in evaluation_loop_source:
items_to_import.append(item)
# Execute the imports and patched method
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(evaluation_loop_source, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
Trainer.evaluation_loop = (
axolotl_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
)
def check_maybe_log_save_evaluate_is_patchable() -> bool:
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
return ORIGINAL_MAYBE_CODE in maybe_log_source
# pylint: disable=protected-access
def patch_maybe_log_save_evaluate():
"""Patch the _maybe_log_save_evaluate method."""
# Check if already patched
if hasattr(Trainer, "_original_maybe_log_save_evaluate"):
LOG.info("Trainer._maybe_log_save_evaluate already patched")
return
# Check if the patterns exist
try:
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
except OSError:
return
Trainer._original_maybe_log_save_evaluate = maybe_log_source
maybe_log_source, _ = detab_code(maybe_log_source)
# Apply the patch
maybe_log_source = maybe_log_source.replace(ORIGINAL_MAYBE_CODE, PATCHED_MAYBE_CODE)
# Rename the function to avoid conflicts
maybe_log_source = maybe_log_source.replace(
"def _maybe_log_save_evaluate(",
"def axolotl_maybe_log_save_evaluate(",
1,
)
# Get the module for necessary imports
module_name = Trainer.__module__
module = importlib.import_module(module_name)
# Import necessary items from the module
items_to_import = []
for item in dir(module):
if item in maybe_log_source:
items_to_import.append(item)
# Execute the imports and patched method
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(maybe_log_source, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate # pylint: disable=undefined-variable # noqa: F821
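
The rationale in the module docstring above can be seen with a tiny NumPy illustration (not taken from the removed file): a fully masked context-parallel chunk produces a NaN loss, which `mean()` propagates but `nanmean()` ignores.

```python
# Why the patch swapped mean() for nanmean() in the loss aggregation.
import numpy as np

losses = np.array([0.8, np.nan, 1.2])  # NaN from a fully masked CP chunk
assert np.isnan(losses.mean())               # plain mean is poisoned by the NaN
assert abs(np.nanmean(losses) - 1.0) < 1e-9  # nanmean averages only valid values
```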

View File

@@ -6,7 +6,7 @@ from typing import Optional
from PIL import Image, ImageOps
from PIL.Image import Resampling
from torch import Tensor, zeros_like
from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
from transformers import ProcessorMixin, VoxtralProcessor
from transformers.image_utils import load_image
from axolotl.utils.dict import remove_none_values
@@ -138,7 +138,7 @@ class ProcessingStrategy:
image_key = key
break
# if the image key exists, add the image to the first user message
# if the image key exists, add the image to the first message
if image_key is not None and processed_example[image_key] is not None:
# TODO: check if it's normal to be single image only for common datasets
# From observation, it's usually a list of single image but some datasets may have several columns for images
@@ -179,34 +179,26 @@ class ProcessingStrategy:
# Look for any image type in the first message
# some dataset have an {type: "image"} in the first message
msg_ind_to_add = None
ind_to_add = None
first_user_idx = None
for msg_idx, msg_content in enumerate(processed_example["messages"]):
if first_user_idx is None and msg_content["role"] == "user":
first_user_idx = msg_idx
for i, content in enumerate(
processed_example["messages"][msg_idx]["content"]
for i, content in enumerate(
processed_example["messages"][0]["content"]
):
# Usually datasets created with image columns, don't have it in the messages itself
if content["type"] == "image" and all(
k not in content for k in ["image", "url", "path", "base64"]
):
# Usually datasets created with image columns, don't have it in the messages itself
if content["type"] == "image" and all(
k not in content for k in ["image", "url", "path", "base64"]
):
msg_ind_to_add = msg_idx
ind_to_add = i
break
ind_to_add = i
break
# If an image type is found, add the image to that index
if ind_to_add is not None and msg_ind_to_add is not None:
processed_example["messages"][msg_ind_to_add]["content"][
ind_to_add
]["image"] = image_value
if ind_to_add is not None:
processed_example["messages"][0]["content"][ind_to_add][
"image"
] = image_value
else:
# if no image type is found, add it to end of the first user message
if first_user_idx is None:
first_user_idx = 0
processed_example["messages"][first_user_idx]["content"].append(
# if no image type is found, add it to end of the first message
processed_example["messages"][0]["content"].append(
{
"type": "image",
"image": image_value,
@@ -403,24 +395,6 @@ class VoxtralProcessingStrategy(ProcessingStrategy):
return labels
class SmolVLM2ProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for SmolVLM2"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.image_token = "<image>" # nosec
self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
processor.tokenizer.additional_special_tokens.index(self.image_token)
]
def get_processing_strategy(
processor: ProcessorMixin,
chat_template,
@@ -428,43 +402,32 @@ def get_processing_strategy(
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
processing_kwargs = {
"processor": processor,
"chat_template": chat_template,
"image_size": image_size,
"image_resize_algorithm": image_resize_algorithm,
}
if chat_template_type in [None, "tokenizer_default"] and hasattr(
processor.tokenizer, "chat_template"
):
processing_kwargs["chat_template"] = processor.tokenizer.chat_template
if chat_template_type == "qwen2_vl":
return Qwen2VLProcessingStrategy(
**processing_kwargs,
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type == "gemma3":
return Gemma3ProcessingStrategy(
**processing_kwargs,
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type == "gemma3n":
return Gemma3nProcessingStrategy(
**processing_kwargs,
processor, chat_template, image_size, image_resize_algorithm
)
if chat_template_type in [
"llama3_2_vision",
"llama4",
"llava",
"mistral_v7_tekken",
"pixtral",
]:
return ProcessingStrategy(
processor, chat_template, image_size, image_resize_algorithm
)
if isinstance(processor, VoxtralProcessor):
return VoxtralProcessingStrategy(
**processing_kwargs,
processor, chat_template, image_size, image_resize_algorithm
)
if isinstance(processor, SmolVLMProcessor):
return SmolVLM2ProcessingStrategy(
**processing_kwargs,
)
# llama3_2_vision, llama4, llava
# mistral_v7_tekken, pixtral, lfm2vl
return ProcessingStrategy(
**processing_kwargs,
)
raise ValueError(f"Unsupported chat template type: {chat_template_type}")

View File

@@ -41,9 +41,7 @@ class ChatTemplatePrompter(Prompter):
field_messages: str = "messages",
field_system: str = "system",
field_tools: str = "tools",
field_thinking: str = "reasoning_content",
roles: dict[str, list[str]] | None = None,
template_thinking_key: str | None = "reasoning_content",
chat_template_kwargs: dict[str, Any] | None = None,
drop_system_message: bool = False,
):
@@ -52,9 +50,8 @@ class ChatTemplatePrompter(Prompter):
message_property_mappings = {
"role": "role",
"content": "content",
"reasoning_content": "reasoning_content",
}
if template_thinking_key and field_thinking:
message_property_mappings[template_thinking_key] = field_thinking
if roles:
self.roles = {s: t for t, sources in roles.items() for s in sources}
@@ -77,12 +74,10 @@ class ChatTemplatePrompter(Prompter):
self.field_messages = field_messages
self.field_system = field_system
self.field_tools = field_tools
self.field_thinking = field_thinking
self.tokenizer = tokenizer
self.processor: ProcessorMixin | None = processor
self.chat_template = chat_template
self.chat_template_kwargs = chat_template_kwargs or {}
self.template_thinking_key: str = template_thinking_key or "reasoning_content"
self.max_length = max_length
self.drop_system_message = drop_system_message
@@ -129,21 +124,13 @@ class ChatTemplatePrompter(Prompter):
images=images,
return_tensors="pt",
)
if hasattr(batch, "to_dict"):
batch = batch.to_dict()
else:
batch = dict(batch)
# workaround since processor works in batches instead of single examples
out = {}
for k, val in batch.items():
if hasattr(val, "tolist"):
out[k] = (
val.tolist() if k == "pixel_values" else val.squeeze(0).tolist()
)
if k in ["pixel_values"]:
batch[k] = val.tolist()
else:
out[k] = val
return out
batch[k] = val.squeeze().tolist()
return batch
return self.tokenizer.apply_chat_template(
conversation,
@@ -441,13 +428,10 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
tokenized_prompt["attention_mask"] = [1] * len(input_ids)
else:
input_ids = tokenized_res["input_ids"]
tokenized_prompt = dict(tokenized_res)
tokenized_prompt = tokenized_res
if not self.train_on_inputs:
if isinstance(prompt_ids, dict):
user_prompt_len = len(prompt_ids["input_ids"])
else:
user_prompt_len = len(prompt_ids)
user_prompt_len = len(prompt_ids)
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
else:
labels = input_ids
@@ -758,9 +742,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
# get the thinking content
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
transformed_message[self.prompter.template_thinking_key] = (
thinking_content.strip()
)
transformed_message["reasoning_content"] = thinking_content.strip()
# take remainder of the content
# strip whitespace from beginning of the remainder (thinking tokens)
@@ -971,10 +953,6 @@ class StrategyLoader:
None,
),
"field_messages": dataset_config.get("field_messages", "messages"),
"field_thinking": dataset_config.get("field_thinking", "reasoning_content"),
"template_thinking_key": dataset_config.get(
"template_thinking_key", "reasoning_content"
),
"roles": dataset_config.get("roles"),
"drop_system_message": dataset_config.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.

View File

@@ -72,10 +72,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
builder_kwargs["message_field_training"] = message_field_training
chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
def format_message(x):
return x
format_message = (
lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment
)
if chat_template == "chatml":
from axolotl.core.chat.format.chatml import format_message # noqa F811
if chat_template.startswith("llama3"):

View File

@@ -4,14 +4,11 @@ from __future__ import annotations
import importlib
import inspect
import json
import os
import shutil
import signal
import sys
import typing
import weakref
from collections import OrderedDict
from contextlib import ExitStack
from pathlib import Path
from typing import Any, Dict
@@ -41,7 +38,6 @@ from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint
from axolotl.utils.trainer import setup_trainer
try:
@@ -50,7 +46,7 @@ except ImportError:
BetterTransformer = None
if typing.TYPE_CHECKING:
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
LOG = get_logger(__name__)
@@ -128,6 +124,32 @@ def setup_reference_model(
return model_ref
def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
"""
Determine the checkpoint to resume from based on configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Path to the checkpoint to resume from, or `None` if not resuming.
"""
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
return cfg.resume_from_checkpoint
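
A quick illustration of the sort key used above: checkpoints are ordered by their numeric suffix, so `checkpoint-1000` correctly sorts after `checkpoint-200` rather than lexicographically before it.

```python
# Illustration of the checkpoint ordering in determine_resume_checkpoint().
paths = ["out/checkpoint-200", "out/checkpoint-1000", "out/checkpoint-50"]
latest = sorted(paths, key=lambda p: int(p.split("-")[-1]))[-1]
assert latest == "out/checkpoint-1000"  # numeric, not lexicographic, ordering
```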
def setup_signal_handler(
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
):
@@ -196,7 +218,6 @@ def execute_training(
ring_attn_func=cfg.ring_attn_func,
heads_k_stride=cfg.heads_k_stride,
gather_outputs=cfg.rl is RLType.GRPO,
device_mesh=trainer.accelerator.torch_device_mesh,
)
)
@@ -253,60 +274,19 @@ def save_trained_model(
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return
if ( # pylint: disable=too-many-nested-blocks
trainer.is_fsdp_enabled or cfg.fsdp_config
):
if trainer.is_fsdp_enabled:
if cfg.fsdp_config or cfg.fsdp:
if cfg.fsdp_config.final_state_dict_type:
state_dict_type = cfg.fsdp_config.final_state_dict_type
else:
state_dict_type = cfg.fsdp_config.state_dict_type
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
trainer.save_model(cfg.output_dir) # only handles FULL_STATE_DICT
trainer.save_model(cfg.output_dir)
if state_dict_type == "SHARDED_STATE_DICT":
LOG.info(
"The final model was saved with a sharded state dict. Please ensure you merge "
"the sharded weights with `merge-sharded-fsdp-weights`."
)
checkpoint_dir = determine_last_checkpoint(cfg, update=False)
if (
not (Path(cfg.output_dir) / "model.safetensors.index.json").exists()
and checkpoint_dir
):
# import here to prevent circular import
from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights
fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
merged_path = str(Path(cfg.output_dir) / "merged")
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=merged_path,
safe_serialization=True,
)
trainer.accelerator.wait_for_everyone()
if trainer.accelerator.is_main_process:
# move all files in merged_path to cfg.output_dir
for merged_file in Path(merged_path).iterdir():
if (Path(cfg.output_dir) / merged_file.name).exists():
(Path(cfg.output_dir) / merged_file.name).unlink()
shutil.move(str(merged_file), cfg.output_dir)
shutil.rmtree(merged_path) # remove what should be an empty dir
# TODO(wing): see https://github.com/huggingface/transformers/pull/40207
# cleanup the FSDP prefix in the model config.json
if trainer.accelerator.is_main_process:
with open(
Path(cfg.output_dir) / "config.json", "r", encoding="utf-8"
) as config_file_io:
# read the model config as an OrderedDict
config = json.load(config_file_io, object_pairs_hook=OrderedDict)
config["architectures"] = [
name.removeprefix("FSDP") for name in config["architectures"]
]
# write the updated model config back
with open(
os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8"
) as config_file_io:
json.dump(config, config_file_io, indent=2)
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone()
@@ -583,13 +563,9 @@ def train(
setup_model_card(cfg)
# Execute the training
resume_from_checkpoint = determine_last_checkpoint(cfg)
resume_from_checkpoint = determine_resume_checkpoint(cfg)
execute_training(cfg, trainer, resume_from_checkpoint)
# clear cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Save the trained model and cleanup
save_trained_model(cfg, trainer, model, safe_serialization)
create_model_card(cfg, trainer)

View File

@@ -1,216 +0,0 @@
# Axolotl TUI (Terminal User Interface)
A comprehensive Terminal User Interface for Axolotl, providing an interactive way to manage configurations, training jobs, datasets, models, and system monitoring.
## Features
### 🏠 Main Dashboard
- **Welcome Screen**: Central hub with quick access to all features
- **Keyboard Navigation**: Efficient navigation with keyboard shortcuts
- **Screen Management**: Easy switching between different functional areas
### 📝 Configuration Management
- **YAML Editor**: Syntax-highlighted editor for Axolotl configurations
- **Real-time Validation**: Instant config validation with detailed error reporting
- **File Browser**: Navigate and select configuration files
- **Template Loading**: Load example configurations
- **Remote Config Support**: Load configurations from URLs
**Key Shortcuts:**
- `Ctrl+N`: New configuration
- `Ctrl+S`: Save configuration
- `Ctrl+V`: Validate configuration
- `Ctrl+E`: Toggle edit mode
### 🚀 Training Management
- **Job Launcher**: Start training with different launchers (accelerate, torchrun)
- **Real-time Monitoring**: Live training progress and metrics
- **Loss Visualization**: Sparkline charts for loss curves
- **Job Control**: Start, stop, resume, and manage multiple training jobs
- **Log Streaming**: Real-time log viewing and filtering
**Key Shortcuts:**
- `Ctrl+T`: New training job
- `Ctrl+R`: Resume training
- `Ctrl+X`: Stop training
- `R`: Refresh status
### 📊 Dataset Management
- **Dataset Browser**: Explore local and remote datasets
- **Preview & Statistics**: View dataset samples and metadata
- **Preprocessing**: Run dataset preprocessing with progress tracking
- **HuggingFace Integration**: Download and manage HF datasets
- **Format Detection**: Automatic dataset format recognition
**Key Shortcuts:**
- `Ctrl+P`: Preprocess dataset
- `Ctrl+V`: Preview dataset
- `Ctrl+I`: Dataset information
- `R`: Refresh dataset list
### 🤖 Model Management
- **Model Discovery**: Automatically find trained models
- **LoRA Operations**: Merge LoRA adapters with base models
- **Quantization**: Quantize models for deployment
- **Evaluation**: Run model evaluation benchmarks
- **Storage Info**: View model sizes and storage details
**Key Shortcuts:**
- `Ctrl+M`: Merge LoRA
- `Ctrl+Q`: Quantize model
- `Ctrl+E`: Evaluate model
- `R`: Refresh model list
### 💬 Inference & Testing
- **Interactive Chat**: Chat interface for model testing
- **Parameter Tuning**: Adjust inference parameters (temperature, top-p, max tokens)
- **Model Loading**: Load and switch between different models
- **Chat History**: Save and load conversation history
- **Gradio Integration**: Launch Gradio web interface
**Key Shortcuts:**
- `Ctrl+Enter`: Send message
- `Ctrl+C`: Clear chat
- `Ctrl+L`: Load model
- `Ctrl+S`: Save chat
### 📈 System Monitoring
- **Resource Monitoring**: Real-time CPU, GPU, and memory usage
- **Process Management**: View and manage running processes
- **Performance Graphs**: Historical usage charts with sparklines
- **GPU Information**: Detailed GPU status and memory usage
- **Temperature Monitoring**: System temperature tracking
**Key Shortcuts:**
- `R`: Refresh metrics
- `Ctrl+K`: Kill selected process
## Installation
### Dependencies
```bash
pip install textual==1.0.0 rich==14.1.0
```
### Launch TUI
```bash
# From command line
python -m axolotl.cli.main tui
# From Python code
from axolotl.tui.app import run
run()
```
## Architecture
### Screen Structure
```
AxolotlTUI (Main App)
├── WelcomeScreen (Dashboard)
├── ConfigScreen (Configuration Management)
├── TrainingScreen (Training Management)
├── DatasetScreen (Dataset Management)
├── ModelScreen (Model Management)
├── InferenceScreen (Inference & Testing)
└── MonitorScreen (System Monitoring)
```
### Key Components
- **BaseScreen**: Common functionality for all screens
- **Screen Navigation**: Stack-based screen management
- **Event Handling**: Reactive UI updates
- **Background Tasks**: Non-blocking operations
- **State Management**: Shared application state
### Integration Points
- **CLI Commands**: Seamless integration with existing axolotl CLI
- **Configuration System**: Uses axolotl's native config loading
- **Training Pipeline**: Integrates with axolotl training functions
- **Model Loading**: Compatible with axolotl model management
## Usage Examples
### 1. Creating a New Configuration
1. Launch TUI: `python -m axolotl.cli.main tui`
2. Select "Configuration Management" or press `C`
3. Press `Ctrl+N` for new configuration
4. Edit the template configuration
5. Press `Ctrl+V` to validate
6. Press `Ctrl+S` to save
### 2. Starting a Training Job
1. Navigate to "Training Management" or press `T`
2. Press `Ctrl+T` for new training job
3. Select configuration file and launcher
4. Monitor progress in real-time
5. View loss curves and logs
### 3. Interactive Model Testing
1. Go to "Inference & Testing" or press `I`
2. Load a trained model with `Ctrl+L`
3. Adjust inference parameters as needed
4. Start chatting with the model
5. Save conversation with `Ctrl+S`
## Navigation
### Global Shortcuts
- `Ctrl+Q`: Quit application
- `Escape`: Go back/close current screen
- `Tab`: Navigate between UI elements
- `Enter`: Select/activate element
- `Space`: Toggle switches/checkboxes
### Screen Shortcuts
Each screen has specific shortcuts displayed in the footer. Common patterns:
- `Ctrl+[Letter]`: Primary actions
- `R`: Refresh/reload
- `F1-F12`: Function keys for advanced features
## Customization
### Themes
The TUI uses Textual's theming system and can be customized by modifying the CSS in each screen class.
### Adding New Screens
1. Create a new screen class inheriting from `BaseScreen`
2. Implement the `compose()` method for UI layout
3. Add event handlers for user interactions
4. Register the screen in the main app navigation (see the sketch below)
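
A minimal sketch of such a screen, assuming the `BaseScreen` described above; the `axolotl.tui.screens.base` import path is an assumption.

```python
# Hypothetical new screen; BaseScreen's module path is assumed, not documented here.
from textual.app import ComposeResult
from textual.widgets import Footer, Header, Static

from axolotl.tui.screens.base import BaseScreen  # assumed location


class HelloScreen(BaseScreen):
    """Example screen to be registered in the main app navigation."""

    def compose(self) -> ComposeResult:
        yield Header()
        yield Static("Hello from a custom Axolotl TUI screen")
        yield Footer()
```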
### Extending Functionality
- Add new widgets to existing screens
- Implement custom data visualization
- Integrate with external tools and APIs
- Add new keyboard shortcuts
## Troubleshooting
### Common Issues
1. **Import Errors**: Ensure textual and rich are installed
2. **Permission Errors**: Check file system permissions for config directories
3. **GPU Monitoring**: Install pynvml for GPU monitoring features
4. **Config Validation**: Ensure axolotl dependencies are properly installed
### Debug Mode
Launch with debug logging:
```bash
TEXTUAL_LOG=DEBUG python -m axolotl.cli.main tui
```
### Performance
- Use `Ctrl+\` to open Textual's debug console
- Monitor memory usage with the system monitor
- Disable auto-refresh for better performance on slower systems
## Contributing
The TUI is designed to be extensible. Contributions are welcome for:
- New screen implementations
- Enhanced visualizations
- Better keyboard navigation
- Additional integrations
- Performance improvements
See the main Axolotl repository for contribution guidelines.

View File

@@ -1 +0,0 @@
"""Axolotl Terminal User Interface (TUI)."""

View File

@@ -1,180 +0,0 @@
"""Main TUI application for Axolotl."""
from textual import on
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Container
from textual.screen import Screen
from textual.widgets import Button, Footer, Header, Static
from axolotl.tui.screens.config import ConfigScreen
from axolotl.tui.screens.datasets import DatasetScreen
from axolotl.tui.screens.inference import InferenceScreen
from axolotl.tui.screens.models import ModelScreen
from axolotl.tui.screens.monitor import MonitorScreen
from axolotl.tui.screens.training import TrainingScreen
class WelcomeScreen(Screen):
"""Welcome screen with main menu."""
BINDINGS = [
Binding("q", "quit", "Quit"),
Binding("c", "config", "Configuration"),
Binding("t", "training", "Training"),
Binding("d", "datasets", "Datasets"),
Binding("m", "models", "Models"),
Binding("i", "inference", "Inference"),
Binding("s", "monitor", "System Monitor"),
]
def compose(self) -> ComposeResult:
"""Compose the welcome screen."""
yield Header()
yield Container(
Static("🦾 Axolotl TUI", classes="title"),
Static(
"A Terminal User Interface for fine-tuning LLMs", classes="subtitle"
),
Container(
Button("Configuration Management [C]", id="config", variant="primary"),
Button("Training Management [T]", id="training", variant="primary"),
Button("Dataset Management [D]", id="datasets", variant="primary"),
Button("Model Management [M]", id="models", variant="primary"),
Button("Inference & Testing [I]", id="inference", variant="primary"),
Button("System Monitor [S]", id="monitor", variant="primary"),
classes="menu-container",
),
classes="welcome-container",
)
yield Footer()
def action_quit(self) -> None:
"""Quit the application."""
self.app.exit()
def action_config(self) -> None:
"""Navigate to config screen."""
self.app.push_screen(ConfigScreen())
def action_training(self) -> None:
"""Navigate to training screen."""
self.app.push_screen(TrainingScreen())
def action_datasets(self) -> None:
"""Navigate to datasets screen."""
self.app.push_screen(DatasetScreen())
def action_models(self) -> None:
"""Navigate to models screen."""
self.app.push_screen(ModelScreen())
def action_inference(self) -> None:
"""Navigate to inference screen."""
self.app.push_screen(InferenceScreen())
def action_monitor(self) -> None:
"""Navigate to monitor screen."""
self.app.push_screen(MonitorScreen())
@on(Button.Pressed, "#config")
def on_config_pressed(self) -> None:
"""Handle config button press."""
self.action_config()
@on(Button.Pressed, "#training")
def on_training_pressed(self) -> None:
"""Handle training button press."""
self.action_training()
@on(Button.Pressed, "#datasets")
def on_datasets_pressed(self) -> None:
"""Handle datasets button press."""
self.action_datasets()
@on(Button.Pressed, "#models")
def on_models_pressed(self) -> None:
"""Handle models button press."""
self.action_models()
@on(Button.Pressed, "#inference")
def on_inference_pressed(self) -> None:
"""Handle inference button press."""
self.action_inference()
@on(Button.Pressed, "#monitor")
def on_monitor_pressed(self) -> None:
"""Handle monitor button press."""
self.action_monitor()
class AxolotlTUI(App):
"""Main Axolotl TUI Application."""
CSS = """
.title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.subtitle {
text-align: center;
padding: 1;
color: $text-muted;
}
.welcome-container {
align: center middle;
height: 100%;
width: 100%;
}
.menu-container {
layout: vertical;
align: center middle;
padding: 2;
width: auto;
height: auto;
}
.menu-container Button {
width: 35;
margin: 1;
}
WelcomeScreen {
align: center middle;
}
"""
BINDINGS = [
Binding("ctrl+q", "quit", "Quit", priority=True),
Binding("escape", "back", "Back", priority=True),
]
def on_mount(self) -> None:
"""Called when the app is mounted."""
self.title = "Axolotl TUI"
self.sub_title = "Fine-tuning LLMs made easy"
self.push_screen(WelcomeScreen())
def action_quit(self) -> None:
"""Quit the application."""
self.exit()
def action_back(self) -> None:
"""Go back to previous screen."""
if len(self.screen_stack) > 1:
self.pop_screen()
def run():
"""Run the Axolotl TUI application."""
app = AxolotlTUI()
app.run()
if __name__ == "__main__":
run()

View File

@@ -1 +0,0 @@
"""TUI dialogs for Axolotl."""

View File

@@ -1,112 +0,0 @@
"""Training dialogs for Axolotl TUI."""
from pathlib import Path
from textual import on
from textual.app import ComposeResult
from textual.containers import Container
from textual.screen import ModalScreen
from textual.widgets import Button, Input, Label, Select, Static
class NewTrainingDialog(ModalScreen):
"""Dialog for starting a new training job."""
CSS = """
NewTrainingDialog {
align: center middle;
}
.dialog-container {
background: $surface;
border: thick $primary;
padding: 2;
width: 60;
height: auto;
}
.dialog-title {
text-align: center;
text-style: bold;
padding: 1;
color: $primary;
}
.form-field {
margin: 1 0;
}
.form-label {
margin: 0 0 1 0;
color: $text-muted;
}
.button-container {
layout: horizontal;
align: center middle;
margin: 2 0 0 0;
}
.button-container Button {
margin: 0 1;
}
"""
def compose(self) -> ComposeResult:
"""Compose the dialog."""
yield Container(
Static("Start New Training Job", classes="dialog-title"),
Container(
Label("Configuration File:", classes="form-label"),
Input(
placeholder="Path to config YAML file",
id="config-path",
value="/workspace/configs/",
),
classes="form-field",
),
Container(
Label("Launcher:", classes="form-label"),
Select(
[
("accelerate", "Accelerate (Recommended)"),
("torchrun", "TorchRun"),
("deepspeed", "DeepSpeed"),
],
id="launcher",
value="accelerate",
),
classes="form-field",
),
Container(
Button("Start Training", variant="primary", id="start"),
Button("Cancel", variant="default", id="cancel"),
classes="button-container",
),
classes="dialog-container",
)
@on(Button.Pressed, "#start")
def handle_start(self) -> None:
"""Handle start button press."""
config_input = self.query_one("#config-path", Input)
launcher_select = self.query_one("#launcher", Select)
config_path = config_input.value.strip()
if not config_path:
return
if not Path(config_path).exists():
return
result = {
"config_path": config_path,
"launcher": launcher_select.value,
}
self.dismiss(result)
@on(Button.Pressed, "#cancel")
def handle_cancel(self) -> None:
"""Handle cancel button press."""
self.dismiss(None)

Some files were not shown because too many files have changed in this diff.