Compare commits

44 Commits
chat-templ...

| Author | SHA1 | Date |
|---|---|---|
|  | 9c0fa60220 |  |
|  | 8efdc59796 |  |
|  | 172b08b209 |  |
|  | 3d45620008 |  |
|  | ce20e838b5 |  |
|  | d4d84d48af |  |
|  | 9b12c05660 |  |
|  | 686933194e |  |
|  | d12b461d19 |  |
|  | d6b81b3683 |  |
|  | 05f1b4b2e8 |  |
|  | 7cfc80ec77 |  |
|  | 0da6a95efa |  |
|  | 2c8497e489 |  |
|  | f70d4de8c7 |  |
|  | 0ae06d756d |  |
|  | 2974670bf8 |  |
|  | 50f2b94d50 |  |
|  | eb2c87b525 |  |
|  | 4db7f023c6 |  |
|  | 4273d5cf7e |  |
|  | c5e5aba547 |  |
|  | 9d5c95db6f |  |
|  | ca796fb56e |  |
|  | 597953bef0 |  |
|  | 39fbd3b2b5 |  |
|  | 46dfacf255 |  |
|  | 4bce713b39 |  |
|  | d09290f2f4 |  |
|  | e442ff22aa |  |
|  | ba3dba3e4f |  |
|  | 97e86c6d47 |  |
|  | 784f8c0e95 |  |
|  | e3177c3210 |  |
|  | 70faea331f |  |
|  | 8021c718ce |  |
|  | 42f5e6f9e9 |  |
|  | ab49d16e34 |  |
|  | 33d094721c |  |
|  | a54c1be972 |  |
|  | 5691992d34 |  |
|  | e758343cac |  |
|  | deac7b18a1 |  |
|  | 10946afae7 |  |
.github/workflows/base.yml (vendored, 27 changes)

@@ -54,7 +54,7 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.6.3
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1

@@ -64,9 +64,16 @@ jobs:
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: nightly
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base-nightly"
dockerfile: "Dockerfile-base"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
# python_version: "3.11"
# pytorch: nightly
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base-nightly"
# # "next" is for release candidates of pytorch
# - cuda: "128"
# cuda_version: 12.8.1

@@ -122,6 +129,13 @@ jobs:
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""

@@ -129,6 +143,13 @@ jobs:
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
steps:
- name: Checkout
uses: actions/checkout@v4

.github/workflows/main.yml (vendored, 23 changes)

@@ -24,12 +24,13 @@ jobs:
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"

@@ -97,6 +98,12 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1

@@ -150,6 +157,18 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

.github/workflows/tests.yml (vendored, 7 changes)

@@ -105,7 +105,8 @@ jobs:

- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml

@@ -179,8 +180,8 @@ jobs:

- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/

- name: cleanup pip cache

@@ -3,7 +3,7 @@ default_language_version:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer

@@ -23,11 +23,11 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.7
rev: v3.3.8
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.17.0
rev: v1.17.1
hooks:
- id: mypy
additional_dependencies:

@@ -185,7 +185,6 @@ datasets:
| `flash_attention` | `false` | Use flash attention |
| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
| `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations |
| `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
| `sdp_attention` | `false` | Use scaled dot product |
| `s2_attention` | `false` | Use shifted sparse attention |

@@ -296,7 +296,6 @@
# flash_attention:
# flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
# flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# # Whether to use scaled-dot-product attention
# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

@@ -541,7 +540,6 @@ xformers_attention: ${XFORMERS_ATTENTION}
flash_attention: ${FLASH_ATTENTION}
flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV}
flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
sdp_attention: ${SDP_ATTENTION}
s2_attention: ${S2_ATTENTION}

CITATION.cff (new file, 10 lines)

@@ -0,0 +1,10 @@
cff-version: 1.2.0
type: software
title: "Axolotl: Post-Training for AI Models"
message: "If you use this software, please cite it as below."
authors:
- name: "Axolotl maintainers and contributors"
repository-code: "https://github.com/axolotl-ai-cloud/axolotl"
url: "https://axolotl.ai/"
license: Apache-2.0
date-released: "2023-05-30"

README.md (33 changes)

@@ -25,17 +25,28 @@

## 🎉 Latest Updates

- 2025/07: Voxtral with mistral-common tokenizer support has been integrated in Axolotl. Read the [docs](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral)!
- 2025/07: TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/07:
  - ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
  - Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
  - FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
  - [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
  - TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.

<details>

<summary>Expand older updates</summary>

- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).

</details>

## ✨ Overview

Axolotl is a tool designed to streamline post-training for various AI models.

@@ -138,6 +149,20 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co

Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)

## 📝 Citing Axolotl

If you use Axolotl in your research or projects, please cite it as follows:

```bibtex
@software{axolotl,
  title = {Axolotl: Post-Training for AI Models},
  author = {{Axolotl maintainers and contributors}},
  url = {https://github.com/axolotl-ai-cloud/axolotl},
  license = {Apache-2.0},
  year = {2023}
}
```

## 📜 License

This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.

@@ -274,6 +274,7 @@ website:
- docs/dataset_preprocessing.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd

- section: "Advanced Features"
contents:

@@ -212,10 +212,11 @@ Instead of passing `tools` via the system prompt, an alternative method would be
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
:::

Example config for Llama4:

```yaml
chat_template: llama4
datasets:
  - path: ...
  - path: Nanobit/text-tools-2k-test
    type: chat_template
    # field_tools: tools # default is `tools`
```

@@ -1,4 +1,6 @@
# N-D Parallelism
---
title: "N-D Parallelism (Beta)"
---

Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:

@@ -71,6 +73,10 @@ Note: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size

## Examples

::: {.callout-tip}
See our example configs [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/distributed-parallel).
:::

1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
   - You want FSDP within each node and DDP across nodes.
   - Set `dp_shard_size: 4` and `dp_replicate_size: 2`.
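A minimal sketch of that HSDP layout, showing only the two parallelism keys; everything else (datasets, FSDP settings, and so on) follows the example configs linked above:

```yaml
# 2 nodes x 4 GPUs = 8 GPUs in total
dp_replicate_size: 2  # DDP-style replication across the 2 nodes
dp_shard_size: 4      # FSDP sharding across the 4 GPUs inside each node
```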
@@ -95,7 +101,7 @@ This matrix describes how different parallelism methods can be combined in Axolo
| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |
| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |
| **FSDP + TP + CP** | 1 | >1 | >1 | >1 | ✅ **3D Parallelism**. Another advanced combination. |
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP/CP without FSDP is inefficient and complex. You should use FSDP instead (`dp_shard_size > 1`). |
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP or CP is currently not supported. You should use FSDP + TP/CP instead (`dp_shard_size > 1`). |
| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |

- `tp_size` refers to `tensor_parallel_size`

docs/optimizers.qmd (new file, 129 lines)

@@ -0,0 +1,129 @@
---
title: Optimizers
description: Configuring optimizers
---

## Overview

Axolotl supports all optimizers supported by [transformers OptimizerNames](https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/training_args.py#L142-L187)

Here is a list of optimizers supported by transformers as of `v4.54.0`:

- `adamw_torch`
- `adamw_torch_fused`
- `adamw_torch_xla`
- `adamw_torch_npu_fused`
- `adamw_apex_fused`
- `adafactor`
- `adamw_anyprecision`
- `adamw_torch_4bit`
- `adamw_torch_8bit`
- `ademamix`
- `sgd`
- `adagrad`
- `adamw_bnb_8bit`
- `adamw_8bit` # alias for adamw_bnb_8bit
- `ademamix_8bit`
- `lion_8bit`
- `lion_32bit`
- `paged_adamw_32bit`
- `paged_adamw_8bit`
- `paged_ademamix_32bit`
- `paged_ademamix_8bit`
- `paged_lion_32bit`
- `paged_lion_8bit`
- `rmsprop`
- `rmsprop_bnb`
- `rmsprop_bnb_8bit`
- `rmsprop_bnb_32bit`
- `galore_adamw`
- `galore_adamw_8bit`
- `galore_adafactor`
- `galore_adamw_layerwise`
- `galore_adamw_8bit_layerwise`
- `galore_adafactor_layerwise`
- `lomo`
- `adalomo`
- `grokadamw`
- `schedule_free_radam`
- `schedule_free_adamw`
- `schedule_free_sgd`
- `apollo_adamw`
- `apollo_adamw_layerwise`
- `stable_adamw`
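Any of the names above can be passed directly as the `optimizer` value of a training config. A minimal sketch (the optimizer and hyperparameters shown here are illustrative, not recommendations):

```yaml
optimizer: adamw_torch_fused
learning_rate: 2e-5
lr_scheduler: cosine
```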

## Custom Optimizers

Enable custom optimizers by passing a string to the `optimizer` argument. Each optimizer will receive the beta and epsilon args; some also accept additional args, which are detailed below.
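As a sketch of how those shared args look in a config (the values below are just the common AdamW-style defaults, not tuned recommendations):

```yaml
optimizer: optimi_adamw
adam_beta1: 0.9
adam_beta2: 0.999
adam_epsilon: 1e-8
```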

### optimi_adamw

```yaml
optimizer: optimi_adamw
```

### ao_adamw_4bit

Deprecated: Please use `adamw_torch_4bit`.

### ao_adamw_8bit

Deprecated: Please use `adamw_torch_8bit`.

### ao_adamw_fp8

```yaml
optimizer: ao_adamw_fp8
```

### adopt_adamw

GitHub: [https://github.com/iShohei220/adopt](https://github.com/iShohei220/adopt)
Paper: [https://arxiv.org/abs/2411.02853](https://arxiv.org/abs/2411.02853)

```yaml
optimizer: adopt_adamw
```

### came_pytorch

GitHub: [https://github.com/yangluo7/CAME/tree/master](https://github.com/yangluo7/CAME/tree/master)
Paper: [https://arxiv.org/abs/2307.02047](https://arxiv.org/abs/2307.02047)

```yaml
optimizer: came_pytorch

# optional args (defaults below)
adam_beta1: 0.9
adam_beta2: 0.999
adam_beta3: 0.9999
adam_epsilon: 1e-30
adam_epsilon2: 1e-16
```

### muon

Blog: [https://kellerjordan.github.io/posts/muon/](https://kellerjordan.github.io/posts/muon/)
Paper: [https://arxiv.org/abs/2502.16982v1](https://arxiv.org/abs/2502.16982v1)

```yaml
optimizer: muon
```

### dion

Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.

GitHub: [https://github.com/microsoft/dion](https://github.com/microsoft/dion)
Paper: [https://arxiv.org/pdf/2504.05295](https://arxiv.org/pdf/2504.05295)
Note: Implementation written for PyTorch 2.7+ for DTensor

```yaml
optimizer: dion
dion_lr: 0.01
dion_momentum: 0.95
lr: 0.00001 # learning rate for embeddings and parameters that fallback to AdamW
```

examples/arcee/README.md (new file, 53 lines)

@@ -0,0 +1,53 @@
# Finetune ArceeAI's AFM with Axolotl

[Arcee Foundation Models (AFM)](https://huggingface.co/collections/arcee-ai/afm-45b-68823397c351603014963473) are a family of 4.5B parameter open weight models trained by Arcee.ai.

This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.

Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the AFM model.

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as AFM is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).

Here is an example of how to install from main for pip:

```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl

pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
```

2. Run the finetuning example:

```bash
axolotl train examples/arcee/afm-4.5b-qlora.yaml
```

This config uses about 7.8GiB VRAM.

Let us know how it goes. Happy finetuning! 🚀

### TIPS

- For inference, the official Arcee.ai team recommends `top_p: 0.95`, `temperature: 0.5`, `top_k: 50`, and `repeat_penalty: 1.1`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).

## Optimization Guides

- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)

## Related Resources

- [AFM Blog](https://docs.arcee.ai/arcee-foundation-models/introduction-to-arcee-foundation-models)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

examples/arcee/afm-4.5b-qlora.yaml (new file, 64 lines)

@@ -0,0 +1,64 @@
base_model: arcee-ai/AFM-4.5B

# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_8bit: false
load_in_4bit: true

datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

# save_first_step: true # uncomment this to validate checkpoint saving works with your config

@@ -47,7 +47,6 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true

warmup_ratio: 0.1

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8\""
]
},
{

@@ -10,17 +10,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

Here is an example of how to install from main for pip:
Here is an example of how to install from pip:

```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0+)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl

# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

2. Run the finetuning example:

examples/distributed-parallel/README.md (new file, 52 lines)

@@ -0,0 +1,52 @@
# ND Parallelism Examples

This directory contains example configurations for training models using ND Parallelism in Axolotl. These examples demonstrate how to compose different parallelism strategies (FSDP, TP, CP, HSDP) for efficient multi-GPU training.

## Quick Start

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

2. Run the command below:

```bash
# Train Qwen3 8B with FSDP + TP + CP on a single 8-GPU node
axolotl train examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml

# Train Llama 3.1 8B with HSDP + TP on 2 nodes (16 GPUs total)
axolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
```

## Example Configurations

### Single Node (8 GPUs)

**Qwen3 8B with FSDP + TP + CP** ([qwen3-8b-fsdp-tp-cp.yaml](./qwen3-8b-fsdp-tp-cp.yaml))
- Uses all 3 parallelism dimensions on a single node
- Ideal for: when model weights, activations, and/or context are too large to fit on single GPU

```yaml
dp_shard_size: 2 # FSDP across 2 GPUs
tensor_parallel_size: 2 # TP across 2 GPUs
context_parallel_size: 2 # CP across 2 GPUs
# Total: 2 × 2 × 2 = 8 GPUs
```

### Multi-Node

**Llama 3.1 8B with HSDP + TP** ([llama-3_1-8b-hsdp-tp.yaml](./llama-3_1-8b-hsdp-tp.yaml))
- FSDP & TP within nodes, DDP across nodes to minimize inter-node communication
- Ideal for: Scaling to multiple nodes while maintaining training efficiency

```yaml
dp_shard_size: 4 # FSDP within each 4-GPU group
tensor_parallel_size: 2 # TP within each node
dp_replicate_size: 2 # DDP across 2 groups
# Total: (4 × 2) × 2 = 16 GPUs (2 nodes)
```

## Learn More

- [ND Parallelism Documentation](https://docs.axolotl.ai/docs/nd_parallelism.html)
- [Blog: Accelerate ND-Parallel Guide](https://huggingface.co/blog/accelerate-nd-parallel)
- [Multi-GPU Training Guide](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml (new file, 47 lines)

@@ -0,0 +1,47 @@
base_model: meta-llama/Llama-3.1-8B

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

dp_shard_size: 4
dp_replicate_size: 2
tensor_parallel_size: 2
# context_parallel_size: 2

dataset_prepared_path: last_run_prepared

special_tokens:
pad_token: <|end_of_text|>

fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
reshard_after_forward: true

datasets:
- path: tatsu-lab/alpaca
type: alpaca

output_dir: ./outputs/ndp-out/

sequence_len: 2048
sample_packing: true
flash_attention: true

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch_fused
lr_scheduler: constant_with_warmup
learning_rate: 2e-6

bf16: true
tf32: true

logging_steps: 1
saves_per_epoch: 1

warmup_ratio: 0.1

examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml (new file, 46 lines)

@@ -0,0 +1,46 @@
base_model: Qwen/Qwen3-8B

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

dp_shard_size: 2
# dp_replicate_size: 1
context_parallel_size: 2
tensor_parallel_size: 2

dataset_prepared_path: last_run_prepared

fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
reshard_after_forward: true

datasets:
- path: tatsu-lab/alpaca
type: alpaca

output_dir: ./outputs/ndp-out/

sequence_len: 8192
sample_packing: true
flash_attention: true

gradient_accumulation_steps: 1
micro_batch_size: 1 # must be 1 when using context parallel
num_epochs: 2
optimizer: adamw_torch_fused
lr_scheduler: constant_with_warmup
learning_rate: 2e-6

bf16: true
tf32: true

logging_steps: 1
saves_per_epoch: 1

warmup_ratio: 0.1

special_tokens:

@@ -4,17 +4,14 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Gemma3n is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

Here is an example of how to install from main for pip:
Here is an example of how to install from pip:

```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min recommended)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl

# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

2. In addition to Axolotl's requirements, Gemma-3n requires:

examples/gpt-oss/README.md (new file, 74 lines)

@@ -0,0 +1,74 @@
# Finetune OpenAI's GPT-OSS with Axolotl

[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.

This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

Here is an example of how to install from pip:

```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))

```bash
# LoRA SFT linear layers (1x48GB @ ~44GiB)
axolotl train examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml

# FFT SFT with offloading (2x24GB @ ~21GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml

# FFT SFT (8x48GB @ ~36GiB/GPU or 4x80GB @ ~46GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
```

Note: Memory usage taken from `device_mem_reserved(gib)` from logs.

### Training 120B

On 8xH100s

```bash
# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
```

### Tool use

GPT-OSS has a comprehensive tool understanding. Axolotl supports tool calling datasets for Supervised Fine-tuning.

Here is an example dataset config:

```yaml
datasets:
- path: Nanobit/text-tools-2k-test
type: chat_template
```

See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-tools-2k-test) for the sample dataset.

Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.

### TIPS

- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).

## Optimization Guides

- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)

## Related Resources

- [GPT-OSS Blog](https://openai.com/index/introducing-gpt-oss/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml (new file, 67 lines)

@@ -0,0 +1,67 @@
# the original mxfp4 quantized model is not supported with FSDP cpu_ram_efficient_loading
# FSDP cpu_ram_efficient_loading is used to reduce the initial CPU memory usage when loading the model
base_model: axolotl-ai-co/gpt-oss-120b-dequantized

use_kernels: false

dp_shard_size: 16 # requires 2x8xH100 nodes

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding

datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking

dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
lr_scheduler: constant_with_warmup
learning_rate: 2e-5

bf16: true
tf32: true

flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3

gradient_checkpointing: true
activation_offloading: true

logging_steps: 1
saves_per_epoch: 1

warmup_ratio: 0.03

special_tokens:
eot_tokens:
- "<|end|>"

fsdp_version: 2
fsdp_config:
offload_params: true
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
cpu_ram_efficient_loading: true

examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml (new file, 58 lines)

@@ -0,0 +1,58 @@
base_model: openai/gpt-oss-20b
use_kernels: false
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding

datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking

dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/

sequence_len: 4096
sample_packing: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-5

bf16: true
tf32: true

flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3

gradient_checkpointing: true
activation_offloading: true

logging_steps: 1
saves_per_epoch: 1

warmup_ratio: 0.03

special_tokens:
eot_tokens:
- "<|end|>"

# choose the zero3 configuration that best fits your system capabilities
deepspeed: deepspeed_configs/zero3_bf16.json

examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml (new file, 68 lines)

@@ -0,0 +1,68 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding

datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking

dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
lr_scheduler: constant_with_warmup
learning_rate: 2e-5

bf16: true
tf32: true

flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3

gradient_checkpointing: true
activation_offloading: true

logging_steps: 1
saves_per_epoch: 1

warmup_ratio: 0.03

special_tokens:
eot_tokens:
- "<|end|>"

fsdp_version: 2
fsdp_config:
offload_params: true
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
# cpu_ram_efficient_loading: true

# cpu_ram_efficient_loading cannot be used with MXFP4 model quantization.
# It can only be used with a dequantized model like `axolotl-ai-co/gpt-oss-120b-dequantized`

examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml (new file, 64 lines)

@@ -0,0 +1,64 @@
base_model: openai/gpt-oss-20b
use_kernels: false
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding

datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking

dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/

sequence_len: 4096
sample_packing: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-5

bf16: true
tf32: true

flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3

gradient_checkpointing: true
activation_offloading: true

logging_steps: 1
saves_per_epoch: 1

warmup_ratio: 0.03

special_tokens:
eot_tokens:
- "<|end|>"

fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
# cpu_ram_efficient_loading: true

examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml (new file, 67 lines)

@@ -0,0 +1,67 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true

plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding

datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking

dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/

sequence_len: 4096
sample_packing: true

adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
lora_target_linear: true

# TODO: not supported for now, see peft#2710
#lora_target_parameters: # target the experts in the last two layers
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1

optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-4

bf16: true
tf32: true

flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3

gradient_checkpointing: true
activation_offloading: true

logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1

special_tokens:
eot_tokens:
- "<|end|>"

@@ -45,7 +45,6 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true

warmup_ratio: 0.1

@@ -49,7 +49,6 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true

warmup_ratio: 0.1

@@ -8,17 +8,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Magistral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

Here is an example of how to install from main for pip:
Here is an example of how to install from pip:

```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl

pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

2. Run the finetuning example:

@@ -27,7 +27,6 @@ sequence_len: 2048
sample_packing: true
eval_sample_packing: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

@@ -26,7 +26,6 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

@@ -26,7 +26,6 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

examples/slurm/README.md (new file, 66 lines)

@@ -0,0 +1,66 @@
# SLURM Multi-Node Training

This directory contains an example SLURM script for running Axolotl training jobs across multiple nodes in a SLURM cluster.

## Prerequisites

- Access to a SLURM cluster with GPU nodes
- Axolotl installed on all nodes (see [installation docs](https://docs.axolotl.ai/docs/installation.html))

## Usage

### Standard SLURM Clusters

1. Copy [`axolotl.slurm`](./axolotl.slurm) to your working directory.
2. Place your Axolotl config file (`train.yaml`) in the same directory.
3. Set the appropriate environment variables for the job:

```bash
export HF_TOKEN="your-huggingface-token"

# metric tracking
# export WANDB_API_KEY="your-wandb-api-key"
# ...
```

4. Submit the job:

```bash
sbatch --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=<master-node>,PRIMARY_PORT=29400 axolotl.slurm
```

Where:
- `NUM_NODES`: Number of nodes to use
- `NUM_TRAINERS`: GPUs per node (typically 8)
- `PRIMARY_ADDR`: Hostname/IP of the master node
- `PRIMARY_PORT`: Port for distributed training (default: 29400)

5. (Optional) Run other slurm commands:

```bash
# check job info
scontrol show job axolotl-cli

# check job queue
squeue

# check cluster status
sinfo
```

### RunPod Instant Clusters

Axolotl works with RunPod Instant Clusters. This feature provides managed SLURM clusters with zero configuration.

1. **Deploy a SLURM Cluster**:
   - Go to [RunPod Instant Clusters](https://console.runpod.io/cluster)
   - Click "Create a Cluster"
   - Choose your GPU type, node count, and region
   - Choose an [Axolotl cloud docker image](https://docs.axolotl.ai/docs/docker.html#cloud)
   - Deploy the cluster

2. **Connect to the Controller Node**: Find the controller node in the RunPod console and connect via SSH

3. **Follow the instructions in [Standard SLURM Clusters](#standard-slurm-clusters)**

## Additional Resources

- [Axolotl Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [SLURM Documentation](https://slurm.schedmd.com/documentation.html)
- [RunPod SLURM Clusters Guide](https://docs.runpod.io/instant-clusters/slurm-clusters)

examples/slurm/axolotl.slurm (new file, 20 lines)

@@ -0,0 +1,20 @@
#!/bin/bash
# Prior to running this script, export your HF_TOKEN and WANDB_API_KEY to your environment; i.e.
# export HF_TOKEN="..."
# export WANDB_API_KEY="..."
#

# ---------- SBATCH commands ---------- #
#SBATCH --job-name=axolotl-slurm-multinode
#SBATCH --ntasks-per-node=1
#SBATCH --nodes=$NUM_NODES
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-task=128

export TORCH_DIST_INIT_BARRIER=0

srun axolotl preprocess train.yaml

srun axolotl train train.yaml --launcher torchrun -- \
--nproc_per_node=$NUM_TRAINERS --nnodes=$NUM_NODES \
--rdzv_id axolotl-cli --rdzv_backend c10d --rdzv_endpoint "${PRIMARY_ADDR}:${PRIMARY_PORT}" --rdzv-conf="join_timeout=1800"

@@ -6,17 +6,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Voxtral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

Here is an example of how to install from main for pip:
Here is an example of how to install from pip:

```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl

pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

2. Please install the below.

@@ -1,8 +1,9 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.46.0
triton>=3.0.0
bitsandbytes==0.46.1
# triton 3.4.0 is not compatible with CCE
triton>=3.0.0,<3.4.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3

@@ -12,19 +13,21 @@ liger-kernel==0.6.1
packaging==23.2

huggingface_hub>=0.33.0
peft==0.16.0
transformers==4.54.1
peft==0.17.0
transformers @ git+https://github.com/vasqu/transformers@fix-fa-integration
tokenizers>=0.21.1
accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
accelerate==1.10.0
datasets==4.0.0
deepspeed>=0.17.0
trl==0.20.0
trl==0.21.0
hf_xet==1.1.5
kernels==0.9.0
trackio

optimum==1.16.2
hf_transfer
sentencepiece
gradio==5.23.3
gradio==5.41.1

modal==1.0.2
pydantic==2.10.6

@@ -66,6 +69,6 @@ torchao==0.12.0
schedulefree==1.4.1

axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
axolotl-contribs-mit==0.0.5

mistral-common==1.8.3

@@ -44,8 +44,13 @@ add_keys_to_authorized() {
chmod 700 -R ~/.ssh
}

# Set SSH port
if [ ! -z "$SSH_PORT" ]; then
sed -i "s/#Port 22/Port $SSH_PORT/" /etc/ssh/sshd_config
fi

if [[ $PUBLIC_KEY ]]; then
# runpod
# runpod, prime intellect
add_keys_to_authorized "$PUBLIC_KEY"
# Start the SSH service in the background
service ssh start

@@ -76,5 +81,13 @@ if [ ! -L "/workspace/axolotl/outputs" ]; then
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
fi

# start the runpod slurm init
SLURM_INIT="${SLURM_INIT:-/slurm-init.sh}"

if [[ -f "$SLURM_INIT" ]]; then
echo "[entrypoint] running $SLURM_INIT..."
bash "$SLURM_INIT"
fi

# Execute the passed arguments (CMD)
exec "$@"

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""

print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
)

@@ -4,4 +4,4 @@ import pkgutil

__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package

__version__ = "0.12.0.dev"
__version__ = "0.13.0.dev"

@@ -153,15 +153,14 @@ def prepare_plugins(cfg: DictDefault):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
for plugin in plugin_manager.plugins.values():
plugin.register(cfg)

def plugin_set_cfg(cfg: DictDefault):
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
plugin_manager.cfg = cfg
# now that we have the finalized cfg, register the plugins individually
for plugin in plugin_manager.plugins.values():
plugin.register(cfg)

def load_cfg(

@@ -123,9 +123,10 @@ def train(
_launcher = None if kwargs.get("use_ray") else launcher

# Process each configuration
for cfg_file in generate_config_files(config, sweep):
for cfg_file, is_group in generate_config_files(config, sweep):
try:
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
use_exec = is_group is not True
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
except subprocess.CalledProcessError as exc:
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:

@@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
import subprocess # nosec
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Any, Iterator, Literal
|
||||
|
||||
@@ -64,10 +65,20 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
||||
return cmd
|
||||
|
||||
|
||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||
"""Generate list of configuration files to process."""
|
||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
|
||||
"""
|
||||
Generate list of configuration files to process.
|
||||
|
||||
Args:
|
||||
config: Base configuration file
|
||||
sweep: Sweep configuration file
|
||||
|
||||
Yields:
|
||||
Tuple of configuration file name and whether this is a group of configurations
|
||||
"""
|
||||
|
||||
if not sweep:
|
||||
yield config
|
||||
yield config, False
|
||||
return
|
||||
|
||||
# Load sweep and base configurations
|
||||
@@ -78,6 +89,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||
|
||||
# Generate all possible configurations
|
||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||
is_group = len(permutations) > 1
|
||||
for permutation in permutations:
|
||||
# pylint: disable=consider-using-with
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
@@ -88,7 +100,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||
)
|
||||
yaml.dump(permutation, temp_file)
|
||||
temp_file.close()
|
||||
yield temp_file.name
|
||||
yield temp_file.name, is_group
|
||||
|
||||
|
||||
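A short sketch of the new generator contract, for reference. The return values shown in the comments are illustrative (the temp file path is a placeholder), assuming one permutation per sweep entry:

```python
# Single config: the file is yielded as-is and is_group is False,
# so the caller may safely exec-replace the current process.
list(generate_config_files("config.yml", sweep=None))
# -> [("config.yml", False)]

# Sweep: each permutation is dumped to a temp YAML and is_group is True,
# so the caller keeps using subprocess and the loop survives each run.
list(generate_config_files("config.yml", sweep="sweep.yml"))
# -> [("/tmp/tmpab12cd.yml", True), ("/tmp/tmpef34gh.yml", True), ...]
```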
def launch_training(
@@ -97,6 +109,7 @@ def launch_training(
    cloud: str | None,
    kwargs: dict,
    launcher_args: list[str] | None = None,
    use_exec: bool = False,
) -> None:
    """Execute training with the given configuration."""
    launcher_args = launcher_args or []
@@ -105,11 +118,14 @@ def launch_training(
        _launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
    elif launcher:
        if launcher == "accelerate":
            _launch_accelerate_training(cfg_file, kwargs, launcher_args)
            _launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
        elif launcher == "torchrun":
            _launch_torchrun_training(cfg_file, kwargs, launcher_args)
            _launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
        elif launcher == "python":
            _launch_python_training(cfg_file, kwargs)
    elif launcher is None:
        # handle ray train launch
        _launch_python_training(cfg_file, kwargs)


def _launch_cloud_training(
@@ -136,7 +152,10 @@ def _launch_cloud_training(


def _launch_accelerate_training(
    cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
    cfg_file: str,
    kwargs: dict,
    launcher_args: list[str] | None = None,
    use_exec: bool = False,
) -> None:
    """Execute training via accelerate launcher."""
    launcher_args = launcher_args or []
@@ -161,11 +180,20 @@ def _launch_accelerate_training(
    base_cmd.append(cfg_file)

    cmd = build_command(base_cmd, kwargs)
    subprocess.run(cmd, check=True)  # nosec B603
    if use_exec:
        # make sure to flush stdout and stderr before replacing the process
        sys.stdout.flush()
        sys.stderr.flush()
        os.execvpe(cmd[0], cmd, os.environ)  # nosec B606
    else:
        subprocess.run(cmd, check=True)  # nosec B603


def _launch_torchrun_training(
    cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
    cfg_file: str,
    kwargs: dict,
    launcher_args: list[str] | None = None,
    use_exec: bool = False,
) -> None:
    """Execute training via torchrun launcher."""
    launcher_args = launcher_args or []
@@ -178,7 +206,13 @@ def _launch_torchrun_training(
    base_cmd.append(cfg_file)

    cmd = build_command(base_cmd, kwargs)
    subprocess.run(cmd, check=True)  # nosec B603
    if use_exec:
        # make sure to flush stdout and stderr before replacing the process
        sys.stdout.flush()
        sys.stderr.flush()
        os.execvpe(cmd[0], cmd, os.environ)  # nosec B606
    else:
        subprocess.run(cmd, check=True)  # nosec B603

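The exec path above replaces the current Python process with the launcher rather than spawning a child, which is why the explicit flushes matter: any buffered output would be lost once the process image is swapped. A minimal, self-contained sketch of the same pattern (the example command is a placeholder):

```python
import os
import subprocess
import sys


def run_cmd(cmd: list[str], use_exec: bool = False) -> None:
    """Run cmd either as a child process or by replacing this process."""
    if use_exec:
        # flush first: buffers do not survive the exec
        sys.stdout.flush()
        sys.stderr.flush()
        os.execvpe(cmd[0], cmd, os.environ)  # never returns on success
    else:
        subprocess.run(cmd, check=True)


# e.g. run_cmd(["accelerate", "launch", "-m", "axolotl.cli.train", "config.yml"], use_exec=True)
```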
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:

@@ -2,12 +2,10 @@
|
||||
CLI to start the vllm server for online RL
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import trl
|
||||
from trl.scripts.vllm_serve import ScriptArguments
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
@@ -42,13 +40,17 @@ def do_vllm_serve(
|
||||
|
||||
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
|
||||
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
|
||||
tensor_parallel_size = 1
|
||||
data_parallel_size = 1
|
||||
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
data_parallel_size = (
|
||||
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
|
||||
)
|
||||
if cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size:
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
if cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size:
|
||||
data_parallel_size = (
|
||||
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
|
||||
)
|
||||
host = cli_args.get("host") or cfg.vllm.host
|
||||
port = cli_args.get("port") or cfg.vllm.port
|
||||
gpu_memory_utilization = (
|
||||
@@ -81,63 +83,3 @@ def do_vllm_serve(
|
||||
enable_reasoning=enable_reasoning,
|
||||
)
|
||||
vllm_serve_main(vllm_script_args)
|
||||
|
||||
|
||||
def patch_vllm_worker():
|
||||
from multiprocessing.connection import Connection
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
def llm_worker(
|
||||
script_args: AxolotlScriptArguments,
|
||||
data_parallel_rank: int,
|
||||
master_port: int,
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
# Set required environment variables for DP to work with vLLM
|
||||
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
|
||||
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
|
||||
|
||||
llm = LLM(
|
||||
model=script_args.model,
|
||||
revision=script_args.revision,
|
||||
tensor_parallel_size=script_args.tensor_parallel_size,
|
||||
gpu_memory_utilization=script_args.gpu_memory_utilization,
|
||||
enforce_eager=script_args.enforce_eager,
|
||||
dtype=script_args.dtype,
|
||||
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
|
||||
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
|
||||
# This is particularly useful here because we generate completions from the same prompts.
|
||||
enable_prefix_caching=script_args.enable_prefix_caching,
|
||||
kv_cache_dtype=script_args.kv_cache_dtype,
|
||||
max_model_len=script_args.max_model_len,
|
||||
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
|
||||
enable_reasoning=script_args.enable_reasoning,
|
||||
reasoning_parser=script_args.reasoning_parser,
|
||||
)
|
||||
|
||||
# Send ready signal to parent process
|
||||
connection.send({"status": "ready"})
|
||||
|
||||
while True:
|
||||
# Wait for commands from the parent process
|
||||
try:
|
||||
command = connection.recv()
|
||||
except KeyboardInterrupt:
|
||||
llm.collective_rpc(method="close_communicator")
|
||||
break
|
||||
|
||||
# Handle commands
|
||||
if command["type"] in ["call", "fire_and_forget"]:
|
||||
method_name = command["method"]
|
||||
args, kwargs = command.get("args", ()), command.get("kwargs", {})
|
||||
method = getattr(llm, method_name)
|
||||
result = method(*args, **kwargs)
|
||||
if command["type"] == "call":
|
||||
connection.send(result)
|
||||
elif command["type"] == "shutdown":
|
||||
break
|
||||
|
||||
trl.scripts.vllm_serve.llm_worker = llm_worker
|
||||
|
||||
@@ -13,4 +13,5 @@ MOE_ARCH_BLOCK = {
    "qwen2_moe": "Qwen2MoeSparseMoeBlock",
    "qwen3_moe": "Qwen3MoeSparseMoeBlock",
    "deepseek_v2": "DeepseekV2MoE",
    "gpt_oss": "GptOssDecoderLayer",
}

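A hedged sketch of how a mapping like `MOE_ARCH_BLOCK` is typically consumed: looking up the block class name for the loaded model's `model_type`. The helper below and the name-resolution call are assumptions for illustration, not part of this diff.

```python
def get_moe_block_cls(model):
    # model is a loaded transformers model; MOE_ARCH_BLOCK maps model_type -> class name
    block_name = MOE_ARCH_BLOCK.get(model.config.model_type)
    if block_name is None:
        return None  # not a known MoE architecture
    # get_module_class_from_name is assumed to resolve a class by name on the model
    return get_module_class_from_name(model, block_name)
```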
@@ -24,12 +24,10 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from accelerate import PartialState
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
)
|
||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||
from transformers.training_args import OptimizerNames
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
||||
@@ -40,6 +38,7 @@ from axolotl.utils.callbacks import (
|
||||
SaveModelOnFirstStepCallback,
|
||||
)
|
||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||
from axolotl.utils.distributed import build_parallelism_config
|
||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@@ -267,27 +266,24 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
optimizer_cls = MuonOptimizerFactory
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "dion":
|
||||
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
|
||||
DionOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls = DionOptimizerFactory
|
||||
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
|
||||
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
_, device_mesh = build_parallelism_config(self.cfg)
|
||||
if device_mesh is not None:
|
||||
optimizer_kwargs["device_mesh"] = device_mesh
|
||||
elif self.cfg.optimizer == "optimi_adamw":
|
||||
from optimi import AdamW
|
||||
|
||||
optimizer_kwargs["foreach"] = False
|
||||
optimizer_cls = AdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_4bit":
|
||||
# TODO remove 20250401
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||
|
||||
optimizer_cls = AdamW4bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
|
||||
LOG.warning(
|
||||
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
||||
)
|
||||
elif self.cfg.optimizer == "ao_adamw_8bit":
|
||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||
|
||||
optimizer_cls = AdamW8bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||
|
||||
@@ -433,30 +429,12 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||
|
||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||
partial_state = PartialState()
|
||||
has_pc_attr = (
|
||||
hasattr(partial_state, "parallelism_config")
|
||||
and partial_state.parallelism_config
|
||||
)
|
||||
has_pc_key = (
|
||||
"parallelism_config"
|
||||
in partial_state._shared_state # pylint: disable=protected-access
|
||||
and partial_state._shared_state[ # pylint: disable=protected-access
|
||||
"parallelism_config"
|
||||
]
|
||||
)
|
||||
use_configured_state = has_pc_attr or has_pc_key
|
||||
if self.cfg.accelerator_config:
|
||||
use_configured_state = self.cfg.accelerator_config.pop(
|
||||
"use_configured_state", use_configured_state
|
||||
)
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||
use_configured_state=use_configured_state, **self.cfg.accelerator_config
|
||||
**self.cfg.accelerator_config
|
||||
)
|
||||
else:
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||
use_configured_state=use_configured_state,
|
||||
)
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.activation_offloading is True:
|
||||
@@ -516,10 +494,20 @@ class TrainerBuilderBase(abc.ABC):
|
||||
"include_tokens_per_second",
|
||||
"weight_decay",
|
||||
"seed",
|
||||
"dion_momentum",
|
||||
"dion_rank_fraction",
|
||||
"dion_rank_multiple_of",
|
||||
]:
|
||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||
|
||||
arg_map = {
|
||||
"dion_learning_rate": "dion_lr",
|
||||
}
|
||||
for kwarg, cfg_arg in arg_map.items():
|
||||
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
|
||||
training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)
|
||||
|
||||
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
||||
training_args_kwargs["average_tokens_across_devices"] = False
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ from axolotl.utils.collators import (
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -136,6 +137,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return AxolotlRewardTrainer
|
||||
if self.cfg.process_reward_model:
|
||||
return AxolotlPRMTrainer
|
||||
|
||||
if self.cfg.trainer_cls:
|
||||
# override the trainer cls
|
||||
try:
|
||||
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
|
||||
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
|
||||
return trainer_cls
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
raise ValueError(
|
||||
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
|
||||
) from e
|
||||
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
@@ -350,7 +363,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
||||
self.cfg.sequence_len / multiple
|
||||
)
|
||||
else:
|
||||
elif self.cfg.pad_to_sequence_len is None:
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
||||
|
||||
@@ -15,6 +15,7 @@ from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders.utils import ensure_dtype
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
@@ -72,6 +73,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
if self.cfg.trainer_cls:
|
||||
# override the trainer cls
|
||||
try:
|
||||
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
|
||||
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
raise ValueError(
|
||||
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
|
||||
) from e
|
||||
|
||||
return trainer_cls, trainer_cls_args
|
||||
|
||||
def _build_training_arguments(self, total_num_steps):
|
||||
|
||||
@@ -10,8 +10,11 @@ from functools import partial, wraps
|
||||
from typing import Any, Callable, Literal, Optional
|
||||
|
||||
import datasets
|
||||
import safetensors
|
||||
import torch
|
||||
from accelerate.state import AcceleratorState
|
||||
from datasets import Dataset
|
||||
from peft import PeftModel
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
DataLoader,
|
||||
@@ -19,8 +22,10 @@ from torch.utils.data import (
|
||||
Sampler,
|
||||
SequentialSampler,
|
||||
)
|
||||
from transformers import Trainer
|
||||
from transformers import PreTrainedModel, Trainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -515,7 +520,18 @@ class AxolotlTrainer(
|
||||
|
||||
@wraps(Trainer.create_accelerator_and_postprocess)
|
||||
def create_accelerator_and_postprocess(self):
|
||||
res = super().create_accelerator_and_postprocess()
|
||||
# cleanup the PartialState states so Accelerate automatically configures everything from the env vars
|
||||
accelerator_config = self.args.accelerator_config.to_dict()
|
||||
use_configured_state = accelerator_config.get("use_configured_state", False)
|
||||
if not use_configured_state:
|
||||
AcceleratorState._reset_state( # pylint: disable=protected-access
|
||||
reset_partial_state=True
|
||||
)
|
||||
|
||||
super().create_accelerator_and_postprocess()
|
||||
|
||||
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
|
||||
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
if (
|
||||
@@ -524,8 +540,6 @@ class AxolotlTrainer(
|
||||
):
|
||||
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
|
||||
|
||||
return res
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def additional_accelerator_args(
|
||||
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
|
||||
@@ -567,10 +581,10 @@ class AxolotlTrainer(
|
||||
# Add memory usage
|
||||
try:
|
||||
active, allocated, reserved = get_gpu_memory_usage()
|
||||
logs["memory/max_memory_active"] = active
|
||||
logs["memory/max_memory_allocated"] = allocated
|
||||
logs["memory/device_memory_reserved"] = reserved
|
||||
except (ValueError, FileNotFoundError):
|
||||
logs["memory/max_mem_active(gib)"] = round(active, 2)
|
||||
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
|
||||
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
|
||||
except (ValueError, TypeError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
del self._stored_metrics[train_eval]
|
||||
@@ -590,3 +604,64 @@ class AxolotlTrainer(
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||
# If we are executing this function, we are the process zero, so we don't check for that.
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
LOG.info(f"Saving model checkpoint to {output_dir}")
|
||||
supported_classes = (
|
||||
(PreTrainedModel,)
|
||||
if not is_peft_available()
|
||||
else (PreTrainedModel, PeftModel)
|
||||
)
|
||||
# Save a trained model and configuration using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
if not isinstance(self.model, supported_classes):
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
if isinstance(
|
||||
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
|
||||
supported_classes,
|
||||
):
|
||||
self.accelerator.unwrap_model(
|
||||
self.model, keep_torch_compile=False
|
||||
).save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
)
|
||||
else:
|
||||
LOG.info(
|
||||
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
|
||||
)
|
||||
if self.args.save_safetensors:
|
||||
safetensors.torch.save_file(
|
||||
state_dict,
|
||||
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
|
||||
metadata={"format": "pt"},
|
||||
)
|
||||
else:
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
else:
|
||||
self.model.save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
is_main_process=self.accelerator.is_main_process,
|
||||
)
|
||||
|
||||
if self.processing_class is not None:
|
||||
self.processing_class.save_pretrained(output_dir)
|
||||
elif (
|
||||
self.data_collator is not None
|
||||
and hasattr(self.data_collator, "tokenizer")
|
||||
and self.data_collator.tokenizer is not None
|
||||
):
|
||||
LOG.info(
|
||||
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
||||
)
|
||||
self.data_collator.tokenizer.save_pretrained(output_dir)
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Mixin for correctly saving fsdp
|
||||
"""
|
||||
|
||||
from accelerate import PartialState
|
||||
from transformers import Trainer
|
||||
|
||||
|
||||
@@ -18,3 +19,15 @@ class DistributedParallelMixin(Trainer):
|
||||
):
|
||||
state_dict = self.accelerator.get_state_dict(self.model)
|
||||
super()._save(output_dir, state_dict=state_dict)
|
||||
|
||||
def create_accelerator_and_postprocess(self):
|
||||
super().create_accelerator_and_postprocess()
|
||||
if (
|
||||
self.accelerator.distributed_type == "FSDP"
|
||||
and self.accelerator.state.fsdp_plugin is None
|
||||
):
|
||||
# pylint: disable=protected-access
|
||||
# handle Context Parallelism without FSDP
|
||||
self.accelerator.state.distributed_type = "MULTI_GPU"
|
||||
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"
|
||||
PartialState().distributed_type = "MULTI_GPU"
|
||||
|
||||
@@ -243,3 +243,18 @@ class AxolotlTrainingMixins:
|
||||
)
|
||||
|
||||
# end of multi-modal section
|
||||
|
||||
dion_learning_rate: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The learning rate for Dion"},
|
||||
)
|
||||
dion_momentum: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The momentum for Dion"},
|
||||
)
|
||||
dion_rank_fraction: float | None = field(
|
||||
default=None,
|
||||
)
|
||||
dion_rank_multiple_of: int | None = field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -26,9 +26,11 @@ import traceback
|
||||
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
||||
|
||||
from peft import PeftModel
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from transformers import PreTrainedModel, Trainer
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
@@ -74,8 +76,8 @@ class BasePlugin:
|
||||
def __init__(self):
|
||||
"""Initializes the BasePlugin."""
|
||||
|
||||
def register(self, cfg: DictDefault): # pylint: disable=unused-argument
|
||||
"""Registers the plugin with the given configuration.
|
||||
def register(self, cfg: dict): # pylint: disable=unused-argument
|
||||
"""Registers the plugin with the given configuration as an unparsed dict.
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the plugin.
|
||||
@@ -641,3 +643,24 @@ class BaseOptimizerFactory:
|
||||
self, opt_model, training_args, **optimizer_kwargs
|
||||
) -> Optimizer | None:
|
||||
pass
|
||||
|
||||
# duplicated from transformers
|
||||
def get_decay_parameter_names(self, model) -> list[str]:
|
||||
"""
|
||||
Get all parameter names that weight decay will be applied to.
|
||||
|
||||
This function filters out parameters in two ways:
|
||||
1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
|
||||
2. By parameter name patterns (containing 'bias', or variation of 'norm')
|
||||
"""
|
||||
forbidden_name_patterns = [
|
||||
r"bias",
|
||||
r"layernorm",
|
||||
r"rmsnorm",
|
||||
r"(?:^|\.)norm(?:$|\.)",
|
||||
r"_norm(?:$|\.)",
|
||||
]
|
||||
decay_parameters = get_parameter_names(
|
||||
model, [nn.LayerNorm], forbidden_name_patterns
|
||||
)
|
||||
return decay_parameters
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh

- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
```

## Usage
@@ -31,6 +31,7 @@ plugins:

## Supported Models

- arcee
- cohere
- cohere2
- gemma
@@ -41,13 +42,17 @@ plugins:
- gemma3n_text
- glm
- glm4
- gpt_oss
- granite
- granitemoe
- hunyuan_v1_dense
- hunyuan_v1_moe
- llama
- llama4
- llama4_text
- mistral
- mistral3
- mixtral
- mllama
- phi
- phi3

@@ -34,7 +34,7 @@ LOG = get_logger(__name__)

_CCE_INSTALL_MESSAGE = (
    "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"`'
    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
)

@@ -284,12 +284,12 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
|
||||
return sample
|
||||
|
||||
def _tokenize_single_prompt(self, prompt):
|
||||
logprobs = prompt.pop(self.logprobs_field)
|
||||
target_token_ids = prompt.pop("target_token_ids")
|
||||
target_token_ids = prompt.get("target_token_ids", None)
|
||||
|
||||
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
||||
tokenized_prompt[self.logprobs_field] = logprobs
|
||||
tokenized_prompt["target_token_ids"] = target_token_ids
|
||||
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||
|
||||
if target_token_ids is not None:
|
||||
tokenized_prompt["target_token_ids"] = target_token_ids
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import Callable
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState
|
||||
from torch import nn
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
from .geglu import geglu_backward, geglu_forward
|
||||
from .quantize import dequantize
|
||||
@@ -25,6 +26,7 @@ def get_lora_parameters(
|
||||
proj: nn.Module,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor | None,
|
||||
QuantState | None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
@@ -37,39 +39,54 @@ def get_lora_parameters(
|
||||
proj: The projection module to extract parameters from.
|
||||
|
||||
Returns:
|
||||
A tuple containing the base weight matrix, quantization state, LoRA A matrix,
|
||||
LoRA B matrix, and scaling factor. States and matrices may be None if not
|
||||
available.
|
||||
A tuple containing the base weights, quantization state, LoRA A and B weights,
|
||||
scaling factor, and base layer bias. Quant state, weights, and bias may be
|
||||
`None` if not available.
|
||||
"""
|
||||
# For DPO or disabled adapters
|
||||
base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
|
||||
W = base_layer.weight
|
||||
b = base_layer.bias
|
||||
|
||||
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
|
||||
quant_state = getattr(W, "quant_state", None)
|
||||
return W, quant_state, None, None, None
|
||||
return W, b, quant_state, None, None, None
|
||||
|
||||
quant_state = getattr(W, "quant_state", None)
|
||||
|
||||
active_adapter = (
|
||||
proj.active_adapters[0]
|
||||
if hasattr(proj, "active_adapters")
|
||||
else proj.active_adapter
|
||||
)
|
||||
A = proj.lora_A[active_adapter].weight
|
||||
B = proj.lora_B[active_adapter].weight
|
||||
|
||||
linear_A = proj.lora_A[active_adapter]
|
||||
linear_B = proj.lora_B[active_adapter]
|
||||
|
||||
# This manual unsharding is needed for FSDP2 + LoRA kernels compatibility.
|
||||
# We fuse linear layers + LoRA adapters calculations into a single
|
||||
# torch.autograd.Function, bypassing the registered unshard / reshard behavior.
|
||||
# Note that we don't apply resharding later in this module (it gets messy quickly),
|
||||
# but LoRA parameters are generally small enough that this is not an issue.
|
||||
if isinstance(linear_A.weight, DTensor):
|
||||
linear_A.unshard()
|
||||
linear_B.unshard()
|
||||
|
||||
A = linear_A.weight
|
||||
B = linear_B.weight
|
||||
s = proj.scaling[active_adapter]
|
||||
|
||||
quant_state = getattr(W, "quant_state", None)
|
||||
|
||||
return W, quant_state, A, B, s
|
||||
return W, b, quant_state, A, B, s
|
||||
|
||||
|
||||
def matmul_lora(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
W_quant: QuantState,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
s: float,
|
||||
b: torch.Tensor | None,
|
||||
W_quant: QuantState | None,
|
||||
A: torch.Tensor | None,
|
||||
B: torch.Tensor | None,
|
||||
s: float | None,
|
||||
out: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -90,20 +107,22 @@ def matmul_lora(
|
||||
dtype = X.dtype
|
||||
W = dequantize(W.t(), W_quant)
|
||||
|
||||
reshape = False
|
||||
if X.dim() == 3:
|
||||
batch, seq_len, _ = X.shape
|
||||
X = X.view(-1, X.shape[-1])
|
||||
reshape = True
|
||||
else:
|
||||
reshape = False
|
||||
|
||||
out = torch.matmul(X, W, out=out)
|
||||
if W_quant is not None:
|
||||
del W
|
||||
|
||||
if A is not None:
|
||||
A, B = A.t(), B.t()
|
||||
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
|
||||
A, B = A.t().to(dtype), B.t().to(dtype) # type: ignore[union-attr]
|
||||
out += s * X @ A @ B
|
||||
|
||||
if b is not None:
|
||||
out += b
|
||||
|
||||
return out.view(batch, seq_len, -1) if reshape else out
|
||||
|
||||
@@ -117,17 +136,20 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
ctx,
|
||||
X: torch.Tensor,
|
||||
gate_weight: torch.Tensor,
|
||||
gate_quant: object | None,
|
||||
gate_bias: torch.Tensor | None,
|
||||
gate_quant: QuantState | None,
|
||||
gate_A: torch.Tensor | None,
|
||||
gate_B: torch.Tensor | None,
|
||||
gate_scale: float,
|
||||
up_weight: torch.Tensor,
|
||||
up_quant: object | None,
|
||||
up_bias: torch.Tensor | None,
|
||||
up_quant: QuantState | None,
|
||||
up_A: torch.Tensor | None,
|
||||
up_B: torch.Tensor | None,
|
||||
up_scale: float,
|
||||
down_weight: torch.Tensor,
|
||||
down_quant: object | None,
|
||||
down_bias: torch.Tensor | None,
|
||||
down_quant: QuantState | None,
|
||||
down_A: torch.Tensor | None,
|
||||
down_B: torch.Tensor | None,
|
||||
down_scale: float,
|
||||
@@ -142,20 +164,22 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
ctx: Autograd context
|
||||
X: Input features
|
||||
gate_weight: Gate projection weight
|
||||
gate_bias: Gate projection bias
|
||||
gate_quant: Gate quantization state
|
||||
gate_A: Gate LoRA A matrix
|
||||
gate_B: Gate LoRA B matrix
|
||||
gate_scale: Gate LoRA scale
|
||||
up_weight: Up-projection weight
|
||||
up_quant: Up-projection quantization state
|
||||
up_A: Up-projection LoRA A matrix
|
||||
up_B: Up-projection LoRA B matrix
|
||||
up_scale: Up-projection LoRA scale
|
||||
down_weight: Down-projection weight
|
||||
down_quant: Down-projection quantization state
|
||||
down_A: Down-projection LoRA A matrix
|
||||
down_B: Down-projection LoRA B matrix
|
||||
down_scale: Down-projection LoRA scale
|
||||
up_weight: Up projection weight
|
||||
up_quant: Up projection quantization state
|
||||
up_A: Up projection LoRA A matrix
|
||||
up_B: Up projection LoRA B matrix
|
||||
up_scale: Up projection LoRA scale
|
||||
down_weight: Down projection weight
|
||||
down_bias: Down projection bias
|
||||
down_quant: Down projection quantization state
|
||||
down_A: Down projection LoRA A matrix
|
||||
down_B: Down projection LoRA B matrix
|
||||
down_scale: Down projection LoRA scale
|
||||
activation_fn: Forward activation function
|
||||
activation_fn_backward: Backward activation function
|
||||
inplace: Whether to perform operations in-place
|
||||
@@ -164,15 +188,17 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
Output transformed by multi-layer perceptron and activation function
|
||||
"""
|
||||
# Compute projections
|
||||
gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
|
||||
up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)
|
||||
gate = matmul_lora(
|
||||
X, gate_weight, gate_bias, gate_quant, gate_A, gate_B, gate_scale
|
||||
)
|
||||
up = matmul_lora(X, up_weight, up_bias, up_quant, up_A, up_B, up_scale)
|
||||
|
||||
# Activation
|
||||
hidden = activation_fn(gate, up)
|
||||
|
||||
# Down projection
|
||||
output = matmul_lora(
|
||||
hidden, down_weight, down_quant, down_A, down_B, down_scale
|
||||
hidden, down_weight, down_bias, down_quant, down_A, down_B, down_scale
|
||||
)
|
||||
|
||||
# Save for backward
|
||||
@@ -195,22 +221,26 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Performs backward pass computation for LoRA MLP.
|
||||
@@ -222,7 +252,7 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
Returns:
|
||||
Tuple containing gradients for all inputs from forward pass:
|
||||
- Input gradient tensor (or `None`)
|
||||
- `None` for weights/quantization states
|
||||
- `None` for weights/biases/quantization states
|
||||
- LoRA A/B matrix gradients (or `None`)
|
||||
- `None` for scaling factors
|
||||
- `None` for activation functions and flags
|
||||
@@ -265,9 +295,10 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
dtype = X.dtype
|
||||
|
||||
# Down projection
|
||||
DW = matmul_lora(
|
||||
grad_down = matmul_lora(
|
||||
grad_output,
|
||||
down_weight.t(),
|
||||
None,
|
||||
down_quant,
|
||||
down_B,
|
||||
down_A,
|
||||
@@ -275,7 +306,7 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
)
|
||||
|
||||
# Activation backward
|
||||
h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)
|
||||
h, grad_gate, grad_up = ctx.activation_fn_backward(grad_down, gate, up)
|
||||
|
||||
# Initialize and compute LoRA gradients
|
||||
d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
|
||||
@@ -315,8 +346,8 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
|
||||
|
||||
# Gate projection gradients
|
||||
gate_weight = dequantize(gate_weight.t(), gate_quant)
|
||||
dX += grad_gate @ gate_weight.t()
|
||||
gate_weight = dequantize(gate_weight, gate_quant)
|
||||
dX += grad_gate @ gate_weight
|
||||
del gate_weight
|
||||
|
||||
if gate_A is not None and gate_B is not None:
|
||||
@@ -334,22 +365,26 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
dX,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_gate_A.t() if d_gate_A is not None else None,
|
||||
d_gate_B.t() if d_gate_B is not None else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_up_A.t() if d_up_A is not None else None,
|
||||
d_up_B.t() if d_up_B is not None else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_down_A.t() if d_down_A is not None else None,
|
||||
d_down_B.t() if d_down_B is not None else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
@@ -364,23 +399,26 @@ def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.
|
||||
Returns:
|
||||
Output tensor after applying LoRA-adapted MLP with SwiGLU activation
|
||||
"""
|
||||
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||
downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
|
||||
out = LoRA_MLP.apply(
|
||||
X,
|
||||
gateW,
|
||||
gateb,
|
||||
gateW_quant,
|
||||
gateA,
|
||||
gateB,
|
||||
gateS,
|
||||
upW,
|
||||
upb,
|
||||
upW_quant,
|
||||
upA,
|
||||
upB,
|
||||
upS,
|
||||
downW,
|
||||
downb,
|
||||
downW_quant,
|
||||
downA,
|
||||
downB,
|
||||
@@ -404,22 +442,25 @@ def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.T
|
||||
Returns:
|
||||
Output tensor after applying LoRA-adapted MLP with GEGLU activation
|
||||
"""
|
||||
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||
downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
out = LoRA_MLP.apply(
|
||||
X,
|
||||
gateW,
|
||||
gateb,
|
||||
gateW_quant,
|
||||
gateA,
|
||||
gateB,
|
||||
gateS,
|
||||
upW,
|
||||
upb,
|
||||
upW_quant,
|
||||
upA,
|
||||
upB,
|
||||
upS,
|
||||
downW,
|
||||
downb,
|
||||
downW_quant,
|
||||
downA,
|
||||
downB,
|
||||
@@ -446,16 +487,19 @@ class LoRA_QKV(torch.autograd.Function):
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
X: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
q_bias: torch.Tensor | None,
|
||||
q_quant: QuantState | None,
|
||||
q_A: torch.Tensor | None,
|
||||
q_B: torch.Tensor | None,
|
||||
q_scale: float,
|
||||
k_weight: torch.Tensor,
|
||||
k_bias: torch.Tensor | None,
|
||||
k_quant: QuantState | None,
|
||||
k_A: torch.Tensor | None,
|
||||
k_B: torch.Tensor | None,
|
||||
k_scale: float,
|
||||
v_weight: torch.Tensor,
|
||||
v_bias: torch.Tensor | None,
|
||||
v_quant: QuantState | None,
|
||||
v_A: torch.Tensor | None,
|
||||
v_B: torch.Tensor | None,
|
||||
@@ -469,16 +513,19 @@ class LoRA_QKV(torch.autograd.Function):
|
||||
ctx: Autograd context
|
||||
X: Input tensor
|
||||
q_weight: Query projection weight
|
||||
q_bias: Query projection bias
|
||||
q_quant: Query quantization state
|
||||
q_A: Query LoRA A matrix
|
||||
q_B: Query LoRA B matrix
|
||||
q_scale: Query LoRA scale
|
||||
k_weight: Key projection weight
|
||||
k_bias: Key projection bias
|
||||
k_quant: Key quantization state
|
||||
k_A: Key LoRA A matrix
|
||||
k_B: Key LoRA B matrix
|
||||
k_scale: Key LoRA scale
|
||||
v_weight: Value projection weight
|
||||
v_bias: Value projection bias
|
||||
v_quant: Value quantization state
|
||||
v_A: Value LoRA A matrix
|
||||
v_B: Value LoRA B matrix
|
||||
@@ -488,20 +535,21 @@ class LoRA_QKV(torch.autograd.Function):
|
||||
Returns:
|
||||
Tuple of (Query, Key, Value) projection tensors
|
||||
"""
|
||||
Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
|
||||
K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
|
||||
V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)
|
||||
Q = matmul_lora(X, q_weight, q_bias, q_quant, q_A, q_B, q_scale)
|
||||
K = matmul_lora(X, k_weight, k_bias, k_quant, k_A, k_B, k_scale)
|
||||
V = matmul_lora(X, v_weight, v_bias, v_quant, v_A, v_B, v_scale)
|
||||
|
||||
ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
|
||||
ctx.scales = (q_scale, k_scale, v_scale)
|
||||
ctx.quants = (q_quant, k_quant, v_quant)
|
||||
ctx.weights = (q_weight, k_weight, v_weight)
|
||||
ctx.biases = (q_bias, k_bias, v_bias)
|
||||
ctx.inplace = inplace
|
||||
|
||||
return Q, K, V
|
||||
|
||||
@staticmethod
|
||||
@torch_amp_custom_fwd
|
||||
@torch_amp_custom_bwd
|
||||
def backward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
q_grad: torch.Tensor,
|
||||
@@ -511,16 +559,19 @@ class LoRA_QKV(torch.autograd.Function):
|
||||
torch.Tensor,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
@@ -608,31 +659,31 @@ class LoRA_QKV(torch.autograd.Function):
|
||||
# Transpose gradients if needed
|
||||
if d_A_q is not None:
|
||||
d_A_q = d_A_q.t()
|
||||
if d_B_q is not None:
|
||||
d_B_q = d_B_q.t()
|
||||
d_B_q = d_B_q.t() # type: ignore[union-attr]
|
||||
if d_A_k is not None:
|
||||
d_A_k = d_A_k.t()
|
||||
if d_B_k is not None:
|
||||
d_B_k = d_B_k.t()
|
||||
d_B_k = d_B_k.t() # type: ignore[union-attr]
|
||||
if d_A_v is not None:
|
||||
d_A_v = d_A_v.t()
|
||||
if d_B_v is not None:
|
||||
d_B_v = d_B_v.t()
|
||||
d_B_v = d_B_v.t() # type: ignore[union-attr]
|
||||
|
||||
return (
|
||||
grad_X.view(batch, seq_len, -1),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_A_q,
|
||||
d_B_q,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_A_k,
|
||||
d_B_k,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_A_v,
|
||||
d_B_v,
|
||||
None,
|
||||
@@ -653,22 +704,25 @@ def apply_lora_qkv(
|
||||
Returns:
|
||||
Tuple of (Query, Key, Value) projection tensors
|
||||
"""
|
||||
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
|
||||
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
|
||||
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
|
||||
QW, Qb, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
|
||||
KW, Kb, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
|
||||
VW, Vb, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
|
||||
Q, K, V = LoRA_QKV.apply(
|
||||
X,
|
||||
QW,
|
||||
Qb,
|
||||
QW_quant,
|
||||
QA,
|
||||
QB,
|
||||
QS,
|
||||
KW,
|
||||
Kb,
|
||||
KW_quant,
|
||||
KA,
|
||||
KB,
|
||||
KS,
|
||||
VW,
|
||||
Vb,
|
||||
VW_quant,
|
||||
VA,
|
||||
VB,
|
||||
@@ -688,10 +742,11 @@ class LoRA_O(torch.autograd.Function):
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
W_quant: QuantState | None,
|
||||
A: torch.Tensor | None,
|
||||
B: torch.Tensor | None,
|
||||
S: float,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
s: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for output projection with LoRA.
|
||||
@@ -700,19 +755,20 @@ class LoRA_O(torch.autograd.Function):
|
||||
ctx: Autograd context
|
||||
X: Input tensor
|
||||
W: Output projection weight
|
||||
b: Output projection bias
|
||||
W_quant: Weight quantization state
|
||||
A: LoRA A matrix
|
||||
B: LoRA B matrix
|
||||
S: LoRA scaling factor
|
||||
s: LoRA scaling factor
|
||||
|
||||
Returns:
|
||||
Output projection tensor
|
||||
Output projection result
|
||||
"""
|
||||
XW = matmul_lora(X, W, W_quant, A, B, S)
|
||||
XW = matmul_lora(X, W, b, W_quant, A, B, s)
|
||||
ctx.custom_saved_tensors = (
|
||||
W,
|
||||
W_quant,
|
||||
S,
|
||||
s,
|
||||
)
|
||||
ctx.save_for_backward(A, B, X)
|
||||
|
||||
@@ -727,8 +783,9 @@ class LoRA_O(torch.autograd.Function):
|
||||
torch.Tensor,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
@@ -741,7 +798,7 @@ class LoRA_O(torch.autograd.Function):
|
||||
Returns:
|
||||
Tuple containing gradients for all forward inputs
|
||||
"""
|
||||
W, W_quant, S = ctx.custom_saved_tensors
|
||||
W, W_quant, s = ctx.custom_saved_tensors
|
||||
A, B, X = ctx.saved_tensors
|
||||
|
||||
batch, seq_len, hd = X.shape
|
||||
@@ -751,17 +808,19 @@ class LoRA_O(torch.autograd.Function):
|
||||
|
||||
# Weight projection
|
||||
dY_X = X.t() @ dY
|
||||
d_A = S * dY_X @ B
|
||||
d_B = S * A @ dY_X
|
||||
d_A = s * dY_X @ B
|
||||
d_B = s * A @ dY_X
|
||||
|
||||
# Get derivative for dX
|
||||
W = dequantize(W.t(), W_quant)
|
||||
dX = dY @ W.t()
|
||||
del W
|
||||
dX += dY @ B.to(dtype) @ (S * A.to(dtype))
|
||||
|
||||
# W, W_quant, A, B, S
|
||||
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
|
||||
A, B = A.to(dtype), B.to(dtype)
|
||||
dX += s * dY @ B @ A
|
||||
|
||||
# W, b, W_quant, A, B, s
|
||||
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None
|
||||
|
||||
|
||||
def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
|
||||
@@ -774,7 +833,7 @@ def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
|
||||
Returns:
|
||||
Transformed output tensor
|
||||
"""
|
||||
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
|
||||
output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)
|
||||
OW, Ob, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
|
||||
output = LoRA_O.apply(X, OW, Ob, OW_quant, OA, OB, OS)
|
||||
|
||||
return output
|
||||
|
||||
@@ -76,6 +76,7 @@ def load_lora(
|
||||
config_only: bool = False,
|
||||
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
|
||||
lora_target_modules = cfg.lora_target_modules or []
|
||||
lora_target_parameters = cfg.lora_target_parameters or []
|
||||
|
||||
if cfg.lora_target_linear:
|
||||
linear_names = find_all_linear_names(model)
|
||||
@@ -106,6 +107,7 @@ def load_lora(
|
||||
r=cfg.lora_r,
|
||||
lora_alpha=cfg.lora_alpha,
|
||||
target_modules=lora_target_modules,
|
||||
target_parameters=lora_target_parameters,
|
||||
layers_to_transform=cfg.peft_layers_to_transform,
|
||||
layers_pattern=cfg.peft_layers_pattern,
|
||||
lora_dropout=cfg.lora_dropout,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""Model loader class implementation for loading, configuring, and patching various
|
||||
models.
|
||||
"""
|
||||
Model loader class implementation for loading, configuring, and patching various models.
|
||||
"""
|
||||
|
||||
import gc
|
||||
@@ -13,7 +13,7 @@ import peft
|
||||
import torch
|
||||
import transformers
|
||||
import transformers.modeling_utils
|
||||
from accelerate import PartialState, init_empty_weights
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.parallelism_config import ParallelismConfig
|
||||
from peft import (
|
||||
PeftConfig,
|
||||
@@ -22,6 +22,7 @@ from peft import (
|
||||
PeftModelForCausalLM,
|
||||
prepare_model_for_kbit_training,
|
||||
)
|
||||
from torch.distributed import DeviceMesh
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForVision2Seq,
|
||||
@@ -49,7 +50,11 @@ from axolotl.loaders.utils import (
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
|
||||
from axolotl.utils.distributed import (
|
||||
build_parallelism_config,
|
||||
get_device_count,
|
||||
get_device_type,
|
||||
)
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
@@ -87,6 +92,7 @@ class ModelLoader:
|
||||
|
||||
use_parallel_config: bool | None = False
|
||||
parallelism_config: ParallelismConfig | None = None
|
||||
device_mesh: DeviceMesh | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -202,6 +208,8 @@ class ModelLoader:
|
||||
self._set_device_map_config()
|
||||
if self.cfg.revision_of_model:
|
||||
self.model_kwargs["revision"] = self.cfg.revision_of_model
|
||||
if self.cfg.use_kernels:
|
||||
self.model_kwargs["use_kernels"] = self.cfg.use_kernels
|
||||
self._set_quantization_config()
|
||||
self._set_attention_config()
|
||||
|
||||
@@ -300,7 +308,10 @@ class ModelLoader:
|
||||
)
|
||||
|
||||
# Handle DeepSpeed Zero3
|
||||
if is_deepspeed_zero3_enabled():
|
||||
if (
|
||||
is_deepspeed_zero3_enabled()
|
||||
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
|
||||
):
|
||||
self._set_z3_leaf_modules()
|
||||
|
||||
# Apply gradient checkpointing if needed
|
||||
@@ -405,85 +416,12 @@ class ModelLoader:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@staticmethod
|
||||
def _get_parallel_config_kwargs(
|
||||
world_size: int,
|
||||
tensor_parallel_size: int = 1,
|
||||
context_parallel_size: int = 1,
|
||||
dp_shard_size: int | None = None,
|
||||
dp_replicate_size: int | None = None,
|
||||
is_fsdp: bool = False,
|
||||
):
|
||||
pc_kwargs = {}
|
||||
remaining_world_size = world_size
|
||||
|
||||
if tensor_parallel_size and tensor_parallel_size > 1:
|
||||
pc_kwargs["tp_size"] = tensor_parallel_size
|
||||
remaining_world_size = remaining_world_size // tensor_parallel_size
|
||||
|
||||
if context_parallel_size and context_parallel_size > 1:
|
||||
pc_kwargs["cp_size"] = context_parallel_size
|
||||
remaining_world_size = remaining_world_size // context_parallel_size
|
||||
|
||||
if dp_shard_size is None and dp_replicate_size in (None, 1):
|
||||
if remaining_world_size > 1:
|
||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if dp_replicate_size and dp_replicate_size > 1:
|
||||
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
||||
remaining_world_size = remaining_world_size // dp_replicate_size
|
||||
|
||||
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
|
||||
if not is_fsdp:
|
||||
raise ValueError(
|
||||
"dp_shard_size was configured without a corresponding fsdp_config! "
|
||||
"Please ensure you have configured FSDP using fsdp_config."
|
||||
)
|
||||
pc_kwargs["dp_shard_size"] = dp_shard_size
|
||||
remaining_world_size = remaining_world_size // dp_shard_size
|
||||
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
|
||||
pc_kwargs["dp_replicate_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if remaining_world_size > 1:
|
||||
if "dp_shard_size" not in pc_kwargs and is_fsdp:
|
||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if remaining_world_size > 1:
|
||||
raise ValueError(
|
||||
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
|
||||
f"{pc_kwargs}"
|
||||
)
|
||||
|
||||
return pc_kwargs
|
||||
|
||||
def _set_parallel_config(self):
|
||||
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
||||
pc_kwargs = ModelLoader._get_parallel_config_kwargs(
|
||||
get_world_size(),
|
||||
self.cfg.tensor_parallel_size,
|
||||
self.cfg.context_parallel_size,
|
||||
self.cfg.dp_shard_size,
|
||||
self.cfg.dp_replicate_size,
|
||||
bool(self.cfg.fsdp or self.cfg.fsdp_config),
|
||||
)
|
||||
|
||||
if pc_kwargs:
|
||||
self.parallelism_config = ParallelismConfig(
|
||||
**pc_kwargs,
|
||||
)
|
||||
device_mesh = self.parallelism_config.build_device_mesh("cuda")
|
||||
partial_state = PartialState()
|
||||
# fmt: off
|
||||
partial_state._shared_state["parallelism_config"] = ( # pylint: disable=protected-access
|
||||
self.parallelism_config
|
||||
)
|
||||
partial_state._shared_state["device_mesh"] = ( # pylint: disable=protected-access
|
||||
device_mesh
|
||||
)
|
||||
# fmt: on
|
||||
parallelism_config, device_mesh = build_parallelism_config(self.cfg)
|
||||
if parallelism_config:
|
||||
self.parallelism_config = parallelism_config
|
||||
self.device_mesh = device_mesh
|
||||
|
||||
def _set_auto_model_loader(self):
|
||||
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
||||
@@ -565,8 +503,17 @@ class ModelLoader:
|
||||
|
||||
def _set_quantization_config(self):
|
||||
"""Set up quantization config (bitsandbytes, awq, gptq, etc.)"""
|
||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
||||
|
||||
if self.cfg.model_quantization_config == "Mxfp4Config":
|
||||
from transformers import Mxfp4Config
|
||||
|
||||
mxfp4_kwargs = {}
|
||||
if self.cfg.model_quantization_config_kwargs:
|
||||
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
|
||||
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
|
||||
else:
|
||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
||||
|
||||
if self.cfg.gptq:
|
||||
if not hasattr(self.model_config, "quantization_config"):
|
||||
@@ -601,7 +548,9 @@ class ModelLoader:
|
||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**self.model_config.quantization_config
|
||||
)
|
||||
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
|
||||
elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
|
||||
"load_in_4bit", False
|
||||
):
|
||||
bnb_config = {
|
||||
"load_in_4bit": True,
|
||||
"llm_int8_threshold": 6.0,
|
||||
@@ -627,7 +576,9 @@ class ModelLoader:
|
||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**bnb_config,
|
||||
)
|
||||
elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
|
||||
elif self.cfg.adapter == "lora" and self.model_kwargs.get(
|
||||
"load_in_8bit", False
|
||||
):
|
||||
bnb_config = {
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
@@ -648,7 +599,9 @@ class ModelLoader:
|
||||
|
||||
def _set_attention_config(self):
|
||||
"""Sample packing uses custom FA2 patch"""
|
||||
if self.cfg.flex_attention:
|
||||
if self.cfg.attn_implementation:
|
||||
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
|
||||
elif self.cfg.flex_attention:
|
||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flex_attention"
|
||||
@@ -721,7 +674,7 @@ class ModelLoader:
|
||||
if self.cfg.tensor_parallel_size > 1:
|
||||
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
||||
self.model_kwargs["tp_plan"] = "auto"
|
||||
self.model_kwargs["device_mesh"] = PartialState().device_mesh
|
||||
self.model_kwargs["device_mesh"] = self.device_mesh
|
||||
if "device_map" in self.model_kwargs:
|
||||
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
|
||||
|
||||
@@ -737,6 +690,18 @@ class ModelLoader:
|
||||
elif self.is_qlora_and_fsdp_enabled:
|
||||
skip_move_to_device = True
|
||||
|
||||
if (
|
||||
self.cfg.tensor_parallel_size <= 1
|
||||
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
||||
and self.cfg.fsdp_version == 2
|
||||
):
|
||||
# setting device_map for TP is not supported
|
||||
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
if local_rank == 0:
|
||||
self.model_kwargs["device_map"] = "cpu"
|
||||
else:
|
||||
self.model_kwargs["device_map"] = "meta"
|
||||
|
||||
if (
|
||||
self.is_qlora_and_fsdp_enabled
|
||||
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
||||
@@ -845,6 +810,9 @@ class ModelLoader:
|
||||
self.model._tp_size = self.cfg.tensor_parallel_size
|
||||
self.model._device_mesh = self.model_kwargs["device_mesh"]
|
||||
|
||||
if self.cfg.experimental_skip_move_to_device is not None:
|
||||
skip_move_to_device = self.cfg.experimental_skip_move_to_device
|
||||
|
||||
return skip_move_to_device
|
||||
|
||||
def _set_z3_leaf_modules(self):
|
||||
|
||||
@@ -65,6 +65,7 @@ class PatchManager:
|
||||
self._patch_llama_derived_model()
|
||||
self._apply_mistral_cross_entropy_patch()
|
||||
self._apply_self_attention_lora_patch()
|
||||
self._apply_fsdp2_bnb_patches()
|
||||
|
||||
def apply_post_plugin_pre_model_load_patches(self):
|
||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||
@@ -72,11 +73,19 @@ class PatchManager:
|
||||
self._apply_voxtral_patches()
|
||||
|
||||
def _apply_transformers_patches(self):
|
||||
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
|
||||
patch_prepare_from_posids,
|
||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||
patch_evaluation_loop,
|
||||
patch_maybe_log_save_evaluate,
|
||||
)
|
||||
|
||||
patch_prepare_from_posids()
|
||||
patch_fsdp2 = (
|
||||
self.cfg.torch_compile
|
||||
and self.cfg.fsdp_config
|
||||
and self.cfg.fsdp_version == 2
|
||||
)
|
||||
|
||||
patch_evaluation_loop(patch_fsdp2)
|
||||
patch_maybe_log_save_evaluate()
|
||||
|
||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||
"""Apply patches that require the model instance."""
|
||||
@@ -103,6 +112,14 @@ class PatchManager:
|
||||
|
||||
def _apply_fsdp_patches(self):
|
||||
"""Apply patches for FSDP configurations."""
|
||||
if self.cfg.context_parallel_size > 1 or (
|
||||
self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
|
||||
):
|
||||
from axolotl.monkeypatch.accelerate.parallelism_config import (
|
||||
patch_parallelism_config,
|
||||
)
|
||||
|
||||
patch_parallelism_config()
|
||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
||||
|
||||
@@ -260,6 +277,23 @@ class PatchManager:
|
||||
has_remote_code=has_remote_code,
|
||||
)
|
||||
|
||||
def _apply_fsdp2_bnb_patches(self):
|
||||
"""Apply FSDP2 BNB patches."""
|
||||
if (
|
||||
self.cfg.fsdp_config
|
||||
and str(self.cfg.fsdp_version) == "2"
|
||||
and self.cfg.adapter == "qlora"
|
||||
):
|
||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||
apply_bnb_torch_function_patch,
|
||||
apply_init_sharded_param_patch,
|
||||
apply_init_unsharded_param_patch,
|
||||
)
|
||||
|
||||
apply_bnb_torch_function_patch()
|
||||
apply_init_sharded_param_patch()
|
||||
apply_init_unsharded_param_patch()
|
||||
|
||||
def _apply_tiled_mlp(self, model_type: str):
|
||||
if self.cfg.tiled_mlp:
|
||||
from axolotl.monkeypatch.tiled_mlp import (
|
||||
@@ -330,31 +364,21 @@ class PatchManager:
|
||||
|
||||
patch_self_attn_lora()
|
||||
|
||||
def _patch_llama_flash_attention(self, packed=False):
|
||||
def _patch_llama_flash_attention(self):
|
||||
"""Apply Flash Attention patches for LLaMA models."""
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_attn_with_flash_attn,
|
||||
)
|
||||
|
||||
if packed:
|
||||
if self.cfg.device not in ["mps", "cpu"] and not self.inference:
|
||||
LOG.info("patching with flash attention for sample packing")
|
||||
replace_llama_attn_with_flash_attn(
|
||||
packed=True,
|
||||
cross_entropy=self.cfg.flash_attn_cross_entropy,
|
||||
rms_norm=self.cfg.flash_attn_rms_norm,
|
||||
)
|
||||
elif self.cfg.s2_attention:
|
||||
if self.cfg.s2_attention:
|
||||
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
|
||||
replace_llama_attn_with_flash_attn(
|
||||
packed=False,
|
||||
cross_entropy=self.cfg.flash_attn_cross_entropy,
|
||||
rms_norm=self.cfg.flash_attn_rms_norm,
|
||||
use_shifted_sparse_attn=True,
|
||||
)
|
||||
elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm:
|
||||
replace_llama_attn_with_flash_attn(
|
||||
packed=False,
|
||||
cross_entropy=self.cfg.flash_attn_cross_entropy,
|
||||
rms_norm=self.cfg.flash_attn_rms_norm,
|
||||
)
|
||||
@@ -385,7 +409,7 @@ class PatchManager:
|
||||
and self.cfg.sample_packing
|
||||
):
|
||||
if self.cfg.flash_attention:
|
||||
self._patch_llama_flash_attention(packed=self.cfg.sample_packing)
|
||||
self._patch_llama_flash_attention()
|
||||
elif self.cfg.xformers_attention:
|
||||
self._patch_llama_xformers_attention()
|
||||
elif self.cfg.sample_packing:
|
||||
@@ -408,17 +432,12 @@ class PatchManager:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
is_xformers_swiglu_available,
|
||||
replace_llama_mlp_with_swiglu,
|
||||
replace_llama_qkv_with_fused,
|
||||
)
|
||||
|
||||
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
||||
LOG.info("Patching with SwiGLU...")
|
||||
replace_llama_mlp_with_swiglu(model)
|
||||
|
||||
if self.cfg.flash_attn_fuse_qkv:
|
||||
LOG.info("Patching with fused QKV...")
|
||||
replace_llama_qkv_with_fused(model)
|
||||
|
||||
def _apply_unsloth_patches(self, model):
|
||||
"""Apply unsloth optimization patches."""
|
||||
if self.cfg.unsloth_lora_mlp:
|
||||
|
||||
@@ -7,6 +7,7 @@ import functools
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
@@ -36,25 +37,49 @@ def fsdp2_load_full_state_dict(
|
||||
|
||||
meta_sharded_sd = model.state_dict()
|
||||
sharded_sd = {}
|
||||
for param_name, full_tensor in full_sd.items():
|
||||
sharded_meta_param = meta_sharded_sd.get(param_name)
|
||||
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
|
||||
for param_name, sharded_meta_param in meta_sharded_sd.items():
|
||||
full_tensor = None
|
||||
if _accelerator.is_main_process:
|
||||
full_tensor = full_sd[param_name]
|
||||
full_tensor = full_tensor.to(sharded_meta_param.dtype)
|
||||
|
||||
if hasattr(sharded_meta_param, "device_mesh"):
|
||||
device_mesh = sharded_meta_param.device_mesh
|
||||
if _accelerator.is_main_process:
|
||||
full_tensor = full_tensor.to(device_mesh.device_type)
|
||||
else:
|
||||
full_tensor = torch.empty(
|
||||
sharded_meta_param.size(),
|
||||
device=device_mesh.device_type,
|
||||
dtype=sharded_meta_param.dtype,
|
||||
)
|
||||
sharded_param = distribute_tensor(
|
||||
full_tensor,
|
||||
sharded_meta_param.device_mesh,
|
||||
device_mesh,
|
||||
sharded_meta_param.placements,
|
||||
src_data_rank=0,
|
||||
)
|
||||
else:
|
||||
sharded_param = full_tensor
|
||||
# Non-sharded parameters
|
||||
if _accelerator.is_main_process:
|
||||
sharded_param = full_tensor.to(torch.device("cuda"))
|
||||
else:
|
||||
# broadcast manually
|
||||
sharded_param = torch.empty_like(
|
||||
sharded_meta_param,
|
||||
device=torch.device("cuda"),
|
||||
dtype=sharded_meta_param.dtype,
|
||||
)
|
||||
dist.broadcast(sharded_param, src=0)
|
||||
|
||||
if offload_to_cpu:
|
||||
sharded_param = sharded_param.cpu()
|
||||
|
||||
sharded_sd[param_name] = nn.Parameter(sharded_param)
|
||||
|
||||
del full_tensor
|
||||
full_sd[param_name] = None
|
||||
|
||||
model.load_state_dict(sharded_sd, assign=True, strict=True)
|
||||
end_time = time.time()
|
||||
LOG.debug(
|
||||
|
||||
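The rewritten loop above iterates the meta-sharded state dict, keeps the full tensors on rank 0 only, and then either calls distribute_tensor(..., src_data_rank=0) for DTensor parameters or broadcasts plain tensors manually. A hedged mini-sketch of the manual broadcast path (run under torchrun; this is an illustration, not the Axolotl helper itself):

import torch
import torch.distributed as dist

def broadcast_full_param(full_tensor, shape, dtype):
    """Rank 0 supplies the real values; other ranks allocate an empty buffer."""
    if dist.get_rank() == 0:
        buf = full_tensor.to(device="cuda", dtype=dtype)
    else:
        buf = torch.empty(shape, device="cuda", dtype=dtype)
    dist.broadcast(buf, src=0)         # every rank now holds identical values
    return buf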
src/axolotl/monkeypatch/accelerate/parallelism_config.py (new file, 77 lines)
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Workaround to allow ParallelismConfig when using pure context parallelism (CP).
|
||||
"""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from accelerate import DistributedType
|
||||
|
||||
|
||||
def _validate_accelerator(self, accelerator):
|
||||
_warnings = set()
|
||||
if not accelerator.multi_device and self.total_size == 1:
|
||||
# No distributed setup, valid parallelism config
|
||||
return
|
||||
|
||||
# We need this to ensure DDP works
|
||||
if self.total_size == 1:
|
||||
self._set_size("dp_replicate", accelerator.num_processes)
|
||||
|
||||
if self.total_size != accelerator.num_processes:
|
||||
raise ValueError(
|
||||
f"ParallelismConfig total_size ({self.total_size}) does not match "
|
||||
f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
|
||||
f"dp_shard_size/tp_size/cp_size."
|
||||
)
|
||||
|
||||
# allow parallelism config when not using fsdp if using pure context parallelism
|
||||
allow_parallelism_config = False
|
||||
|
||||
if (
|
||||
self.cp_size > 1 # pylint: disable=chained-comparison
|
||||
and self.dp_shard_size <= 1
|
||||
and os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true"
|
||||
):
|
||||
allow_parallelism_config = True
|
||||
|
||||
if (
|
||||
self.total_size > 1
|
||||
and not allow_parallelism_config
|
||||
and not (accelerator.is_fsdp2 or accelerator.multi_device)
|
||||
):
|
||||
raise ValueError(
|
||||
f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
|
||||
)
|
||||
|
||||
for parallelism, size in self._sizes.items():
|
||||
if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
|
||||
_warnings.add(
|
||||
f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
|
||||
)
|
||||
|
||||
if _warnings and accelerator.is_main_process:
|
||||
warnings.warn(
|
||||
"ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
|
||||
def patched_is_fsdp2(self) -> bool:
|
||||
"""
|
||||
Patched version of is_fsdp2 that guards against a None fsdp_plugin.
|
||||
"""
|
||||
# The new logic checks if fsdp_plugin exists before accessing its attributes
|
||||
return (
|
||||
self.distributed_type == DistributedType.FSDP
|
||||
and self.fsdp_plugin
|
||||
and self.fsdp_plugin.fsdp_version == 2
|
||||
)
|
||||
|
||||
|
||||
def patch_parallelism_config():
|
||||
from accelerate.accelerator import AcceleratorState, ParallelismConfig
|
||||
|
||||
ParallelismConfig._validate_accelerator = _validate_accelerator
|
||||
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
|
||||
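A hedged usage sketch for the new module: the patch has to be installed before Accelerate validates its ParallelismConfig, and standalone context parallelism stays opt-in behind an environment variable. The import path matches the file above; the surrounding setup is illustrative.

import os

os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"   # opt in to pure-CP runs

from axolotl.monkeypatch.accelerate.parallelism_config import (
    patch_parallelism_config,
)

patch_parallelism_config()   # swaps in _validate_accelerator and the guarded is_fsdp2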
src/axolotl/monkeypatch/fsdp2_qlora.py (new file, 205 lines)
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2 and
allows our LoRA / QLoRA Triton kernels to work with FSDP2.
|
||||
|
||||
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
|
||||
Params4bit parameters.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def patched_torch_function(cls, func, types, args=(), kwargs=None):
|
||||
"""
|
||||
Patched version of Params4bit.__torch_function__ for preserving Params4bit
|
||||
class identity and attributes.
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
if func in [torch.chunk, torch.split]:
|
||||
tensor = args[0]
|
||||
result = Parameter.__torch_function__(func, types, args, kwargs)
|
||||
|
||||
if isinstance(result, tuple):
|
||||
return tuple(
|
||||
cls(
|
||||
data=chunk,
|
||||
requires_grad=tensor.requires_grad,
|
||||
quant_state=tensor.quant_state,
|
||||
blocksize=tensor.blocksize,
|
||||
compress_statistics=tensor.compress_statistics,
|
||||
quant_type=tensor.quant_type,
|
||||
quant_storage=tensor.quant_storage,
|
||||
module=tensor.module,
|
||||
bnb_quantized=tensor.bnb_quantized,
|
||||
)
|
||||
for chunk in result
|
||||
)
|
||||
|
||||
return cls(
|
||||
data=result,
|
||||
requires_grad=tensor.requires_grad,
|
||||
quant_state=tensor.quant_state,
|
||||
blocksize=tensor.blocksize,
|
||||
compress_statistics=tensor.compress_statistics,
|
||||
quant_type=tensor.quant_type,
|
||||
quant_storage=tensor.quant_storage,
|
||||
module=tensor.module,
|
||||
bnb_quantized=tensor.bnb_quantized,
|
||||
)
|
||||
|
||||
return Parameter.__torch_function__(func, types, args, kwargs)
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def apply_bnb_torch_function_patch():
|
||||
"""
|
||||
Patch Params4bit.__torch_function__ so that torch.chunk / torch.split results
keep their Params4bit class and quantization metadata.
|
||||
"""
|
||||
from bitsandbytes.nn.modules import Params4bit
|
||||
|
||||
Params4bit.__torch_function__ = classmethod(patched_torch_function)
|
||||
|
||||
LOG.info("Successfully patched Params4bit.__torch_function__")
|
||||
|
||||
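To build intuition for what the patch buys: after it is applied, torch.chunk / torch.split on a Params4bit return Params4bit chunks that still carry their quantization metadata, which is what FSDP2 needs when it slices parameters into shards. A hedged illustration (requires bitsandbytes; the tensor here is a placeholder, not real quantized weights):

import torch
import bitsandbytes as bnb

from axolotl.monkeypatch.fsdp2_qlora import apply_bnb_torch_function_patch

apply_bnb_torch_function_patch()

param = bnb.nn.Params4bit(
    torch.empty(1024, 64, dtype=torch.uint8),  # placeholder packed data
    requires_grad=False,
    quant_type="nf4",
)
chunks = torch.chunk(param, 2, dim=0)
assert all(isinstance(c, bnb.nn.Params4bit) for c in chunks)
assert chunks[0].quant_type == "nf4"           # metadata survives the split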
|
||||
# pylint: disable=protected-access
|
||||
def apply_init_sharded_param_patch():
|
||||
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||
|
||||
# Get original source
|
||||
original_source = inspect.getsource(FSDPParam._init_sharded_param)
|
||||
original_source, _ = detab_code(original_source)
|
||||
|
||||
# Define the replacement
|
||||
original_param_creation = """ self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
||||
self.sharded_param.requires_grad_(param.requires_grad)"""
|
||||
|
||||
patched_param_creation = """ import bitsandbytes as bnb
|
||||
if isinstance(param, bnb.nn.modules.Params4bit):
|
||||
self.sharded_param = bnb.nn.modules.Params4bit(
|
||||
data=sharded_param,
|
||||
requires_grad=param.requires_grad,
|
||||
quant_state=param.quant_state,
|
||||
blocksize=param.blocksize,
|
||||
compress_statistics=param.compress_statistics,
|
||||
quant_type=param.quant_type,
|
||||
quant_storage=param.quant_storage,
|
||||
module=param.module,
|
||||
bnb_quantized=param.bnb_quantized,
|
||||
)
|
||||
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
||||
else:
|
||||
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
||||
self.sharded_param.requires_grad_(param.requires_grad)"""
|
||||
|
||||
# Apply the replacement
|
||||
if original_param_creation in original_source:
|
||||
patched_source = original_source.replace(
|
||||
original_param_creation, patched_param_creation
|
||||
)
|
||||
patched_source = patched_source.replace(
|
||||
"def _init_sharded_param(",
|
||||
"def patched_init_sharded_param(",
|
||||
1,
|
||||
)
|
||||
|
||||
# Load necessary imports
|
||||
module_name = FSDPParam.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in patched_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec(patched_source, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
# Replace the method
|
||||
FSDPParam._init_sharded_param = patched_init_sharded_param # pylint: disable=undefined-variable # noqa: F821
|
||||
LOG.info("Successfully applied FSDP _init_sharded_param patch")
|
||||
else:
|
||||
LOG.warning("Could not find target code for _init_sharded_param patching")
|
||||
|
||||
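Both patches in this file use the same source-rewriting technique: read the target's source with inspect, replace one statement as text, exec the result, and swap the new function in; when the marker text is missing (for example because upstream changed), the code logs a warning instead of failing hard. A generic, hedged sketch of that pattern on a toy class:

import inspect
import textwrap


class Greeter:
    def greet(self) -> str:
        return "hello"


src = textwrap.dedent(inspect.getsource(Greeter.greet))   # dedent, like detab_code
src = src.replace('return "hello"', 'return "hello, patched"')
src = src.replace("def greet(", "def patched_greet(", 1)

namespace: dict = {}
exec(src, namespace)  # nosec - illustration of the technique only
Greeter.greet = namespace["patched_greet"]

assert Greeter().greet() == "hello, patched"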
|
||||
def apply_init_unsharded_param_patch():
|
||||
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||
|
||||
# Get original source
|
||||
original_source = inspect.getsource(FSDPParam.init_unsharded_param)
|
||||
original_source, _ = detab_code(original_source)
|
||||
|
||||
# Define the replacement
|
||||
original_param_creation = """ self._unsharded_param = nn.Parameter(
|
||||
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
||||
)"""
|
||||
|
||||
patched_param_creation = """ import bitsandbytes as bnb
|
||||
local_tensor = self.sharded_param._local_tensor
|
||||
if isinstance(local_tensor, bnb.nn.modules.Params4bit):
|
||||
self._unsharded_param = bnb.nn.modules.Params4bit(
|
||||
data=unsharded_param,
|
||||
requires_grad=self.sharded_param.requires_grad,
|
||||
quant_state=local_tensor.quant_state,
|
||||
blocksize=local_tensor.blocksize,
|
||||
compress_statistics=local_tensor.compress_statistics,
|
||||
quant_type=local_tensor.quant_type,
|
||||
quant_storage=local_tensor.quant_storage,
|
||||
module=local_tensor.module,
|
||||
bnb_quantized=local_tensor.bnb_quantized,
|
||||
)
|
||||
else:
|
||||
self._unsharded_param = nn.Parameter(
|
||||
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
||||
)"""
|
||||
|
||||
# Apply the replacement
|
||||
if original_param_creation in original_source:
|
||||
patched_source = original_source.replace(
|
||||
original_param_creation, patched_param_creation
|
||||
)
|
||||
patched_source = patched_source.replace(
|
||||
"def init_unsharded_param(",
|
||||
"def patched_init_unsharded_param(",
|
||||
1,
|
||||
)
|
||||
|
||||
# Load necessary imports
|
||||
module_name = FSDPParam.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in patched_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec(patched_source, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
# Replace the method
|
||||
FSDPParam.init_unsharded_param = patched_init_unsharded_param # pylint: disable=undefined-variable # noqa: F821
|
||||
LOG.info("Successfully applied FSDP init_unsharded_param patch")
|
||||
else:
|
||||
LOG.warning("Could not find target code for patching")
|
||||
@@ -3,39 +3,26 @@
|
||||
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
||||
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaMLP,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||
from axolotl.monkeypatch.utils import set_module_name
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
flash_attn_kvpacked_func,
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
except ImportError:
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
|
||||
)
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
@@ -82,19 +69,6 @@ def replace_llama_mlp_with_swiglu(model):
|
||||
set_module_name(model, name, mlp)
|
||||
|
||||
|
||||
def replace_llama_qkv_with_fused(model):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
qkv = FusedAttention(
|
||||
module.config,
|
||||
module.q_proj,
|
||||
module.k_proj,
|
||||
module.v_proj,
|
||||
module.o_proj,
|
||||
)
|
||||
set_module_name(model, name, qkv)
|
||||
|
||||
|
||||
def patch_fa_llama_cross_entropy():
|
||||
LOG.info(
|
||||
"patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
|
||||
@@ -142,7 +116,6 @@ def patch_llama_rms_norm():
|
||||
|
||||
|
||||
def replace_llama_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
cross_entropy: Optional[bool] = False,
|
||||
rms_norm: Optional[bool] = False,
|
||||
use_shifted_sparse_attn: Optional[bool] = False,
|
||||
@@ -154,16 +127,6 @@ def replace_llama_attn_with_flash_attn(
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||
flashattn_forward_with_s2attn
|
||||
)
|
||||
else:
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||
flashattn_forward
|
||||
)
|
||||
|
||||
if packed:
|
||||
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
||||
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
||||
llama_model_forward
|
||||
)
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if cross_entropy:
|
||||
@@ -174,49 +137,6 @@ def replace_llama_attn_with_flash_attn(
|
||||
patch_llama_rms_norm()
|
||||
|
||||
|
||||
class FusedAttention(LlamaAttention):
|
||||
"""
|
||||
Fused QKV Attention layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
q: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
k: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
v: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
o: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.init_device = next(iter(q.state_dict().values())).device
|
||||
|
||||
# define equivalent fused qkv projection
|
||||
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
|
||||
self.qkv_proj = torch.nn.Linear(
|
||||
q.in_features, sum(self.out_features), device=self.init_device, bias=False
|
||||
)
|
||||
self.o_proj = o
|
||||
|
||||
# overwrite initialized weights with pretrained weights
|
||||
self.qkv_proj.weight.data = torch.cat(
|
||||
(q.weight.data, k.weight.data, v.weight.data), dim=0
|
||||
)
|
||||
|
||||
def _post_training(self, model, name):
|
||||
q_proj, k_proj, v_proj = torch.split(
|
||||
self.qkv_proj.weight.data, self.out_features, dim=0
|
||||
)
|
||||
|
||||
new_attn = LlamaAttention(self.config)
|
||||
new_attn.q_proj.weight.data = q_proj
|
||||
new_attn.k_proj.weight.data = k_proj
|
||||
new_attn.v_proj.weight.data = v_proj
|
||||
new_attn.o_proj.weight.data = self.o_proj.weight.data
|
||||
|
||||
set_module_name(model, name, new_attn)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
@@ -355,576 +275,3 @@ def flashattn_forward_with_s2attn(
|
||||
.reshape(bsz, q_len, nheads, self.head_dim)
|
||||
)
|
||||
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
|
||||
|
||||
|
||||
def flashattn_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel
|
||||
|
||||
attention_mask: [bsz, q_len]
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if not hasattr(self, "pretraining_tp"):
|
||||
self.pretraining_tp = 1
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
key_value_slicing = (
|
||||
self.num_key_value_heads * self.head_dim
|
||||
) // self.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [
|
||||
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [
|
||||
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [
|
||||
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
if isinstance(self, FusedAttention):
|
||||
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
|
||||
self.out_features, dim=-1
|
||||
)
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
# [bsz, q_len, nh, hd]
|
||||
# [bsz, nh, q_len, hd]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if output_attentions:
|
||||
warnings.warn(
|
||||
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
||||
)
|
||||
|
||||
#
|
||||
# flash-attn v2 start
|
||||
#
|
||||
|
||||
if self.training:
|
||||
# during training q,k,v always have same seqlen
|
||||
assert key_states.shape == query_states.shape
|
||||
is_causal = True
|
||||
else:
|
||||
# turn off FA causal mask after first inference autoregressive iteration
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
|
||||
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
qkvpacked=True,
|
||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||
# the attention_mask should be the same as the key_padding_mask
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=(
|
||||
attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||
qkv_unpad,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
else:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
if attention_mask is None or attention_mask.all().item():
|
||||
output = flash_attn_kvpacked_func(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
dropout_p=dropout_rate,
|
||||
causal=is_causal,
|
||||
)
|
||||
else:
|
||||
( # pylint: disable=unbalanced-tuple-unpacking
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
_,
|
||||
_,
|
||||
output_pad_fn,
|
||||
) = generate_qkv(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
kvpacked=True,
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=(
|
||||
attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
if q_unpad.dtype != kv_unpad.dtype:
|
||||
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
|
||||
attn_output = output
|
||||
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||
|
||||
#
|
||||
# flash-attn v2 end
|
||||
#
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(
|
||||
self.hidden_size // self.pretraining_tp, dim=1
|
||||
)
|
||||
attn_output = sum(
|
||||
F.linear(attn_output[i], o_proj_slices[i])
|
||||
for i in range(self.pretraining_tp)
|
||||
)
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
||||
def generate_qkv(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_padding_mask=None,
|
||||
key_padding_mask=None,
|
||||
kvpacked=False,
|
||||
qkvpacked=False,
|
||||
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seqlen_q, nheads, d)
|
||||
k: (batch_size, seqlen_k, nheads_k, d)
|
||||
v: (batch_size, seqlen_k, nheads_k, d)
|
||||
query_padding_mask: (batch_size, seqlen), bool
|
||||
key_padding_mask: (batch_size, seqlen), bool
|
||||
"""
|
||||
assert not (kvpacked and qkvpacked)
|
||||
batch_size, seqlen_q, nheads, d = q.shape
|
||||
_, seqlen_k, nheads_k, _ = k.shape
|
||||
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||
|
||||
if query_padding_mask is not None:
|
||||
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
||||
q, query_padding_mask
|
||||
)
|
||||
|
||||
def output_pad_fn(output_unpad):
|
||||
return pad_input( # noqa: E731
|
||||
output_unpad, indices_q, batch_size, seqlen_q
|
||||
)
|
||||
|
||||
else:
|
||||
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||
cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen_q,
|
||||
step=seqlen_q,
|
||||
dtype=torch.int32,
|
||||
device=q_unpad.device,
|
||||
)
|
||||
max_seqlen_q = seqlen_q
|
||||
|
||||
def output_pad_fn(output_unpad):
|
||||
return rearrange( # noqa: E731
|
||||
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||
)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
||||
else:
|
||||
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||
cu_seqlens_k = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen_k,
|
||||
step=seqlen_k,
|
||||
dtype=torch.int32,
|
||||
device=k_unpad.device,
|
||||
)
|
||||
max_seqlen_k = seqlen_k
|
||||
|
||||
if qkvpacked:
|
||||
assert nheads == nheads_k
|
||||
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||
qkv = torch.stack([q, k, v], dim=2)
|
||||
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
||||
|
||||
if kvpacked:
|
||||
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||
kv = torch.stack([k, v], dim=2)
|
||||
return (
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
kv,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
return (
|
||||
q_unpad,
|
||||
k_unpad,
|
||||
v_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
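For the sample-packing path above, the varlen flash-attention kernels only need cumulative sequence lengths and the longest sample; both are recoverable from position_ids because each packed sample restarts at 0. A hedged sketch of that idea (it mirrors what get_cu_seqlens_from_pos_ids computes, not its implementation):

import torch

def cu_seqlens_from_pos_ids(position_ids: torch.Tensor):
    """position_ids: (1, total_len) with a 0 at the start of each packed sample."""
    pos = position_ids.squeeze(0)
    starts = torch.nonzero(pos == 0).flatten()                 # sample boundaries
    ends = torch.cat([starts[1:], torch.tensor([pos.numel()])])
    lengths = ends - starts
    cu_seqlens = torch.zeros(len(lengths) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    return cu_seqlens, int(lengths.max())

# two samples of length 3 and 2 packed into one row
cu, max_len = cu_seqlens_from_pos_ids(torch.tensor([[0, 1, 2, 0, 1]]))
assert cu.tolist() == [0, 3, 5] and max_len == 3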
|
||||
def llama_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[ # pylint: disable=unused-argument
|
||||
torch.LongTensor
|
||||
] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
if input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
cu_seqlens = None
|
||||
max_seqlen = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
padding_mask = None
|
||||
else:
|
||||
if 0 in attention_mask:
|
||||
padding_mask = attention_mask
|
||||
else:
|
||||
padding_mask = None
|
||||
|
||||
attention_mask = (
|
||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
transformers.logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(
|
||||
*inputs,
|
||||
)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
None,
|
||||
padding_mask,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||
"""
|
||||
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||
]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -156,6 +156,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
||||
|
||||
return Llama4TextAttention
|
||||
|
||||
if model_type == "mistral3":
|
||||
from transformers.models.mistral.modeling_mistral import MistralAttention
|
||||
|
||||
return MistralAttention
|
||||
|
||||
try:
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
@@ -390,7 +395,6 @@ def apply_lora_kernel_patches(
|
||||
]
|
||||
can_patch_qkv = all(
|
||||
hasattr(module, "lora_A")
|
||||
and getattr(module, "base_layer", module).bias is None
|
||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
@@ -400,7 +404,8 @@ def apply_lora_kernel_patches(
|
||||
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
|
||||
"Cannot patch some attention QKV projections - requires LoRA "
|
||||
"adapters and no lora_magnitude_vector (DoRA)"
|
||||
)
|
||||
if cfg.lora_o_kernel:
|
||||
# Output patching
|
||||
@@ -409,7 +414,6 @@ def apply_lora_kernel_patches(
|
||||
]
|
||||
can_patch_o = all(
|
||||
hasattr(module, "lora_A")
|
||||
and getattr(module, "base_layer", module).bias is None
|
||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
@@ -418,14 +422,14 @@ def apply_lora_kernel_patches(
|
||||
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
|
||||
"Cannot patch some attention output projection - requires LoRA "
|
||||
"adapters and no lora_magnitude_vector (DoRA)"
|
||||
)
|
||||
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
||||
if cfg.lora_mlp_kernel:
|
||||
# MLP patching
|
||||
can_patch_mlp = all(
|
||||
hasattr(proj, "lora_A")
|
||||
and getattr(proj, "base_layer", proj).bias is None
|
||||
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
|
||||
for proj in (gate_proj, up_proj, down_proj)
|
||||
)
|
||||
@@ -435,7 +439,8 @@ def apply_lora_kernel_patches(
|
||||
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
|
||||
"Cannot patch some MLP layers - requires LoRA adapters and no "
|
||||
"lora_magnitude_vector (DoRA)"
|
||||
)
|
||||
|
||||
LOG.setLevel(original_level)
|
||||
|
||||
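The reworded warnings above reflect the actual eligibility test: the Triton LoRA kernels are applied only when every projection carries plain LoRA adapters, the base layer has no bias, and there is no lora_magnitude_vector (i.e. not DoRA). Distilled into a hedged standalone predicate:

def can_patch_lora_kernel(module) -> bool:
    """True when a PEFT linear is eligible for the fused LoRA kernel patch."""
    base = getattr(module, "base_layer", module)
    return (
        hasattr(module, "lora_A")                                         # is a LoRA layer
        and base.bias is None                                             # no bias to fold in
        and len(getattr(module, "lora_magnitude_vector", []) or []) == 0  # not DoRA
    )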
@@ -3,53 +3,14 @@
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
flash_attn_kvpacked_func,
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralAttention as OriginalMistralAttention,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def replace_mistral_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
):
|
||||
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
||||
flashattn_forward
|
||||
)
|
||||
if packed:
|
||||
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
||||
MistralDecoderLayer
|
||||
)
|
||||
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
|
||||
mistral_model_forward
|
||||
)
|
||||
|
||||
|
||||
def patch_mistral_cross_entropy():
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
|
||||
@@ -57,604 +18,3 @@ def patch_mistral_cross_entropy():
|
||||
transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _make_sliding_window_causal_mask(
|
||||
bsz: int,
|
||||
tgt_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
past_key_values_length: int = 0,
|
||||
sliding_window: int = 4096,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for sliding window attention
|
||||
"""
|
||||
tensor = torch.full(
|
||||
(tgt_len, tgt_len),
|
||||
fill_value=1,
|
||||
device=device,
|
||||
)
|
||||
mask = torch.tril(tensor, diagonal=0)
|
||||
# make the mask banded to account for sliding window
|
||||
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
|
||||
mask = torch.triu(mask, diagonal=-sliding_window + 1)
|
||||
mask = torch.log(mask).to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
tgt_len, past_key_values_length, dtype=dtype, device=device
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return mask[None, None, :, :].expand(
|
||||
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
||||
)
|
||||
|
||||
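A small numeric illustration of the banded mask built above (hedged; values shown for tgt_len=5 and a window of 3): tril gives the causal triangle, triu with diagonal=-(window - 1) drops keys older than the window, and log maps the zeros to -inf.

import torch

tgt_len, window = 5, 3
band = torch.triu(torch.tril(torch.ones(tgt_len, tgt_len)),
                  diagonal=-(window - 1))
print(band)
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [0., 1., 1., 1., 0.],
#         [0., 0., 1., 1., 1.]])
mask = torch.log(band)          # 1 -> 0.0 (keep), 0 -> -inf (blocked)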
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
input_shape,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window,
|
||||
): # pylint: disable=unused-argument
|
||||
# [bsz, seq_len]
|
||||
if attention_mask is None or sliding_window is None:
|
||||
return attention_mask
|
||||
|
||||
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
||||
# Unless attention_mask.shape[0] == 1, an error is triggered after the eval loss step, but only when wandb is enabled.
|
||||
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
|
||||
sliding_window_mask = _make_sliding_window_causal_mask(
|
||||
bsz=input_shape[0],
|
||||
tgt_len=input_shape[1],
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
attention_mask = attention_mask + sliding_window_mask
|
||||
else:
|
||||
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
def flashattn_forward(
|
||||
self: OriginalMistralAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
|
||||
use_sliding_windows = (
|
||||
getattr(self.config, "sliding_window") is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
)
|
||||
|
||||
if use_sliding_windows:
|
||||
window_size = (self.config.sliding_window, self.config.sliding_window)
|
||||
else:
|
||||
window_size = (-1, -1)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a `sliding_window` attribute
|
||||
if (
|
||||
hasattr(self.config, "sliding_window")
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
):
|
||||
slicing_tokens = kv_seq_len - self.config.sliding_window
|
||||
|
||||
past_key = past_key_value[0]
|
||||
past_value = past_key_value[1]
|
||||
|
||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||
|
||||
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||
raise ValueError(
|
||||
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||
f" {past_key.shape}"
|
||||
)
|
||||
|
||||
past_key_value = (past_key, past_value) if use_cache else None
|
||||
|
||||
if past_key_value is not None:
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if self.training:
|
||||
# during training q,k,v always have same seqlen
|
||||
assert key_states.shape == query_states.shape
|
||||
is_causal = True
|
||||
else:
|
||||
# turn off FA causal mask after first inference autoregressive iteration
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
|
||||
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
qkvpacked=True,
|
||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||
# the attention_mask should be the same as the key_padding_mask
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=(
|
||||
attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||
qkv_unpad,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
else:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
if attention_mask is None or attention_mask.all().item():
|
||||
output = flash_attn_kvpacked_func(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
dropout_p=dropout_rate,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
else:
|
||||
( # pylint: disable=unbalanced-tuple-unpacking
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
_,
|
||||
_,
|
||||
output_pad_fn,
|
||||
) = generate_qkv(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
kvpacked=True,
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=(
|
||||
attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
if q_unpad.dtype != kv_unpad.dtype:
|
||||
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
|
||||
attn_output = output
|
||||
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
||||
def generate_qkv(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_padding_mask=None,
|
||||
key_padding_mask=None,
|
||||
kvpacked=False,
|
||||
qkvpacked=False,
|
||||
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seqlen_q, nheads, d)
|
||||
k: (batch_size, seqlen_k, nheads_k, d)
|
||||
v: (batch_size, seqlen_k, nheads_k, d)
|
||||
query_padding_mask: (batch_size, seqlen), bool
|
||||
key_padding_mask: (batch_size, seqlen), bool
|
||||
"""
|
||||
assert not (kvpacked and qkvpacked)
|
||||
batch_size, seqlen_q, nheads, d = q.shape
|
||||
_, seqlen_k, nheads_k, _ = k.shape
|
||||
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||
|
||||
if query_padding_mask is not None:
|
||||
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
||||
q, query_padding_mask
|
||||
)
|
||||
|
||||
def output_pad_fn(output_unpad):
|
||||
return pad_input( # noqa: E731
|
||||
output_unpad, indices_q, batch_size, seqlen_q
|
||||
)
|
||||
|
||||
else:
|
||||
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||
cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen_q,
|
||||
step=seqlen_q,
|
||||
dtype=torch.int32,
|
||||
device=q_unpad.device,
|
||||
)
|
||||
max_seqlen_q = seqlen_q
|
||||
|
||||
def output_pad_fn(output_unpad):
|
||||
return rearrange( # noqa: E731
|
||||
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||
)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
||||
else:
|
||||
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||
cu_seqlens_k = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen_k,
|
||||
step=seqlen_k,
|
||||
dtype=torch.int32,
|
||||
device=k_unpad.device,
|
||||
)
|
||||
max_seqlen_k = seqlen_k
|
||||
|
||||
if qkvpacked:
|
||||
assert nheads == nheads_k
|
||||
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||
qkv = torch.stack([q, k, v], dim=2)
|
||||
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
||||
|
||||
if kvpacked:
|
||||
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||
kv = torch.stack([k, v], dim=2)
|
||||
return (
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
kv,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
return (
|
||||
q_unpad,
|
||||
k_unpad,
|
||||
v_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
|
||||
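
# Illustrative sketch (not from the diff, names are made up): how a boolean
# key_padding_mask maps onto the cu_seqlens / max_seqlen values that the varlen
# flash-attention entry points fed by generate_qkv ultimately consume.
import torch

def cu_seqlens_from_padding_mask(padding_mask: torch.Tensor):
    # padding_mask: (batch, seqlen) bool, True marks real (unpadded) tokens
    seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
    return cu_seqlens, int(seqlens.max())

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
print(cu_seqlens_from_padding_mask(mask))  # (tensor([0, 3, 5], dtype=torch.int32), 3)
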
def mistral_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[ # pylint: disable=unused-argument
|
||||
torch.LongTensor
|
||||
] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
if input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
cu_seqlens = None
|
||||
max_seqlen = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
attention_mask = (
|
||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
transformers.logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = (
|
||||
self._gradient_checkpointing_func( # pylint: disable=protected-access
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
None,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
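
# Illustrative sketch (not the axolotl helper itself): how packed position_ids, where
# each sub-sequence restarts at 0, can be turned into the cu_seqlens / max_seqlen that
# get_cu_seqlens_from_pos_ids supplies to the decoder layers above.
import torch

def cu_seqlens_from_pos_ids(position_ids: torch.Tensor):
    pos = position_ids.flatten()
    starts = torch.nonzero(pos == 0).flatten().to(torch.int32)
    cu_seqlens = torch.cat([starts, torch.tensor([pos.numel()], dtype=torch.int32)])
    return cu_seqlens, int(cu_seqlens.diff().max())

pos = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
print(cu_seqlens_from_pos_ids(pos))  # (tensor([0, 3, 5, 9], dtype=torch.int32), 4)
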
class MistralDecoderLayer(OriginalMistralDecoderLayer):
|
||||
"""
|
||||
patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||
]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cu_seqlens (`torch.Tensor`, *optional*): cumulative sequence lengths when using sample packing
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -36,6 +36,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
    "glm",
    "glm4",
    "smollm3",
    "gpt_oss",
    "arcee",
]

@@ -18,9 +18,7 @@ from torch.distributed import DeviceMesh
try:
    from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
    from transformers.modeling_flash_attention_utils import (
        _flash_supports_window_size as _flash_supports_window,
    )
_flash_supports_window = True

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger

@@ -1,78 +0,0 @@
|
||||
"""
|
||||
fix for FSDP2 evals when using torch.compile
|
||||
"""
|
||||
|
||||
import inspect
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
ORIGINAL_TRAINER_CODE = """
|
||||
model.eval()
|
||||
"""
|
||||
|
||||
PATCHED_TRAINER_CODE = """
|
||||
if hasattr(model, "eval") and callable(model.eval):
|
||||
self.model.eval()
|
||||
"""
|
||||
|
||||
|
||||
def get_evaluation_loop_code() -> str:
|
||||
training_loop = inspect.getsource(Trainer.evaluation_loop)
|
||||
return training_loop
|
||||
|
||||
|
||||
def check_evaluation_loop_is_patchable() -> bool:
|
||||
eval_loop = get_evaluation_loop_code()
|
||||
eval_loop, _ = detab_code(eval_loop)
|
||||
return ORIGINAL_TRAINER_CODE in eval_loop
|
||||
|
||||
|
||||
def patch_evaluation_loop_for_fsdp2():
|
||||
"""
|
||||
monkeypatch for fixing the eval loop for fsdp2 with torch.compile
|
||||
"""
|
||||
|
||||
try:
|
||||
evaluation_loop = get_evaluation_loop_code()
|
||||
except OSError:
|
||||
return
|
||||
Trainer._original_evaluation_loop = ( # pylint: disable=protected-access
|
||||
evaluation_loop
|
||||
)
|
||||
evaluation_loop, _ = detab_code(evaluation_loop)
|
||||
if ORIGINAL_TRAINER_CODE not in evaluation_loop:
|
||||
return
|
||||
|
||||
evaluation_loop = evaluation_loop.replace(
|
||||
ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
|
||||
)
|
||||
evaluation_loop = evaluation_loop.replace(
|
||||
"def evaluation_loop(",
|
||||
"def _fixed_evaluation_loop(",
|
||||
1,
|
||||
)
|
||||
|
||||
# load imports necessary
|
||||
import transformers.trainer
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.trainer):
|
||||
if item in evaluation_loop:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from transformers.trainer import ("
|
||||
+ ", ".join(x for x in items_to_import)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
exec(evaluation_loop, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching _inner_training_loop for fsdp optimizer save")
|
||||
Trainer.evaluation_loop = ( # pylint: disable=protected-access
|
||||
_fixed_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
@@ -1,87 +0,0 @@
|
||||
"""
|
||||
Monkey patch to fix transformers.modeling_flash_attention_utils.
|
||||
|
||||
see https://github.com/huggingface/transformers/pull/39653/files
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _prepare_from_posids(query, key, value, position_ids):
|
||||
"""
|
||||
This function returns necessary arguments to call `flash_attn_varlen_func`.
|
||||
All three query, key, value states will be flattened.
|
||||
Cumulative lengths of each examples in the batch will be extracted from position_ids.
|
||||
NOTE: ideally cumulative lengths should be prepared at the data collator stage
|
||||
Arguments:
|
||||
query (`torch.Tensor`):
|
||||
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
||||
key (`torch.Tensor`):
|
||||
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||
value (`torch.Tensor`):
|
||||
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||
position_ids (`torch.Tensor`):
|
||||
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
||||
Return:
|
||||
query (`torch.Tensor`):
|
||||
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
||||
key (`torch.Tensor`):
|
||||
Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
||||
value (`torch.Tensor`):
|
||||
Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
||||
indices_q (`torch.Tensor`):
|
||||
The indices of non-masked tokens from the flattened input target sequence.
|
||||
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
|
||||
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
|
||||
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
||||
"""
|
||||
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
|
||||
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
|
||||
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
|
||||
|
||||
position_ids = position_ids.flatten()
|
||||
indices_q = torch.arange(
|
||||
position_ids.size(0), device=position_ids.device, dtype=torch.int32
|
||||
)
|
||||
|
||||
cu_seq_lens = torch.cat(
|
||||
(
|
||||
indices_q[position_ids == 0],
|
||||
torch.tensor(
|
||||
position_ids.size(), device=position_ids.device, dtype=torch.int32
|
||||
),
|
||||
)
|
||||
)
|
||||
# NOTE: With torch compile, this will cause a graph break if you don't set
|
||||
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
|
||||
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
|
||||
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
|
||||
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
|
||||
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
|
||||
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
|
||||
# for some models (e.g. qwen2-vl).
|
||||
max_length = cu_seq_lens.diff().max().item()
|
||||
return (
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
indices_q,
|
||||
(cu_seq_lens, cu_seq_lens),
|
||||
(max_length, max_length),
|
||||
)
|
||||
|
||||
|
||||
def patch_prepare_from_posids():
|
||||
import transformers.modeling_flash_attention_utils
|
||||
|
||||
transformers.modeling_flash_attention_utils._prepare_from_posids = ( # pylint: disable=protected-access
|
||||
_prepare_from_posids
|
||||
)
|
||||
setattr(
|
||||
sys.modules["transformers.modeling_flash_attention_utils"],
|
||||
"_prepare_from_posids",
|
||||
_prepare_from_posids,
|
||||
)
|
||||
src/axolotl/monkeypatch/transformers/trainer_loss_calc.py (new file, 165 lines)
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Module for patching transformers Trainer loss calculation to use nanmean.
|
||||
|
||||
This is needed for context parallelism since chunks of the input sequences may be fully
|
||||
masked and return NaNs in the loss calculation.
|
||||
|
||||
Also includes a patch for FSDP2 + torch.compile. We need to bundle this together with
|
||||
the other evaluation_loop patch because we can't patch the same code twice without
|
||||
raising an OSError.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
ORIGINAL_EVAL_CODE = {
|
||||
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()',
|
||||
"array": 'metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()',
|
||||
}
|
||||
PATCHED_EVAL_CODE = {
|
||||
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()',
|
||||
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
|
||||
}
|
||||
|
||||
ORIGINAL_FSDP2_CODE = """
|
||||
model.eval()
|
||||
"""
|
||||
|
||||
PATCHED_FSDP2_CODE = """
|
||||
if hasattr(model, "eval") and callable(model.eval):
|
||||
self.model.eval()
|
||||
"""
|
||||
|
||||
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
|
||||
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
|
||||
|
||||
|
||||
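
# Quick illustration of why the patch swaps mean() for nanmean(): with context
# parallelism a rank's loss chunk can be NaN when every token in its shard is masked,
# and a single NaN poisons a plain mean.
import numpy as np

all_losses = np.array([0.8, np.nan, 1.2, 1.0])
print(np.mean(all_losses))     # nan
print(np.nanmean(all_losses))  # 1.0
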
def check_evaluation_loop_is_patchable() -> bool:
|
||||
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||
return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values())
|
||||
|
||||
|
||||
def check_evaluation_loop_is_fsdp2_patchable() -> bool:
|
||||
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
|
||||
return ORIGINAL_FSDP2_CODE in evaluation_loop_source
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def patch_evaluation_loop(patch_fsdp2: bool):
|
||||
"""Patch the evaluation_loop method."""
|
||||
# Check if already patched
|
||||
if hasattr(Trainer, "_original_evaluation_loop"):
|
||||
LOG.info("Trainer.evaluation_loop already patched")
|
||||
return
|
||||
|
||||
# Check if the patterns exist
|
||||
try:
|
||||
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||
except OSError:
|
||||
return
|
||||
Trainer._original_evaluation_loop = evaluation_loop_source
|
||||
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
|
||||
|
||||
# Apply the nanmean patches
|
||||
evaluation_loop_source = evaluation_loop_source.replace(
|
||||
ORIGINAL_EVAL_CODE["list"], PATCHED_EVAL_CODE["list"]
|
||||
)
|
||||
evaluation_loop_source = evaluation_loop_source.replace(
|
||||
ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"]
|
||||
)
|
||||
|
||||
# Apply FSDP2 eval guard patch if needed
|
||||
if patch_fsdp2 and ORIGINAL_FSDP2_CODE in evaluation_loop_source:
|
||||
evaluation_loop_source = evaluation_loop_source.replace(
|
||||
ORIGINAL_FSDP2_CODE, PATCHED_FSDP2_CODE
|
||||
)
|
||||
LOG.info("Applied FSDP2 eval guard patch to evaluation_loop")
|
||||
|
||||
# Rename the function to avoid conflicts
|
||||
evaluation_loop_source = evaluation_loop_source.replace(
|
||||
"def evaluation_loop(",
|
||||
"def axolotl_evaluation_loop(",
|
||||
1,
|
||||
)
|
||||
|
||||
# Get the module for necessary imports
|
||||
module_name = Trainer.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Import necessary items from the module
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in evaluation_loop_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
# Execute the imports and patched method
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec(evaluation_loop_source, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
|
||||
Trainer.evaluation_loop = (
|
||||
axolotl_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
|
||||
|
||||
def check_maybe_log_save_evaluate_is_patchable() -> bool:
|
||||
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
|
||||
return ORIGINAL_MAYBE_CODE in maybe_log_source
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def patch_maybe_log_save_evaluate():
|
||||
"""Patch the _maybe_log_save_evaluate method."""
|
||||
# Check if already patched
|
||||
if hasattr(Trainer, "_original_maybe_log_save_evaluate"):
|
||||
LOG.info("Trainer._maybe_log_save_evaluate already patched")
|
||||
return
|
||||
|
||||
# Check if the patterns exist
|
||||
try:
|
||||
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
|
||||
except OSError:
|
||||
return
|
||||
Trainer._original_maybe_log_save_evaluate = maybe_log_source
|
||||
maybe_log_source, _ = detab_code(maybe_log_source)
|
||||
|
||||
# Apply the patch
|
||||
maybe_log_source = maybe_log_source.replace(ORIGINAL_MAYBE_CODE, PATCHED_MAYBE_CODE)
|
||||
|
||||
# Rename the function to avoid conflicts
|
||||
maybe_log_source = maybe_log_source.replace(
|
||||
"def _maybe_log_save_evaluate(",
|
||||
"def axolotl_maybe_log_save_evaluate(",
|
||||
1,
|
||||
)
|
||||
|
||||
# Get the module for necessary imports
|
||||
module_name = Trainer.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Import necessary items from the module
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in maybe_log_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
# Execute the imports and patched method
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec(maybe_log_source, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
|
||||
Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate # pylint: disable=undefined-variable # noqa: F821
|
||||
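
# Miniature, self-contained version of the patching pattern used in this module:
# fetch the source with inspect, replace a known snippet, rename the function, exec
# the rewritten source, and rebind. The toy greet() target stands in for the
# transformers.Trainer methods patched above.
import inspect

def greet():
    return "hello"

src = inspect.getsource(greet)
src = src.replace('"hello"', '"hello, patched"')
src = src.replace("def greet(", "def greet_patched(", 1)
namespace = {}
exec(src, namespace)  # nosec B102 - illustration of the monkeypatch mechanism
greet = namespace["greet_patched"]
print(greet())  # hello, patched
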
@@ -41,7 +41,9 @@ class ChatTemplatePrompter(Prompter):
|
||||
field_messages: str = "messages",
|
||||
field_system: str = "system",
|
||||
field_tools: str = "tools",
|
||||
field_thinking: str = "reasoning_content",
|
||||
roles: dict[str, list[str]] | None = None,
|
||||
template_thinking_key: str | None = "reasoning_content",
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
drop_system_message: bool = False,
|
||||
):
|
||||
@@ -50,8 +52,9 @@ class ChatTemplatePrompter(Prompter):
|
||||
message_property_mappings = {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
"reasoning_content": "reasoning_content",
|
||||
}
|
||||
if template_thinking_key and field_thinking:
|
||||
message_property_mappings[template_thinking_key] = field_thinking
|
||||
|
||||
if roles:
|
||||
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
||||
@@ -74,10 +77,12 @@ class ChatTemplatePrompter(Prompter):
|
||||
self.field_messages = field_messages
|
||||
self.field_system = field_system
|
||||
self.field_tools = field_tools
|
||||
self.field_thinking = field_thinking
|
||||
self.tokenizer = tokenizer
|
||||
self.processor: ProcessorMixin | None = processor
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_kwargs = chat_template_kwargs or {}
|
||||
self.template_thinking_key: str = template_thinking_key or "reasoning_content"
|
||||
self.max_length = max_length
|
||||
self.drop_system_message = drop_system_message
|
||||
|
||||
@@ -742,7 +747,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
# get the thinking content
|
||||
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
|
||||
transformed_message["reasoning_content"] = thinking_content.strip()
|
||||
transformed_message[self.prompter.template_thinking_key] = (
|
||||
thinking_content.strip()
|
||||
)
|
||||
|
||||
# take remainder of the content
|
||||
# strip whitespace from beginning of the remainder (thinking tokens)
|
||||
@@ -953,6 +960,10 @@ class StrategyLoader:
|
||||
None,
|
||||
),
|
||||
"field_messages": dataset_config.get("field_messages", "messages"),
|
||||
"field_thinking": dataset_config.get("field_thinking", "reasoning_content"),
|
||||
"template_thinking_key": dataset_config.get(
|
||||
"template_thinking_key", "reasoning_content"
|
||||
),
|
||||
"roles": dataset_config.get("roles"),
|
||||
"drop_system_message": dataset_config.get("drop_system_message", False),
|
||||
# we need to add one for detecting sequences exceeding the `sequence_len` limit.
|
||||
|
||||
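
# Sketch of the mapping added above: the dataset's reasoning field (field_thinking)
# is exposed to the chat template under whatever key the template expects
# (template_thinking_key). The "thinking" value below is just an example.
field_thinking = "reasoning_content"
template_thinking_key = "thinking"

message_property_mappings = {"role": "role", "content": "content"}
if template_thinking_key and field_thinking:
    message_property_mappings[template_thinking_key] = field_thinking
print(message_property_mappings)
# {'role': 'role', 'content': 'content', 'thinking': 'reasoning_content'}
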
@@ -218,6 +218,7 @@ def execute_training(
|
||||
ring_attn_func=cfg.ring_attn_func,
|
||||
heads_k_stride=cfg.heads_k_stride,
|
||||
gather_outputs=cfg.rl is RLType.GRPO,
|
||||
device_mesh=trainer.accelerator.torch_device_mesh,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -274,7 +275,7 @@ def save_trained_model(
|
||||
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
||||
return
|
||||
|
||||
if trainer.is_fsdp_enabled:
|
||||
if trainer.is_fsdp_enabled or cfg.fsdp_config:
|
||||
if cfg.fsdp_config or cfg.fsdp:
|
||||
if cfg.fsdp_config.final_state_dict_type:
|
||||
state_dict_type = cfg.fsdp_config.final_state_dict_type
|
||||
@@ -566,6 +567,10 @@ def train(
|
||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||
|
||||
# clear cache
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Save the trained model and cleanup
|
||||
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||
create_model_card(cfg, trainer)
|
||||
|
||||
@@ -161,6 +161,8 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
Collator for multipack, specific to using the BatchSampler
|
||||
"""
|
||||
|
||||
squash_position_ids: bool = False
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
if not isinstance(features[0], list):
|
||||
features: List[List[dict]] = [features]
|
||||
@@ -176,6 +178,15 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
elif feature == "position_ids" and self.squash_position_ids:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
# concatenate, get total length and create arange of new total position ids
|
||||
position_ids = np.concatenate(arrays)
|
||||
total_length = position_ids.shape[0]
|
||||
position_ids = np.arange(total_length)
|
||||
out_features[i][feature] = position_ids
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
|
||||
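
# Sketch of what squash_position_ids does to a packed row: instead of keeping each
# sample's own 0..n-1 ids after concatenation, a single contiguous range is emitted
# for the whole packed sequence.
import numpy as np

per_sample = [np.arange(3), np.arange(4)]
packed = np.concatenate(per_sample)    # [0 1 2 0 1 2 3]  (default behaviour)
squashed = np.arange(packed.shape[0])  # [0 1 2 3 4 5 6]  (squash_position_ids=True)
print(packed, squashed)
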
@@ -5,8 +5,8 @@ import inspect
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate import PartialState
|
||||
from torch import nn
|
||||
from torch.distributed import DeviceMesh
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import ModelOutput
|
||||
@@ -194,6 +194,7 @@ class SequenceParallelContextManager:
|
||||
ring_attn_func: RingAttnFunc,
|
||||
heads_k_stride: int | None,
|
||||
gather_outputs: bool,
|
||||
device_mesh: DeviceMesh | None = None,
|
||||
):
|
||||
self.models = models
|
||||
self.context_parallel_size = context_parallel_size
|
||||
@@ -201,6 +202,7 @@ class SequenceParallelContextManager:
|
||||
self.ring_attn_func = ring_attn_func
|
||||
self.heads_k_stride = heads_k_stride
|
||||
self.gather_outputs = gather_outputs
|
||||
self.device_mesh = device_mesh
|
||||
|
||||
self._register_ring_attn()
|
||||
|
||||
@@ -240,9 +242,8 @@ class SequenceParallelContextManager:
|
||||
|
||||
def _register_ring_attn(self):
|
||||
# Initialize ring attn for sequence parallelism
|
||||
partial_state = PartialState()
|
||||
register_ring_attn_from_device_mesh(
|
||||
device_mesh=partial_state.device_mesh,
|
||||
device_mesh=self.device_mesh,
|
||||
context_parallel_dim=("cp",),
|
||||
heads_k_stride=self.heads_k_stride,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Data handling specific to SFT."""
|
||||
|
||||
import functools
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Literal
|
||||
|
||||
@@ -104,6 +105,9 @@ def _prepare_standard_dataset(
|
||||
finally:
|
||||
loader.cleanup()
|
||||
|
||||
if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
|
||||
return train_dataset, eval_dataset, -1, prompters
|
||||
|
||||
# Validate sample packing configuration for evaluation
|
||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||
|
||||
@@ -8,6 +8,7 @@ from datetime import timedelta
import torch
import torch.distributed as dist
from accelerate import PartialState
from accelerate.utils import ParallelismConfig
from transformers.utils.import_utils import (
    is_torch_cuda_available,
    is_torch_mps_available,

@@ -50,7 +51,10 @@ def init_distributed_state():
    global distributed_state  # pylint: disable=global-statement
    if distributed_state is None:
        timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
        distributed_state = PartialState(timeout=timedelta(seconds=timeout))
        try:
            distributed_state = PartialState(timeout=timedelta(seconds=timeout))
        except ValueError:
            pass


def get_distributed_state() -> PartialState | None:
@@ -290,3 +294,77 @@ def reduce_and_broadcast(fn1, fn2):
|
||||
# Use compute_and_broadcast to compute the reduced value on the main process
|
||||
# and then broadcast it to all ranks
|
||||
return compute_and_broadcast(lambda: fn2(gathered_values))
|
||||
|
||||
|
||||
def build_parallelism_config(cfg):
|
||||
pc_kwargs = _get_parallel_config_kwargs(
|
||||
get_world_size(),
|
||||
cfg.tensor_parallel_size,
|
||||
cfg.context_parallel_size,
|
||||
cfg.dp_shard_size,
|
||||
cfg.dp_replicate_size,
|
||||
bool(cfg.fsdp or cfg.fsdp_config),
|
||||
)
|
||||
|
||||
if pc_kwargs:
|
||||
parallelism_config = ParallelismConfig(
|
||||
**pc_kwargs,
|
||||
)
|
||||
device_mesh = parallelism_config.build_device_mesh("cuda")
|
||||
|
||||
return parallelism_config, device_mesh
|
||||
return None, None
|
||||
|
||||
|
||||
def _get_parallel_config_kwargs(
|
||||
world_size: int,
|
||||
tensor_parallel_size: int = 1,
|
||||
context_parallel_size: int = 1,
|
||||
dp_shard_size: int | None = None,
|
||||
dp_replicate_size: int | None = None,
|
||||
is_fsdp: bool = False,
|
||||
):
|
||||
pc_kwargs = {}
|
||||
remaining_world_size = world_size
|
||||
|
||||
if tensor_parallel_size and tensor_parallel_size > 1:
|
||||
pc_kwargs["tp_size"] = tensor_parallel_size
|
||||
remaining_world_size = remaining_world_size // tensor_parallel_size
|
||||
|
||||
if context_parallel_size and context_parallel_size > 1:
|
||||
pc_kwargs["cp_size"] = context_parallel_size
|
||||
remaining_world_size = remaining_world_size // context_parallel_size
|
||||
|
||||
if dp_shard_size is None and dp_replicate_size in (None, 1):
|
||||
if remaining_world_size > 1:
|
||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if dp_replicate_size and dp_replicate_size > 1:
|
||||
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
||||
remaining_world_size = remaining_world_size // dp_replicate_size
|
||||
|
||||
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
|
||||
if not is_fsdp:
|
||||
raise ValueError(
|
||||
"dp_shard_size was configured without a corresponding fsdp_config! "
|
||||
"Please ensure you have configured FSDP using fsdp_config."
|
||||
)
|
||||
pc_kwargs["dp_shard_size"] = dp_shard_size
|
||||
remaining_world_size = remaining_world_size // dp_shard_size
|
||||
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
|
||||
pc_kwargs["dp_replicate_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if remaining_world_size > 1:
|
||||
if "dp_shard_size" not in pc_kwargs and is_fsdp:
|
||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if remaining_world_size > 1:
|
||||
raise ValueError(
|
||||
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
|
||||
f"{pc_kwargs}"
|
||||
)
|
||||
|
||||
return pc_kwargs
|
||||
|
||||
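
# Toy restatement of the factorisation performed above (standalone, not the axolotl
# API): with 8 GPUs, tensor_parallel_size=2 and context_parallel_size=2, the leftover
# factor of 2 becomes the FSDP data-parallel shard size.
world_size = 8
tp_size, cp_size = 2, 2

remaining = world_size // (tp_size * cp_size)
pc_kwargs = {"tp_size": tp_size, "cp_size": cp_size, "dp_shard_size": remaining}
print(pc_kwargs)  # {'tp_size': 2, 'cp_size': 2, 'dp_shard_size': 2}
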
src/axolotl/utils/import_helper.py (new file, 28 lines)
@@ -0,0 +1,28 @@
"""
Helper for importing modules from strings
"""

import importlib


def get_cls_from_module_str(module_str: str):
    # use importlib to dynamically load the reward function from the module
    if not isinstance(module_str, str) or not module_str.strip():
        raise ValueError("module_str must be a non-empty string")

    parts = module_str.split(".")
    if len(parts) < 2:
        raise ValueError(f"Invalid module string format: {module_str}")

    try:
        cls_name = parts[-1]
        module_path = ".".join(parts[:-1])
        mod = importlib.import_module(module_path)
        mod_cls = getattr(mod, cls_name)
        return mod_cls
    except ImportError as e:
        raise ImportError(f"Failed to import module '{module_path}': {e}") from e
    except AttributeError as e:
        raise AttributeError(
            f"Class '{cls_name}' not found in module '{module_path}': {e}"
        ) from e
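
# Usage sketch for the helper above; the dotted path here is an arbitrary stdlib
# example, and the import assumes the new module ships as axolotl.utils.import_helper.
from axolotl.utils.import_helper import get_cls_from_module_str

ordered_dict_cls = get_cls_from_module_str("collections.OrderedDict")
print(ordered_dict_cls.__name__)  # OrderedDict
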
@@ -4,6 +4,7 @@ import math
from functools import partial
from typing import Sequence

from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler

@@ -45,8 +46,10 @@ class RexLR(LRScheduler):

        # Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
        for group in optimizer.param_groups:
            group.setdefault("initial_lr", group["lr"])
            initial_lr = group["lr"]
            if isinstance(initial_lr, Tensor):
                initial_lr = initial_lr.clone()
            group.setdefault("initial_lr", initial_lr)
        # Pass self.last_step as last_epoch to the parent.
        super().__init__(optimizer, last_epoch=self.last_step)

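
# Why the clone above matters (illustrative): when a param group's lr is a tensor,
# storing it by reference as "initial_lr" would alias the live value and drift as the
# scheduler updates it in place; cloning freezes the starting lr.
import torch

lr = torch.tensor(1e-3, dtype=torch.float64)
aliased = lr           # same storage as the live lr
frozen = lr.clone()    # independent copy

lr.mul_(0.5)           # scheduler-style in-place update
print(aliased.item(), frozen.item())  # 0.0005 0.001
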
@@ -110,6 +110,13 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
|
||||
trainer_cls: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "module to custom trainer class to use for training"
|
||||
},
|
||||
)
|
||||
|
||||
rl: RLType | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -531,12 +538,6 @@ class AxolotlInputConfig(
|
||||
"description": "Whether to use flash-attention rms norm implementation - advanced use only"
|
||||
},
|
||||
)
|
||||
flash_attn_fuse_qkv: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to fuse QKV into a single operation"
|
||||
},
|
||||
)
|
||||
flash_attn_fuse_mlp: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -550,6 +551,13 @@ class AxolotlInputConfig(
|
||||
|
||||
eager_attention: bool | None = None
|
||||
|
||||
attn_implementation: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Specify a custom attention implementation, used mostly for kernels."
|
||||
},
|
||||
)
|
||||
|
||||
unsloth_cross_entropy_loss: bool | None = None
|
||||
unsloth_lora_mlp: bool | None = None
|
||||
unsloth_lora_qkv: bool | None = None
|
||||
|
||||
@@ -118,6 +118,18 @@ class SFTDataset(BaseModel):
|
||||
"description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'
|
||||
},
|
||||
)
|
||||
field_thinking: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": 'Key containing the reasoning trace (default: "reasoning_content").'
|
||||
},
|
||||
)
|
||||
template_thinking_key: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The key the chat template expects that indicates the reasoning trace."
|
||||
},
|
||||
)
|
||||
# deprecated, use message_property_mappings
|
||||
message_field_role: str | None = None
|
||||
# deprecated, use message_property_mappings
|
||||
|
||||
@@ -79,6 +79,7 @@ class CustomSupportedOptimizers(str, Enum):
|
||||
adopt_adamw = "adopt_adamw"
|
||||
came_pytorch = "came_pytorch"
|
||||
muon = "muon"
|
||||
dion = "dion"
|
||||
|
||||
|
||||
class RingAttnFunc(str, Enum):
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Pydantic models for model input / output, etc. configuration"""
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
@@ -62,6 +64,28 @@ class ModelInputConfig(BaseModel):
|
||||
json_schema_extra={"description": "Trust remote code for untrusted source"},
|
||||
)
|
||||
|
||||
experimental_skip_move_to_device: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Don't move the model to the device before sharding. "
|
||||
"This is an experimental feature that may be included in the future as the default."
|
||||
},
|
||||
)
|
||||
|
||||
use_kernels: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."},
|
||||
)
|
||||
|
||||
model_quantization_config: Literal["Mxfp4Config"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Model loading quantization config"},
|
||||
)
|
||||
model_quantization_config_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "kwargs for model quantization config"},
|
||||
)
|
||||
|
||||
@field_validator("trust_remote_code")
|
||||
@classmethod
|
||||
def hint_trust_remote_code(cls, trust_remote_code):
|
||||
|
||||
@@ -54,6 +54,7 @@ class LoraConfig(BaseModel):
|
||||
lora_alpha: int | None = None
|
||||
lora_fan_in_fan_out: bool | None = None
|
||||
lora_target_modules: str | list[str] | None = None
|
||||
lora_target_parameters: str | list[str] | None = None
|
||||
lora_target_linear: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "If true, will target all linear modules"},
|
||||
|
||||
@@ -138,6 +138,26 @@ class HyperparametersConfig(BaseModel):
|
||||
adam_beta3: float | None = Field(
|
||||
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
|
||||
)
|
||||
|
||||
dion_lr: float | None = Field(
|
||||
default=None, json_schema_extra={"description": "Dion Optimizer learning rate"}
|
||||
)
|
||||
dion_momentum: float | None = Field(
|
||||
default=None, json_schema_extra={"description": "Dion Optimizer momentum"}
|
||||
)
|
||||
dion_rank_fraction: float | None = Field(
|
||||
default=1.0,
|
||||
json_schema_extra={
|
||||
"description": "Dion Optimizer: r/d fraction for low-rank approximation. Used to compute the low-rank dimension."
|
||||
},
|
||||
)
|
||||
dion_rank_multiple_of: int | None = Field(
|
||||
default=1,
|
||||
json_schema_extra={
|
||||
"description": "Dion Optimizer: Round up the low-rank dimension to a multiple of this number. This may be useful to ensure even sharding."
|
||||
},
|
||||
)
|
||||
|
||||
max_grad_norm: float | None = Field(
|
||||
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
|
||||
)
|
||||
|
||||
@@ -559,20 +559,6 @@ class LoRAValidationMixin:
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_8bit(cls, data):
|
||||
if (
|
||||
data.get("lora_mlp_kernel")
|
||||
or data.get("lora_qkv_kernel")
|
||||
or data.get("lora_o_kernel")
|
||||
):
|
||||
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_axolotl_unsloth(cls, data):
|
||||
@@ -591,9 +577,7 @@ class LoRAValidationMixin:
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_fused_lora(self):
|
||||
if self.adapter in ["lora", "qlora"] and (
|
||||
self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp
|
||||
):
|
||||
if self.adapter in ["lora", "qlora"] and self.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with LoRA/QLoRA")
|
||||
return self
|
||||
|
||||
@@ -619,7 +603,7 @@ class LoRAValidationMixin:
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_kernel_8bit(cls, data):
|
||||
def check_lora_kernels_8bit(cls, data):
|
||||
if (
|
||||
data.get("lora_mlp_kernel")
|
||||
or data.get("lora_qkv_kernel")
|
||||
@@ -627,20 +611,36 @@ class LoRAValidationMixin:
|
||||
):
|
||||
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
||||
"compatible with 8-bit LoRA a the moment."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_kernel_rl(cls, data):
|
||||
def check_lora_kernels_dora(cls, data):
|
||||
if (
|
||||
data.get("lora_mlp_kernel")
|
||||
or data.get("lora_qkv_kernel")
|
||||
or data.get("lora_o_kernel")
|
||||
) and data.get("peft_use_dora"):
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
||||
"compatible with DoRA at the moment."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_kernels_rl(cls, data):
|
||||
if (
|
||||
data.get("lora_mlp_kernel")
|
||||
or data.get("lora_qkv_kernel")
|
||||
or data.get("lora_o_kernel")
|
||||
) and data.get("rl"):
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with RL at the moment."
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
||||
"compatible with RL at the moment."
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -972,6 +972,16 @@ class SystemValidationMixin:
|
||||
raise ValueError("deepspeed and fsdp cannot be used together.")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_model_quantization_config_vs_bnb(cls, data):
|
||||
if data.get("model_quantization_config"):
|
||||
if data.get("load_in_8bit") or data.get("load_in_4bit"):
|
||||
raise ValueError(
|
||||
"model_quantization_config and load_in_8bit or load_in_4bit cannot be used together."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_npu_config(cls, data):
|
||||
@@ -1137,6 +1147,19 @@ class ModelCompatibilityValidationMixin:
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_gpt_oss_fsdp_loading(cls, data):
|
||||
if data.get("model_quantization_config", "") == "Mxfp4Config":
|
||||
if (
|
||||
data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False)
|
||||
is True
|
||||
):
|
||||
raise ValueError(
|
||||
"FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class ComplexValidationMixin:
|
||||
"""Complex validation methods that involve multiple systems."""
|
||||
@@ -1182,7 +1205,7 @@ class ComplexValidationMixin:
|
||||
"ReLoRA is not compatible with the one_cycle scheduler"
|
||||
)
|
||||
|
||||
if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
|
||||
if self.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with ReLoRA")
|
||||
return self
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from datasets import IterableDataset, disable_caching, enable_caching
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
||||
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
|
||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||
from axolotl.utils.logging import get_logger
|
||||
@@ -597,6 +596,25 @@ def setup_fsdp_envs(cfg):
|
||||
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = "true"
|
||||
|
||||
|
||||
def setup_parallelism_envs(cfg):
|
||||
set_accelerate_parallelism_config = False
|
||||
if cfg.tensor_parallel_size and cfg.tensor_parallel_size > 1:
|
||||
set_accelerate_parallelism_config = True
|
||||
os.environ["PARALLELISM_CONFIG_TP_SIZE"] = str(cfg.tensor_parallel_size)
|
||||
if cfg.dp_shard_size and cfg.dp_shard_size > 1:
|
||||
set_accelerate_parallelism_config = True
|
||||
os.environ["PARALLELISM_CONFIG_DP_SHARD_SIZE"] = str(cfg.dp_shard_size)
|
||||
if cfg.dp_replicate_size and cfg.dp_replicate_size > 1:
|
||||
set_accelerate_parallelism_config = True
|
||||
os.environ["PARALLELISM_CONFIG_DP_REPLICATE_SIZE"] = str(cfg.dp_replicate_size)
|
||||
if cfg.context_parallel_size and cfg.context_parallel_size > 1:
|
||||
set_accelerate_parallelism_config = True
|
||||
os.environ["PARALLELISM_CONFIG_CP_SIZE"] = str(cfg.context_parallel_size)
|
||||
os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
|
||||
if set_accelerate_parallelism_config:
|
||||
os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
|
||||
|
||||
|
||||
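
# Example of the environment this helper produces (illustrative values): a run with
# tensor_parallel_size=2 and context_parallel_size=2 would export the accelerate
# parallelism-config variables before the trainer is built.
import os

os.environ["PARALLELISM_CONFIG_TP_SIZE"] = "2"
os.environ["PARALLELISM_CONFIG_CP_SIZE"] = "2"
os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
print({k: v for k, v in os.environ.items() if k.startswith(("PARALLELISM_CONFIG_", "ACCELERATE_"))})
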
def prepare_optim_env(cfg):
|
||||
if not check_cuda_p2p_ib_support():
|
||||
if os.getenv("NCCL_P2P_DISABLE") is None:
|
||||
@@ -615,6 +633,7 @@ def prepare_optim_env(cfg):
|
||||
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
|
||||
setup_deepspeed_env(cfg, stage=stage)
|
||||
|
||||
setup_parallelism_envs(cfg)
|
||||
setup_torch_compile_env(cfg)
|
||||
|
||||
if cfg.fp8:
|
||||
@@ -667,8 +686,6 @@ def setup_trainer(
|
||||
"""
|
||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
|
||||
if cfg.torch_compile and cfg.fsdp_config and cfg.fsdp_version == 2:
|
||||
patch_evaluation_loop_for_fsdp2()
|
||||
if cfg.rl:
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
|
||||
trainer_builder.model_ref = model_ref
|
||||
|
||||
@@ -47,7 +47,9 @@ class BaseCliTest:
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock:
|
||||
mock_fn = "os.execvpe" if command == "train" else "subprocess.run"
|
||||
|
||||
with patch(mock_fn) as mock:
|
||||
result = cli_runner.invoke(cli, [command, str(config_path)])
|
||||
|
||||
assert mock.called
|
||||
@@ -65,8 +67,12 @@ class BaseCliTest:
|
||||
if train:
|
||||
expected.append("--shard=False")
|
||||
|
||||
assert mock.call_args.args[0] == expected
|
||||
assert mock.call_args.kwargs == {"check": True}
|
||||
if command == "train":
|
||||
assert mock.call_args.args[0] == "accelerate"
|
||||
assert mock.call_args.args[1] == expected
|
||||
else:
|
||||
assert mock.call_args.args[0] == expected
|
||||
assert mock.call_args.kwargs == {"check": True}
|
||||
assert result.exit_code == 0
|
||||
|
||||
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):
|
||||
|
||||
@@ -85,7 +85,7 @@ class TestTrainCommand(BaseCliTest):
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
with patch("os.execvpe") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
@@ -104,7 +104,7 @@ class TestTrainCommand(BaseCliTest):
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify launcher args are passed to torchrun
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
called_cmd = mock_subprocess.call_args.args[1]
|
||||
assert called_cmd[0] == "torchrun"
|
||||
assert "--nproc_per_node=2" in called_cmd
|
||||
assert "--nnodes=1" in called_cmd
|
||||
@@ -118,7 +118,7 @@ class TestTrainCommand(BaseCliTest):
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
with patch("os.execvpe") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
@@ -137,7 +137,8 @@ class TestTrainCommand(BaseCliTest):
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify launcher args are passed to accelerate
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert mock_subprocess.call_args.args[0] == "accelerate"
|
||||
called_cmd = mock_subprocess.call_args.args[1]
|
||||
assert called_cmd[0] == "accelerate"
|
||||
assert called_cmd[1] == "launch"
|
||||
assert "--config_file=accelerate_config.yml" in called_cmd
|
||||
@@ -152,7 +153,7 @@ class TestTrainCommand(BaseCliTest):
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
with patch("os.execvpe") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
@@ -170,7 +171,8 @@ class TestTrainCommand(BaseCliTest):
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify no launcher args contamination
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert mock_subprocess.call_args.args[0] == "accelerate"
|
||||
called_cmd = mock_subprocess.call_args.args[1]
|
||||
assert called_cmd[0] == "accelerate"
|
||||
assert called_cmd[1] == "launch"
|
||||
# Should not contain any extra launcher args
|
||||
@@ -186,7 +188,7 @@ class TestTrainCommand(BaseCliTest):
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
with patch("os.execvpe") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
@@ -207,7 +209,8 @@ class TestTrainCommand(BaseCliTest):
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert mock_subprocess.call_args.args[0] == "torchrun"
|
||||
called_cmd = mock_subprocess.call_args.args[1]
|
||||
# Verify launcher args
|
||||
assert "--nproc_per_node=8" in called_cmd
|
||||
# Verify axolotl args are also present
|
||||
|
||||
@@ -281,7 +281,9 @@ class TestHFRLTrainerBuilder:
|
||||
# Other settings
|
||||
assert training_arguments.dataloader_num_workers == 1
|
||||
assert training_arguments.dataloader_pin_memory is True
|
||||
assert training_arguments.gradient_checkpointing is False
|
||||
|
||||
# TODO(wing): restore once trl releases 0.22.0
|
||||
# assert training_arguments.gradient_checkpointing is True
|
||||
|
||||
def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
|
||||
builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
|
||||
|
||||
@@ -64,6 +64,7 @@ def sample_tensors():
|
||||
batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
|
||||
),
|
||||
"W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
|
||||
"b": torch.randn(out_dim, device="cuda", dtype=torch.float16),
|
||||
"scale": 0.5,
|
||||
"shapes": {
|
||||
"batch": batch_size,
|
||||
@@ -103,23 +104,24 @@ def mock_proj():
|
||||
def test_get_lora_parameters(mock_proj):
|
||||
"""Tests get_lora_parameters function"""
|
||||
# Test with LoRA enabled
|
||||
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
|
||||
assert isinstance(W, torch.Tensor)
|
||||
assert W.shape == (128, 64)
|
||||
assert b.shape == (128,)
|
||||
assert A.shape == (8, 64)
|
||||
assert B.shape == (128, 8)
|
||||
assert s == 0.5
|
||||
|
||||
# Test with LoRA disabled
|
||||
mock_proj.disable_adapters = True
|
||||
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
assert A is None and B is None and s is None
|
||||
|
||||
# Test with merged state
|
||||
mock_proj.disable_adapters = False
|
||||
mock_proj.merged = True
|
||||
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
assert A is None and B is None and s is None
|
||||
|
||||
|
||||
@@ -127,6 +129,7 @@ def test_matmul_lora(sample_tensors):
|
||||
"""Tests matmul_lora function"""
|
||||
X = sample_tensors["X"]
|
||||
W = sample_tensors["W"]
|
||||
b = sample_tensors["b"]
|
||||
scale = sample_tensors["scale"]
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
@@ -138,19 +141,20 @@ def test_matmul_lora(sample_tensors):
|
||||
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
||||
|
||||
# Test base matmul
|
||||
out1 = matmul_lora(X, W, None, None, None, None)
|
||||
expected1 = torch.matmul(X, W.t())
|
||||
out1 = matmul_lora(X, W, b, None, None, None, None)
|
||||
matmul = torch.matmul(X, W.t())
|
||||
expected1 = matmul + b
|
||||
assert torch.allclose(out1, expected1, rtol=1e-3)
|
||||
|
||||
# Test with LoRA
|
||||
out2 = matmul_lora(X, W, None, A, B, scale)
|
||||
out2 = matmul_lora(X, W, b, None, A, B, scale)
|
||||
lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
|
||||
expected2 = expected1 + lora_term
|
||||
expected2 = matmul + lora_term + b
|
||||
assert torch.allclose(out2, expected2, rtol=1e-3)
|
||||
|
||||
# Test 3D input reshaping
|
||||
X_3d = X.clone()
|
||||
out3 = matmul_lora(X_3d, W, None, A, B, scale)
|
||||
out3 = matmul_lora(X_3d, W, b, None, A, B, scale)
|
||||
assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
|
||||
|
||||
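
# Plain-torch restatement of what these tests check (a sketch, not the axolotl
# kernels): a LoRA linear with bias computes X @ W.T + b + scale * (X @ A.T) @ B.T.
import torch

X = torch.randn(2, 5, 16)   # (batch, seq, hidden)
W = torch.randn(32, 16)     # (out, hidden)
b = torch.randn(32)
A = torch.randn(4, 16)      # (rank, hidden)
B = torch.randn(32, 4)      # (out, rank)
scale = 0.5

out = X @ W.t() + b + scale * (X @ A.t()) @ B.t()
print(out.shape)  # torch.Size([2, 5, 32])
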
@@ -175,16 +179,19 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
|
||||
output = LoRA_MLP.apply(
|
||||
X,
|
||||
gate_proj.weight,
|
||||
gate_proj.bias,
|
||||
None, # gate_quant
|
||||
None, # gate_A
|
||||
None, # gate_B
|
||||
None, # gate_scale
|
||||
up_proj.weight,
|
||||
up_proj.bias,
|
||||
None, # up_quant
|
||||
None, # up_A
|
||||
None, # up_B
|
||||
None, # up_scale
|
||||
down_proj.weight,
|
||||
down_proj.bias,
|
||||
None, # down_quant
|
||||
None, # down_A
|
||||
None, # down_B
|
||||
@@ -243,16 +250,19 @@ def test_lora_mlp_with_adapters(
|
||||
output = LoRA_MLP.apply(
|
||||
X,
|
||||
gate_proj.weight,
|
||||
gate_proj.bias,
|
||||
None,
|
||||
gate_A,
|
||||
gate_B,
|
||||
scale,
|
||||
up_proj.weight,
|
||||
up_proj.bias,
|
||||
None,
|
||||
up_A,
|
||||
up_B,
|
||||
scale,
|
||||
down_proj.weight,
|
||||
down_proj.bias,
|
||||
None,
|
||||
down_A,
|
||||
down_B,
|
||||
@@ -323,6 +333,7 @@ def test_lora_qkv(sample_tensors):
|
||||
X.requires_grad = True
|
||||
|
||||
# Test without LoRA adapters
|
||||
# pylint: disable=duplicate-code
|
||||
Q1, K1, V1 = LoRA_QKV.apply(
|
||||
X,
|
||||
q_weight,
|
||||
@@ -330,16 +341,19 @@ def test_lora_qkv(sample_tensors):
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
k_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
v_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
)
|
||||
|
||||
@@ -356,16 +370,19 @@ def test_lora_qkv(sample_tensors):
|
||||
X,
|
||||
q_weight,
|
||||
None,
|
||||
None,
|
||||
q_A,
|
||||
q_B,
|
||||
scale,
|
||||
k_weight,
|
||||
None,
|
||||
None,
|
||||
k_A,
|
||||
k_B,
|
||||
scale,
|
||||
v_weight,
|
||||
None,
|
||||
None,
|
||||
v_A,
|
||||
v_B,
|
||||
scale,
|
||||
@@ -399,6 +416,7 @@ def test_lora_o(sample_tensors):
|
||||
"""Tests LoRA output projection"""
|
||||
X = sample_tensors["X"]
|
||||
W = sample_tensors["W"]
|
||||
b = sample_tensors["b"]
|
||||
scale = sample_tensors["scale"]
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
@@ -411,7 +429,7 @@ def test_lora_o(sample_tensors):
|
||||
|
||||
# Test forward pass
|
||||
X.requires_grad = True
|
||||
output = LoRA_O.apply(X, W, None, A, B, scale)
|
||||
output = LoRA_O.apply(X, W, b, None, A, B, scale)
|
||||
|
||||
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
|
||||
@@ -425,6 +443,7 @@ def test_with_quantization(sample_tensors, mock_quantstate):
|
||||
"""Tests LoRA with quantized weights"""
|
||||
X = sample_tensors["X"] # [batch, seq, hidden]
|
||||
W = sample_tensors["W"] # [out, hidden]
|
||||
b = sample_tensors["b"] # [out]
|
||||
scale = 0.5
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
@@ -436,13 +455,13 @@ def test_with_quantization(sample_tensors, mock_quantstate):
|
||||
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
||||
|
||||
# Test matmul with quantization
|
||||
out = matmul_lora(X, W, mock_quantstate, A, B, scale)
|
||||
out = matmul_lora(X, W, b, mock_quantstate, A, B, scale)
|
||||
assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
assert not torch.isnan(out).any()
|
||||
|
||||
# Test with different batch sizes
|
||||
X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
|
||||
out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale)
|
||||
out2 = matmul_lora(X2, W, b, mock_quantstate, A, B, scale)
|
||||
assert out2.shape == (4, 6, W.shape[0])
|
||||
assert not torch.isnan(out2).any()
|
||||
|
||||
@@ -459,11 +478,12 @@ def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
|
||||
"""Tests various input shapes and dimensions"""
|
||||
X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
|
||||
W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
|
||||
b = torch.randn(out, device="cuda", dtype=torch.float16)
|
||||
A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
|
||||
B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
|
||||
scale = 0.5
|
||||
|
||||
result = matmul_lora(X, W, None, A, B, scale)
|
||||
result = matmul_lora(X, W, b, None, A, B, scale)
|
||||
assert result.shape == (batch, seq, out)
|
||||
|
||||
|
||||
@@ -471,6 +491,7 @@ def test_gradient_flow(sample_tensors):
|
||||
"""Tests gradient flow through LoRA layers"""
|
||||
X = sample_tensors["X"].clone()
|
||||
W = sample_tensors["W"].clone()
|
||||
b = sample_tensors["b"].clone()
|
||||
scale = sample_tensors["scale"]
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
@@ -486,7 +507,7 @@ def test_gradient_flow(sample_tensors):
|
||||
B.requires_grad = True
|
||||
|
||||
# Forward pass
|
||||
out = matmul_lora(X, W, None, A, B, scale)
|
||||
out = matmul_lora(X, W, b, None, A, B, scale)
|
||||
loss = out.sum()
|
||||
|
||||
# Backward pass
|
||||
|
||||
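The hunks above thread an explicit bias tensor b through get_lora_parameters and matmul_lora. A minimal dense reference of the behaviour the updated assertions check, assuming the seven-argument call order used in the tests; the matmul_lora_reference name is illustrative and the quantized path is not reproduced:

    import torch

    def matmul_lora_reference(X, W, b, quant_state, A, B, scale):
        # X @ W^T plus bias, with an optional scaled LoRA correction.
        # quant_state is ignored here; the real kernel dequantizes W when it is set.
        out = torch.matmul(X, W.t())
        if b is not None:
            out = out + b
        if A is not None and B is not None:
            out = out + scale * torch.matmul(torch.matmul(X, A.t()), B.t())
        return out
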
@@ -174,6 +174,69 @@ class TestFSDP2:

verify_training_success(temp_dir)

@require_torch_2_7_0
def test_lora_sft_kernels(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)

# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)

verify_training_success(temp_dir)

@require_torch_2_7_0
def test_qlora_sft(self, temp_dir):
cfg = DictDefault(
@@ -236,6 +299,70 @@ class TestFSDP2:

verify_training_success(temp_dir)

@require_torch_2_7_0
def test_qlora_sft_kernels(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)

# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)

verify_training_success(temp_dir)

@require_torch_2_7_0
def test_dpo_fft(self, temp_dir):
cfg = DictDefault(

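The two new FSDP2 tests differ only in the adapter setup (lora versus qlora with load_in_4bit); both enable the custom LoRA kernels through the same three flags. A trimmed sketch of the kernel-related keys taken from the configs above, with all other keys following the tests verbatim; the key descriptions are inferred from the kernel names:

    from axolotl.utils.dict import DictDefault

    kernel_cfg = DictDefault(
        {
            "adapter": "lora",        # the QLoRA variant uses "qlora" plus "load_in_4bit": True
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_target_linear": True,
            "lora_mlp_kernel": True,  # LoRA_MLP kernel path
            "lora_qkv_kernel": True,  # LoRA_QKV kernel path
            "lora_o_kernel": True,    # LoRA_O output-projection kernel path
            "fsdp_version": 2,
            "bf16": True,
        }
    )
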
@@ -10,7 +10,11 @@ from accelerate.test_utils import execute_subprocess_async

from axolotl.utils.dict import DictDefault

from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
from tests.e2e.utils import (
check_tensorboard,
require_torch_2_7_0,
require_torch_lt_2_6_0,
)

AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

@@ -139,3 +143,71 @@ class TestMultiGPURay:
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)

@require_torch_2_7_0
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
)
def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"save_first_step": False,
}
)

# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--use-ray",
"--ray-num-workers",
"2",
]
)

check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)

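The Ray variant launches the same CLI entry point as the other FSDP2 tests but swaps the --num-processes/--main-process-port pair for Ray worker flags. A minimal launch sketch mirroring the subprocess call above; the config path is a placeholder:

    from pathlib import Path

    from accelerate.test_utils import execute_subprocess_async

    config_path = Path("/path/to/output") / "config.yaml"  # placeholder location
    execute_subprocess_async(
        ["axolotl", "train", str(config_path), "--use-ray", "--ray-num-workers", "2"]
    )
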
131
tests/e2e/patched/test_fsdp2_qlora.py
Normal file
@@ -0,0 +1,131 @@
"""Integration tests for FSDP Params4bit patches."""

from unittest.mock import Mock, patch

import bitsandbytes as bnb
import pytest
import torch
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam

from axolotl.monkeypatch.fsdp2_qlora import (
apply_bnb_torch_function_patch,
patched_torch_function,
)


@pytest.fixture
def mock_params4bit():
"""Create a mock Params4bit instance with test attributes."""
mock_instance = Mock()
mock_instance.requires_grad = True
mock_instance.quant_state = "test_state"
mock_instance.blocksize = 128
mock_instance.compress_statistics = True
mock_instance.quant_type = "fp4"
mock_instance.quant_storage = "test_storage"
mock_instance.module = "test_module"
mock_instance.bnb_quantized = True
return mock_instance


class TestBnbTorchFunctionPatch:
"""Test the Params4bit.__torch_function__ patch."""

def test_apply_patch(self):
"""Test that the patch can be applied."""
with patch("bitsandbytes.nn.modules.Params4bit") as mock_cls:
apply_bnb_torch_function_patch()
assert hasattr(mock_cls, "__torch_function__")
assert isinstance(mock_cls.__torch_function__, classmethod)

# pylint: disable=redefined-outer-name
def test_torch_chunk_preserves_attributes(self, mock_params4bit):
"""Test that torch.chunk preserves Params4bit attributes."""
mock_cls = Mock()
chunks = (torch.tensor([1, 2]), torch.tensor([3, 4]))

with patch("torch.nn.Parameter.__torch_function__", return_value=chunks):
result = patched_torch_function(
mock_cls,
torch.chunk,
(type(mock_params4bit),),
args=(mock_params4bit, 2),
)

assert isinstance(result, tuple)
assert len(result) == 2

# Check that Params4bit constructor was called with preserved attributes
assert mock_cls.call_count == 2
for call in mock_cls.call_args_list:
kwargs = call[1]
assert kwargs["requires_grad"] == mock_params4bit.requires_grad
assert kwargs["quant_state"] == mock_params4bit.quant_state
assert kwargs["blocksize"] == mock_params4bit.blocksize

# pylint: disable=redefined-outer-name
def test_other_functions_fallback(self, mock_params4bit):
"""Test that non-chunk/split functions use Parameter fallback."""
mock_cls = Mock()
fallback_result = torch.tensor([5, 6, 7])

with patch(
"torch.nn.Parameter.__torch_function__", return_value=fallback_result
) as mock_fallback:
result = patched_torch_function(
mock_cls, torch.add, (type(mock_params4bit),), args=(mock_params4bit, 1)
)

# Should call Parameter.__torch_function__ and return its result
mock_fallback.assert_called_once()
assert result is fallback_result
mock_cls.assert_not_called()


class TestFSDPPatchIntegration:
"""Test FSDP patch integration."""

@pytest.mark.integration
def test_all_patches_together(self):
"""Test that all patches can be applied together."""
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
)

# Store original methods before patching
original_torch_function = getattr(
bnb.nn.modules.Params4bit, "__torch_function__", None
)

# pylint: disable=protected-access
original_init_sharded = FSDPParam._init_sharded_param
original_init_unsharded = FSDPParam.init_unsharded_param

# Apply patches
apply_bnb_torch_function_patch()
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()

# Verify patches were applied
current_torch_function = getattr(
bnb.nn.modules.Params4bit, "__torch_function__", None
)
if original_torch_function is not None:
assert (
current_torch_function != original_torch_function
), "Params4bit.__torch_function__ was not patched"
else:
assert (
current_torch_function is not None
), "Params4bit.__torch_function__ was not added"

# Check that FSDP methods were patched
assert (
# pylint: disable=protected-access
FSDPParam._init_sharded_param
!= original_init_sharded
), "_init_sharded_param was not patched"
assert (
FSDPParam.init_unsharded_param != original_init_unsharded
), "init_unsharded_param was not patched"
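The new module exercises three patch entry points from axolotl.monkeypatch.fsdp2_qlora. A sketch of applying them together before FSDP2 QLoRA training, using only the functions the tests import and in the same order as the integration test; the wrapper function name is illustrative:

    from axolotl.monkeypatch.fsdp2_qlora import (
        apply_bnb_torch_function_patch,
        apply_init_sharded_param_patch,
        apply_init_unsharded_param_patch,
    )

    def apply_fsdp2_qlora_patches():
        # Patch Params4bit.__torch_function__ so torch.chunk/split keep the
        # quantization metadata, then patch the FSDPParam init paths that
        # shard and unshard those parameters.
        apply_bnb_torch_function_patch()
        apply_init_sharded_param_patch()
        apply_init_unsharded_param_patch()
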
@@ -29,7 +29,6 @@ class TestFusedLlama(unittest.TestCase):
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"pad_to_sequence_len": True,
"flash_attn_fuse_qkv": True,
"flash_attn_fuse_mlp": True,
"sample_packing": True,
"sequence_len": 1024,

@@ -13,6 +13,7 @@ from .utils import (
check_model_output_exists,
require_torch_2_5_1,
require_torch_2_6_0,
require_torch_2_7_0,
with_temp_dir,
)

@@ -160,6 +161,49 @@ class TestCustomOptimizers(unittest.TestCase):
check_model_output_exists(temp_dir, cfg)
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__

@with_temp_dir
@require_torch_2_7_0
def test_dion(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "dion",
"dion_lr": 0.01,
"dion_momentum": 0.95,
"lr_scheduler": "cosine",
"weight_decay": 0.01,
"save_first_step": False,
}
)

cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)

_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "Dion" in trainer.optimizer.optimizer.__class__.__name__

@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):
# pylint: disable=duplicate-code

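The Dion test introduces two optimizer-specific keys alongside the usual trainer settings. The relevant slice of the config above, with key descriptions inferred from the key names:

    optimizer_cfg = {
        "optimizer": "dion",
        "dion_lr": 0.01,           # Dion-specific learning rate
        "dion_momentum": 0.95,     # Dion momentum term
        "learning_rate": 0.00001,  # base learning rate still set on the trainer
        "lr_scheduler": "cosine",
        "weight_decay": 0.01,
    }
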
26
tests/monkeypatch/test_trainer_loss_calc.py
Normal file
@@ -0,0 +1,26 @@
"""Unit tests for trainer loss calc monkeypatch."""

import unittest

from axolotl.monkeypatch.transformers.trainer_loss_calc import (
check_evaluation_loop_is_patchable,
check_maybe_log_save_evaluate_is_patchable,
)


class TestTrainerLossCalc(unittest.TestCase):
"""
Unit test class for trainer loss calc monkeypatch
"""

def test_trainer_loss_calc_is_patchable(self):
"""
Test that the upstream transformers code is still patchable. This will fail if
the patched code changes upstream.
"""
assert check_evaluation_loop_is_patchable()
assert check_maybe_log_save_evaluate_is_patchable()


if __name__ == "__main__":
unittest.main()
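The two check_*_is_patchable helpers guard against upstream drift in transformers.Trainer; their implementation is not shown in this diff. A common shape for this kind of check, sketched under that assumption, where ORIGINAL_SNIPPET stands in for the copy of upstream source the patch module keeps:

    import inspect

    from transformers.trainer import Trainer

    ORIGINAL_SNIPPET = "..."  # placeholder for the stored upstream source fragment

    def check_evaluation_loop_is_patchable() -> bool:
        # If the expected fragment is still present, the monkeypatch can be applied.
        source = inspect.getsource(Trainer.evaluation_loop)
        return ORIGINAL_SNIPPET in source
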
@@ -9,6 +9,7 @@ from transformers.utils.import_utils import is_torch_mps_available

from axolotl.loaders import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import _get_parallel_config_kwargs


class TestModelsUtils:
@@ -193,15 +194,13 @@ class TestModelsUtils:
is_fsdp,
expected,
):
res = (
ModelLoader._get_parallel_config_kwargs( # pylint: disable=protected-access
world_size,
tensor_parallel_size,
context_parallel_size,
dp_shard_size,
dp_replicate_size,
is_fsdp,
)
res = _get_parallel_config_kwargs( # pylint: disable=protected-access
world_size,
tensor_parallel_size,
context_parallel_size,
dp_shard_size,
dp_replicate_size,
is_fsdp,
)

if expected[0] > 1:

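After this change the test calls _get_parallel_config_kwargs as a module-level helper from axolotl.utils.distributed rather than a ModelLoader method. A call sketch with the argument order taken from the test above; the sizes are hypothetical:

    from axolotl.utils.distributed import _get_parallel_config_kwargs

    parallel_kwargs = _get_parallel_config_kwargs(  # pylint: disable=protected-access
        8,     # world_size (hypothetical: 8 GPUs)
        2,     # tensor_parallel_size
        1,     # context_parallel_size
        4,     # dp_shard_size
        1,     # dp_replicate_size
        True,  # is_fsdp
    )
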
37
tests/utils/test_import_helper.py
Normal file
@@ -0,0 +1,37 @@
"""
test cases for axolotl.utils.import_helper
"""

import pytest

from axolotl.utils.import_helper import get_cls_from_module_str


def test_get_cls_from_module_str():
cls = get_cls_from_module_str("axolotl.core.trainers.base.AxolotlTrainer")
assert cls.__name__ == "AxolotlTrainer"


def test_get_cls_from_module_str_empty_string():
with pytest.raises(ValueError, match="module_str must be a non-empty string"):
get_cls_from_module_str("")


def test_get_cls_from_module_str_whitespace_only():
with pytest.raises(ValueError, match="module_str must be a non-empty string"):
get_cls_from_module_str(" ")


def test_get_cls_from_module_str_invalid_format():
with pytest.raises(ValueError, match="Invalid module string format"):
get_cls_from_module_str("single_part")


def test_get_cls_from_module_str_nonexistent_module():
with pytest.raises(ImportError, match="Failed to import module"):
get_cls_from_module_str("nonexistent.module.Class")


def test_get_cls_from_module_str_nonexistent_class():
with pytest.raises(AttributeError, match="Class 'NonExistentClass' not found"):
get_cls_from_module_str("axolotl.core.trainers.base.NonExistentClass")
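The new helper resolves a dotted module path to a class object and raises typed errors for malformed paths, missing modules, or missing classes, as the tests above exercise. A minimal usage sketch:

    from axolotl.utils.import_helper import get_cls_from_module_str

    trainer_cls = get_cls_from_module_str("axolotl.core.trainers.base.AxolotlTrainer")
    assert trainer_cls.__name__ == "AxolotlTrainer"
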