Compare commits
2 Commits
fix-previe
...
chat-templ
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b5198d8734 | ||
|
|
4ab6a1bd7e |
7
.github/CONTRIBUTING.md
vendored
7
.github/CONTRIBUTING.md
vendored
@@ -57,13 +57,6 @@ We welcome ideas for improvements and new features. To suggest an enhancement, o
|
|||||||
5. Push your branch to your fork on GitHub.
|
5. Push your branch to your fork on GitHub.
|
||||||
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
|
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
|
||||||
|
|
||||||
#### Skipping CI Checks
|
|
||||||
|
|
||||||
You can skip certain CI checks by including specific keywords in your commit messages:
|
|
||||||
|
|
||||||
- `[skip ci]` or `skip ci` - Skips all CI checks for that commit
|
|
||||||
- `[skip-e2e]` or `skip-e2e` - Skips only end-to-end tests while running other CI checks
|
|
||||||
|
|
||||||
## Style Guidelines
|
## Style Guidelines
|
||||||
|
|
||||||
### Code Style
|
### Code Style
|
||||||
|
|||||||
27
.github/workflows/base.yml
vendored
27
.github/workflows/base.yml
vendored
@@ -54,7 +54,7 @@ jobs:
|
|||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base"
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.6.3
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.7.1
|
||||||
@@ -64,16 +64,9 @@ jobs:
|
|||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.8.0
|
pytorch: nightly
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base-nightly"
|
||||||
# - cuda: "128"
|
|
||||||
# cuda_version: 12.8.1
|
|
||||||
# cudnn_version: ""
|
|
||||||
# python_version: "3.11"
|
|
||||||
# pytorch: nightly
|
|
||||||
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
# dockerfile: "Dockerfile-base-nightly"
|
|
||||||
# # "next" is for release candidates of pytorch
|
# # "next" is for release candidates of pytorch
|
||||||
# - cuda: "128"
|
# - cuda: "128"
|
||||||
# cuda_version: 12.8.1
|
# cuda_version: 12.8.1
|
||||||
@@ -129,13 +122,6 @@ jobs:
|
|||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-uv-base"
|
dockerfile: "Dockerfile-uv-base"
|
||||||
- cuda: "126"
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
dockerfile: "Dockerfile-uv-base"
|
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
@@ -143,13 +129,6 @@ jobs:
|
|||||||
pytorch: 2.7.1
|
pytorch: 2.7.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-uv-base"
|
dockerfile: "Dockerfile-uv-base"
|
||||||
- cuda: "128"
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.8.0
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
dockerfile: "Dockerfile-uv-base"
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
23
.github/workflows/main.yml
vendored
23
.github/workflows/main.yml
vendored
@@ -24,13 +24,12 @@ jobs:
|
|||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.0
|
pytorch: 2.7.0
|
||||||
axolotl_extras:
|
axolotl_extras: vllm
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.7.1
|
||||||
axolotl_extras: vllm
|
axolotl_extras:
|
||||||
is_latest: true
|
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -98,12 +97,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.7.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest:
|
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
axolotl_extras: vllm
|
|
||||||
is_latest: true
|
is_latest: true
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
@@ -157,18 +150,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
axolotl_extras:
|
|
||||||
is_latest:
|
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
axolotl_extras: vllm
|
|
||||||
is_latest: true
|
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
7
.github/workflows/tests.yml
vendored
7
.github/workflows/tests.yml
vendored
@@ -105,8 +105,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
|
||||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
|
||||||
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
|
|
||||||
@@ -180,8 +179,8 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v --durations=10 tests/patched/
|
||||||
pytest -v --durations=10 tests/cli/
|
pytest -v --durations=10 tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ default_language_version:
|
|||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v6.0.0
|
rev: v5.0.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
@@ -23,11 +23,11 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
- repo: https://github.com/pylint-dev/pylint
|
- repo: https://github.com/pylint-dev/pylint
|
||||||
rev: v3.3.8
|
rev: v3.3.7
|
||||||
hooks:
|
hooks:
|
||||||
- id: pylint
|
- id: pylint
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v1.17.1
|
rev: v1.17.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
|
|||||||
@@ -185,6 +185,7 @@ datasets:
|
|||||||
| `flash_attention` | `false` | Use flash attention |
|
| `flash_attention` | `false` | Use flash attention |
|
||||||
| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
|
| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
|
||||||
| `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
|
| `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
|
||||||
|
| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations |
|
||||||
| `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
|
| `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
|
||||||
| `sdp_attention` | `false` | Use scaled dot product |
|
| `sdp_attention` | `false` | Use scaled dot product |
|
||||||
| `s2_attention` | `false` | Use shifted sparse attention |
|
| `s2_attention` | `false` | Use shifted sparse attention |
|
||||||
|
|||||||
@@ -296,6 +296,7 @@
|
|||||||
# flash_attention:
|
# flash_attention:
|
||||||
# flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
# flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||||
# flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
# flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||||
|
# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
|
||||||
# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
||||||
# # Whether to use scaled-dot-product attention
|
# # Whether to use scaled-dot-product attention
|
||||||
# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||||
@@ -540,6 +541,7 @@ xformers_attention: ${XFORMERS_ATTENTION}
|
|||||||
flash_attention: ${FLASH_ATTENTION}
|
flash_attention: ${FLASH_ATTENTION}
|
||||||
flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
|
flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
|
||||||
flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
|
flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
|
||||||
|
flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV}
|
||||||
flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
|
flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
|
||||||
sdp_attention: ${SDP_ATTENTION}
|
sdp_attention: ${SDP_ATTENTION}
|
||||||
s2_attention: ${S2_ATTENTION}
|
s2_attention: ${S2_ATTENTION}
|
||||||
|
|||||||
10
CITATION.cff
10
CITATION.cff
@@ -1,10 +0,0 @@
|
|||||||
cff-version: 1.2.0
|
|
||||||
type: software
|
|
||||||
title: "Axolotl: Post-Training for AI Models"
|
|
||||||
message: "If you use this software, please cite it as below."
|
|
||||||
authors:
|
|
||||||
- name: "Axolotl maintainers and contributors"
|
|
||||||
repository-code: "https://github.com/axolotl-ai-cloud/axolotl"
|
|
||||||
url: "https://axolotl.ai/"
|
|
||||||
license: Apache-2.0
|
|
||||||
date-released: "2023-05-30"
|
|
||||||
33
README.md
33
README.md
@@ -25,28 +25,17 @@
|
|||||||
|
|
||||||
## 🎉 Latest Updates
|
## 🎉 Latest Updates
|
||||||
|
|
||||||
- 2025/07:
|
- 2025/07: Voxtral with mistral-common tokenizer support has been integrated in Axolotl. Read the [docs](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral)!
|
||||||
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
|
- 2025/07: TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
||||||
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
|
|
||||||
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
|
|
||||||
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
|
|
||||||
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
|
||||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
|
||||||
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
|
||||||
|
|
||||||
<details>
|
|
||||||
|
|
||||||
<summary>Expand older updates</summary>
|
|
||||||
|
|
||||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
||||||
|
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||||
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||||
|
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
||||||
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
|
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
|
||||||
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
|
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
|
||||||
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
|
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
|
||||||
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
|
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
## ✨ Overview
|
## ✨ Overview
|
||||||
|
|
||||||
Axolotl is a tool designed to streamline post-training for various AI models.
|
Axolotl is a tool designed to streamline post-training for various AI models.
|
||||||
@@ -149,20 +138,6 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
|
|||||||
|
|
||||||
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
|
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
|
||||||
|
|
||||||
## 📝 Citing Axolotl
|
|
||||||
|
|
||||||
If you use Axolotl in your research or projects, please cite it as follows:
|
|
||||||
|
|
||||||
```bibtex
|
|
||||||
@software{axolotl,
|
|
||||||
title = {Axolotl: Post-Training for AI Models},
|
|
||||||
author = {{Axolotl maintainers and contributors}},
|
|
||||||
url = {https://github.com/axolotl-ai-cloud/axolotl},
|
|
||||||
license = {Apache-2.0},
|
|
||||||
year = {2023}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 📜 License
|
## 📜 License
|
||||||
|
|
||||||
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
|
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
|
||||||
|
|||||||
10
TODO.md
Normal file
10
TODO.md
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# todo list
|
||||||
|
|
||||||
|
- [] Validation of parameters for combinations that won't work
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## things that are known not to work
|
||||||
|
|
||||||
|
- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
|
||||||
|
- adamw_bnb_8bit doesn't play well with FSDP offload
|
||||||
@@ -274,7 +274,6 @@ website:
|
|||||||
- docs/dataset_preprocessing.qmd
|
- docs/dataset_preprocessing.qmd
|
||||||
- docs/multipack.qmd
|
- docs/multipack.qmd
|
||||||
- docs/mixed_precision.qmd
|
- docs/mixed_precision.qmd
|
||||||
- docs/optimizers.qmd
|
|
||||||
|
|
||||||
- section: "Advanced Features"
|
- section: "Advanced Features"
|
||||||
contents:
|
contents:
|
||||||
|
|||||||
@@ -212,11 +212,10 @@ Instead of passing `tools` via the system prompt, an alternative method would be
|
|||||||
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||||
:::
|
:::
|
||||||
|
|
||||||
Example config for Llama4:
|
|
||||||
```yaml
|
```yaml
|
||||||
chat_template: llama4
|
chat_template: llama4
|
||||||
datasets:
|
datasets:
|
||||||
- path: Nanobit/text-tools-2k-test
|
- path: ...
|
||||||
type: chat_template
|
type: chat_template
|
||||||
# field_tools: tools # default is `tools`
|
# field_tools: tools # default is `tools`
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
---
|
# N-D Parallelism
|
||||||
title: "N-D Parallelism (Beta)"
|
|
||||||
---
|
|
||||||
|
|
||||||
Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
|
Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
|
||||||
|
|
||||||
@@ -73,10 +71,6 @@ Note: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size
|
|||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
See our example configs [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/distributed-parallel).
|
|
||||||
:::
|
|
||||||
|
|
||||||
1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
|
1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
|
||||||
- You want FSDP within each node and DDP across nodes.
|
- You want FSDP within each node and DDP across nodes.
|
||||||
- Set `dp_shard_size: 4` and `dp_replicate_size: 2`.
|
- Set `dp_shard_size: 4` and `dp_replicate_size: 2`.
|
||||||
@@ -101,7 +95,7 @@ This matrix describes how different parallelism methods can be combined in Axolo
|
|||||||
| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |
|
| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |
|
||||||
| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |
|
| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |
|
||||||
| **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. |
|
| **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. |
|
||||||
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP or CP is currently not supported. You should use FSDP + TP/CP instead (`dp_shard_size > 1`). |
|
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP/CP without FSDP is inefficient and complex. You should use FSDP instead (`dp_shard_size > 1`). |
|
||||||
| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |
|
| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |
|
||||||
|
|
||||||
- `tp_size` refers to `tensor_parallel_size`
|
- `tp_size` refers to `tensor_parallel_size`
|
||||||
|
|||||||
@@ -1,129 +0,0 @@
|
|||||||
---
|
|
||||||
title: Optimizers
|
|
||||||
description: Configuring optimizers
|
|
||||||
---
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
Axolotl supports all optimizers supported by [transformers OptimizerNames](https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/training_args.py#L142-L187)
|
|
||||||
|
|
||||||
Here is a list of optimizers supported by transformers as of `v4.54.0`:
|
|
||||||
|
|
||||||
- `adamw_torch`
|
|
||||||
- `adamw_torch_fused`
|
|
||||||
- `adamw_torch_xla`
|
|
||||||
- `adamw_torch_npu_fused`
|
|
||||||
- `adamw_apex_fused`
|
|
||||||
- `adafactor`
|
|
||||||
- `adamw_anyprecision`
|
|
||||||
- `adamw_torch_4bit`
|
|
||||||
- `adamw_torch_8bit`
|
|
||||||
- `ademamix`
|
|
||||||
- `sgd`
|
|
||||||
- `adagrad`
|
|
||||||
- `adamw_bnb_8bit`
|
|
||||||
- `adamw_8bit` # alias for adamw_bnb_8bit
|
|
||||||
- `ademamix_8bit`
|
|
||||||
- `lion_8bit`
|
|
||||||
- `lion_32bit`
|
|
||||||
- `paged_adamw_32bit`
|
|
||||||
- `paged_adamw_8bit`
|
|
||||||
- `paged_ademamix_32bit`
|
|
||||||
- `paged_ademamix_8bit`
|
|
||||||
- `paged_lion_32bit`
|
|
||||||
- `paged_lion_8bit`
|
|
||||||
- `rmsprop`
|
|
||||||
- `rmsprop_bnb`
|
|
||||||
- `rmsprop_bnb_8bit`
|
|
||||||
- `rmsprop_bnb_32bit`
|
|
||||||
- `galore_adamw`
|
|
||||||
- `galore_adamw_8bit`
|
|
||||||
- `galore_adafactor`
|
|
||||||
- `galore_adamw_layerwise`
|
|
||||||
- `galore_adamw_8bit_layerwise`
|
|
||||||
- `galore_adafactor_layerwise`
|
|
||||||
- `lomo`
|
|
||||||
- `adalomo`
|
|
||||||
- `grokadamw`
|
|
||||||
- `schedule_free_radam`
|
|
||||||
- `schedule_free_adamw`
|
|
||||||
- `schedule_free_sgd`
|
|
||||||
- `apollo_adamw`
|
|
||||||
- `apollo_adamw_layerwise`
|
|
||||||
- `stable_adamw`
|
|
||||||
|
|
||||||
|
|
||||||
## Custom Optimizers
|
|
||||||
|
|
||||||
Enable custom optimizers by passing a string to the `optimizer` argument. Each optimizer will receive beta and epsilon args, however, some may accept additional args which are detailed below.
|
|
||||||
|
|
||||||
### optimi_adamw
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
optimizer: optimi_adamw
|
|
||||||
```
|
|
||||||
|
|
||||||
### ao_adamw_4bit
|
|
||||||
|
|
||||||
Deprecated: Please use `adamw_torch_4bit`.
|
|
||||||
|
|
||||||
### ao_adamw_8bit
|
|
||||||
|
|
||||||
Deprecated: Please use `adamw_torch_8bit`.
|
|
||||||
|
|
||||||
### ao_adamw_fp8
|
|
||||||
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
optimizer: ao_adamw_fp8
|
|
||||||
```
|
|
||||||
|
|
||||||
### adopt_adamw
|
|
||||||
|
|
||||||
GitHub: [https://github.com/iShohei220/adopt](https://github.com/iShohei220/adopt)
|
|
||||||
Paper: [https://arxiv.org/abs/2411.02853](https://arxiv.org/abs/2411.02853)
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
optimizer: adopt_adamw
|
|
||||||
```
|
|
||||||
|
|
||||||
### came_pytorch
|
|
||||||
|
|
||||||
GitHub: [https://github.com/yangluo7/CAME/tree/master](https://github.com/yangluo7/CAME/tree/master)
|
|
||||||
Paper: [https://arxiv.org/abs/2307.02047](https://arxiv.org/abs/2307.02047)
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
optimizer: came_pytorch
|
|
||||||
|
|
||||||
# optional args (defaults below)
|
|
||||||
adam_beta1: 0.9
|
|
||||||
adam_beta2: 0.999
|
|
||||||
adam_beta3: 0.9999
|
|
||||||
adam_epsilon: 1e-30
|
|
||||||
adam_epsilon2: 1e-16
|
|
||||||
```
|
|
||||||
|
|
||||||
### muon
|
|
||||||
|
|
||||||
Blog: [https://kellerjordan.github.io/posts/muon/](https://kellerjordan.github.io/posts/muon/)
|
|
||||||
Paper: [https://arxiv.org/abs/2502.16982v1](https://arxiv.org/abs/2502.16982v1)
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
optimizer: muon
|
|
||||||
```
|
|
||||||
|
|
||||||
### dion
|
|
||||||
|
|
||||||
Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient
|
|
||||||
orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
|
|
||||||
|
|
||||||
GitHub: [https://github.com/microsoft/dion](https://github.com/microsoft/dion)
|
|
||||||
Paper: [https://arxiv.org/pdf/2504.05295](https://arxiv.org/pdf/2504.05295)
|
|
||||||
Note: Implementation written for PyTorch 2.7+ for DTensor
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
optimizer: dion
|
|
||||||
dion_lr: 0.01
|
|
||||||
dion_momentum: 0.95
|
|
||||||
lr: 0.00001 # learning rate for embeddings and parameters that fallback to AdamW
|
|
||||||
```
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
# Finetune ArceeAI's AFM with Axolotl
|
|
||||||
|
|
||||||
[Arcee Foundation Models (AFM)](https://huggingface.co/collections/arcee-ai/afm-45b-68823397c351603014963473) are a family of 4.5B parameter open weight models trained by Arcee.ai.
|
|
||||||
|
|
||||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
|
||||||
|
|
||||||
Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the AFM model.
|
|
||||||
|
|
||||||
## Getting started
|
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as AFM is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
|
||||||
|
|
||||||
Here is an example of how to install from main for pip:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
|
||||||
cd axolotl
|
|
||||||
|
|
||||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Run the finetuning example:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
axolotl train examples/arcee/afm-4.5b-qlora.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
This config uses about 7.8GiB VRAM.
|
|
||||||
|
|
||||||
Let us know how it goes. Happy finetuning! 🚀
|
|
||||||
|
|
||||||
### TIPS
|
|
||||||
|
|
||||||
- For inference, the official Arcee.ai team recommends `top_p: 0.95`, `temperature: 0.5`, `top_k: 50`, and `repeat_penalty: 1.1`.
|
|
||||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
|
||||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
|
||||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
|
||||||
|
|
||||||
## Optimization Guides
|
|
||||||
|
|
||||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
|
||||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
|
||||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
|
||||||
|
|
||||||
## Related Resources
|
|
||||||
|
|
||||||
- [AFM Blog](https://docs.arcee.ai/arcee-foundation-models/introduction-to-arcee-foundation-models)
|
|
||||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
|
||||||
- [Axolotl Website](https://axolotl.ai)
|
|
||||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
|
||||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
base_model: arcee-ai/AFM-4.5B
|
|
||||||
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
load_in_8bit: false
|
|
||||||
load_in_4bit: true
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
|
||||||
type: chat_template
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.1
|
|
||||||
output_dir: ./outputs/lora-out
|
|
||||||
|
|
||||||
adapter: qlora
|
|
||||||
lora_model_dir:
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
lora_r: 32
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.05
|
|
||||||
lora_target_linear: true
|
|
||||||
lora_target_modules:
|
|
||||||
- gate_proj
|
|
||||||
- down_proj
|
|
||||||
- up_proj
|
|
||||||
- q_proj
|
|
||||||
- v_proj
|
|
||||||
- k_proj
|
|
||||||
- o_proj
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
evals_per_epoch: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
@@ -47,6 +47,7 @@ logging_steps: 1
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
flash_attn_cross_entropy: false
|
flash_attn_cross_entropy: false
|
||||||
flash_attn_rms_norm: true
|
flash_attn_rms_norm: true
|
||||||
|
flash_attn_fuse_qkv: false
|
||||||
flash_attn_fuse_mlp: true
|
flash_attn_fuse_mlp: true
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -10,14 +10,17 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
|||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||||
|
|
||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.6.0+)
|
||||||
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Run the finetuning example:
|
2. Run the finetuning example:
|
||||||
|
|||||||
@@ -1,52 +0,0 @@
|
|||||||
# ND Parallelism Examples
|
|
||||||
|
|
||||||
This directory contains example configurations for training models using ND Parallelism in Axolotl. These examples demonstrate how to compose different parallelism strategies (FSDP, TP, CP, HSDP) for efficient multi-GPU training.
|
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
|
||||||
|
|
||||||
2. Run the command below:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Train Qwen3 8B with FSDP + TP + CP on a single 8-GPU node
|
|
||||||
axolotl train examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
|
|
||||||
|
|
||||||
# Train Llama 3.1 8B with HSDP + TP on 2 nodes (16 GPUs total)
|
|
||||||
axolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
## Example Configurations
|
|
||||||
|
|
||||||
### Single Node (8 GPUs)
|
|
||||||
|
|
||||||
**Qwen3 8B with FSDP + TP + CP** ([qwen3-8b-fsdp-tp-cp.yaml](./qwen3-8b-fsdp-tp-cp.yaml))
|
|
||||||
- Uses all 3 parallelism dimensions on a single node
|
|
||||||
- Ideal for: when model weights, activations, and/or context are too large to fit on single GPU
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
dp_shard_size: 2 # FSDP across 2 GPUs
|
|
||||||
tensor_parallel_size: 2 # TP across 2 GPUs
|
|
||||||
context_parallel_size: 2 # CP across 2 GPUs
|
|
||||||
# Total: 2 × 2 × 2 = 8 GPUs
|
|
||||||
```
|
|
||||||
|
|
||||||
### Multi-Node
|
|
||||||
|
|
||||||
**Llama 3.1 8B with HSDP + TP** ([llama-3_1-8b-hsdp-tp.yaml](./llama-3_1-8b-hsdp-tp.yaml))
|
|
||||||
- FSDP & TP within nodes, DDP across nodes to minimize inter-node communication
|
|
||||||
- Ideal for: Scaling to multiple nodes while maintaining training efficiency
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
dp_shard_size: 4 # FSDP within each 4-GPU group
|
|
||||||
tensor_parallel_size: 2 # TP within each node
|
|
||||||
dp_replicate_size: 2 # DDP across 2 groups
|
|
||||||
# Total: (4 × 2) × 2 = 16 GPUs (2 nodes)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Learn More
|
|
||||||
|
|
||||||
- [ND Parallelism Documentation](https://docs.axolotl.ai/docs/nd_parallelism.html)
|
|
||||||
- [Blog: Accelerate ND-Parallel Guide](https://huggingface.co/blog/accelerate-nd-parallel)
|
|
||||||
- [Multi-GPU Training Guide](https://docs.axolotl.ai/docs/multi-gpu.html)
|
|
||||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
base_model: meta-llama/Llama-3.1-8B
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
dp_shard_size: 4
|
|
||||||
dp_replicate_size: 2
|
|
||||||
tensor_parallel_size: 2
|
|
||||||
# context_parallel_size: 2
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
|
|
||||||
special_tokens:
|
|
||||||
pad_token: <|end_of_text|>
|
|
||||||
|
|
||||||
fsdp_version: 2
|
|
||||||
fsdp_config:
|
|
||||||
offload_params: false
|
|
||||||
state_dict_type: FULL_STATE_DICT
|
|
||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
|
||||||
reshard_after_forward: true
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: tatsu-lab/alpaca
|
|
||||||
type: alpaca
|
|
||||||
|
|
||||||
output_dir: ./outputs/ndp-out/
|
|
||||||
|
|
||||||
sequence_len: 2048
|
|
||||||
sample_packing: true
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 2
|
|
||||||
optimizer: adamw_torch_fused
|
|
||||||
lr_scheduler: constant_with_warmup
|
|
||||||
learning_rate: 2e-6
|
|
||||||
|
|
||||||
bf16: true
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
logging_steps: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
base_model: Qwen/Qwen3-8B
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
dp_shard_size: 2
|
|
||||||
# dp_replicate_size: 1
|
|
||||||
context_parallel_size: 2
|
|
||||||
tensor_parallel_size: 2
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
|
|
||||||
fsdp_version: 2
|
|
||||||
fsdp_config:
|
|
||||||
offload_params: false
|
|
||||||
state_dict_type: FULL_STATE_DICT
|
|
||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
|
|
||||||
reshard_after_forward: true
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: tatsu-lab/alpaca
|
|
||||||
type: alpaca
|
|
||||||
|
|
||||||
output_dir: ./outputs/ndp-out/
|
|
||||||
|
|
||||||
sequence_len: 8192
|
|
||||||
sample_packing: true
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
|
||||||
micro_batch_size: 1 # must be 1 when using context parallel
|
|
||||||
num_epochs: 2
|
|
||||||
optimizer: adamw_torch_fused
|
|
||||||
lr_scheduler: constant_with_warmup
|
|
||||||
learning_rate: 2e-6
|
|
||||||
|
|
||||||
bf16: true
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
logging_steps: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
|
|
||||||
special_tokens:
|
|
||||||
@@ -4,14 +4,17 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
|
|||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Gemma3n is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||||
|
|
||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min recommended)
|
||||||
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. In addition to Axolotl's requirements, Gemma-3n requires:
|
2. In addition to Axolotl's requirements, Gemma-3n requires:
|
||||||
|
|||||||
@@ -1,74 +0,0 @@
|
|||||||
# Finetune OpenAI's GPT-OSS with Axolotl
|
|
||||||
|
|
||||||
[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.
|
|
||||||
|
|
||||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
|
||||||
|
|
||||||
## Getting started
|
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
|
||||||
|
|
||||||
Here is an example of how to install from pip:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
|
||||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# LoRA SFT linear layers (1x48GB @ ~44GiB)
|
|
||||||
axolotl train examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
|
|
||||||
|
|
||||||
# FFT SFT with offloading (2x24GB @ ~21GiB/GPU)
|
|
||||||
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
|
|
||||||
|
|
||||||
# FFT SFT (8x48GB @ ~36GiB/GPU or 4x80GB @ ~46GiB/GPU)
|
|
||||||
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
Note: Memory usage taken from `device_mem_reserved(gib)` from logs.
|
|
||||||
|
|
||||||
### Training 120B
|
|
||||||
|
|
||||||
On 8xH100s
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
|
|
||||||
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
### Tool use
|
|
||||||
|
|
||||||
GPT-OSS has a comprehensive tool understanding. Axolotl supports tool calling datasets for Supervised Fine-tuning.
|
|
||||||
|
|
||||||
Here is an example dataset config:
|
|
||||||
```yaml
|
|
||||||
datasets:
|
|
||||||
- path: Nanobit/text-tools-2k-test
|
|
||||||
type: chat_template
|
|
||||||
```
|
|
||||||
|
|
||||||
See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-tools-2k-test) for the sample dataset.
|
|
||||||
|
|
||||||
Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.
|
|
||||||
|
|
||||||
### TIPS
|
|
||||||
|
|
||||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
|
||||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
|
||||||
|
|
||||||
## Optimization Guides
|
|
||||||
|
|
||||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
|
||||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
|
||||||
|
|
||||||
## Related Resources
|
|
||||||
|
|
||||||
- [GPT-OSS Blog](https://openai.com/index/introducing-gpt-oss/)
|
|
||||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
|
||||||
- [Axolotl Website](https://axolotl.ai)
|
|
||||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
|
||||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
# the original mxfp4 quantized model is not supported with FSDP cpu_ram_efficient_loading
|
|
||||||
# FSDP cpu_ram_efficient_loading is used to reduce the initial CPU memory usage when loading the model
|
|
||||||
base_model: axolotl-ai-co/gpt-oss-120b-dequantized
|
|
||||||
|
|
||||||
use_kernels: false
|
|
||||||
|
|
||||||
dp_shard_size: 16 # requires 2x8xH100 nodes
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/Multilingual-Thinking
|
|
||||||
type: chat_template
|
|
||||||
field_thinking: thinking
|
|
||||||
template_thinking_key: thinking
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0
|
|
||||||
output_dir: ./outputs/gpt-oss-out/
|
|
||||||
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
|
|
||||||
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
|
|
||||||
lr_scheduler: constant_with_warmup
|
|
||||||
learning_rate: 2e-5
|
|
||||||
|
|
||||||
bf16: true
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
flash_attention: true
|
|
||||||
attn_implementation: kernels-community/vllm-flash-attn3
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
activation_offloading: true
|
|
||||||
|
|
||||||
logging_steps: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
warmup_ratio: 0.03
|
|
||||||
|
|
||||||
special_tokens:
|
|
||||||
eot_tokens:
|
|
||||||
- "<|end|>"
|
|
||||||
|
|
||||||
fsdp_version: 2
|
|
||||||
fsdp_config:
|
|
||||||
offload_params: true
|
|
||||||
state_dict_type: SHARDED_STATE_DICT
|
|
||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
|
||||||
reshard_after_forward: true
|
|
||||||
cpu_ram_efficient_loading: true
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
base_model: openai/gpt-oss-20b
|
|
||||||
use_kernels: false
|
|
||||||
model_quantization_config: Mxfp4Config
|
|
||||||
model_quantization_config_kwargs:
|
|
||||||
dequantize: true
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/Multilingual-Thinking
|
|
||||||
type: chat_template
|
|
||||||
field_thinking: thinking
|
|
||||||
template_thinking_key: thinking
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0
|
|
||||||
output_dir: ./outputs/gpt-oss-out/
|
|
||||||
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
|
|
||||||
optimizer: adamw_torch_8bit
|
|
||||||
lr_scheduler: constant_with_warmup
|
|
||||||
learning_rate: 2e-5
|
|
||||||
|
|
||||||
bf16: true
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
flash_attention: true
|
|
||||||
attn_implementation: kernels-community/vllm-flash-attn3
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
activation_offloading: true
|
|
||||||
|
|
||||||
logging_steps: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
warmup_ratio: 0.03
|
|
||||||
|
|
||||||
special_tokens:
|
|
||||||
eot_tokens:
|
|
||||||
- "<|end|>"
|
|
||||||
|
|
||||||
# choose the zero3 configuration that best fits your system capabilities
|
|
||||||
deepspeed: deepspeed_configs/zero3_bf16.json
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
base_model: openai/gpt-oss-20b
|
|
||||||
use_kernels: true
|
|
||||||
model_quantization_config: Mxfp4Config
|
|
||||||
model_quantization_config_kwargs:
|
|
||||||
dequantize: true
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/Multilingual-Thinking
|
|
||||||
type: chat_template
|
|
||||||
field_thinking: thinking
|
|
||||||
template_thinking_key: thinking
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0
|
|
||||||
output_dir: ./outputs/gpt-oss-out/
|
|
||||||
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
|
|
||||||
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
|
|
||||||
lr_scheduler: constant_with_warmup
|
|
||||||
learning_rate: 2e-5
|
|
||||||
|
|
||||||
bf16: true
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
flash_attention: true
|
|
||||||
attn_implementation: kernels-community/vllm-flash-attn3
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
activation_offloading: true
|
|
||||||
|
|
||||||
logging_steps: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
warmup_ratio: 0.03
|
|
||||||
|
|
||||||
special_tokens:
|
|
||||||
eot_tokens:
|
|
||||||
- "<|end|>"
|
|
||||||
|
|
||||||
fsdp_version: 2
|
|
||||||
fsdp_config:
|
|
||||||
offload_params: true
|
|
||||||
state_dict_type: SHARDED_STATE_DICT
|
|
||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
|
||||||
reshard_after_forward: true
|
|
||||||
# cpu_ram_efficient_loading: true
|
|
||||||
|
|
||||||
# cpu_ram_efficient_loading cannot be used with MXFP4 model quantization.
|
|
||||||
# It can only be used with a dequantized model like `axolotl-ai-co/gpt-oss-120b-dequantized`
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
base_model: openai/gpt-oss-20b
|
|
||||||
use_kernels: false
|
|
||||||
model_quantization_config: Mxfp4Config
|
|
||||||
model_quantization_config_kwargs:
|
|
||||||
dequantize: true
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/Multilingual-Thinking
|
|
||||||
type: chat_template
|
|
||||||
field_thinking: thinking
|
|
||||||
template_thinking_key: thinking
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0
|
|
||||||
output_dir: ./outputs/gpt-oss-out/
|
|
||||||
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 2
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
|
|
||||||
optimizer: adamw_torch_8bit
|
|
||||||
lr_scheduler: constant_with_warmup
|
|
||||||
learning_rate: 2e-5
|
|
||||||
|
|
||||||
bf16: true
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
flash_attention: true
|
|
||||||
attn_implementation: kernels-community/vllm-flash-attn3
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
activation_offloading: true
|
|
||||||
|
|
||||||
logging_steps: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
|
|
||||||
warmup_ratio: 0.03
|
|
||||||
|
|
||||||
special_tokens:
|
|
||||||
eot_tokens:
|
|
||||||
- "<|end|>"
|
|
||||||
|
|
||||||
fsdp_version: 2
|
|
||||||
fsdp_config:
|
|
||||||
offload_params: false
|
|
||||||
state_dict_type: SHARDED_STATE_DICT
|
|
||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
|
||||||
reshard_after_forward: true
|
|
||||||
# cpu_ram_efficient_loading: true
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
base_model: openai/gpt-oss-20b
|
|
||||||
use_kernels: true
|
|
||||||
model_quantization_config: Mxfp4Config
|
|
||||||
model_quantization_config_kwargs:
|
|
||||||
dequantize: true
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: HuggingFaceH4/Multilingual-Thinking
|
|
||||||
type: chat_template
|
|
||||||
field_thinking: thinking
|
|
||||||
template_thinking_key: thinking
|
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0
|
|
||||||
output_dir: ./outputs/gpt-oss-out/
|
|
||||||
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_r: 8
|
|
||||||
lora_alpha: 16
|
|
||||||
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
|
|
||||||
lora_target_linear: true
|
|
||||||
|
|
||||||
# TODO: not supported for now, see peft#2710
|
|
||||||
#lora_target_parameters: # target the experts in the last two layers
|
|
||||||
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
|
||||||
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
|
|
||||||
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
|
||||||
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_name:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 8
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
|
|
||||||
optimizer: adamw_torch_8bit
|
|
||||||
lr_scheduler: constant_with_warmup
|
|
||||||
learning_rate: 2e-4
|
|
||||||
|
|
||||||
bf16: true
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
flash_attention: true
|
|
||||||
attn_implementation: kernels-community/vllm-flash-attn3
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
activation_offloading: true
|
|
||||||
|
|
||||||
logging_steps: 1
|
|
||||||
saves_per_epoch: 1
|
|
||||||
warmup_ratio: 0.1
|
|
||||||
|
|
||||||
special_tokens:
|
|
||||||
eot_tokens:
|
|
||||||
- "<|end|>"
|
|
||||||
@@ -45,6 +45,7 @@ logging_steps: 1
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
flash_attn_cross_entropy: false
|
flash_attn_cross_entropy: false
|
||||||
flash_attn_rms_norm: true
|
flash_attn_rms_norm: true
|
||||||
|
flash_attn_fuse_qkv: false
|
||||||
flash_attn_fuse_mlp: true
|
flash_attn_fuse_mlp: true
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ logging_steps: 1
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
flash_attn_cross_entropy: false
|
flash_attn_cross_entropy: false
|
||||||
flash_attn_rms_norm: true
|
flash_attn_rms_norm: true
|
||||||
|
flash_attn_fuse_qkv: false
|
||||||
flash_attn_fuse_mlp: true
|
flash_attn_fuse_mlp: true
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -8,14 +8,17 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
|||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Magistral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||||
|
|
||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Run the finetuning example:
|
2. Run the finetuning example:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ sequence_len: 2048
|
|||||||
sample_packing: true
|
sample_packing: true
|
||||||
eval_sample_packing: false
|
eval_sample_packing: false
|
||||||
|
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ lora_model_dir:
|
|||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ lora_model_dir:
|
|||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -1,66 +0,0 @@
|
|||||||
# SLURM Multi-Node Training
|
|
||||||
|
|
||||||
This directory contains an example SLURM script for running Axolotl training jobs across multiple nodes in a SLURM cluster.
|
|
||||||
|
|
||||||
## Prerequisites
|
|
||||||
|
|
||||||
- Access to a SLURM cluster with GPU nodes
|
|
||||||
- Axolotl installed on all nodes (see [installation docs](https://docs.axolotl.ai/docs/installation.html))
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
### Standard SLURM Clusters
|
|
||||||
|
|
||||||
1. Copy [`axolotl.slurm`](./axolotl.slurm) to your working directory.
|
|
||||||
2. Place your Axolotl config file (`train.yaml`) in the same directory.
|
|
||||||
3. Set the appropriate environment variables for the job:
|
|
||||||
```bash
|
|
||||||
export HF_TOKEN="your-huggingface-token"
|
|
||||||
|
|
||||||
# metric tracking
|
|
||||||
# export WANDB_API_KEY="your-wandb-api-key"
|
|
||||||
# ...
|
|
||||||
```
|
|
||||||
4. Submit the job:
|
|
||||||
```bash
|
|
||||||
sbatch --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=<master-node>,PRIMARY_PORT=29400 axolotl.slurm
|
|
||||||
```
|
|
||||||
|
|
||||||
Where:
|
|
||||||
- `NUM_NODES`: Number of nodes to use
|
|
||||||
- `NUM_TRAINERS`: GPUs per node (typically 8)
|
|
||||||
- `PRIMARY_ADDR`: Hostname/IP of the master node
|
|
||||||
- `PRIMARY_PORT`: Port for distributed training (default: 29400)
|
|
||||||
|
|
||||||
5. (Optional) Run other slurm commands:
|
|
||||||
```bash
|
|
||||||
# check job info
|
|
||||||
scontrol show job axolotl-cli
|
|
||||||
|
|
||||||
# check job queue
|
|
||||||
squeue
|
|
||||||
|
|
||||||
# check cluster status
|
|
||||||
sinfo
|
|
||||||
```
|
|
||||||
|
|
||||||
### RunPod Instant Clusters
|
|
||||||
|
|
||||||
Axolotl works with RunPod Instant Clusters. This feature provides managed SLURM clusters with zero configuration.
|
|
||||||
|
|
||||||
1. **Deploy a SLURM Cluster**:
|
|
||||||
- Go to [RunPod Instant Clusters](https://console.runpod.io/cluster)
|
|
||||||
- Click "Create a Cluster"
|
|
||||||
- Choose your GPU type, node count, and region
|
|
||||||
- Choose an [Axolotl cloud docker image](https://docs.axolotl.ai/docs/docker.html#cloud)
|
|
||||||
- Deploy the cluster
|
|
||||||
|
|
||||||
2. **Connect to the Controller Node**: Find the controller node in the RunPod console and connect via SSH
|
|
||||||
|
|
||||||
3. **Follow the instructions in [Standard SLURM Clusters](#standard-slurm-clusters)**
|
|
||||||
|
|
||||||
## Additional Resources
|
|
||||||
|
|
||||||
- [Axolotl Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
|
||||||
- [SLURM Documentation](https://slurm.schedmd.com/documentation.html)
|
|
||||||
- [RunPod SLURM Clusters Guide](https://docs.runpod.io/instant-clusters/slurm-clusters)
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
# Prior to running this script, export your HF_TOKEN and WANDB_API_KEY to your environment; i.e.
|
|
||||||
# export HF_TOKEN="..."
|
|
||||||
# export WANDB_API_KEY="..."
|
|
||||||
#
|
|
||||||
|
|
||||||
# ---------- SBATCH commands ---------- #
|
|
||||||
#SBATCH --job-name=axolotl-slurm-multinode
|
|
||||||
#SBATCH --ntasks-per-node=1
|
|
||||||
#SBATCH --nodes=$NUM_NODES
|
|
||||||
#SBATCH --gpus-per-task=8
|
|
||||||
#SBATCH --cpus-per-task=128
|
|
||||||
|
|
||||||
export TORCH_DIST_INIT_BARRIER=0
|
|
||||||
|
|
||||||
srun axolotl preprocess train.yaml
|
|
||||||
|
|
||||||
srun axolotl train train.yaml --launcher torchrun -- \
|
|
||||||
--nproc_per_node=$NUM_TRAINERS --nnodes=$NUM_NODES \
|
|
||||||
--rdzv_id axolotl-cli --rdzv_backend c10d --rdzv_endpoint "${PRIMARY_ADDR}:${PRIMARY_PORT}" --rdzv-conf="join_timeout=1800"
|
|
||||||
@@ -6,14 +6,17 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
|||||||
|
|
||||||
## Getting started
|
## Getting started
|
||||||
|
|
||||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Voxtral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||||
|
|
||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||||
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Please install the below.
|
2. Please install the below.
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
|
|
||||||
# START section of dependencies that don't install on Darwin/MacOS
|
# START section of dependencies that don't install on Darwin/MacOS
|
||||||
bitsandbytes==0.46.1
|
bitsandbytes==0.46.0
|
||||||
# triton 3.4.0 is not compatible with CCE
|
triton>=3.0.0
|
||||||
triton>=3.0.0,<3.4.0
|
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
autoawq==0.2.7.post3
|
autoawq==0.2.7.post3
|
||||||
@@ -13,21 +12,19 @@ liger-kernel==0.6.1
|
|||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
huggingface_hub>=0.33.0
|
huggingface_hub>=0.33.0
|
||||||
peft==0.17.0
|
peft==0.16.0
|
||||||
transformers==4.55.0
|
transformers==4.54.1
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.10.0
|
accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
|
||||||
datasets==4.0.0
|
datasets==4.0.0
|
||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.21.0
|
trl==0.20.0
|
||||||
hf_xet==1.1.5
|
hf_xet==1.1.5
|
||||||
kernels==0.9.0
|
|
||||||
trackio
|
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
sentencepiece
|
sentencepiece
|
||||||
gradio==5.41.1
|
gradio==5.23.3
|
||||||
|
|
||||||
modal==1.0.2
|
modal==1.0.2
|
||||||
pydantic==2.10.6
|
pydantic==2.10.6
|
||||||
@@ -69,6 +66,6 @@ torchao==0.12.0
|
|||||||
schedulefree==1.4.1
|
schedulefree==1.4.1
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.6
|
axolotl-contribs-lgpl==0.0.6
|
||||||
axolotl-contribs-mit==0.0.5
|
axolotl-contribs-mit==0.0.3
|
||||||
|
|
||||||
mistral-common==1.8.3
|
mistral-common==1.8.3
|
||||||
|
|||||||
@@ -44,13 +44,8 @@ add_keys_to_authorized() {
|
|||||||
chmod 700 -R ~/.ssh
|
chmod 700 -R ~/.ssh
|
||||||
}
|
}
|
||||||
|
|
||||||
# Set SSH port
|
|
||||||
if [ ! -z "$SSH_PORT" ]; then
|
|
||||||
sed -i "s/#Port 22/Port $SSH_PORT/" /etc/ssh/sshd_config
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [[ $PUBLIC_KEY ]]; then
|
if [[ $PUBLIC_KEY ]]; then
|
||||||
# runpod, prime intellect
|
# runpod
|
||||||
add_keys_to_authorized "$PUBLIC_KEY"
|
add_keys_to_authorized "$PUBLIC_KEY"
|
||||||
# Start the SSH service in the background
|
# Start the SSH service in the background
|
||||||
service ssh start
|
service ssh start
|
||||||
@@ -81,13 +76,5 @@ if [ ! -L "/workspace/axolotl/outputs" ]; then
|
|||||||
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
|
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# start the runpod slurm init
|
|
||||||
SLURM_INIT="${SLURM_INIT:-/slurm-init.sh}"
|
|
||||||
|
|
||||||
if [[ -f "$SLURM_INIT" ]]; then
|
|
||||||
echo "[entrypoint] running $SLURM_INIT..."
|
|
||||||
bash "$SLURM_INIT"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Execute the passed arguments (CMD)
|
# Execute the passed arguments (CMD)
|
||||||
exec "$@"
|
exec "$@"
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,4 +4,4 @@ import pkgutil
|
|||||||
|
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||||
|
|
||||||
__version__ = "0.13.0.dev"
|
__version__ = "0.12.0.dev"
|
||||||
|
|||||||
@@ -153,14 +153,15 @@ def prepare_plugins(cfg: DictDefault):
|
|||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
for plugin_name in cfg["plugins"]:
|
for plugin_name in cfg["plugins"]:
|
||||||
plugin_manager.register(plugin_name)
|
plugin_manager.register(plugin_name)
|
||||||
for plugin in plugin_manager.plugins.values():
|
|
||||||
plugin.register(cfg)
|
|
||||||
|
|
||||||
|
|
||||||
def plugin_set_cfg(cfg: DictDefault):
|
def plugin_set_cfg(cfg: DictDefault):
|
||||||
if cfg.get("plugins"):
|
if cfg.get("plugins"):
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
plugin_manager.cfg = cfg
|
plugin_manager.cfg = cfg
|
||||||
|
# now that we have the finalized cfg, register the plugins individually
|
||||||
|
for plugin in plugin_manager.plugins.values():
|
||||||
|
plugin.register(cfg)
|
||||||
|
|
||||||
|
|
||||||
def load_cfg(
|
def load_cfg(
|
||||||
|
|||||||
@@ -123,10 +123,9 @@ def train(
|
|||||||
_launcher = None if kwargs.get("use_ray") else launcher
|
_launcher = None if kwargs.get("use_ray") else launcher
|
||||||
|
|
||||||
# Process each configuration
|
# Process each configuration
|
||||||
for cfg_file, is_group in generate_config_files(config, sweep):
|
for cfg_file in generate_config_files(config, sweep):
|
||||||
try:
|
try:
|
||||||
use_exec = is_group is not True
|
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
|
||||||
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
|
|
||||||
except subprocess.CalledProcessError as exc:
|
except subprocess.CalledProcessError as exc:
|
||||||
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
||||||
if not sweep:
|
if not sweep:
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import subprocess # nosec
|
import subprocess # nosec
|
||||||
import sys
|
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Any, Iterator, Literal
|
from typing import Any, Iterator, Literal
|
||||||
|
|
||||||
@@ -65,18 +64,10 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
|||||||
return cmd
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
|
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||||
"""
|
"""Generate list of configuration files to process."""
|
||||||
Generate list of configuration files to process. Yields a tuple of the configuration file name and a boolean indicating
|
|
||||||
whether this is a group of configurations (i.e., a sweep).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: Base configuration file
|
|
||||||
sweep: Sweep configuration file
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not sweep:
|
if not sweep:
|
||||||
yield config, False
|
yield config
|
||||||
return
|
return
|
||||||
|
|
||||||
# Load sweep and base configurations
|
# Load sweep and base configurations
|
||||||
@@ -87,7 +78,6 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
|
|||||||
|
|
||||||
# Generate all possible configurations
|
# Generate all possible configurations
|
||||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||||
is_group = len(permutations) > 1
|
|
||||||
for permutation in permutations:
|
for permutation in permutations:
|
||||||
# pylint: disable=consider-using-with
|
# pylint: disable=consider-using-with
|
||||||
temp_file = tempfile.NamedTemporaryFile(
|
temp_file = tempfile.NamedTemporaryFile(
|
||||||
@@ -98,7 +88,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
|
|||||||
)
|
)
|
||||||
yaml.dump(permutation, temp_file)
|
yaml.dump(permutation, temp_file)
|
||||||
temp_file.close()
|
temp_file.close()
|
||||||
yield temp_file.name, is_group
|
yield temp_file.name
|
||||||
|
|
||||||
|
|
||||||
def launch_training(
|
def launch_training(
|
||||||
@@ -107,7 +97,6 @@ def launch_training(
|
|||||||
cloud: str | None,
|
cloud: str | None,
|
||||||
kwargs: dict,
|
kwargs: dict,
|
||||||
launcher_args: list[str] | None = None,
|
launcher_args: list[str] | None = None,
|
||||||
use_exec: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute training with the given configuration."""
|
"""Execute training with the given configuration."""
|
||||||
launcher_args = launcher_args or []
|
launcher_args = launcher_args or []
|
||||||
@@ -116,14 +105,11 @@ def launch_training(
|
|||||||
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
|
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
|
||||||
elif launcher:
|
elif launcher:
|
||||||
if launcher == "accelerate":
|
if launcher == "accelerate":
|
||||||
_launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
|
_launch_accelerate_training(cfg_file, kwargs, launcher_args)
|
||||||
elif launcher == "torchrun":
|
elif launcher == "torchrun":
|
||||||
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
|
_launch_torchrun_training(cfg_file, kwargs, launcher_args)
|
||||||
elif launcher == "python":
|
elif launcher == "python":
|
||||||
_launch_python_training(cfg_file, kwargs)
|
_launch_python_training(cfg_file, kwargs)
|
||||||
elif launcher is None:
|
|
||||||
# handle ray train launch
|
|
||||||
_launch_python_training(cfg_file, kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def _launch_cloud_training(
|
def _launch_cloud_training(
|
||||||
@@ -150,10 +136,7 @@ def _launch_cloud_training(
|
|||||||
|
|
||||||
|
|
||||||
def _launch_accelerate_training(
|
def _launch_accelerate_training(
|
||||||
cfg_file: str,
|
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
|
||||||
kwargs: dict,
|
|
||||||
launcher_args: list[str] | None = None,
|
|
||||||
use_exec: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute training via accelerate launcher."""
|
"""Execute training via accelerate launcher."""
|
||||||
launcher_args = launcher_args or []
|
launcher_args = launcher_args or []
|
||||||
@@ -178,20 +161,11 @@ def _launch_accelerate_training(
|
|||||||
base_cmd.append(cfg_file)
|
base_cmd.append(cfg_file)
|
||||||
|
|
||||||
cmd = build_command(base_cmd, kwargs)
|
cmd = build_command(base_cmd, kwargs)
|
||||||
if use_exec:
|
subprocess.run(cmd, check=True) # nosec B603
|
||||||
# make sure to flush stdout and stderr before replacing the process
|
|
||||||
sys.stdout.flush()
|
|
||||||
sys.stderr.flush()
|
|
||||||
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
|
||||||
else:
|
|
||||||
subprocess.run(cmd, check=True) # nosec B603
|
|
||||||
|
|
||||||
|
|
||||||
def _launch_torchrun_training(
|
def _launch_torchrun_training(
|
||||||
cfg_file: str,
|
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
|
||||||
kwargs: dict,
|
|
||||||
launcher_args: list[str] | None = None,
|
|
||||||
use_exec: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute training via torchrun launcher."""
|
"""Execute training via torchrun launcher."""
|
||||||
launcher_args = launcher_args or []
|
launcher_args = launcher_args or []
|
||||||
@@ -204,13 +178,7 @@ def _launch_torchrun_training(
|
|||||||
base_cmd.append(cfg_file)
|
base_cmd.append(cfg_file)
|
||||||
|
|
||||||
cmd = build_command(base_cmd, kwargs)
|
cmd = build_command(base_cmd, kwargs)
|
||||||
if use_exec:
|
subprocess.run(cmd, check=True) # nosec B603
|
||||||
# make sure to flush stdout and stderr before replacing the process
|
|
||||||
sys.stdout.flush()
|
|
||||||
sys.stderr.flush()
|
|
||||||
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
|
||||||
else:
|
|
||||||
subprocess.run(cmd, check=True) # nosec B603
|
|
||||||
|
|
||||||
|
|
||||||
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
|
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
|
||||||
|
|||||||
@@ -2,10 +2,12 @@
|
|||||||
CLI to start the vllm server for online RL
|
CLI to start the vllm server for online RL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
import trl
|
||||||
from trl.scripts.vllm_serve import ScriptArguments
|
from trl.scripts.vllm_serve import ScriptArguments
|
||||||
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
@@ -40,17 +42,13 @@ def do_vllm_serve(
|
|||||||
|
|
||||||
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
|
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
|
||||||
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
|
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
|
||||||
tensor_parallel_size = 1
|
|
||||||
data_parallel_size = 1
|
|
||||||
|
|
||||||
if cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size:
|
tensor_parallel_size = (
|
||||||
tensor_parallel_size = (
|
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
)
|
||||||
)
|
data_parallel_size = (
|
||||||
if cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size:
|
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
|
||||||
data_parallel_size = (
|
)
|
||||||
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
|
|
||||||
)
|
|
||||||
host = cli_args.get("host") or cfg.vllm.host
|
host = cli_args.get("host") or cfg.vllm.host
|
||||||
port = cli_args.get("port") or cfg.vllm.port
|
port = cli_args.get("port") or cfg.vllm.port
|
||||||
gpu_memory_utilization = (
|
gpu_memory_utilization = (
|
||||||
@@ -83,3 +81,63 @@ def do_vllm_serve(
|
|||||||
enable_reasoning=enable_reasoning,
|
enable_reasoning=enable_reasoning,
|
||||||
)
|
)
|
||||||
vllm_serve_main(vllm_script_args)
|
vllm_serve_main(vllm_script_args)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_vllm_worker():
|
||||||
|
from multiprocessing.connection import Connection
|
||||||
|
|
||||||
|
from vllm import LLM
|
||||||
|
|
||||||
|
def llm_worker(
|
||||||
|
script_args: AxolotlScriptArguments,
|
||||||
|
data_parallel_rank: int,
|
||||||
|
master_port: int,
|
||||||
|
connection: Connection,
|
||||||
|
) -> None:
|
||||||
|
# Set required environment variables for DP to work with vLLM
|
||||||
|
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
|
||||||
|
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
|
||||||
|
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
|
||||||
|
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model=script_args.model,
|
||||||
|
revision=script_args.revision,
|
||||||
|
tensor_parallel_size=script_args.tensor_parallel_size,
|
||||||
|
gpu_memory_utilization=script_args.gpu_memory_utilization,
|
||||||
|
enforce_eager=script_args.enforce_eager,
|
||||||
|
dtype=script_args.dtype,
|
||||||
|
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
|
||||||
|
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
|
||||||
|
# This is particularly useful here because we generate completions from the same prompts.
|
||||||
|
enable_prefix_caching=script_args.enable_prefix_caching,
|
||||||
|
kv_cache_dtype=script_args.kv_cache_dtype,
|
||||||
|
max_model_len=script_args.max_model_len,
|
||||||
|
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
|
||||||
|
enable_reasoning=script_args.enable_reasoning,
|
||||||
|
reasoning_parser=script_args.reasoning_parser,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send ready signal to parent process
|
||||||
|
connection.send({"status": "ready"})
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# Wait for commands from the parent process
|
||||||
|
try:
|
||||||
|
command = connection.recv()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
llm.collective_rpc(method="close_communicator")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Handle commands
|
||||||
|
if command["type"] in ["call", "fire_and_forget"]:
|
||||||
|
method_name = command["method"]
|
||||||
|
args, kwargs = command.get("args", ()), command.get("kwargs", {})
|
||||||
|
method = getattr(llm, method_name)
|
||||||
|
result = method(*args, **kwargs)
|
||||||
|
if command["type"] == "call":
|
||||||
|
connection.send(result)
|
||||||
|
elif command["type"] == "shutdown":
|
||||||
|
break
|
||||||
|
|
||||||
|
trl.scripts.vllm_serve.llm_worker = llm_worker
|
||||||
|
|||||||
@@ -13,5 +13,4 @@ MOE_ARCH_BLOCK = {
|
|||||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||||
"deepseek_v2": "DeepseekV2MoE",
|
"deepseek_v2": "DeepseekV2MoE",
|
||||||
"gpt_oss": "GptOssDecoderLayer",
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,10 +24,12 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate import PartialState
|
||||||
from transformers import (
|
from transformers import (
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||||
|
from transformers.training_args import OptimizerNames
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
||||||
@@ -38,7 +40,6 @@ from axolotl.utils.callbacks import (
|
|||||||
SaveModelOnFirstStepCallback,
|
SaveModelOnFirstStepCallback,
|
||||||
)
|
)
|
||||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||||
from axolotl.utils.distributed import build_parallelism_config
|
|
||||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@@ -266,24 +267,27 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
|
|
||||||
optimizer_cls = MuonOptimizerFactory
|
optimizer_cls = MuonOptimizerFactory
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
elif self.cfg.optimizer == "dion":
|
|
||||||
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
|
|
||||||
DionOptimizerFactory,
|
|
||||||
)
|
|
||||||
|
|
||||||
optimizer_cls = DionOptimizerFactory
|
|
||||||
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
|
|
||||||
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
|
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
|
||||||
_, device_mesh = build_parallelism_config(self.cfg)
|
|
||||||
if device_mesh is not None:
|
|
||||||
optimizer_kwargs["device_mesh"] = device_mesh
|
|
||||||
elif self.cfg.optimizer == "optimi_adamw":
|
elif self.cfg.optimizer == "optimi_adamw":
|
||||||
from optimi import AdamW
|
from optimi import AdamW
|
||||||
|
|
||||||
optimizer_kwargs["foreach"] = False
|
optimizer_kwargs["foreach"] = False
|
||||||
optimizer_cls = AdamW
|
optimizer_cls = AdamW
|
||||||
optimizer_kwargs.update(adam_kwargs)
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "ao_adamw_4bit":
|
||||||
|
# TODO remove 20250401
|
||||||
|
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||||
|
|
||||||
|
optimizer_cls = AdamW4bit
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
|
||||||
|
LOG.warning(
|
||||||
|
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
||||||
|
)
|
||||||
|
elif self.cfg.optimizer == "ao_adamw_8bit":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||||
|
|
||||||
|
optimizer_cls = AdamW8bit
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
elif self.cfg.optimizer == "ao_adamw_fp8":
|
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||||
|
|
||||||
@@ -429,12 +433,30 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||||
|
|
||||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||||
|
partial_state = PartialState()
|
||||||
|
has_pc_attr = (
|
||||||
|
hasattr(partial_state, "parallelism_config")
|
||||||
|
and partial_state.parallelism_config
|
||||||
|
)
|
||||||
|
has_pc_key = (
|
||||||
|
"parallelism_config"
|
||||||
|
in partial_state._shared_state # pylint: disable=protected-access
|
||||||
|
and partial_state._shared_state[ # pylint: disable=protected-access
|
||||||
|
"parallelism_config"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
use_configured_state = has_pc_attr or has_pc_key
|
||||||
if self.cfg.accelerator_config:
|
if self.cfg.accelerator_config:
|
||||||
|
use_configured_state = self.cfg.accelerator_config.pop(
|
||||||
|
"use_configured_state", use_configured_state
|
||||||
|
)
|
||||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||||
**self.cfg.accelerator_config
|
use_configured_state=use_configured_state, **self.cfg.accelerator_config
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
|
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||||
|
use_configured_state=use_configured_state,
|
||||||
|
)
|
||||||
|
|
||||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||||
if self.cfg.activation_offloading is True:
|
if self.cfg.activation_offloading is True:
|
||||||
@@ -494,20 +516,10 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
"include_tokens_per_second",
|
"include_tokens_per_second",
|
||||||
"weight_decay",
|
"weight_decay",
|
||||||
"seed",
|
"seed",
|
||||||
"dion_momentum",
|
|
||||||
"dion_rank_fraction",
|
|
||||||
"dion_rank_multiple_of",
|
|
||||||
]:
|
]:
|
||||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||||
|
|
||||||
arg_map = {
|
|
||||||
"dion_learning_rate": "dion_lr",
|
|
||||||
}
|
|
||||||
for kwarg, cfg_arg in arg_map.items():
|
|
||||||
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
|
|
||||||
training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)
|
|
||||||
|
|
||||||
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
||||||
training_args_kwargs["average_tokens_across_devices"] = False
|
training_args_kwargs["average_tokens_across_devices"] = False
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,6 @@ from axolotl.utils.collators import (
|
|||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
@@ -137,18 +136,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return AxolotlRewardTrainer
|
return AxolotlRewardTrainer
|
||||||
if self.cfg.process_reward_model:
|
if self.cfg.process_reward_model:
|
||||||
return AxolotlPRMTrainer
|
return AxolotlPRMTrainer
|
||||||
|
|
||||||
if self.cfg.trainer_cls:
|
|
||||||
# override the trainer cls
|
|
||||||
try:
|
|
||||||
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
|
|
||||||
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
|
|
||||||
return trainer_cls
|
|
||||||
except (ImportError, AttributeError, ValueError) as e:
|
|
||||||
raise ValueError(
|
|
||||||
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
@@ -363,7 +350,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
||||||
self.cfg.sequence_len / multiple
|
self.cfg.sequence_len / multiple
|
||||||
)
|
)
|
||||||
elif self.cfg.pad_to_sequence_len is None:
|
else:
|
||||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from axolotl.core.trainers.grpo import GRPOStrategy
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.loaders.utils import ensure_dtype
|
from axolotl.loaders.utils import ensure_dtype
|
||||||
from axolotl.utils.callbacks.qat import QATCallback
|
from axolotl.utils.callbacks.qat import QATCallback
|
||||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.enums import RLType
|
from axolotl.utils.schemas.enums import RLType
|
||||||
|
|
||||||
@@ -73,16 +72,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
|
|
||||||
if self.cfg.trainer_cls:
|
|
||||||
# override the trainer cls
|
|
||||||
try:
|
|
||||||
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
|
|
||||||
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
|
|
||||||
except (ImportError, AttributeError, ValueError) as e:
|
|
||||||
raise ValueError(
|
|
||||||
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
return trainer_cls, trainer_cls_args
|
return trainer_cls, trainer_cls_args
|
||||||
|
|
||||||
def _build_training_arguments(self, total_num_steps):
|
def _build_training_arguments(self, total_num_steps):
|
||||||
|
|||||||
@@ -10,11 +10,8 @@ from functools import partial, wraps
|
|||||||
from typing import Any, Callable, Literal, Optional
|
from typing import Any, Callable, Literal, Optional
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import safetensors
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.state import AcceleratorState
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft import PeftModel
|
|
||||||
from torch.utils.data import (
|
from torch.utils.data import (
|
||||||
BatchSampler,
|
BatchSampler,
|
||||||
DataLoader,
|
DataLoader,
|
||||||
@@ -22,10 +19,8 @@ from torch.utils.data import (
|
|||||||
Sampler,
|
Sampler,
|
||||||
SequentialSampler,
|
SequentialSampler,
|
||||||
)
|
)
|
||||||
from transformers import PreTrainedModel, Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer import TRAINING_ARGS_NAME
|
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
||||||
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
|
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
@@ -520,18 +515,7 @@ class AxolotlTrainer(
|
|||||||
|
|
||||||
@wraps(Trainer.create_accelerator_and_postprocess)
|
@wraps(Trainer.create_accelerator_and_postprocess)
|
||||||
def create_accelerator_and_postprocess(self):
|
def create_accelerator_and_postprocess(self):
|
||||||
# cleanup the PartialState states so Accelerate automatically configures everything from the env vars
|
res = super().create_accelerator_and_postprocess()
|
||||||
accelerator_config = self.args.accelerator_config.to_dict()
|
|
||||||
use_configured_state = accelerator_config.get("use_configured_state", False)
|
|
||||||
if not use_configured_state:
|
|
||||||
AcceleratorState._reset_state( # pylint: disable=protected-access
|
|
||||||
reset_partial_state=True
|
|
||||||
)
|
|
||||||
|
|
||||||
super().create_accelerator_and_postprocess()
|
|
||||||
|
|
||||||
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
|
|
||||||
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
|
|
||||||
|
|
||||||
if self.is_fsdp_enabled:
|
if self.is_fsdp_enabled:
|
||||||
if (
|
if (
|
||||||
@@ -540,6 +524,8 @@ class AxolotlTrainer(
|
|||||||
):
|
):
|
||||||
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
|
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def additional_accelerator_args(
|
def additional_accelerator_args(
|
||||||
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
|
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
|
||||||
@@ -581,10 +567,10 @@ class AxolotlTrainer(
|
|||||||
# Add memory usage
|
# Add memory usage
|
||||||
try:
|
try:
|
||||||
active, allocated, reserved = get_gpu_memory_usage()
|
active, allocated, reserved = get_gpu_memory_usage()
|
||||||
logs["memory/max_mem_active(gib)"] = round(active, 2)
|
logs["memory/max_memory_active"] = active
|
||||||
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
|
logs["memory/max_memory_allocated"] = allocated
|
||||||
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
|
logs["memory/device_memory_reserved"] = reserved
|
||||||
except (ValueError, TypeError, FileNotFoundError):
|
except (ValueError, FileNotFoundError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
@@ -604,64 +590,3 @@ class AxolotlTrainer(
|
|||||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
|
||||||
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
|
||||||
# If we are executing this function, we are the process zero, so we don't check for that.
|
|
||||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
LOG.info(f"Saving model checkpoint to {output_dir}")
|
|
||||||
supported_classes = (
|
|
||||||
(PreTrainedModel,)
|
|
||||||
if not is_peft_available()
|
|
||||||
else (PreTrainedModel, PeftModel)
|
|
||||||
)
|
|
||||||
# Save a trained model and configuration using `save_pretrained()`.
|
|
||||||
# They can then be reloaded using `from_pretrained()`
|
|
||||||
if not isinstance(self.model, supported_classes):
|
|
||||||
if state_dict is None:
|
|
||||||
state_dict = self.model.state_dict()
|
|
||||||
if isinstance(
|
|
||||||
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
|
|
||||||
supported_classes,
|
|
||||||
):
|
|
||||||
self.accelerator.unwrap_model(
|
|
||||||
self.model, keep_torch_compile=False
|
|
||||||
).save_pretrained(
|
|
||||||
output_dir,
|
|
||||||
state_dict=state_dict,
|
|
||||||
safe_serialization=self.args.save_safetensors,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
LOG.info(
|
|
||||||
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
|
|
||||||
)
|
|
||||||
if self.args.save_safetensors:
|
|
||||||
safetensors.torch.save_file(
|
|
||||||
state_dict,
|
|
||||||
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
|
|
||||||
metadata={"format": "pt"},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
|
||||||
else:
|
|
||||||
self.model.save_pretrained(
|
|
||||||
output_dir,
|
|
||||||
state_dict=state_dict,
|
|
||||||
safe_serialization=self.args.save_safetensors,
|
|
||||||
is_main_process=self.accelerator.is_main_process,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.processing_class is not None:
|
|
||||||
self.processing_class.save_pretrained(output_dir)
|
|
||||||
elif (
|
|
||||||
self.data_collator is not None
|
|
||||||
and hasattr(self.data_collator, "tokenizer")
|
|
||||||
and self.data_collator.tokenizer is not None
|
|
||||||
):
|
|
||||||
LOG.info(
|
|
||||||
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
|
||||||
)
|
|
||||||
self.data_collator.tokenizer.save_pretrained(output_dir)
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
|
||||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
Mixin for correctly saving fsdp
|
Mixin for correctly saving fsdp
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from accelerate import PartialState
|
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
|
|
||||||
|
|
||||||
@@ -19,15 +18,3 @@ class DistributedParallelMixin(Trainer):
|
|||||||
):
|
):
|
||||||
state_dict = self.accelerator.get_state_dict(self.model)
|
state_dict = self.accelerator.get_state_dict(self.model)
|
||||||
super()._save(output_dir, state_dict=state_dict)
|
super()._save(output_dir, state_dict=state_dict)
|
||||||
|
|
||||||
def create_accelerator_and_postprocess(self):
|
|
||||||
super().create_accelerator_and_postprocess()
|
|
||||||
if (
|
|
||||||
self.accelerator.distributed_type == "FSDP"
|
|
||||||
and self.accelerator.state.fsdp_plugin is None
|
|
||||||
):
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
# handle Context Parallelism without FSDP
|
|
||||||
self.accelerator.state.distributed_type = "MULTI_GPU"
|
|
||||||
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"
|
|
||||||
PartialState().distributed_type = "MULTI_GPU"
|
|
||||||
|
|||||||
@@ -243,18 +243,3 @@ class AxolotlTrainingMixins:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# end of multi-modal section
|
# end of multi-modal section
|
||||||
|
|
||||||
dion_learning_rate: float | None = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "The learning rate for Dion"},
|
|
||||||
)
|
|
||||||
dion_momentum: float | None = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "The momentum for Dion"},
|
|
||||||
)
|
|
||||||
dion_rank_fraction: float | None = field(
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
dion_rank_multiple_of: int | None = field(
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -26,11 +26,9 @@ import traceback
|
|||||||
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
||||||
|
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
from torch import nn
|
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
from transformers import PreTrainedModel, Trainer
|
from transformers import PreTrainedModel, Trainer
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -76,8 +74,8 @@ class BasePlugin:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initializes the BasePlugin."""
|
"""Initializes the BasePlugin."""
|
||||||
|
|
||||||
def register(self, cfg: dict): # pylint: disable=unused-argument
|
def register(self, cfg: DictDefault): # pylint: disable=unused-argument
|
||||||
"""Registers the plugin with the given configuration as an unparsed dict.
|
"""Registers the plugin with the given configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: The configuration for the plugin.
|
cfg: The configuration for the plugin.
|
||||||
@@ -643,24 +641,3 @@ class BaseOptimizerFactory:
|
|||||||
self, opt_model, training_args, **optimizer_kwargs
|
self, opt_model, training_args, **optimizer_kwargs
|
||||||
) -> Optimizer | None:
|
) -> Optimizer | None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# duplicated from transformers
|
|
||||||
def get_decay_parameter_names(self, model) -> list[str]:
|
|
||||||
"""
|
|
||||||
Get all parameter names that weight decay will be applied to.
|
|
||||||
|
|
||||||
This function filters out parameters in two ways:
|
|
||||||
1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
|
|
||||||
2. By parameter name patterns (containing 'bias', or variation of 'norm')
|
|
||||||
"""
|
|
||||||
forbidden_name_patterns = [
|
|
||||||
r"bias",
|
|
||||||
r"layernorm",
|
|
||||||
r"rmsnorm",
|
|
||||||
r"(?:^|\.)norm(?:$|\.)",
|
|
||||||
r"_norm(?:$|\.)",
|
|
||||||
]
|
|
||||||
decay_parameters = get_parameter_names(
|
|
||||||
model, [nn.LayerNorm], forbidden_name_patterns
|
|
||||||
)
|
|
||||||
return decay_parameters
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@@ -31,7 +31,6 @@ plugins:
|
|||||||
|
|
||||||
## Supported Models
|
## Supported Models
|
||||||
|
|
||||||
- arcee
|
|
||||||
- cohere
|
- cohere
|
||||||
- cohere2
|
- cohere2
|
||||||
- gemma
|
- gemma
|
||||||
@@ -42,17 +41,13 @@ plugins:
|
|||||||
- gemma3n_text
|
- gemma3n_text
|
||||||
- glm
|
- glm
|
||||||
- glm4
|
- glm4
|
||||||
- gpt_oss
|
|
||||||
- granite
|
- granite
|
||||||
- granitemoe
|
- granitemoe
|
||||||
- hunyuan_v1_dense
|
|
||||||
- hunyuan_v1_moe
|
|
||||||
- llama
|
- llama
|
||||||
- llama4
|
- llama4
|
||||||
- llama4_text
|
- llama4_text
|
||||||
- mistral
|
- mistral
|
||||||
- mistral3
|
- mistral3
|
||||||
- mixtral
|
|
||||||
- mllama
|
- mllama
|
||||||
- phi
|
- phi
|
||||||
- phi3
|
- phi3
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -284,12 +284,12 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
def _tokenize_single_prompt(self, prompt):
|
def _tokenize_single_prompt(self, prompt):
|
||||||
target_token_ids = prompt.get("target_token_ids", None)
|
logprobs = prompt.pop(self.logprobs_field)
|
||||||
|
target_token_ids = prompt.pop("target_token_ids")
|
||||||
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
||||||
|
tokenized_prompt[self.logprobs_field] = logprobs
|
||||||
if target_token_ids is not None:
|
tokenized_prompt["target_token_ids"] = target_token_ids
|
||||||
tokenized_prompt["target_token_ids"] = target_token_ids
|
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||||
|
|
||||||
return tokenized_prompt
|
return tokenized_prompt
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from typing import Callable
|
|||||||
import torch
|
import torch
|
||||||
from bitsandbytes.functional import QuantState
|
from bitsandbytes.functional import QuantState
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.distributed.tensor import DTensor
|
|
||||||
|
|
||||||
from .geglu import geglu_backward, geglu_forward
|
from .geglu import geglu_backward, geglu_forward
|
||||||
from .quantize import dequantize
|
from .quantize import dequantize
|
||||||
@@ -26,7 +25,6 @@ def get_lora_parameters(
|
|||||||
proj: nn.Module,
|
proj: nn.Module,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
torch.Tensor | None,
|
|
||||||
QuantState | None,
|
QuantState | None,
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
@@ -39,54 +37,39 @@ def get_lora_parameters(
|
|||||||
proj: The projection module to extract parameters from.
|
proj: The projection module to extract parameters from.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple containing the base weights, quantization state, LoRA A and B weights,
|
A tuple containing the base weight matrix, quantization state, LoRA A matrix,
|
||||||
scaling factor, and base layer bias. Quant state, weights, and bias may be
|
LoRA B matrix, and scaling factor. States and matrices may be None if not
|
||||||
`None` if not available.
|
available.
|
||||||
"""
|
"""
|
||||||
# For DPO or disabled adapters
|
# For DPO or disabled adapters
|
||||||
base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
|
base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
|
||||||
W = base_layer.weight
|
W = base_layer.weight
|
||||||
b = base_layer.bias
|
|
||||||
|
|
||||||
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
|
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
|
||||||
quant_state = getattr(W, "quant_state", None)
|
quant_state = getattr(W, "quant_state", None)
|
||||||
return W, b, quant_state, None, None, None
|
return W, quant_state, None, None, None
|
||||||
|
|
||||||
quant_state = getattr(W, "quant_state", None)
|
|
||||||
|
|
||||||
active_adapter = (
|
active_adapter = (
|
||||||
proj.active_adapters[0]
|
proj.active_adapters[0]
|
||||||
if hasattr(proj, "active_adapters")
|
if hasattr(proj, "active_adapters")
|
||||||
else proj.active_adapter
|
else proj.active_adapter
|
||||||
)
|
)
|
||||||
|
A = proj.lora_A[active_adapter].weight
|
||||||
linear_A = proj.lora_A[active_adapter]
|
B = proj.lora_B[active_adapter].weight
|
||||||
linear_B = proj.lora_B[active_adapter]
|
|
||||||
|
|
||||||
# This manual unsharding is needed for FSDP2 + LoRA kernels compatibility.
|
|
||||||
# We fuse linear layers + LoRA adapters calculations into a single
|
|
||||||
# torch.autograd.Function, bypassing the registered unshard / reshard behavior.
|
|
||||||
# Note that we don't apply resharding later in this module (it gets messy quickly),
|
|
||||||
# but LoRA parameters are generally small enough that this is not an issue.
|
|
||||||
if isinstance(linear_A.weight, DTensor):
|
|
||||||
linear_A.unshard()
|
|
||||||
linear_B.unshard()
|
|
||||||
|
|
||||||
A = linear_A.weight
|
|
||||||
B = linear_B.weight
|
|
||||||
s = proj.scaling[active_adapter]
|
s = proj.scaling[active_adapter]
|
||||||
|
|
||||||
return W, b, quant_state, A, B, s
|
quant_state = getattr(W, "quant_state", None)
|
||||||
|
|
||||||
|
return W, quant_state, A, B, s
|
||||||
|
|
||||||
|
|
||||||
def matmul_lora(
|
def matmul_lora(
|
||||||
X: torch.Tensor,
|
X: torch.Tensor,
|
||||||
W: torch.Tensor,
|
W: torch.Tensor,
|
||||||
b: torch.Tensor | None,
|
W_quant: QuantState,
|
||||||
W_quant: QuantState | None,
|
A: torch.Tensor,
|
||||||
A: torch.Tensor | None,
|
B: torch.Tensor,
|
||||||
B: torch.Tensor | None,
|
s: float,
|
||||||
s: float | None,
|
|
||||||
out: torch.Tensor | None = None,
|
out: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -107,22 +90,20 @@ def matmul_lora(
|
|||||||
dtype = X.dtype
|
dtype = X.dtype
|
||||||
W = dequantize(W.t(), W_quant)
|
W = dequantize(W.t(), W_quant)
|
||||||
|
|
||||||
reshape = False
|
|
||||||
if X.dim() == 3:
|
if X.dim() == 3:
|
||||||
batch, seq_len, _ = X.shape
|
batch, seq_len, _ = X.shape
|
||||||
X = X.view(-1, X.shape[-1])
|
X = X.view(-1, X.shape[-1])
|
||||||
reshape = True
|
reshape = True
|
||||||
|
else:
|
||||||
|
reshape = False
|
||||||
|
|
||||||
out = torch.matmul(X, W, out=out)
|
out = torch.matmul(X, W, out=out)
|
||||||
if W_quant is not None:
|
if W_quant is not None:
|
||||||
del W
|
del W
|
||||||
|
|
||||||
if A is not None:
|
if A is not None:
|
||||||
A, B = A.t().to(dtype), B.t().to(dtype) # type: ignore[union-attr]
|
A, B = A.t(), B.t()
|
||||||
out += s * X @ A @ B
|
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
|
||||||
|
|
||||||
if b is not None:
|
|
||||||
out += b
|
|
||||||
|
|
||||||
return out.view(batch, seq_len, -1) if reshape else out
|
return out.view(batch, seq_len, -1) if reshape else out
|
||||||
|
|
||||||
@@ -136,20 +117,17 @@ class LoRA_MLP(torch.autograd.Function):
|
|||||||
ctx,
|
ctx,
|
||||||
X: torch.Tensor,
|
X: torch.Tensor,
|
||||||
gate_weight: torch.Tensor,
|
gate_weight: torch.Tensor,
|
||||||
gate_bias: torch.Tensor | None,
|
gate_quant: object | None,
|
||||||
gate_quant: QuantState | None,
|
|
||||||
gate_A: torch.Tensor | None,
|
gate_A: torch.Tensor | None,
|
||||||
gate_B: torch.Tensor | None,
|
gate_B: torch.Tensor | None,
|
||||||
gate_scale: float,
|
gate_scale: float,
|
||||||
up_weight: torch.Tensor,
|
up_weight: torch.Tensor,
|
||||||
up_bias: torch.Tensor | None,
|
up_quant: object | None,
|
||||||
up_quant: QuantState | None,
|
|
||||||
up_A: torch.Tensor | None,
|
up_A: torch.Tensor | None,
|
||||||
up_B: torch.Tensor | None,
|
up_B: torch.Tensor | None,
|
||||||
up_scale: float,
|
up_scale: float,
|
||||||
down_weight: torch.Tensor,
|
down_weight: torch.Tensor,
|
||||||
down_bias: torch.Tensor | None,
|
down_quant: object | None,
|
||||||
down_quant: QuantState | None,
|
|
||||||
down_A: torch.Tensor | None,
|
down_A: torch.Tensor | None,
|
||||||
down_B: torch.Tensor | None,
|
down_B: torch.Tensor | None,
|
||||||
down_scale: float,
|
down_scale: float,
|
||||||
@@ -164,22 +142,20 @@ class LoRA_MLP(torch.autograd.Function):
|
|||||||
ctx: Autograd context
|
ctx: Autograd context
|
||||||
X: Input features
|
X: Input features
|
||||||
gate_weight: Gate projection weight
|
gate_weight: Gate projection weight
|
||||||
gate_bias: Gate projection bias
|
|
||||||
gate_quant: Gate quantization state
|
gate_quant: Gate quantization state
|
||||||
gate_A: Gate LoRA A matrix
|
gate_A: Gate LoRA A matrix
|
||||||
gate_B: Gate LoRA B matrix
|
gate_B: Gate LoRA B matrix
|
||||||
gate_scale: Gate LoRA scale
|
gate_scale: Gate LoRA scale
|
||||||
up_weight: Up projection weight
|
up_weight: Up-projection weight
|
||||||
up_quant: Up projection quantization state
|
up_quant: Up-projection quantization state
|
||||||
up_A: Up projection LoRA A matrix
|
up_A: Up-projection LoRA A matrix
|
||||||
up_B: Up projection LoRA B matrix
|
up_B: Up-projection LoRA B matrix
|
||||||
up_scale: Up projection LoRA scale
|
up_scale: Up-projection LoRA scale
|
||||||
down_weight: Down projection weight
|
down_weight: Down-projection weight
|
||||||
down_bias: Down projection bias
|
down_quant: Down-projection quantization state
|
||||||
down_quant: Down projection quantization state
|
down_A: Down-projection LoRA A matrix
|
||||||
down_A: Down projection LoRA A matrix
|
down_B: Down-projection LoRA B matrix
|
||||||
down_B: Down projection LoRA B matrix
|
down_scale: Down-projection LoRA scale
|
||||||
down_scale: Down projection LoRA scale
|
|
||||||
activation_fn: Forward activation function
|
activation_fn: Forward activation function
|
||||||
activation_fn_backward: Backward activation function
|
activation_fn_backward: Backward activation function
|
||||||
inplace: Whether to perform operations in-place
|
inplace: Whether to perform operations in-place
|
||||||
@@ -188,17 +164,15 @@ class LoRA_MLP(torch.autograd.Function):
|
|||||||
Output transformed by multi-layer perceptron and activation function
|
Output transformed by multi-layer perceptron and activation function
|
||||||
"""
|
"""
|
||||||
# Compute projections
|
# Compute projections
|
||||||
gate = matmul_lora(
|
gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
|
||||||
X, gate_weight, gate_bias, gate_quant, gate_A, gate_B, gate_scale
|
up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)
|
||||||
)
|
|
||||||
up = matmul_lora(X, up_weight, up_bias, up_quant, up_A, up_B, up_scale)
|
|
||||||
|
|
||||||
# Activation
|
# Activation
|
||||||
hidden = activation_fn(gate, up)
|
hidden = activation_fn(gate, up)
|
||||||
|
|
||||||
# Down projection
|
# Down projection
|
||||||
output = matmul_lora(
|
output = matmul_lora(
|
||||||
hidden, down_weight, down_bias, down_quant, down_A, down_B, down_scale
|
hidden, down_weight, down_quant, down_A, down_B, down_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save for backward
|
# Save for backward
|
||||||
@@ -221,26 +195,22 @@ class LoRA_MLP(torch.autograd.Function):
|
|||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
Performs backward pass computation for LoRA MLP.
|
Performs backward pass computation for LoRA MLP.
|
||||||
@@ -252,7 +222,7 @@ class LoRA_MLP(torch.autograd.Function):
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple containing gradients for all inputs from forward pass:
|
Tuple containing gradients for all inputs from forward pass:
|
||||||
- Input gradient tensor (or `None`)
|
- Input gradient tensor (or `None`)
|
||||||
- `None` for weights/biases/quantization states
|
- `None` for weights/quantization states
|
||||||
- LoRA A/B matrix gradients (or `None`)
|
- LoRA A/B matrix gradients (or `None`)
|
||||||
- `None` for scaling factors
|
- `None` for scaling factors
|
||||||
- `None` for activation functions and flags
|
- `None` for activation functions and flags
|
||||||
@@ -295,10 +265,9 @@ class LoRA_MLP(torch.autograd.Function):
|
|||||||
dtype = X.dtype
|
dtype = X.dtype
|
||||||
|
|
||||||
# Down projection
|
# Down projection
|
||||||
grad_down = matmul_lora(
|
DW = matmul_lora(
|
||||||
grad_output,
|
grad_output,
|
||||||
down_weight.t(),
|
down_weight.t(),
|
||||||
None,
|
|
||||||
down_quant,
|
down_quant,
|
||||||
down_B,
|
down_B,
|
||||||
down_A,
|
down_A,
|
||||||
@@ -306,7 +275,7 @@ class LoRA_MLP(torch.autograd.Function):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Activation backward
|
# Activation backward
|
||||||
h, grad_gate, grad_up = ctx.activation_fn_backward(grad_down, gate, up)
|
h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)
|
||||||
|
|
||||||
# Initialize and compute LoRA gradients
|
# Initialize and compute LoRA gradients
|
||||||
d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
|
d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
|
||||||
@@ -346,8 +315,8 @@ class LoRA_MLP(torch.autograd.Function):
|
|||||||
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
|
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
|
||||||
|
|
||||||
# Gate projection gradients
|
# Gate projection gradients
|
||||||
gate_weight = dequantize(gate_weight, gate_quant)
|
gate_weight = dequantize(gate_weight.t(), gate_quant)
|
||||||
dX += grad_gate @ gate_weight
|
dX += grad_gate @ gate_weight.t()
|
||||||
del gate_weight
|
del gate_weight
|
||||||
|
|
||||||
if gate_A is not None and gate_B is not None:
|
if gate_A is not None and gate_B is not None:
|
||||||
@@ -365,26 +334,22 @@ class LoRA_MLP(torch.autograd.Function):
|
|||||||
dX,
|
dX,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
d_gate_A.t() if d_gate_A is not None else None,
|
d_gate_A.t() if d_gate_A is not None else None,
|
||||||
d_gate_B.t() if d_gate_B is not None else None,
|
d_gate_B.t() if d_gate_B is not None else None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
d_up_A.t() if d_up_A is not None else None,
|
d_up_A.t() if d_up_A is not None else None,
|
||||||
d_up_B.t() if d_up_B is not None else None,
|
d_up_B.t() if d_up_B is not None else None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
d_down_A.t() if d_down_A is not None else None,
|
d_down_A.t() if d_down_A is not None else None,
|
||||||
d_down_B.t() if d_down_B is not None else None,
|
d_down_B.t() if d_down_B is not None else None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -399,26 +364,23 @@ def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.
|
|||||||
Returns:
|
Returns:
|
||||||
Output tensor after applying LoRA-adapted MLP with SwiGLU activation
|
Output tensor after applying LoRA-adapted MLP with SwiGLU activation
|
||||||
"""
|
"""
|
||||||
gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||||
upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||||
downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||||
|
|
||||||
out = LoRA_MLP.apply(
|
out = LoRA_MLP.apply(
|
||||||
X,
|
X,
|
||||||
gateW,
|
gateW,
|
||||||
gateb,
|
|
||||||
gateW_quant,
|
gateW_quant,
|
||||||
gateA,
|
gateA,
|
||||||
gateB,
|
gateB,
|
||||||
gateS,
|
gateS,
|
||||||
upW,
|
upW,
|
||||||
upb,
|
|
||||||
upW_quant,
|
upW_quant,
|
||||||
upA,
|
upA,
|
||||||
upB,
|
upB,
|
||||||
upS,
|
upS,
|
||||||
downW,
|
downW,
|
||||||
downb,
|
|
||||||
downW_quant,
|
downW_quant,
|
||||||
downA,
|
downA,
|
||||||
downB,
|
downB,
|
||||||
@@ -442,25 +404,22 @@ def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.T
|
|||||||
Returns:
|
Returns:
|
||||||
Output tensor after applying LoRA-adapted MLP with GEGLU activation
|
Output tensor after applying LoRA-adapted MLP with GEGLU activation
|
||||||
"""
|
"""
|
||||||
gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||||
upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||||
downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||||
out = LoRA_MLP.apply(
|
out = LoRA_MLP.apply(
|
||||||
X,
|
X,
|
||||||
gateW,
|
gateW,
|
||||||
gateb,
|
|
||||||
gateW_quant,
|
gateW_quant,
|
||||||
gateA,
|
gateA,
|
||||||
gateB,
|
gateB,
|
||||||
gateS,
|
gateS,
|
||||||
upW,
|
upW,
|
||||||
upb,
|
|
||||||
upW_quant,
|
upW_quant,
|
||||||
upA,
|
upA,
|
||||||
upB,
|
upB,
|
||||||
upS,
|
upS,
|
||||||
downW,
|
downW,
|
||||||
downb,
|
|
||||||
downW_quant,
|
downW_quant,
|
||||||
downA,
|
downA,
|
||||||
downB,
|
downB,
|
||||||
@@ -487,19 +446,16 @@ class LoRA_QKV(torch.autograd.Function):
|
|||||||
ctx: torch.autograd.function.FunctionCtx,
|
ctx: torch.autograd.function.FunctionCtx,
|
||||||
X: torch.Tensor,
|
X: torch.Tensor,
|
||||||
q_weight: torch.Tensor,
|
q_weight: torch.Tensor,
|
||||||
q_bias: torch.Tensor | None,
|
|
||||||
q_quant: QuantState | None,
|
q_quant: QuantState | None,
|
||||||
q_A: torch.Tensor | None,
|
q_A: torch.Tensor | None,
|
||||||
q_B: torch.Tensor | None,
|
q_B: torch.Tensor | None,
|
||||||
q_scale: float,
|
q_scale: float,
|
||||||
k_weight: torch.Tensor,
|
k_weight: torch.Tensor,
|
||||||
k_bias: torch.Tensor | None,
|
|
||||||
k_quant: QuantState | None,
|
k_quant: QuantState | None,
|
||||||
k_A: torch.Tensor | None,
|
k_A: torch.Tensor | None,
|
||||||
k_B: torch.Tensor | None,
|
k_B: torch.Tensor | None,
|
||||||
k_scale: float,
|
k_scale: float,
|
||||||
v_weight: torch.Tensor,
|
v_weight: torch.Tensor,
|
||||||
v_bias: torch.Tensor | None,
|
|
||||||
v_quant: QuantState | None,
|
v_quant: QuantState | None,
|
||||||
v_A: torch.Tensor | None,
|
v_A: torch.Tensor | None,
|
||||||
v_B: torch.Tensor | None,
|
v_B: torch.Tensor | None,
|
||||||
@@ -513,19 +469,16 @@ class LoRA_QKV(torch.autograd.Function):
|
|||||||
ctx: Autograd context
|
ctx: Autograd context
|
||||||
X: Input tensor
|
X: Input tensor
|
||||||
q_weight: Query projection weight
|
q_weight: Query projection weight
|
||||||
q_bias: Query projection bias
|
|
||||||
q_quant: Query quantization state
|
q_quant: Query quantization state
|
||||||
q_A: Query LoRA A matrix
|
q_A: Query LoRA A matrix
|
||||||
q_B: Query LoRA B matrix
|
q_B: Query LoRA B matrix
|
||||||
q_scale: Query LoRA scale
|
q_scale: Query LoRA scale
|
||||||
k_weight: Key projection weight
|
k_weight: Key projection weight
|
||||||
k_bias: Key projection bias
|
|
||||||
k_quant: Key quantization state
|
k_quant: Key quantization state
|
||||||
k_A: Key LoRA A matrix
|
k_A: Key LoRA A matrix
|
||||||
k_B: Key LoRA B matrix
|
k_B: Key LoRA B matrix
|
||||||
k_scale: Key LoRA scale
|
k_scale: Key LoRA scale
|
||||||
v_weight: Value projection weight
|
v_weight: Value projection weight
|
||||||
v_bias: Value projection bias
|
|
||||||
v_quant: Value quantization state
|
v_quant: Value quantization state
|
||||||
v_A: Value LoRA A matrix
|
v_A: Value LoRA A matrix
|
||||||
v_B: Value LoRA B matrix
|
v_B: Value LoRA B matrix
|
||||||
@@ -535,21 +488,20 @@ class LoRA_QKV(torch.autograd.Function):
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (Query, Key, Value) projection tensors
|
Tuple of (Query, Key, Value) projection tensors
|
||||||
"""
|
"""
|
||||||
Q = matmul_lora(X, q_weight, q_bias, q_quant, q_A, q_B, q_scale)
|
Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
|
||||||
K = matmul_lora(X, k_weight, k_bias, k_quant, k_A, k_B, k_scale)
|
K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
|
||||||
V = matmul_lora(X, v_weight, v_bias, v_quant, v_A, v_B, v_scale)
|
V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)
|
||||||
|
|
||||||
ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
|
ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
|
||||||
ctx.scales = (q_scale, k_scale, v_scale)
|
ctx.scales = (q_scale, k_scale, v_scale)
|
||||||
ctx.quants = (q_quant, k_quant, v_quant)
|
ctx.quants = (q_quant, k_quant, v_quant)
|
||||||
ctx.weights = (q_weight, k_weight, v_weight)
|
ctx.weights = (q_weight, k_weight, v_weight)
|
||||||
ctx.biases = (q_bias, k_bias, v_bias)
|
|
||||||
ctx.inplace = inplace
|
ctx.inplace = inplace
|
||||||
|
|
||||||
return Q, K, V
|
return Q, K, V
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@torch_amp_custom_bwd
|
@torch_amp_custom_fwd
|
||||||
def backward(
|
def backward(
|
||||||
ctx: torch.autograd.function.FunctionCtx,
|
ctx: torch.autograd.function.FunctionCtx,
|
||||||
q_grad: torch.Tensor,
|
q_grad: torch.Tensor,
|
||||||
@@ -559,19 +511,16 @@ class LoRA_QKV(torch.autograd.Function):
|
|||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
torch.Tensor | None,
|
torch.Tensor | None,
|
||||||
None,
|
None,
|
||||||
@@ -659,31 +608,31 @@ class LoRA_QKV(torch.autograd.Function):
|
|||||||
# Transpose gradients if needed
|
# Transpose gradients if needed
|
||||||
if d_A_q is not None:
|
if d_A_q is not None:
|
||||||
d_A_q = d_A_q.t()
|
d_A_q = d_A_q.t()
|
||||||
d_B_q = d_B_q.t() # type: ignore[union-attr]
|
if d_B_q is not None:
|
||||||
|
d_B_q = d_B_q.t()
|
||||||
if d_A_k is not None:
|
if d_A_k is not None:
|
||||||
d_A_k = d_A_k.t()
|
d_A_k = d_A_k.t()
|
||||||
d_B_k = d_B_k.t() # type: ignore[union-attr]
|
if d_B_k is not None:
|
||||||
|
d_B_k = d_B_k.t()
|
||||||
if d_A_v is not None:
|
if d_A_v is not None:
|
||||||
d_A_v = d_A_v.t()
|
d_A_v = d_A_v.t()
|
||||||
d_B_v = d_B_v.t() # type: ignore[union-attr]
|
if d_B_v is not None:
|
||||||
|
d_B_v = d_B_v.t()
|
||||||
|
|
||||||
return (
|
return (
|
||||||
grad_X.view(batch, seq_len, -1),
|
grad_X.view(batch, seq_len, -1),
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
d_A_q,
|
d_A_q,
|
||||||
d_B_q,
|
d_B_q,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
d_A_k,
|
d_A_k,
|
||||||
d_B_k,
|
d_B_k,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
d_A_v,
|
d_A_v,
|
||||||
d_B_v,
|
d_B_v,
|
||||||
None,
|
None,
|
||||||
@@ -704,25 +653,22 @@ def apply_lora_qkv(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (Query, Key, Value) projection tensors
|
Tuple of (Query, Key, Value) projection tensors
|
||||||
"""
|
"""
|
||||||
QW, Qb, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
|
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
|
||||||
KW, Kb, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
|
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
|
||||||
VW, Vb, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
|
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
|
||||||
Q, K, V = LoRA_QKV.apply(
|
Q, K, V = LoRA_QKV.apply(
|
||||||
X,
|
X,
|
||||||
QW,
|
QW,
|
||||||
Qb,
|
|
||||||
QW_quant,
|
QW_quant,
|
||||||
QA,
|
QA,
|
||||||
QB,
|
QB,
|
||||||
QS,
|
QS,
|
||||||
KW,
|
KW,
|
||||||
Kb,
|
|
||||||
KW_quant,
|
KW_quant,
|
||||||
KA,
|
KA,
|
||||||
KB,
|
KB,
|
||||||
KS,
|
KS,
|
||||||
VW,
|
VW,
|
||||||
Vb,
|
|
||||||
VW_quant,
|
VW_quant,
|
||||||
VA,
|
VA,
|
||||||
VB,
|
VB,
|
||||||
@@ -742,11 +688,10 @@ class LoRA_O(torch.autograd.Function):
|
|||||||
ctx: torch.autograd.function.FunctionCtx,
|
ctx: torch.autograd.function.FunctionCtx,
|
||||||
X: torch.Tensor,
|
X: torch.Tensor,
|
||||||
W: torch.Tensor,
|
W: torch.Tensor,
|
||||||
b: torch.Tensor,
|
|
||||||
W_quant: QuantState | None,
|
W_quant: QuantState | None,
|
||||||
A: torch.Tensor,
|
A: torch.Tensor | None,
|
||||||
B: torch.Tensor,
|
B: torch.Tensor | None,
|
||||||
s: float,
|
S: float,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Forward pass for output projection with LoRA.
|
Forward pass for output projection with LoRA.
|
||||||
@@ -755,20 +700,19 @@ class LoRA_O(torch.autograd.Function):
|
|||||||
ctx: Autograd context
|
ctx: Autograd context
|
||||||
X: Input tensor
|
X: Input tensor
|
||||||
W: Output projection weight
|
W: Output projection weight
|
||||||
b: Output projection bias
|
|
||||||
W_quant: Weight quantization state
|
W_quant: Weight quantization state
|
||||||
A: LoRA A matrix
|
A: LoRA A matrix
|
||||||
B: LoRA B matrix
|
B: LoRA B matrix
|
||||||
s: LoRA scaling factor
|
S: LoRA scaling factor
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Output projection result
|
Output projection tensor
|
||||||
"""
|
"""
|
||||||
XW = matmul_lora(X, W, b, W_quant, A, B, s)
|
XW = matmul_lora(X, W, W_quant, A, B, S)
|
||||||
ctx.custom_saved_tensors = (
|
ctx.custom_saved_tensors = (
|
||||||
W,
|
W,
|
||||||
W_quant,
|
W_quant,
|
||||||
s,
|
S,
|
||||||
)
|
)
|
||||||
ctx.save_for_backward(A, B, X)
|
ctx.save_for_backward(A, B, X)
|
||||||
|
|
||||||
@@ -783,9 +727,8 @@ class LoRA_O(torch.autograd.Function):
|
|||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
torch.Tensor | None,
|
||||||
torch.Tensor,
|
torch.Tensor | None,
|
||||||
torch.Tensor,
|
|
||||||
None,
|
None,
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
@@ -798,7 +741,7 @@ class LoRA_O(torch.autograd.Function):
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple containing gradients for all forward inputs
|
Tuple containing gradients for all forward inputs
|
||||||
"""
|
"""
|
||||||
W, W_quant, s = ctx.custom_saved_tensors
|
W, W_quant, S = ctx.custom_saved_tensors
|
||||||
A, B, X = ctx.saved_tensors
|
A, B, X = ctx.saved_tensors
|
||||||
|
|
||||||
batch, seq_len, hd = X.shape
|
batch, seq_len, hd = X.shape
|
||||||
@@ -808,19 +751,17 @@ class LoRA_O(torch.autograd.Function):
|
|||||||
|
|
||||||
# Weight projection
|
# Weight projection
|
||||||
dY_X = X.t() @ dY
|
dY_X = X.t() @ dY
|
||||||
d_A = s * dY_X @ B
|
d_A = S * dY_X @ B
|
||||||
d_B = s * A @ dY_X
|
d_B = S * A @ dY_X
|
||||||
|
|
||||||
# Get derivative for dX
|
# Get derivative for dX
|
||||||
W = dequantize(W.t(), W_quant)
|
W = dequantize(W.t(), W_quant)
|
||||||
dX = dY @ W.t()
|
dX = dY @ W.t()
|
||||||
del W
|
del W
|
||||||
|
dX += dY @ B.to(dtype) @ (S * A.to(dtype))
|
||||||
|
|
||||||
A, B = A.to(dtype), B.to(dtype)
|
# W, W_quant, A, B, S
|
||||||
dX += s * dY @ B @ A
|
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
|
||||||
|
|
||||||
# W, b, W_quant, A, B, s
|
|
||||||
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None
|
|
||||||
|
|
||||||
|
|
||||||
def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
|
def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -833,7 +774,7 @@ def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
|
|||||||
Returns:
|
Returns:
|
||||||
Transformed output tensor
|
Transformed output tensor
|
||||||
"""
|
"""
|
||||||
OW, Ob, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
|
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
|
||||||
output = LoRA_O.apply(X, OW, Ob, OW_quant, OA, OB, OS)
|
output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -76,7 +76,6 @@ def load_lora(
|
|||||||
config_only: bool = False,
|
config_only: bool = False,
|
||||||
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
|
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
|
||||||
lora_target_modules = cfg.lora_target_modules or []
|
lora_target_modules = cfg.lora_target_modules or []
|
||||||
lora_target_parameters = cfg.lora_target_parameters or []
|
|
||||||
|
|
||||||
if cfg.lora_target_linear:
|
if cfg.lora_target_linear:
|
||||||
linear_names = find_all_linear_names(model)
|
linear_names = find_all_linear_names(model)
|
||||||
@@ -107,7 +106,6 @@ def load_lora(
|
|||||||
r=cfg.lora_r,
|
r=cfg.lora_r,
|
||||||
lora_alpha=cfg.lora_alpha,
|
lora_alpha=cfg.lora_alpha,
|
||||||
target_modules=lora_target_modules,
|
target_modules=lora_target_modules,
|
||||||
target_parameters=lora_target_parameters,
|
|
||||||
layers_to_transform=cfg.peft_layers_to_transform,
|
layers_to_transform=cfg.peft_layers_to_transform,
|
||||||
layers_pattern=cfg.peft_layers_pattern,
|
layers_pattern=cfg.peft_layers_pattern,
|
||||||
lora_dropout=cfg.lora_dropout,
|
lora_dropout=cfg.lora_dropout,
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""Model loader class implementation for loading, configuring, and patching various
|
||||||
Model loader class implementation for loading, configuring, and patching various models.
|
models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
@@ -13,7 +13,7 @@ import peft
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
import transformers.modeling_utils
|
import transformers.modeling_utils
|
||||||
from accelerate import init_empty_weights
|
from accelerate import PartialState, init_empty_weights
|
||||||
from accelerate.parallelism_config import ParallelismConfig
|
from accelerate.parallelism_config import ParallelismConfig
|
||||||
from peft import (
|
from peft import (
|
||||||
PeftConfig,
|
PeftConfig,
|
||||||
@@ -22,7 +22,6 @@ from peft import (
|
|||||||
PeftModelForCausalLM,
|
PeftModelForCausalLM,
|
||||||
prepare_model_for_kbit_training,
|
prepare_model_for_kbit_training,
|
||||||
)
|
)
|
||||||
from torch.distributed import DeviceMesh
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForVision2Seq,
|
AutoModelForVision2Seq,
|
||||||
@@ -50,11 +49,7 @@ from axolotl.loaders.utils import (
|
|||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
|
||||||
build_parallelism_config,
|
|
||||||
get_device_count,
|
|
||||||
get_device_type,
|
|
||||||
)
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
||||||
from axolotl.utils.schemas.enums import RLType
|
from axolotl.utils.schemas.enums import RLType
|
||||||
@@ -92,7 +87,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
use_parallel_config: bool | None = False
|
use_parallel_config: bool | None = False
|
||||||
parallelism_config: ParallelismConfig | None = None
|
parallelism_config: ParallelismConfig | None = None
|
||||||
device_mesh: DeviceMesh | None = None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -208,8 +202,6 @@ class ModelLoader:
|
|||||||
self._set_device_map_config()
|
self._set_device_map_config()
|
||||||
if self.cfg.revision_of_model:
|
if self.cfg.revision_of_model:
|
||||||
self.model_kwargs["revision"] = self.cfg.revision_of_model
|
self.model_kwargs["revision"] = self.cfg.revision_of_model
|
||||||
if self.cfg.use_kernels:
|
|
||||||
self.model_kwargs["use_kernels"] = self.cfg.use_kernels
|
|
||||||
self._set_quantization_config()
|
self._set_quantization_config()
|
||||||
self._set_attention_config()
|
self._set_attention_config()
|
||||||
|
|
||||||
@@ -308,10 +300,7 @@ class ModelLoader:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Handle DeepSpeed Zero3
|
# Handle DeepSpeed Zero3
|
||||||
if (
|
if is_deepspeed_zero3_enabled():
|
||||||
is_deepspeed_zero3_enabled()
|
|
||||||
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
|
|
||||||
):
|
|
||||||
self._set_z3_leaf_modules()
|
self._set_z3_leaf_modules()
|
||||||
|
|
||||||
# Apply gradient checkpointing if needed
|
# Apply gradient checkpointing if needed
|
||||||
@@ -416,12 +405,85 @@ class ModelLoader:
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_parallel_config_kwargs(
|
||||||
|
world_size: int,
|
||||||
|
tensor_parallel_size: int = 1,
|
||||||
|
context_parallel_size: int = 1,
|
||||||
|
dp_shard_size: int | None = None,
|
||||||
|
dp_replicate_size: int | None = None,
|
||||||
|
is_fsdp: bool = False,
|
||||||
|
):
|
||||||
|
pc_kwargs = {}
|
||||||
|
remaining_world_size = world_size
|
||||||
|
|
||||||
|
if tensor_parallel_size and tensor_parallel_size > 1:
|
||||||
|
pc_kwargs["tp_size"] = tensor_parallel_size
|
||||||
|
remaining_world_size = remaining_world_size // tensor_parallel_size
|
||||||
|
|
||||||
|
if context_parallel_size and context_parallel_size > 1:
|
||||||
|
pc_kwargs["cp_size"] = context_parallel_size
|
||||||
|
remaining_world_size = remaining_world_size // context_parallel_size
|
||||||
|
|
||||||
|
if dp_shard_size is None and dp_replicate_size in (None, 1):
|
||||||
|
if remaining_world_size > 1:
|
||||||
|
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||||
|
remaining_world_size = 1
|
||||||
|
|
||||||
|
if dp_replicate_size and dp_replicate_size > 1:
|
||||||
|
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
||||||
|
remaining_world_size = remaining_world_size // dp_replicate_size
|
||||||
|
|
||||||
|
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
|
||||||
|
if not is_fsdp:
|
||||||
|
raise ValueError(
|
||||||
|
"dp_shard_size was configured without a corresponding fsdp_config! "
|
||||||
|
"Please ensure you have configured FSDP using fsdp_config."
|
||||||
|
)
|
||||||
|
pc_kwargs["dp_shard_size"] = dp_shard_size
|
||||||
|
remaining_world_size = remaining_world_size // dp_shard_size
|
||||||
|
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
|
||||||
|
pc_kwargs["dp_replicate_size"] = remaining_world_size
|
||||||
|
remaining_world_size = 1
|
||||||
|
|
||||||
|
if remaining_world_size > 1:
|
||||||
|
if "dp_shard_size" not in pc_kwargs and is_fsdp:
|
||||||
|
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||||
|
remaining_world_size = 1
|
||||||
|
|
||||||
|
if remaining_world_size > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
|
||||||
|
f"{pc_kwargs}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return pc_kwargs
|
||||||
|
|
||||||
def _set_parallel_config(self):
|
def _set_parallel_config(self):
|
||||||
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
||||||
parallelism_config, device_mesh = build_parallelism_config(self.cfg)
|
pc_kwargs = ModelLoader._get_parallel_config_kwargs(
|
||||||
if parallelism_config:
|
get_world_size(),
|
||||||
self.parallelism_config = parallelism_config
|
self.cfg.tensor_parallel_size,
|
||||||
self.device_mesh = device_mesh
|
self.cfg.context_parallel_size,
|
||||||
|
self.cfg.dp_shard_size,
|
||||||
|
self.cfg.dp_replicate_size,
|
||||||
|
bool(self.cfg.fsdp or self.cfg.fsdp_config),
|
||||||
|
)
|
||||||
|
|
||||||
|
if pc_kwargs:
|
||||||
|
self.parallelism_config = ParallelismConfig(
|
||||||
|
**pc_kwargs,
|
||||||
|
)
|
||||||
|
device_mesh = self.parallelism_config.build_device_mesh("cuda")
|
||||||
|
partial_state = PartialState()
|
||||||
|
# fmt: off
|
||||||
|
partial_state._shared_state["parallelism_config"] = ( # pylint: disable=protected-access
|
||||||
|
self.parallelism_config
|
||||||
|
)
|
||||||
|
partial_state._shared_state["device_mesh"] = ( # pylint: disable=protected-access
|
||||||
|
device_mesh
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
def _set_auto_model_loader(self):
|
def _set_auto_model_loader(self):
|
||||||
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
||||||
@@ -503,17 +565,8 @@ class ModelLoader:
|
|||||||
|
|
||||||
def _set_quantization_config(self):
|
def _set_quantization_config(self):
|
||||||
"""Set up quantization config (bitsandbytes, awq, gptq, etc.)"""
|
"""Set up quantization config (bitsandbytes, awq, gptq, etc.)"""
|
||||||
|
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||||
if self.cfg.model_quantization_config == "Mxfp4Config":
|
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
||||||
from transformers import Mxfp4Config
|
|
||||||
|
|
||||||
mxfp4_kwargs = {}
|
|
||||||
if self.cfg.model_quantization_config_kwargs:
|
|
||||||
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
|
|
||||||
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
|
|
||||||
else:
|
|
||||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
|
||||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
|
||||||
|
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
if not hasattr(self.model_config, "quantization_config"):
|
if not hasattr(self.model_config, "quantization_config"):
|
||||||
@@ -548,9 +601,7 @@ class ModelLoader:
|
|||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**self.model_config.quantization_config
|
**self.model_config.quantization_config
|
||||||
)
|
)
|
||||||
elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
|
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
|
||||||
"load_in_4bit", False
|
|
||||||
):
|
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"llm_int8_threshold": 6.0,
|
"llm_int8_threshold": 6.0,
|
||||||
@@ -576,9 +627,7 @@ class ModelLoader:
|
|||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
elif self.cfg.adapter == "lora" and self.model_kwargs.get(
|
elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
|
||||||
"load_in_8bit", False
|
|
||||||
):
|
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": True,
|
||||||
}
|
}
|
||||||
@@ -599,9 +648,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
def _set_attention_config(self):
|
def _set_attention_config(self):
|
||||||
"""Sample packing uses custom FA2 patch"""
|
"""Sample packing uses custom FA2 patch"""
|
||||||
if self.cfg.attn_implementation:
|
if self.cfg.flex_attention:
|
||||||
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
|
|
||||||
elif self.cfg.flex_attention:
|
|
||||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
self.model_kwargs["attn_implementation"] = "flex_attention"
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"flex_attention"
|
"flex_attention"
|
||||||
@@ -674,7 +721,7 @@ class ModelLoader:
|
|||||||
if self.cfg.tensor_parallel_size > 1:
|
if self.cfg.tensor_parallel_size > 1:
|
||||||
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
||||||
self.model_kwargs["tp_plan"] = "auto"
|
self.model_kwargs["tp_plan"] = "auto"
|
||||||
self.model_kwargs["device_mesh"] = self.device_mesh
|
self.model_kwargs["device_mesh"] = PartialState().device_mesh
|
||||||
if "device_map" in self.model_kwargs:
|
if "device_map" in self.model_kwargs:
|
||||||
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
|
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
|
||||||
|
|
||||||
@@ -690,18 +737,6 @@ class ModelLoader:
|
|||||||
elif self.is_qlora_and_fsdp_enabled:
|
elif self.is_qlora_and_fsdp_enabled:
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
|
|
||||||
if (
|
|
||||||
self.cfg.tensor_parallel_size <= 1
|
|
||||||
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
|
||||||
and self.cfg.fsdp_version == 2
|
|
||||||
):
|
|
||||||
# setting device_map for TP is not supported
|
|
||||||
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
|
||||||
if local_rank == 0:
|
|
||||||
self.model_kwargs["device_map"] = "cpu"
|
|
||||||
else:
|
|
||||||
self.model_kwargs["device_map"] = "meta"
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.is_qlora_and_fsdp_enabled
|
self.is_qlora_and_fsdp_enabled
|
||||||
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
||||||
@@ -810,9 +845,6 @@ class ModelLoader:
|
|||||||
self.model._tp_size = self.cfg.tensor_parallel_size
|
self.model._tp_size = self.cfg.tensor_parallel_size
|
||||||
self.model._device_mesh = self.model_kwargs["device_mesh"]
|
self.model._device_mesh = self.model_kwargs["device_mesh"]
|
||||||
|
|
||||||
if self.cfg.experimental_skip_move_to_device is not None:
|
|
||||||
skip_move_to_device = self.cfg.experimental_skip_move_to_device
|
|
||||||
|
|
||||||
return skip_move_to_device
|
return skip_move_to_device
|
||||||
|
|
||||||
def _set_z3_leaf_modules(self):
|
def _set_z3_leaf_modules(self):
|
||||||
|
|||||||
@@ -65,7 +65,6 @@ class PatchManager:
|
|||||||
self._patch_llama_derived_model()
|
self._patch_llama_derived_model()
|
||||||
self._apply_mistral_cross_entropy_patch()
|
self._apply_mistral_cross_entropy_patch()
|
||||||
self._apply_self_attention_lora_patch()
|
self._apply_self_attention_lora_patch()
|
||||||
self._apply_fsdp2_bnb_patches()
|
|
||||||
|
|
||||||
def apply_post_plugin_pre_model_load_patches(self):
|
def apply_post_plugin_pre_model_load_patches(self):
|
||||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||||
@@ -73,19 +72,11 @@ class PatchManager:
|
|||||||
self._apply_voxtral_patches()
|
self._apply_voxtral_patches()
|
||||||
|
|
||||||
def _apply_transformers_patches(self):
|
def _apply_transformers_patches(self):
|
||||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
|
||||||
patch_evaluation_loop,
|
patch_prepare_from_posids,
|
||||||
patch_maybe_log_save_evaluate,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
patch_fsdp2 = (
|
patch_prepare_from_posids()
|
||||||
self.cfg.torch_compile
|
|
||||||
and self.cfg.fsdp_config
|
|
||||||
and self.cfg.fsdp_version == 2
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_evaluation_loop(patch_fsdp2)
|
|
||||||
patch_maybe_log_save_evaluate()
|
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches that require the model instance."""
|
"""Apply patches that require the model instance."""
|
||||||
@@ -112,14 +103,6 @@ class PatchManager:
|
|||||||
|
|
||||||
def _apply_fsdp_patches(self):
|
def _apply_fsdp_patches(self):
|
||||||
"""Apply patches for FSDP configurations."""
|
"""Apply patches for FSDP configurations."""
|
||||||
if self.cfg.context_parallel_size > 1 or (
|
|
||||||
self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
|
|
||||||
):
|
|
||||||
from axolotl.monkeypatch.accelerate.parallelism_config import (
|
|
||||||
patch_parallelism_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_parallelism_config()
|
|
||||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
||||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
||||||
|
|
||||||
@@ -277,23 +260,6 @@ class PatchManager:
|
|||||||
has_remote_code=has_remote_code,
|
has_remote_code=has_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _apply_fsdp2_bnb_patches(self):
|
|
||||||
"""Apply FSDP2 BNB patches."""
|
|
||||||
if (
|
|
||||||
self.cfg.fsdp_config
|
|
||||||
and str(self.cfg.fsdp_version) == "2"
|
|
||||||
and self.cfg.adapter == "qlora"
|
|
||||||
):
|
|
||||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
|
||||||
apply_bnb_torch_function_patch,
|
|
||||||
apply_init_sharded_param_patch,
|
|
||||||
apply_init_unsharded_param_patch,
|
|
||||||
)
|
|
||||||
|
|
||||||
apply_bnb_torch_function_patch()
|
|
||||||
apply_init_sharded_param_patch()
|
|
||||||
apply_init_unsharded_param_patch()
|
|
||||||
|
|
||||||
def _apply_tiled_mlp(self, model_type: str):
|
def _apply_tiled_mlp(self, model_type: str):
|
||||||
if self.cfg.tiled_mlp:
|
if self.cfg.tiled_mlp:
|
||||||
from axolotl.monkeypatch.tiled_mlp import (
|
from axolotl.monkeypatch.tiled_mlp import (
|
||||||
@@ -364,21 +330,31 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_self_attn_lora()
|
patch_self_attn_lora()
|
||||||
|
|
||||||
def _patch_llama_flash_attention(self):
|
def _patch_llama_flash_attention(self, packed=False):
|
||||||
"""Apply Flash Attention patches for LLaMA models."""
|
"""Apply Flash Attention patches for LLaMA models."""
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
replace_llama_attn_with_flash_attn,
|
replace_llama_attn_with_flash_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.s2_attention:
|
if packed:
|
||||||
|
if self.cfg.device not in ["mps", "cpu"] and not self.inference:
|
||||||
|
LOG.info("patching with flash attention for sample packing")
|
||||||
|
replace_llama_attn_with_flash_attn(
|
||||||
|
packed=True,
|
||||||
|
cross_entropy=self.cfg.flash_attn_cross_entropy,
|
||||||
|
rms_norm=self.cfg.flash_attn_rms_norm,
|
||||||
|
)
|
||||||
|
elif self.cfg.s2_attention:
|
||||||
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
|
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
|
||||||
replace_llama_attn_with_flash_attn(
|
replace_llama_attn_with_flash_attn(
|
||||||
|
packed=False,
|
||||||
cross_entropy=self.cfg.flash_attn_cross_entropy,
|
cross_entropy=self.cfg.flash_attn_cross_entropy,
|
||||||
rms_norm=self.cfg.flash_attn_rms_norm,
|
rms_norm=self.cfg.flash_attn_rms_norm,
|
||||||
use_shifted_sparse_attn=True,
|
use_shifted_sparse_attn=True,
|
||||||
)
|
)
|
||||||
elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm:
|
elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm:
|
||||||
replace_llama_attn_with_flash_attn(
|
replace_llama_attn_with_flash_attn(
|
||||||
|
packed=False,
|
||||||
cross_entropy=self.cfg.flash_attn_cross_entropy,
|
cross_entropy=self.cfg.flash_attn_cross_entropy,
|
||||||
rms_norm=self.cfg.flash_attn_rms_norm,
|
rms_norm=self.cfg.flash_attn_rms_norm,
|
||||||
)
|
)
|
||||||
@@ -409,7 +385,7 @@ class PatchManager:
|
|||||||
and self.cfg.sample_packing
|
and self.cfg.sample_packing
|
||||||
):
|
):
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
self._patch_llama_flash_attention()
|
self._patch_llama_flash_attention(packed=self.cfg.sample_packing)
|
||||||
elif self.cfg.xformers_attention:
|
elif self.cfg.xformers_attention:
|
||||||
self._patch_llama_xformers_attention()
|
self._patch_llama_xformers_attention()
|
||||||
elif self.cfg.sample_packing:
|
elif self.cfg.sample_packing:
|
||||||
@@ -432,12 +408,17 @@ class PatchManager:
|
|||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
is_xformers_swiglu_available,
|
is_xformers_swiglu_available,
|
||||||
replace_llama_mlp_with_swiglu,
|
replace_llama_mlp_with_swiglu,
|
||||||
|
replace_llama_qkv_with_fused,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
||||||
LOG.info("Patching with SwiGLU...")
|
LOG.info("Patching with SwiGLU...")
|
||||||
replace_llama_mlp_with_swiglu(model)
|
replace_llama_mlp_with_swiglu(model)
|
||||||
|
|
||||||
|
if self.cfg.flash_attn_fuse_qkv:
|
||||||
|
LOG.info("Patching with fused QKV...")
|
||||||
|
replace_llama_qkv_with_fused(model)
|
||||||
|
|
||||||
def _apply_unsloth_patches(self, model):
|
def _apply_unsloth_patches(self, model):
|
||||||
"""Apply unsloth optimization patches."""
|
"""Apply unsloth optimization patches."""
|
||||||
if self.cfg.unsloth_lora_mlp:
|
if self.cfg.unsloth_lora_mlp:
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import functools
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
@@ -37,49 +36,25 @@ def fsdp2_load_full_state_dict(
|
|||||||
|
|
||||||
meta_sharded_sd = model.state_dict()
|
meta_sharded_sd = model.state_dict()
|
||||||
sharded_sd = {}
|
sharded_sd = {}
|
||||||
for param_name, sharded_meta_param in meta_sharded_sd.items():
|
for param_name, full_tensor in full_sd.items():
|
||||||
full_tensor = None
|
sharded_meta_param = meta_sharded_sd.get(param_name)
|
||||||
if _accelerator.is_main_process:
|
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
|
||||||
full_tensor = full_sd[param_name]
|
|
||||||
full_tensor = full_tensor.to(sharded_meta_param.dtype)
|
|
||||||
|
|
||||||
if hasattr(sharded_meta_param, "device_mesh"):
|
if hasattr(sharded_meta_param, "device_mesh"):
|
||||||
device_mesh = sharded_meta_param.device_mesh
|
|
||||||
if _accelerator.is_main_process:
|
|
||||||
full_tensor = full_tensor.to(device_mesh.device_type)
|
|
||||||
else:
|
|
||||||
full_tensor = torch.empty(
|
|
||||||
sharded_meta_param.size(),
|
|
||||||
device=device_mesh.device_type,
|
|
||||||
dtype=sharded_meta_param.dtype,
|
|
||||||
)
|
|
||||||
sharded_param = distribute_tensor(
|
sharded_param = distribute_tensor(
|
||||||
full_tensor,
|
full_tensor,
|
||||||
device_mesh,
|
sharded_meta_param.device_mesh,
|
||||||
sharded_meta_param.placements,
|
sharded_meta_param.placements,
|
||||||
src_data_rank=0,
|
src_data_rank=0,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Non-sharded parameters
|
sharded_param = full_tensor
|
||||||
if _accelerator.is_main_process:
|
|
||||||
sharded_param = full_tensor.to(torch.device("cuda"))
|
|
||||||
else:
|
|
||||||
# broadcast manually
|
|
||||||
sharded_param = torch.empty_like(
|
|
||||||
sharded_meta_param,
|
|
||||||
device=torch.device("cuda"),
|
|
||||||
dtype=sharded_meta_param.dtype,
|
|
||||||
)
|
|
||||||
dist.broadcast(sharded_param, src=0)
|
|
||||||
|
|
||||||
if offload_to_cpu:
|
if offload_to_cpu:
|
||||||
sharded_param = sharded_param.cpu()
|
sharded_param = sharded_param.cpu()
|
||||||
|
|
||||||
sharded_sd[param_name] = nn.Parameter(sharded_param)
|
sharded_sd[param_name] = nn.Parameter(sharded_param)
|
||||||
|
|
||||||
del full_tensor
|
del full_tensor
|
||||||
full_sd[param_name] = None
|
full_sd[param_name] = None
|
||||||
|
|
||||||
model.load_state_dict(sharded_sd, assign=True, strict=True)
|
model.load_state_dict(sharded_sd, assign=True, strict=True)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
|
|||||||
@@ -1,77 +0,0 @@
|
|||||||
"""
|
|
||||||
workaround to allow parallelism config for pure CP
|
|
||||||
"""
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
import os
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from accelerate import DistributedType
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_accelerator(self, accelerator):
|
|
||||||
_warnings = set()
|
|
||||||
if not accelerator.multi_device and self.total_size == 1:
|
|
||||||
# No distributed setup, valid parallelism config
|
|
||||||
return
|
|
||||||
|
|
||||||
# We need this to ensure DDP works
|
|
||||||
if self.total_size == 1:
|
|
||||||
self._set_size("dp_replicate", accelerator.num_processes)
|
|
||||||
|
|
||||||
if self.total_size != accelerator.num_processes:
|
|
||||||
raise ValueError(
|
|
||||||
f"ParallelismConfig total_size ({self.total_size}) does not match "
|
|
||||||
f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
|
|
||||||
f"dp_shard_size/tp_size/cp_size."
|
|
||||||
)
|
|
||||||
|
|
||||||
# allow parallelism config when not using fsdp if using pure context parallelism
|
|
||||||
allow_parallelism_config = False
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.cp_size > 1 # pylint: disable=chained-comparison
|
|
||||||
and self.dp_shard_size <= 1
|
|
||||||
and os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true"
|
|
||||||
):
|
|
||||||
allow_parallelism_config = True
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.total_size > 1
|
|
||||||
and not allow_parallelism_config
|
|
||||||
and not (accelerator.is_fsdp2 or accelerator.multi_device)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
|
|
||||||
)
|
|
||||||
|
|
||||||
for parallelism, size in self._sizes.items():
|
|
||||||
if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
|
|
||||||
_warnings.add(
|
|
||||||
f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
|
|
||||||
)
|
|
||||||
|
|
||||||
if _warnings and accelerator.is_main_process:
|
|
||||||
warnings.warn(
|
|
||||||
"ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
|
|
||||||
UserWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patched_is_fsdp2(self) -> bool:
|
|
||||||
"""
|
|
||||||
Patched version of is_fsdp2 that guards against a None fsdp_plugin.
|
|
||||||
"""
|
|
||||||
# The new logic checks if fsdp_plugin exists before accessing its attributes
|
|
||||||
return (
|
|
||||||
self.distributed_type == DistributedType.FSDP
|
|
||||||
and self.fsdp_plugin
|
|
||||||
and self.fsdp_plugin.fsdp_version == 2
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_parallelism_config():
|
|
||||||
from accelerate.accelerator import AcceleratorState, ParallelismConfig
|
|
||||||
|
|
||||||
ParallelismConfig._validate_accelerator = _validate_accelerator
|
|
||||||
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
|
|
||||||
@@ -1,205 +0,0 @@
|
|||||||
"""
|
|
||||||
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
|
|
||||||
our LoRA / QLoRA Triton kernels to work with FSDP2.
|
|
||||||
|
|
||||||
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
|
|
||||||
Params4bit parameters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn import Parameter
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def patched_torch_function(cls, func, types, args=(), kwargs=None):
|
|
||||||
"""
|
|
||||||
Patched version of Params4bit.__torch_function__ for preserving Params4bit
|
|
||||||
class identity and attributes.
|
|
||||||
"""
|
|
||||||
if kwargs is None:
|
|
||||||
kwargs = {}
|
|
||||||
|
|
||||||
if func in [torch.chunk, torch.split]:
|
|
||||||
tensor = args[0]
|
|
||||||
result = Parameter.__torch_function__(func, types, args, kwargs)
|
|
||||||
|
|
||||||
if isinstance(result, tuple):
|
|
||||||
return tuple(
|
|
||||||
cls(
|
|
||||||
data=chunk,
|
|
||||||
requires_grad=tensor.requires_grad,
|
|
||||||
quant_state=tensor.quant_state,
|
|
||||||
blocksize=tensor.blocksize,
|
|
||||||
compress_statistics=tensor.compress_statistics,
|
|
||||||
quant_type=tensor.quant_type,
|
|
||||||
quant_storage=tensor.quant_storage,
|
|
||||||
module=tensor.module,
|
|
||||||
bnb_quantized=tensor.bnb_quantized,
|
|
||||||
)
|
|
||||||
for chunk in result
|
|
||||||
)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
data=result,
|
|
||||||
requires_grad=tensor.requires_grad,
|
|
||||||
quant_state=tensor.quant_state,
|
|
||||||
blocksize=tensor.blocksize,
|
|
||||||
compress_statistics=tensor.compress_statistics,
|
|
||||||
quant_type=tensor.quant_type,
|
|
||||||
quant_storage=tensor.quant_storage,
|
|
||||||
module=tensor.module,
|
|
||||||
bnb_quantized=tensor.bnb_quantized,
|
|
||||||
)
|
|
||||||
|
|
||||||
return Parameter.__torch_function__(func, types, args, kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
def apply_bnb_torch_function_patch():
|
|
||||||
"""
|
|
||||||
Patch Params4bit.__torch_function__ using Axolotl-style approach.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if patching succeeded, False otherwise.
|
|
||||||
"""
|
|
||||||
from bitsandbytes.nn.modules import Params4bit
|
|
||||||
|
|
||||||
Params4bit.__torch_function__ = classmethod(patched_torch_function)
|
|
||||||
|
|
||||||
LOG.info("Successfully patched Params4bit.__torch_function__")
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
def apply_init_sharded_param_patch():
|
|
||||||
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
|
|
||||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
|
||||||
|
|
||||||
# Get original source
|
|
||||||
original_source = inspect.getsource(FSDPParam._init_sharded_param)
|
|
||||||
original_source, _ = detab_code(original_source)
|
|
||||||
|
|
||||||
# Define the replacement
|
|
||||||
original_param_creation = """ self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
|
||||||
self.sharded_param.requires_grad_(param.requires_grad)"""
|
|
||||||
|
|
||||||
patched_param_creation = """ import bitsandbytes as bnb
|
|
||||||
if isinstance(param, bnb.nn.modules.Params4bit):
|
|
||||||
self.sharded_param = bnb.nn.modules.Params4bit(
|
|
||||||
data=sharded_param,
|
|
||||||
requires_grad=param.requires_grad,
|
|
||||||
quant_state=param.quant_state,
|
|
||||||
blocksize=param.blocksize,
|
|
||||||
compress_statistics=param.compress_statistics,
|
|
||||||
quant_type=param.quant_type,
|
|
||||||
quant_storage=param.quant_storage,
|
|
||||||
module=param.module,
|
|
||||||
bnb_quantized=param.bnb_quantized,
|
|
||||||
)
|
|
||||||
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
|
||||||
else:
|
|
||||||
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
|
||||||
self.sharded_param.requires_grad_(param.requires_grad)"""
|
|
||||||
|
|
||||||
# Apply the replacement
|
|
||||||
if original_param_creation in original_source:
|
|
||||||
patched_source = original_source.replace(
|
|
||||||
original_param_creation, patched_param_creation
|
|
||||||
)
|
|
||||||
patched_source = patched_source.replace(
|
|
||||||
"def _init_sharded_param(",
|
|
||||||
"def patched_init_sharded_param(",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load necessary imports
|
|
||||||
module_name = FSDPParam.__module__
|
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
|
|
||||||
items_to_import = []
|
|
||||||
for item in dir(module):
|
|
||||||
if item in patched_source:
|
|
||||||
items_to_import.append(item)
|
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
exec(patched_source, globals()) # pylint: disable=exec-used # nosec B102
|
|
||||||
|
|
||||||
# Replace the method
|
|
||||||
FSDPParam._init_sharded_param = patched_init_sharded_param # pylint: disable=undefined-variable # noqa: F821
|
|
||||||
LOG.info("Successfully applied FSDP _init_sharded_param patch")
|
|
||||||
else:
|
|
||||||
LOG.warning("Could not find target code for _init_sharded_param patching")
|
|
||||||
|
|
||||||
|
|
||||||
def apply_init_unsharded_param_patch():
|
|
||||||
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
|
|
||||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
|
||||||
|
|
||||||
# Get original source
|
|
||||||
original_source = inspect.getsource(FSDPParam.init_unsharded_param)
|
|
||||||
original_source, _ = detab_code(original_source)
|
|
||||||
|
|
||||||
# Define the replacement
|
|
||||||
original_param_creation = """ self._unsharded_param = nn.Parameter(
|
|
||||||
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
|
||||||
)"""
|
|
||||||
|
|
||||||
patched_param_creation = """ import bitsandbytes as bnb
|
|
||||||
local_tensor = self.sharded_param._local_tensor
|
|
||||||
if isinstance(local_tensor, bnb.nn.modules.Params4bit):
|
|
||||||
self._unsharded_param = bnb.nn.modules.Params4bit(
|
|
||||||
data=unsharded_param,
|
|
||||||
requires_grad=self.sharded_param.requires_grad,
|
|
||||||
quant_state=local_tensor.quant_state,
|
|
||||||
blocksize=local_tensor.blocksize,
|
|
||||||
compress_statistics=local_tensor.compress_statistics,
|
|
||||||
quant_type=local_tensor.quant_type,
|
|
||||||
quant_storage=local_tensor.quant_storage,
|
|
||||||
module=local_tensor.module,
|
|
||||||
bnb_quantized=local_tensor.bnb_quantized,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._unsharded_param = nn.Parameter(
|
|
||||||
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
|
||||||
)"""
|
|
||||||
|
|
||||||
# Apply the replacement
|
|
||||||
if original_param_creation in original_source:
|
|
||||||
patched_source = original_source.replace(
|
|
||||||
original_param_creation, patched_param_creation
|
|
||||||
)
|
|
||||||
patched_source = patched_source.replace(
|
|
||||||
"def init_unsharded_param(",
|
|
||||||
"def patched_init_unsharded_param(",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load necessary imports
|
|
||||||
module_name = FSDPParam.__module__
|
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
|
|
||||||
items_to_import = []
|
|
||||||
for item in dir(module):
|
|
||||||
if item in patched_source:
|
|
||||||
items_to_import.append(item)
|
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
exec(patched_source, globals()) # pylint: disable=exec-used # nosec B102
|
|
||||||
|
|
||||||
# Replace the method
|
|
||||||
FSDPParam.init_unsharded_param = patched_init_unsharded_param # pylint: disable=undefined-variable # noqa: F821
|
|
||||||
LOG.info("Successfully applied FSDP init_unsharded_param patch")
|
|
||||||
else:
|
|
||||||
LOG.warning("Could not find target code for patching")
|
|
||||||
@@ -3,26 +3,39 @@
|
|||||||
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import transformers
|
import transformers
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaAttention,
|
||||||
|
)
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
||||||
|
)
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaMLP,
|
LlamaMLP,
|
||||||
apply_rotary_pos_emb,
|
apply_rotary_pos_emb,
|
||||||
repeat_kv,
|
repeat_kv,
|
||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import set_module_name
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||||
|
flash_attn_kvpacked_func,
|
||||||
|
flash_attn_varlen_kvpacked_func,
|
||||||
flash_attn_varlen_qkvpacked_func,
|
flash_attn_varlen_qkvpacked_func,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
from flash_attn.flash_attn_interface import (
|
||||||
|
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
|
||||||
|
)
|
||||||
from flash_attn.flash_attn_interface import (
|
from flash_attn.flash_attn_interface import (
|
||||||
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
||||||
)
|
)
|
||||||
@@ -69,6 +82,19 @@ def replace_llama_mlp_with_swiglu(model):
|
|||||||
set_module_name(model, name, mlp)
|
set_module_name(model, name, mlp)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_llama_qkv_with_fused(model):
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if isinstance(module, LlamaAttention):
|
||||||
|
qkv = FusedAttention(
|
||||||
|
module.config,
|
||||||
|
module.q_proj,
|
||||||
|
module.k_proj,
|
||||||
|
module.v_proj,
|
||||||
|
module.o_proj,
|
||||||
|
)
|
||||||
|
set_module_name(model, name, qkv)
|
||||||
|
|
||||||
|
|
||||||
def patch_fa_llama_cross_entropy():
|
def patch_fa_llama_cross_entropy():
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
|
"patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
|
||||||
@@ -116,6 +142,7 @@ def patch_llama_rms_norm():
|
|||||||
|
|
||||||
|
|
||||||
def replace_llama_attn_with_flash_attn(
|
def replace_llama_attn_with_flash_attn(
|
||||||
|
packed: Optional[bool] = False,
|
||||||
cross_entropy: Optional[bool] = False,
|
cross_entropy: Optional[bool] = False,
|
||||||
rms_norm: Optional[bool] = False,
|
rms_norm: Optional[bool] = False,
|
||||||
use_shifted_sparse_attn: Optional[bool] = False,
|
use_shifted_sparse_attn: Optional[bool] = False,
|
||||||
@@ -127,6 +154,16 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||||
flashattn_forward_with_s2attn
|
flashattn_forward_with_s2attn
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||||
|
flashattn_forward
|
||||||
|
)
|
||||||
|
|
||||||
|
if packed:
|
||||||
|
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
||||||
|
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
||||||
|
llama_model_forward
|
||||||
|
)
|
||||||
|
|
||||||
# skip only if explicitly disabled
|
# skip only if explicitly disabled
|
||||||
if cross_entropy:
|
if cross_entropy:
|
||||||
@@ -137,6 +174,49 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
patch_llama_rms_norm()
|
patch_llama_rms_norm()
|
||||||
|
|
||||||
|
|
||||||
|
class FusedAttention(LlamaAttention):
|
||||||
|
"""
|
||||||
|
Fused QKV Attention layer for incrementally improved training efficiency
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
q: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
k: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
v: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
o: torch.nn.Linear, # pylint: disable=invalid-name
|
||||||
|
):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
self.init_device = next(iter(q.state_dict().values())).device
|
||||||
|
|
||||||
|
# define equivalent fused qkv projection
|
||||||
|
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
|
||||||
|
self.qkv_proj = torch.nn.Linear(
|
||||||
|
q.in_features, sum(self.out_features), device=self.init_device, bias=False
|
||||||
|
)
|
||||||
|
self.o_proj = o
|
||||||
|
|
||||||
|
# overwrite initialized weights with pretrained weights
|
||||||
|
self.qkv_proj.weight.data = torch.cat(
|
||||||
|
(q.weight.data, k.weight.data, v.weight.data), dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
def _post_training(self, model, name):
|
||||||
|
q_proj, k_proj, v_proj = torch.split(
|
||||||
|
self.qkv_proj.weight.data, self.out_features, dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
new_attn = LlamaAttention(self.config)
|
||||||
|
new_attn.q_proj.weight.data = q_proj
|
||||||
|
new_attn.k_proj.weight.data = k_proj
|
||||||
|
new_attn.v_proj.weight.data = v_proj
|
||||||
|
new_attn.o_proj.weight.data = self.o_proj.weight.data
|
||||||
|
|
||||||
|
set_module_name(model, name, new_attn)
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
# requires the attention mask to be the same as the key_padding_mask
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
def _prepare_decoder_attention_mask(
|
def _prepare_decoder_attention_mask(
|
||||||
@@ -275,3 +355,576 @@ def flashattn_forward_with_s2attn(
|
|||||||
.reshape(bsz, q_len, nheads, self.head_dim)
|
.reshape(bsz, q_len, nheads, self.head_dim)
|
||||||
)
|
)
|
||||||
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
|
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def flashattn_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
"""Input shape: Batch x Time x Channel
|
||||||
|
|
||||||
|
attention_mask: [bsz, q_len]
|
||||||
|
"""
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
if not hasattr(self, "pretraining_tp"):
|
||||||
|
self.pretraining_tp = 1
|
||||||
|
|
||||||
|
if self.pretraining_tp > 1:
|
||||||
|
key_value_slicing = (
|
||||||
|
self.num_key_value_heads * self.head_dim
|
||||||
|
) // self.pretraining_tp
|
||||||
|
query_slices = self.q_proj.weight.split(
|
||||||
|
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
||||||
|
)
|
||||||
|
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
|
||||||
|
query_states = [
|
||||||
|
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
||||||
|
]
|
||||||
|
query_states = torch.cat(query_states, dim=-1)
|
||||||
|
|
||||||
|
key_states = [
|
||||||
|
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
||||||
|
]
|
||||||
|
key_states = torch.cat(key_states, dim=-1)
|
||||||
|
|
||||||
|
value_states = [
|
||||||
|
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
||||||
|
]
|
||||||
|
value_states = torch.cat(value_states, dim=-1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if isinstance(self, FusedAttention):
|
||||||
|
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
|
||||||
|
self.out_features, dim=-1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
# [bsz, q_len, nh, hd]
|
||||||
|
# [bsz, nh, q_len, hd]
|
||||||
|
|
||||||
|
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, position_ids
|
||||||
|
)
|
||||||
|
# [bsz, nh, t, hd]
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
warnings.warn(
|
||||||
|
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
#
|
||||||
|
# flash-attn v2 start
|
||||||
|
#
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
# during training q,k,v always have same seqlen
|
||||||
|
assert key_states.shape == query_states.shape
|
||||||
|
is_causal = True
|
||||||
|
else:
|
||||||
|
# turn off FA causal mask after first inference autoregressive iteration
|
||||||
|
# only on first autoregressive step q,k,v have same seqlen
|
||||||
|
is_causal = key_states.shape == query_states.shape
|
||||||
|
|
||||||
|
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||||
|
|
||||||
|
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||||
|
# special handling using sample packing
|
||||||
|
qkv = torch.stack(
|
||||||
|
[query_states, key_states, value_states], dim=2
|
||||||
|
) # [bsz, nh, 3, q_len, hd]
|
||||||
|
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||||
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
|
||||||
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
dropout_p=dropout_rate,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
|
elif query_states.shape == key_states.shape:
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
qkvpacked=True,
|
||||||
|
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||||
|
# the attention_mask should be the same as the key_padding_mask
|
||||||
|
key_padding_mask=attention_mask,
|
||||||
|
query_padding_mask=(
|
||||||
|
attention_mask[:, -query_states.size(1) :]
|
||||||
|
if attention_mask is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
max_seqlen_q,
|
||||||
|
dropout_p=dropout_rate,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
|
else:
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
if attention_mask is None or attention_mask.all().item():
|
||||||
|
output = flash_attn_kvpacked_func(
|
||||||
|
query_states,
|
||||||
|
torch.stack([key_states, value_states], 2),
|
||||||
|
dropout_p=dropout_rate,
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
( # pylint: disable=unbalanced-tuple-unpacking
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
output_pad_fn,
|
||||||
|
) = generate_qkv(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
kvpacked=True,
|
||||||
|
key_padding_mask=attention_mask,
|
||||||
|
query_padding_mask=(
|
||||||
|
attention_mask[:, -query_states.size(1) :]
|
||||||
|
if attention_mask is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if q_unpad.dtype != kv_unpad.dtype:
|
||||||
|
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||||
|
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
dropout_p=dropout_rate,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
|
|
||||||
|
attn_output = output
|
||||||
|
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||||
|
|
||||||
|
#
|
||||||
|
# flash-attn v2 end
|
||||||
|
#
|
||||||
|
|
||||||
|
if self.pretraining_tp > 1:
|
||||||
|
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||||
|
o_proj_slices = self.o_proj.weight.split(
|
||||||
|
self.hidden_size // self.pretraining_tp, dim=1
|
||||||
|
)
|
||||||
|
attn_output = sum(
|
||||||
|
F.linear(attn_output[i], o_proj_slices[i])
|
||||||
|
for i in range(self.pretraining_tp)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
||||||
|
def generate_qkv(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
query_padding_mask=None,
|
||||||
|
key_padding_mask=None,
|
||||||
|
kvpacked=False,
|
||||||
|
qkvpacked=False,
|
||||||
|
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
||||||
|
"""
|
||||||
|
Arguments:
|
||||||
|
q: (batch_size, seqlen_q, nheads, d)
|
||||||
|
k: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
v: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
query_padding_mask: (batch_size, seqlen), bool
|
||||||
|
key_padding_mask: (batch_size, seqlen), bool
|
||||||
|
"""
|
||||||
|
assert not (kvpacked and qkvpacked)
|
||||||
|
batch_size, seqlen_q, nheads, d = q.shape
|
||||||
|
_, seqlen_k, nheads_k, _ = k.shape
|
||||||
|
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
|
||||||
|
if query_padding_mask is not None:
|
||||||
|
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
||||||
|
q, query_padding_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
def output_pad_fn(output_unpad):
|
||||||
|
return pad_input( # noqa: E731
|
||||||
|
output_unpad, indices_q, batch_size, seqlen_q
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_q = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_q,
|
||||||
|
step=seqlen_q,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=q_unpad.device,
|
||||||
|
)
|
||||||
|
max_seqlen_q = seqlen_q
|
||||||
|
|
||||||
|
def output_pad_fn(output_unpad):
|
||||||
|
return rearrange( # noqa: E731
|
||||||
|
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||||
|
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
||||||
|
else:
|
||||||
|
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||||
|
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_k = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_k,
|
||||||
|
step=seqlen_k,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=k_unpad.device,
|
||||||
|
)
|
||||||
|
max_seqlen_k = seqlen_k
|
||||||
|
|
||||||
|
if qkvpacked:
|
||||||
|
assert nheads == nheads_k
|
||||||
|
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||||
|
qkv = torch.stack([q, k, v], dim=2)
|
||||||
|
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
||||||
|
|
||||||
|
if kvpacked:
|
||||||
|
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||||
|
kv = torch.stack([k, v], dim=2)
|
||||||
|
return (
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
kv,
|
||||||
|
output_pad_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
q_unpad,
|
||||||
|
k_unpad,
|
||||||
|
v_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
output_pad_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def llama_model_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[ # pylint: disable=unused-argument
|
||||||
|
torch.LongTensor
|
||||||
|
] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||||
|
)
|
||||||
|
if input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||||
|
)
|
||||||
|
|
||||||
|
seq_length_with_past = seq_length
|
||||||
|
past_key_values_length = 0
|
||||||
|
|
||||||
|
if past_key_values is not None:
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2]
|
||||||
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
|
||||||
|
cu_seqlens = None
|
||||||
|
max_seqlen = None
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length,
|
||||||
|
seq_length + past_key_values_length,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
|
cu_seqlens = cu_seqlens.squeeze()
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
# embed positions
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(
|
||||||
|
(batch_size, seq_length_with_past),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=inputs_embeds.device,
|
||||||
|
)
|
||||||
|
padding_mask = None
|
||||||
|
else:
|
||||||
|
if 0 in attention_mask:
|
||||||
|
padding_mask = attention_mask
|
||||||
|
else:
|
||||||
|
padding_mask = None
|
||||||
|
|
||||||
|
attention_mask = (
|
||||||
|
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
transformers.logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(
|
||||||
|
*inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(decoder_layer),
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
None,
|
||||||
|
padding_mask,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||||
|
"""
|
||||||
|
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[
|
||||||
|
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||||
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
|
(see `past_key_values`).
|
||||||
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
|
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
|
||||||
|
"""
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (self_attn_weights,)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
outputs += (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|||||||
@@ -156,11 +156,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
|||||||
|
|
||||||
return Llama4TextAttention
|
return Llama4TextAttention
|
||||||
|
|
||||||
if model_type == "mistral3":
|
|
||||||
from transformers.models.mistral.modeling_mistral import MistralAttention
|
|
||||||
|
|
||||||
return MistralAttention
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Dynamically import the module and attention class
|
# Dynamically import the module and attention class
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
@@ -395,6 +390,7 @@ def apply_lora_kernel_patches(
|
|||||||
]
|
]
|
||||||
can_patch_qkv = all(
|
can_patch_qkv = all(
|
||||||
hasattr(module, "lora_A")
|
hasattr(module, "lora_A")
|
||||||
|
and getattr(module, "base_layer", module).bias is None
|
||||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||||
for module in layer_modules
|
for module in layer_modules
|
||||||
)
|
)
|
||||||
@@ -404,8 +400,7 @@ def apply_lora_kernel_patches(
|
|||||||
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some attention QKV projections - requires LoRA "
|
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
|
||||||
"adapters and no lora_magnitude_vector (DoRA)"
|
|
||||||
)
|
)
|
||||||
if cfg.lora_o_kernel:
|
if cfg.lora_o_kernel:
|
||||||
# Output patching
|
# Output patching
|
||||||
@@ -414,6 +409,7 @@ def apply_lora_kernel_patches(
|
|||||||
]
|
]
|
||||||
can_patch_o = all(
|
can_patch_o = all(
|
||||||
hasattr(module, "lora_A")
|
hasattr(module, "lora_A")
|
||||||
|
and getattr(module, "base_layer", module).bias is None
|
||||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||||
for module in layer_modules
|
for module in layer_modules
|
||||||
)
|
)
|
||||||
@@ -422,14 +418,14 @@ def apply_lora_kernel_patches(
|
|||||||
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some attention output projection - requires LoRA "
|
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
|
||||||
"adapters and no lora_magnitude_vector (DoRA)"
|
|
||||||
)
|
)
|
||||||
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
||||||
if cfg.lora_mlp_kernel:
|
if cfg.lora_mlp_kernel:
|
||||||
# MLP patching
|
# MLP patching
|
||||||
can_patch_mlp = all(
|
can_patch_mlp = all(
|
||||||
hasattr(proj, "lora_A")
|
hasattr(proj, "lora_A")
|
||||||
|
and getattr(proj, "base_layer", proj).bias is None
|
||||||
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
|
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
|
||||||
for proj in (gate_proj, up_proj, down_proj)
|
for proj in (gate_proj, up_proj, down_proj)
|
||||||
)
|
)
|
||||||
@@ -439,8 +435,7 @@ def apply_lora_kernel_patches(
|
|||||||
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some MLP layers - requires LoRA adapters and no "
|
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
|
||||||
"lora_magnitude_vector (DoRA)"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.setLevel(original_level)
|
LOG.setLevel(original_level)
|
||||||
|
|||||||
@@ -3,14 +3,53 @@
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
from einops import rearrange
|
||||||
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
|
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||||
|
flash_attn_kvpacked_func,
|
||||||
|
flash_attn_varlen_kvpacked_func,
|
||||||
|
flash_attn_varlen_qkvpacked_func,
|
||||||
|
)
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
|
MistralAttention as OriginalMistralAttention,
|
||||||
|
)
|
||||||
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
|
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||||
|
)
|
||||||
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
repeat_kv,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_mistral_attn_with_flash_attn(
|
||||||
|
packed: Optional[bool] = False,
|
||||||
|
):
|
||||||
|
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||||
|
_prepare_decoder_attention_mask
|
||||||
|
)
|
||||||
|
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
||||||
|
flashattn_forward
|
||||||
|
)
|
||||||
|
if packed:
|
||||||
|
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
||||||
|
MistralDecoderLayer
|
||||||
|
)
|
||||||
|
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
|
||||||
|
mistral_model_forward
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def patch_mistral_cross_entropy():
|
def patch_mistral_cross_entropy():
|
||||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||||
|
|
||||||
@@ -18,3 +57,604 @@ def patch_mistral_cross_entropy():
|
|||||||
transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
|
transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
|
||||||
CrossEntropyLoss, inplace_backward=True
|
CrossEntropyLoss, inplace_backward=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def _make_sliding_window_causal_mask(
|
||||||
|
bsz: int,
|
||||||
|
tgt_len: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
past_key_values_length: int = 0,
|
||||||
|
sliding_window: int = 4096,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Make causal mask used for sliding window attention
|
||||||
|
"""
|
||||||
|
tensor = torch.full(
|
||||||
|
(tgt_len, tgt_len),
|
||||||
|
fill_value=1,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
mask = torch.tril(tensor, diagonal=0)
|
||||||
|
# make the mask banded to account for sliding window
|
||||||
|
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
|
||||||
|
mask = torch.triu(mask, diagonal=-sliding_window + 1)
|
||||||
|
mask = torch.log(mask).to(dtype)
|
||||||
|
|
||||||
|
if past_key_values_length > 0:
|
||||||
|
mask = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros(
|
||||||
|
tgt_len, past_key_values_length, dtype=dtype, device=device
|
||||||
|
),
|
||||||
|
mask,
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
return mask[None, None, :, :].expand(
|
||||||
|
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
|
def _prepare_decoder_attention_mask(
|
||||||
|
self,
|
||||||
|
attention_mask,
|
||||||
|
input_shape,
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
sliding_window,
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
# [bsz, seq_len]
|
||||||
|
if attention_mask is None or sliding_window is None:
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
||||||
|
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
|
||||||
|
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
|
||||||
|
sliding_window_mask = _make_sliding_window_causal_mask(
|
||||||
|
bsz=input_shape[0],
|
||||||
|
tgt_len=input_shape[1],
|
||||||
|
dtype=inputs_embeds.dtype,
|
||||||
|
device=inputs_embeds.device,
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
|
sliding_window=sliding_window,
|
||||||
|
)
|
||||||
|
attention_mask = attention_mask + sliding_window_mask
|
||||||
|
else:
|
||||||
|
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
|
||||||
|
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
def flashattn_forward(
|
||||||
|
self: OriginalMistralAttention,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, position_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
use_sliding_windows = (
|
||||||
|
getattr(self.config, "sliding_window") is not None
|
||||||
|
and kv_seq_len > self.config.sliding_window
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_sliding_windows:
|
||||||
|
window_size = (self.config.sliding_window, self.config.sliding_window)
|
||||||
|
else:
|
||||||
|
window_size = (-1, -1)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||||
|
if (
|
||||||
|
hasattr(self.config, "sliding_window")
|
||||||
|
and kv_seq_len > self.config.sliding_window
|
||||||
|
):
|
||||||
|
slicing_tokens = kv_seq_len - self.config.sliding_window
|
||||||
|
|
||||||
|
past_key = past_key_value[0]
|
||||||
|
past_value = past_key_value[1]
|
||||||
|
|
||||||
|
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||||
|
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||||
|
|
||||||
|
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||||
|
f" {past_key.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
past_key_value = (past_key, past_value) if use_cache else None
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
# during training q,k,v always have same seqlen
|
||||||
|
assert key_states.shape == query_states.shape
|
||||||
|
is_causal = True
|
||||||
|
else:
|
||||||
|
# turn off FA causal mask after first inference autoregressive iteration
|
||||||
|
# only on first autoregressive step q,k,v have same seqlen
|
||||||
|
is_causal = key_states.shape == query_states.shape
|
||||||
|
|
||||||
|
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||||
|
|
||||||
|
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||||
|
# special handling using sample packing
|
||||||
|
qkv = torch.stack(
|
||||||
|
[query_states, key_states, value_states], dim=2
|
||||||
|
) # [bsz, nh, 3, q_len, hd]
|
||||||
|
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||||
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
|
||||||
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
dropout_p=dropout_rate,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=True,
|
||||||
|
window_size=window_size,
|
||||||
|
)
|
||||||
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
|
elif query_states.shape == key_states.shape:
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
qkvpacked=True,
|
||||||
|
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||||
|
# the attention_mask should be the same as the key_padding_mask
|
||||||
|
key_padding_mask=attention_mask,
|
||||||
|
query_padding_mask=(
|
||||||
|
attention_mask[:, -query_states.size(1) :]
|
||||||
|
if attention_mask is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
max_seqlen_q,
|
||||||
|
dropout_p=dropout_rate,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
window_size=window_size,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
|
else:
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
if attention_mask is None or attention_mask.all().item():
|
||||||
|
output = flash_attn_kvpacked_func(
|
||||||
|
query_states,
|
||||||
|
torch.stack([key_states, value_states], 2),
|
||||||
|
dropout_p=dropout_rate,
|
||||||
|
causal=is_causal,
|
||||||
|
window_size=window_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
( # pylint: disable=unbalanced-tuple-unpacking
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
output_pad_fn,
|
||||||
|
) = generate_qkv(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
kvpacked=True,
|
||||||
|
key_padding_mask=attention_mask,
|
||||||
|
query_padding_mask=(
|
||||||
|
attention_mask[:, -query_states.size(1) :]
|
||||||
|
if attention_mask is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if q_unpad.dtype != kv_unpad.dtype:
|
||||||
|
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||||
|
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
dropout_p=dropout_rate,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
window_size=window_size,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
|
|
||||||
|
attn_output = output
|
||||||
|
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
||||||
|
def generate_qkv(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
query_padding_mask=None,
|
||||||
|
key_padding_mask=None,
|
||||||
|
kvpacked=False,
|
||||||
|
qkvpacked=False,
|
||||||
|
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
||||||
|
"""
|
||||||
|
Arguments:
|
||||||
|
q: (batch_size, seqlen_q, nheads, d)
|
||||||
|
k: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
v: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
query_padding_mask: (batch_size, seqlen), bool
|
||||||
|
key_padding_mask: (batch_size, seqlen), bool
|
||||||
|
"""
|
||||||
|
assert not (kvpacked and qkvpacked)
|
||||||
|
batch_size, seqlen_q, nheads, d = q.shape
|
||||||
|
_, seqlen_k, nheads_k, _ = k.shape
|
||||||
|
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
|
||||||
|
if query_padding_mask is not None:
|
||||||
|
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
||||||
|
q, query_padding_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
def output_pad_fn(output_unpad):
|
||||||
|
return pad_input( # noqa: E731
|
||||||
|
output_unpad, indices_q, batch_size, seqlen_q
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_q = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_q,
|
||||||
|
step=seqlen_q,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=q_unpad.device,
|
||||||
|
)
|
||||||
|
max_seqlen_q = seqlen_q
|
||||||
|
|
||||||
|
def output_pad_fn(output_unpad):
|
||||||
|
return rearrange( # noqa: E731
|
||||||
|
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||||
|
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
||||||
|
else:
|
||||||
|
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||||
|
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_k = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_k,
|
||||||
|
step=seqlen_k,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=k_unpad.device,
|
||||||
|
)
|
||||||
|
max_seqlen_k = seqlen_k
|
||||||
|
|
||||||
|
if qkvpacked:
|
||||||
|
assert nheads == nheads_k
|
||||||
|
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||||
|
qkv = torch.stack([q, k, v], dim=2)
|
||||||
|
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
||||||
|
|
||||||
|
if kvpacked:
|
||||||
|
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||||
|
kv = torch.stack([k, v], dim=2)
|
||||||
|
return (
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
kv,
|
||||||
|
output_pad_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
q_unpad,
|
||||||
|
k_unpad,
|
||||||
|
v_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
output_pad_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mistral_model_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[ # pylint: disable=unused-argument
|
||||||
|
torch.LongTensor
|
||||||
|
] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||||
|
)
|
||||||
|
if input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||||
|
)
|
||||||
|
|
||||||
|
seq_length_with_past = seq_length
|
||||||
|
past_key_values_length = 0
|
||||||
|
|
||||||
|
if past_key_values is not None:
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2]
|
||||||
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
|
||||||
|
cu_seqlens = None
|
||||||
|
max_seqlen = None
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length,
|
||||||
|
seq_length + past_key_values_length,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
|
cu_seqlens = cu_seqlens.squeeze()
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
# embed positions
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones(
|
||||||
|
(batch_size, seq_length_with_past),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=inputs_embeds.device,
|
||||||
|
)
|
||||||
|
attention_mask = (
|
||||||
|
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
transformers.logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
|
|
||||||
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
layer_outputs = (
|
||||||
|
self._gradient_checkpointing_func( # pylint: disable=protected-access
|
||||||
|
decoder_layer.__call__,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_value,
|
||||||
|
output_attentions,
|
||||||
|
None,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MistralDecoderLayer(OriginalMistralDecoderLayer):
|
||||||
|
"""
|
||||||
|
patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[
|
||||||
|
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||||
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
|
(see `past_key_values`).
|
||||||
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
|
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
|
||||||
|
"""
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (self_attn_weights,)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
outputs += (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|||||||
@@ -36,8 +36,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"glm",
|
"glm",
|
||||||
"glm4",
|
"glm4",
|
||||||
"smollm3",
|
"smollm3",
|
||||||
"gpt_oss",
|
"granite",
|
||||||
"arcee",
|
"granitemoe",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
78
src/axolotl/monkeypatch/trainer_eval_guard.py
Normal file
78
src/axolotl/monkeypatch/trainer_eval_guard.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""
|
||||||
|
fix for FSDP2 evals when using torch.compile
|
||||||
|
"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from transformers import Trainer
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
ORIGINAL_TRAINER_CODE = """
|
||||||
|
model.eval()
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_TRAINER_CODE = """
|
||||||
|
if hasattr(model, "eval") and callable(model.eval):
|
||||||
|
self.model.eval()
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_evaluation_loop_code() -> str:
|
||||||
|
training_loop = inspect.getsource(Trainer.evaluation_loop)
|
||||||
|
return training_loop
|
||||||
|
|
||||||
|
|
||||||
|
def check_evaluation_loop_is_patchable() -> bool:
|
||||||
|
eval_loop = get_evaluation_loop_code()
|
||||||
|
eval_loop, _ = detab_code(eval_loop)
|
||||||
|
return ORIGINAL_TRAINER_CODE in eval_loop
|
||||||
|
|
||||||
|
|
||||||
|
def patch_evaluation_loop_for_fsdp2():
|
||||||
|
"""
|
||||||
|
monkeypatch for fixing the eval loop for fsdp2 with torch.compile
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
evaluation_loop = get_evaluation_loop_code()
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
Trainer._original_evaluation_loop = ( # pylint: disable=protected-access
|
||||||
|
evaluation_loop
|
||||||
|
)
|
||||||
|
evaluation_loop, _ = detab_code(evaluation_loop)
|
||||||
|
if ORIGINAL_TRAINER_CODE not in evaluation_loop:
|
||||||
|
return
|
||||||
|
|
||||||
|
evaluation_loop = evaluation_loop.replace(
|
||||||
|
ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
|
||||||
|
)
|
||||||
|
evaluation_loop = evaluation_loop.replace(
|
||||||
|
"def evaluation_loop(",
|
||||||
|
"def _fixed_evaluation_loop(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load imports necessary
|
||||||
|
import transformers.trainer
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(transformers.trainer):
|
||||||
|
if item in evaluation_loop:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from transformers.trainer import ("
|
||||||
|
+ ", ".join(x for x in items_to_import)
|
||||||
|
+ ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(evaluation_loop, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
LOG.info("patching _inner_training_loop for fsdp optimizer save")
|
||||||
|
Trainer.evaluation_loop = ( # pylint: disable=protected-access
|
||||||
|
_fixed_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
)
|
||||||
@@ -0,0 +1,87 @@
|
|||||||
|
"""
|
||||||
|
Monkey patch to fix transformers.modeling_flash_attention_utils.
|
||||||
|
|
||||||
|
see https://github.com/huggingface/transformers/pull/39653/files
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_from_posids(query, key, value, position_ids):
|
||||||
|
"""
|
||||||
|
This function returns necessary arguments to call `flash_attn_varlen_func`.
|
||||||
|
All three query, key, value states will be flattened.
|
||||||
|
Cumulative lengths of each examples in the batch will be extracted from position_ids.
|
||||||
|
NOTE: ideally cumulative lengths should be prepared at the data collator stage
|
||||||
|
Arguments:
|
||||||
|
query (`torch.Tensor`):
|
||||||
|
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
||||||
|
key (`torch.Tensor`):
|
||||||
|
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||||
|
value (`torch.Tensor`):
|
||||||
|
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||||
|
position_ids (`torch.Tensor`):
|
||||||
|
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
||||||
|
Return:
|
||||||
|
query (`torch.Tensor`):
|
||||||
|
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
||||||
|
key (`torch.Tensor`):
|
||||||
|
Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
||||||
|
value (`torch.Tensor`):
|
||||||
|
Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
||||||
|
indices_q (`torch.Tensor`):
|
||||||
|
The indices of non-masked tokens from the flattened input target sequence.
|
||||||
|
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
|
||||||
|
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
|
||||||
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
|
||||||
|
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
||||||
|
"""
|
||||||
|
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
|
||||||
|
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
|
||||||
|
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
|
||||||
|
|
||||||
|
position_ids = position_ids.flatten()
|
||||||
|
indices_q = torch.arange(
|
||||||
|
position_ids.size(0), device=position_ids.device, dtype=torch.int32
|
||||||
|
)
|
||||||
|
|
||||||
|
cu_seq_lens = torch.cat(
|
||||||
|
(
|
||||||
|
indices_q[position_ids == 0],
|
||||||
|
torch.tensor(
|
||||||
|
position_ids.size(), device=position_ids.device, dtype=torch.int32
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# NOTE: With torch compile, this will cause a graph break if you don't set
|
||||||
|
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
|
||||||
|
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
|
||||||
|
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
|
||||||
|
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
|
||||||
|
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
|
||||||
|
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
|
||||||
|
# for some models (e.g. qwen2-vl).
|
||||||
|
max_length = cu_seq_lens.diff().max().item()
|
||||||
|
return (
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
indices_q,
|
||||||
|
(cu_seq_lens, cu_seq_lens),
|
||||||
|
(max_length, max_length),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_prepare_from_posids():
|
||||||
|
import transformers.modeling_flash_attention_utils
|
||||||
|
|
||||||
|
transformers.modeling_flash_attention_utils._prepare_from_posids = ( # pylint: disable=protected-access
|
||||||
|
_prepare_from_posids
|
||||||
|
)
|
||||||
|
setattr(
|
||||||
|
sys.modules["transformers.modeling_flash_attention_utils"],
|
||||||
|
"_prepare_from_posids",
|
||||||
|
_prepare_from_posids,
|
||||||
|
)
|
||||||
@@ -1,165 +0,0 @@
|
|||||||
"""
|
|
||||||
Module for patching transformers Trainer loss calculation to use nanmean.
|
|
||||||
|
|
||||||
This is needed for context parallelism since chunks of the input sequences may be fully
|
|
||||||
masked and return NaNs in the loss calculation.
|
|
||||||
|
|
||||||
Also includes a patch for FSDP2 + torch.compile. We need to bundle this together with
|
|
||||||
the other evaluation_loop patch because we can't patch the same code twice without
|
|
||||||
raising an OSError.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
from transformers import Trainer
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
ORIGINAL_EVAL_CODE = {
|
|
||||||
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()',
|
|
||||||
"array": 'metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()',
|
|
||||||
}
|
|
||||||
PATCHED_EVAL_CODE = {
|
|
||||||
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()',
|
|
||||||
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
|
|
||||||
}
|
|
||||||
|
|
||||||
ORIGINAL_FSDP2_CODE = """
|
|
||||||
model.eval()
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATCHED_FSDP2_CODE = """
|
|
||||||
if hasattr(model, "eval") and callable(model.eval):
|
|
||||||
self.model.eval()
|
|
||||||
"""
|
|
||||||
|
|
||||||
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
|
|
||||||
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
|
|
||||||
|
|
||||||
|
|
||||||
def check_evaluation_loop_is_patchable() -> bool:
|
|
||||||
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
|
||||||
return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values())
|
|
||||||
|
|
||||||
|
|
||||||
def check_evaluation_loop_is_fsdp2_patchable() -> bool:
|
|
||||||
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
|
||||||
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
|
|
||||||
return ORIGINAL_FSDP2_CODE in evaluation_loop_source
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
def patch_evaluation_loop(patch_fsdp2: bool):
|
|
||||||
"""Patch the evaluation_loop method."""
|
|
||||||
# Check if already patched
|
|
||||||
if hasattr(Trainer, "_original_evaluation_loop"):
|
|
||||||
LOG.info("Trainer.evaluation_loop already patched")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check if the patterns exist
|
|
||||||
try:
|
|
||||||
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
|
||||||
except OSError:
|
|
||||||
return
|
|
||||||
Trainer.evaluation = evaluation_loop_source
|
|
||||||
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
|
|
||||||
|
|
||||||
# Apply the nanmean patches
|
|
||||||
evaluation_loop_source = evaluation_loop_source.replace(
|
|
||||||
ORIGINAL_EVAL_CODE["list"], PATCHED_EVAL_CODE["list"]
|
|
||||||
)
|
|
||||||
evaluation_loop_source = evaluation_loop_source.replace(
|
|
||||||
ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply FSDP2 eval guard patch if needed
|
|
||||||
if patch_fsdp2 and ORIGINAL_FSDP2_CODE in evaluation_loop_source:
|
|
||||||
evaluation_loop_source = evaluation_loop_source.replace(
|
|
||||||
ORIGINAL_FSDP2_CODE, PATCHED_FSDP2_CODE
|
|
||||||
)
|
|
||||||
LOG.info("Applied FSDP2 eval guard patch to evaluation_loop")
|
|
||||||
|
|
||||||
# Rename the function to avoid conflicts
|
|
||||||
evaluation_loop_source = evaluation_loop_source.replace(
|
|
||||||
"def evaluation_loop(",
|
|
||||||
"def axolotl_evaluation_loop(",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the module for necessary imports
|
|
||||||
module_name = Trainer.__module__
|
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
|
|
||||||
# Import necessary items from the module
|
|
||||||
items_to_import = []
|
|
||||||
for item in dir(module):
|
|
||||||
if item in evaluation_loop_source:
|
|
||||||
items_to_import.append(item)
|
|
||||||
|
|
||||||
# Execute the imports and patched method
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
exec(evaluation_loop_source, globals()) # pylint: disable=exec-used # nosec B102
|
|
||||||
|
|
||||||
LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
|
|
||||||
Trainer.evaluation_loop = (
|
|
||||||
axolotl_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def check_maybe_log_save_evaluate_is_patchable() -> bool:
|
|
||||||
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
|
|
||||||
return ORIGINAL_MAYBE_CODE in maybe_log_source
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
def patch_maybe_log_save_evaluate():
|
|
||||||
"""Patch the _maybe_log_save_evaluate method."""
|
|
||||||
# Check if already patched
|
|
||||||
if hasattr(Trainer, "_original_maybe_log_save_evaluate"):
|
|
||||||
LOG.info("Trainer._maybe_log_save_evaluate already patched")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check if the patterns exist
|
|
||||||
try:
|
|
||||||
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
|
|
||||||
except OSError:
|
|
||||||
return
|
|
||||||
Trainer._original_maybe_log_save_evaluate = maybe_log_source
|
|
||||||
maybe_log_source, _ = detab_code(maybe_log_source)
|
|
||||||
|
|
||||||
# Apply the patch
|
|
||||||
maybe_log_source = maybe_log_source.replace(ORIGINAL_MAYBE_CODE, PATCHED_MAYBE_CODE)
|
|
||||||
|
|
||||||
# Rename the function to avoid conflicts
|
|
||||||
maybe_log_source = maybe_log_source.replace(
|
|
||||||
"def _maybe_log_save_evaluate(",
|
|
||||||
"def axolotl_maybe_log_save_evaluate(",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the module for necessary imports
|
|
||||||
module_name = Trainer.__module__
|
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
|
|
||||||
# Import necessary items from the module
|
|
||||||
items_to_import = []
|
|
||||||
for item in dir(module):
|
|
||||||
if item in maybe_log_source:
|
|
||||||
items_to_import.append(item)
|
|
||||||
|
|
||||||
# Execute the imports and patched method
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
exec(maybe_log_source, globals()) # pylint: disable=exec-used # nosec B102
|
|
||||||
|
|
||||||
LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
|
|
||||||
Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate # pylint: disable=undefined-variable # noqa: F821
|
|
||||||
@@ -41,9 +41,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
field_messages: str = "messages",
|
field_messages: str = "messages",
|
||||||
field_system: str = "system",
|
field_system: str = "system",
|
||||||
field_tools: str = "tools",
|
field_tools: str = "tools",
|
||||||
field_thinking: str = "reasoning_content",
|
|
||||||
roles: dict[str, list[str]] | None = None,
|
roles: dict[str, list[str]] | None = None,
|
||||||
template_thinking_key: str | None = "reasoning_content",
|
|
||||||
chat_template_kwargs: dict[str, Any] | None = None,
|
chat_template_kwargs: dict[str, Any] | None = None,
|
||||||
drop_system_message: bool = False,
|
drop_system_message: bool = False,
|
||||||
):
|
):
|
||||||
@@ -52,9 +50,8 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
message_property_mappings = {
|
message_property_mappings = {
|
||||||
"role": "role",
|
"role": "role",
|
||||||
"content": "content",
|
"content": "content",
|
||||||
|
"reasoning_content": "reasoning_content",
|
||||||
}
|
}
|
||||||
if template_thinking_key and field_thinking:
|
|
||||||
message_property_mappings[template_thinking_key] = field_thinking
|
|
||||||
|
|
||||||
if roles:
|
if roles:
|
||||||
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
||||||
@@ -77,12 +74,10 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
self.field_messages = field_messages
|
self.field_messages = field_messages
|
||||||
self.field_system = field_system
|
self.field_system = field_system
|
||||||
self.field_tools = field_tools
|
self.field_tools = field_tools
|
||||||
self.field_thinking = field_thinking
|
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.processor: ProcessorMixin | None = processor
|
self.processor: ProcessorMixin | None = processor
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.chat_template_kwargs = chat_template_kwargs or {}
|
self.chat_template_kwargs = chat_template_kwargs or {}
|
||||||
self.template_thinking_key: str = template_thinking_key or "reasoning_content"
|
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.drop_system_message = drop_system_message
|
self.drop_system_message = drop_system_message
|
||||||
|
|
||||||
@@ -747,9 +742,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
# get the thinking content
|
# get the thinking content
|
||||||
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
|
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
|
||||||
transformed_message[self.prompter.template_thinking_key] = (
|
transformed_message["reasoning_content"] = thinking_content.strip()
|
||||||
thinking_content.strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
# take remainder of the content
|
# take remainder of the content
|
||||||
# strip whitespace from beginning of the remainder (thinking tokens)
|
# strip whitespace from beginning of the remainder (thinking tokens)
|
||||||
@@ -960,10 +953,6 @@ class StrategyLoader:
|
|||||||
None,
|
None,
|
||||||
),
|
),
|
||||||
"field_messages": dataset_config.get("field_messages", "messages"),
|
"field_messages": dataset_config.get("field_messages", "messages"),
|
||||||
"field_thinking": dataset_config.get("field_thinking", "reasoning_content"),
|
|
||||||
"template_thinking_key": dataset_config.get(
|
|
||||||
"template_thinking_key", "reasoning_content"
|
|
||||||
),
|
|
||||||
"roles": dataset_config.get("roles"),
|
"roles": dataset_config.get("roles"),
|
||||||
"drop_system_message": dataset_config.get("drop_system_message", False),
|
"drop_system_message": dataset_config.get("drop_system_message", False),
|
||||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||||
|
|||||||
@@ -218,7 +218,6 @@ def execute_training(
|
|||||||
ring_attn_func=cfg.ring_attn_func,
|
ring_attn_func=cfg.ring_attn_func,
|
||||||
heads_k_stride=cfg.heads_k_stride,
|
heads_k_stride=cfg.heads_k_stride,
|
||||||
gather_outputs=cfg.rl is RLType.GRPO,
|
gather_outputs=cfg.rl is RLType.GRPO,
|
||||||
device_mesh=trainer.accelerator.torch_device_mesh,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -275,7 +274,7 @@ def save_trained_model(
|
|||||||
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
||||||
return
|
return
|
||||||
|
|
||||||
if trainer.is_fsdp_enabled or cfg.fsdp_config:
|
if trainer.is_fsdp_enabled:
|
||||||
if cfg.fsdp_config or cfg.fsdp:
|
if cfg.fsdp_config or cfg.fsdp:
|
||||||
if cfg.fsdp_config.final_state_dict_type:
|
if cfg.fsdp_config.final_state_dict_type:
|
||||||
state_dict_type = cfg.fsdp_config.final_state_dict_type
|
state_dict_type = cfg.fsdp_config.final_state_dict_type
|
||||||
@@ -567,10 +566,6 @@ def train(
|
|||||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||||
execute_training(cfg, trainer, resume_from_checkpoint)
|
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||||
|
|
||||||
# clear cache
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# Save the trained model and cleanup
|
# Save the trained model and cleanup
|
||||||
save_trained_model(cfg, trainer, model, safe_serialization)
|
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||||
create_model_card(cfg, trainer)
|
create_model_card(cfg, trainer)
|
||||||
|
|||||||
62
src/axolotl/utils/chat_templates/templates/granite.jinja
Normal file
62
src/axolotl/utils/chat_templates/templates/granite.jinja
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
{# Alias tools -> available_tools #}
|
||||||
|
{%- if tools and not available_tools -%}
|
||||||
|
{%- set available_tools = tools -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if messages[0]['role'] == 'system' %}
|
||||||
|
{%- set system_message = messages[0]['content'] %}
|
||||||
|
{%- set loop_messages = messages[1:] %}
|
||||||
|
{%- else %}
|
||||||
|
{%- set system_message = "Knowledge Cutoff Date: April 2024.
|
||||||
|
Today's Date: " + strftime_now('%B %d, %Y') + ".
|
||||||
|
You are Granite, developed by IBM." %}
|
||||||
|
{%- if available_tools and documents %}
|
||||||
|
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.
|
||||||
|
Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
|
||||||
|
{%- elif available_tools %}
|
||||||
|
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
|
||||||
|
{%- elif documents %}
|
||||||
|
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
|
||||||
|
{%- elif thinking %}
|
||||||
|
{%- set system_message = system_message + " You are a helpful AI assistant.
|
||||||
|
Respond to every user query in a comprehensive and detailed way. You can write down your thoughts and reasoning process before responding. In the thought process, engage in a comprehensive cycle of analysis, summarization, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. In the response section, based on various attempts, explorations, and reflections from the thoughts section, systematically present the final solution that you deem correct. The response should summarize the thought process. Write your thoughts between <think></think> and write your response between <response></response> for each user query." %}
|
||||||
|
{%- else %}
|
||||||
|
{%- set system_message = system_message + " You are a helpful AI assistant." %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if 'citations' in controls and documents %}
|
||||||
|
{%- set system_message = system_message + '
|
||||||
|
Use the symbols <|start_of_cite|> and <|end_of_cite|> to indicate when a fact comes from a document in the search result, e.g <|start_of_cite|> {document_id: 1}my fact <|end_of_cite|> for a fact from document 1. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if 'hallucinations' in controls and documents %}
|
||||||
|
{%- set system_message = system_message + '
|
||||||
|
Finally, after the response is written, include a numbered list of sentences from the response with a corresponding risk value that are hallucinated and not based in the documents.' %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- set loop_messages = messages %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>
|
||||||
|
' }}
|
||||||
|
{%- if available_tools %}
|
||||||
|
{{- '<|start_of_role|>available_tools<|end_of_role|>' }}
|
||||||
|
{{- available_tools | tojson(indent=4) }}
|
||||||
|
{{- '<|end_of_text|>
|
||||||
|
' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if documents %}
|
||||||
|
{%- for document in documents %}
|
||||||
|
{{- '<|start_of_role|>document {"document_id": "' + document['doc_id'] | string + '"}<|end_of_role|>
|
||||||
|
' }}
|
||||||
|
{{- document['text'] }}
|
||||||
|
{{- '<|end_of_text|>
|
||||||
|
' }}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- for message in loop_messages %}
|
||||||
|
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>
|
||||||
|
' }}
|
||||||
|
{%- if loop.last and add_generation_prompt %}
|
||||||
|
{{- '<|start_of_role|>assistant' }}
|
||||||
|
{%- if controls %}
|
||||||
|
{{- ' ' + controls | tojson()}}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<|end_of_role|>' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
64
src/axolotl/utils/chat_templates/templates/granitemoe.jinja
Normal file
64
src/axolotl/utils/chat_templates/templates/granitemoe.jinja
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
{%- if messages[0]['role'] == 'system' %}
|
||||||
|
{%- set system_message = messages[0]['content'] %}
|
||||||
|
{%- set loop_messages = messages[1:] %}
|
||||||
|
{%- else %}
|
||||||
|
{%- set system_message = "Knowledge Cutoff Date: April 2024.
|
||||||
|
Today's Date: " + strftime_now('%B %d, %Y') + ".
|
||||||
|
You are Granite, developed by IBM." %}
|
||||||
|
{%- if tools and documents %}
|
||||||
|
{%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.
|
||||||
|
|
||||||
|
Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
|
||||||
|
{%- elif tools %}
|
||||||
|
{%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
|
||||||
|
{%- elif documents %}
|
||||||
|
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
|
||||||
|
{%- else %}
|
||||||
|
{%- set system_message = system_message + " You are a helpful AI assistant." %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if 'citations' in controls and documents %}
|
||||||
|
{%- set system_message = system_message + '
|
||||||
|
|
||||||
|
In your response, use the symbols <co> and </co> to indicate when a fact comes from a document in the search result, e.g <co>0</co> for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if 'hallucinations' in controls and documents %}
|
||||||
|
{%- set system_message = system_message + '
|
||||||
|
|
||||||
|
Finally, after the response is written, include a numbered list of sentences from the response that are potentially hallucinated and not based in the documents.' %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- set loop_messages = messages %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>
|
||||||
|
' }}
|
||||||
|
{%- if tools %}
|
||||||
|
{{- '<|start_of_role|>tools<|end_of_role|>' }}
|
||||||
|
{{- tools | tojson(indent=4) }}
|
||||||
|
{{- '<|end_of_text|>
|
||||||
|
' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if documents %}
|
||||||
|
{{- '<|start_of_role|>documents<|end_of_role|>' }}
|
||||||
|
{%- for document in documents %}
|
||||||
|
{{- 'Document ' + loop.index0 | string + '
|
||||||
|
' }}
|
||||||
|
{{- document['text'] }}
|
||||||
|
{%- if not loop.last %}
|
||||||
|
{{- '
|
||||||
|
|
||||||
|
'}}
|
||||||
|
{%- endif%}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- '<|end_of_text|>
|
||||||
|
' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- for message in loop_messages %}
|
||||||
|
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>
|
||||||
|
' }}
|
||||||
|
{%- if loop.last and add_generation_prompt %}
|
||||||
|
{{- '<|start_of_role|>assistant' }}
|
||||||
|
{%- if controls %}
|
||||||
|
{{- ' ' + controls | tojson()}}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<|end_of_role|>' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
@@ -161,8 +161,6 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
Collator for multipack specific to the using the BatchSampler
|
Collator for multipack specific to the using the BatchSampler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
squash_position_ids: bool = False
|
|
||||||
|
|
||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
if not isinstance(features[0], list):
|
if not isinstance(features[0], list):
|
||||||
features: List[List[dict]] = [features]
|
features: List[List[dict]] = [features]
|
||||||
@@ -178,15 +176,6 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
if feature in item
|
if feature in item
|
||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
elif feature == "position_ids" and self.squash_position_ids:
|
|
||||||
arrays = [
|
|
||||||
np.array(item[feature]) for item in features_ if feature in item
|
|
||||||
]
|
|
||||||
# concatenate, get total length and create arange of new total position ids
|
|
||||||
position_ids = np.concatenate(arrays)
|
|
||||||
total_length = position_ids.shape[0]
|
|
||||||
position_ids = np.arange(total_length)
|
|
||||||
out_features[i][feature] = position_ids
|
|
||||||
else:
|
else:
|
||||||
arrays = [
|
arrays = [
|
||||||
np.array(item[feature]) for item in features_ if feature in item
|
np.array(item[feature]) for item in features_ if feature in item
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import inspect
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
from accelerate import PartialState
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.distributed import DeviceMesh
|
|
||||||
from torch.utils.hooks import RemovableHandle
|
from torch.utils.hooks import RemovableHandle
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.utils import ModelOutput
|
from transformers.utils import ModelOutput
|
||||||
@@ -194,7 +194,6 @@ class SequenceParallelContextManager:
|
|||||||
ring_attn_func: RingAttnFunc,
|
ring_attn_func: RingAttnFunc,
|
||||||
heads_k_stride: int | None,
|
heads_k_stride: int | None,
|
||||||
gather_outputs: bool,
|
gather_outputs: bool,
|
||||||
device_mesh: DeviceMesh | None = None,
|
|
||||||
):
|
):
|
||||||
self.models = models
|
self.models = models
|
||||||
self.context_parallel_size = context_parallel_size
|
self.context_parallel_size = context_parallel_size
|
||||||
@@ -202,7 +201,6 @@ class SequenceParallelContextManager:
|
|||||||
self.ring_attn_func = ring_attn_func
|
self.ring_attn_func = ring_attn_func
|
||||||
self.heads_k_stride = heads_k_stride
|
self.heads_k_stride = heads_k_stride
|
||||||
self.gather_outputs = gather_outputs
|
self.gather_outputs = gather_outputs
|
||||||
self.device_mesh = device_mesh
|
|
||||||
|
|
||||||
self._register_ring_attn()
|
self._register_ring_attn()
|
||||||
|
|
||||||
@@ -242,8 +240,9 @@ class SequenceParallelContextManager:
|
|||||||
|
|
||||||
def _register_ring_attn(self):
|
def _register_ring_attn(self):
|
||||||
# Initialize ring attn for sequence parallelism
|
# Initialize ring attn for sequence parallelism
|
||||||
|
partial_state = PartialState()
|
||||||
register_ring_attn_from_device_mesh(
|
register_ring_attn_from_device_mesh(
|
||||||
device_mesh=self.device_mesh,
|
device_mesh=partial_state.device_mesh,
|
||||||
context_parallel_dim=("cp",),
|
context_parallel_dim=("cp",),
|
||||||
heads_k_stride=self.heads_k_stride,
|
heads_k_stride=self.heads_k_stride,
|
||||||
ring_attn_func=self.ring_attn_func,
|
ring_attn_func=self.ring_attn_func,
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Data handling specific to SFT."""
|
"""Data handling specific to SFT."""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import os
|
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
@@ -105,9 +104,6 @@ def _prepare_standard_dataset(
|
|||||||
finally:
|
finally:
|
||||||
loader.cleanup()
|
loader.cleanup()
|
||||||
|
|
||||||
if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
|
|
||||||
return train_dataset, eval_dataset, -1, prompters
|
|
||||||
|
|
||||||
# Validate sample packing configuration for evaluation
|
# Validate sample packing configuration for evaluation
|
||||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
||||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from datetime import timedelta
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate import PartialState
|
from accelerate import PartialState
|
||||||
from accelerate.utils import ParallelismConfig
|
|
||||||
from transformers.utils.import_utils import (
|
from transformers.utils.import_utils import (
|
||||||
is_torch_cuda_available,
|
is_torch_cuda_available,
|
||||||
is_torch_mps_available,
|
is_torch_mps_available,
|
||||||
@@ -51,10 +50,7 @@ def init_distributed_state():
|
|||||||
global distributed_state # pylint: disable=global-statement
|
global distributed_state # pylint: disable=global-statement
|
||||||
if distributed_state is None:
|
if distributed_state is None:
|
||||||
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
|
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
|
||||||
try:
|
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
|
||||||
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def get_distributed_state() -> PartialState | None:
|
def get_distributed_state() -> PartialState | None:
|
||||||
@@ -294,77 +290,3 @@ def reduce_and_broadcast(fn1, fn2):
|
|||||||
# Use compute_and_broadcast to compute the reduced value on the main process
|
# Use compute_and_broadcast to compute the reduced value on the main process
|
||||||
# and then broadcast it to all ranks
|
# and then broadcast it to all ranks
|
||||||
return compute_and_broadcast(lambda: fn2(gathered_values))
|
return compute_and_broadcast(lambda: fn2(gathered_values))
|
||||||
|
|
||||||
|
|
||||||
def build_parallelism_config(cfg):
|
|
||||||
pc_kwargs = _get_parallel_config_kwargs(
|
|
||||||
get_world_size(),
|
|
||||||
cfg.tensor_parallel_size,
|
|
||||||
cfg.context_parallel_size,
|
|
||||||
cfg.dp_shard_size,
|
|
||||||
cfg.dp_replicate_size,
|
|
||||||
bool(cfg.fsdp or cfg.fsdp_config),
|
|
||||||
)
|
|
||||||
|
|
||||||
if pc_kwargs:
|
|
||||||
parallelism_config = ParallelismConfig(
|
|
||||||
**pc_kwargs,
|
|
||||||
)
|
|
||||||
device_mesh = parallelism_config.build_device_mesh("cuda")
|
|
||||||
|
|
||||||
return parallelism_config, device_mesh
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_parallel_config_kwargs(
|
|
||||||
world_size: int,
|
|
||||||
tensor_parallel_size: int = 1,
|
|
||||||
context_parallel_size: int = 1,
|
|
||||||
dp_shard_size: int | None = None,
|
|
||||||
dp_replicate_size: int | None = None,
|
|
||||||
is_fsdp: bool = False,
|
|
||||||
):
|
|
||||||
pc_kwargs = {}
|
|
||||||
remaining_world_size = world_size
|
|
||||||
|
|
||||||
if tensor_parallel_size and tensor_parallel_size > 1:
|
|
||||||
pc_kwargs["tp_size"] = tensor_parallel_size
|
|
||||||
remaining_world_size = remaining_world_size // tensor_parallel_size
|
|
||||||
|
|
||||||
if context_parallel_size and context_parallel_size > 1:
|
|
||||||
pc_kwargs["cp_size"] = context_parallel_size
|
|
||||||
remaining_world_size = remaining_world_size // context_parallel_size
|
|
||||||
|
|
||||||
if dp_shard_size is None and dp_replicate_size in (None, 1):
|
|
||||||
if remaining_world_size > 1:
|
|
||||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
|
||||||
remaining_world_size = 1
|
|
||||||
|
|
||||||
if dp_replicate_size and dp_replicate_size > 1:
|
|
||||||
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
|
||||||
remaining_world_size = remaining_world_size // dp_replicate_size
|
|
||||||
|
|
||||||
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
|
|
||||||
if not is_fsdp:
|
|
||||||
raise ValueError(
|
|
||||||
"dp_shard_size was configured without a corresponding fsdp_config! "
|
|
||||||
"Please ensure you have configured FSDP using fsdp_config."
|
|
||||||
)
|
|
||||||
pc_kwargs["dp_shard_size"] = dp_shard_size
|
|
||||||
remaining_world_size = remaining_world_size // dp_shard_size
|
|
||||||
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
|
|
||||||
pc_kwargs["dp_replicate_size"] = remaining_world_size
|
|
||||||
remaining_world_size = 1
|
|
||||||
|
|
||||||
if remaining_world_size > 1:
|
|
||||||
if "dp_shard_size" not in pc_kwargs and is_fsdp:
|
|
||||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
|
||||||
remaining_world_size = 1
|
|
||||||
|
|
||||||
if remaining_world_size > 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
|
|
||||||
f"{pc_kwargs}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return pc_kwargs
|
|
||||||
|
|||||||
@@ -1,28 +0,0 @@
|
|||||||
"""
|
|
||||||
Helper for importing modules from strings
|
|
||||||
"""
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
|
|
||||||
def get_cls_from_module_str(module_str: str):
|
|
||||||
# use importlib to dynamically load the reward function from the module
|
|
||||||
if not isinstance(module_str, str) or not module_str.strip():
|
|
||||||
raise ValueError("module_str must be a non-empty string")
|
|
||||||
|
|
||||||
parts = module_str.split(".")
|
|
||||||
if len(parts) < 2:
|
|
||||||
raise ValueError(f"Invalid module string format: {module_str}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
cls_name = parts[-1]
|
|
||||||
module_path = ".".join(parts[:-1])
|
|
||||||
mod = importlib.import_module(module_path)
|
|
||||||
mod_cls = getattr(mod, cls_name)
|
|
||||||
return mod_cls
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(f"Failed to import module '{module_path}': {e}") from e
|
|
||||||
except AttributeError as e:
|
|
||||||
raise AttributeError(
|
|
||||||
f"Class '{cls_name}' not found in module '{module_path}': {e}"
|
|
||||||
) from e
|
|
||||||
@@ -4,7 +4,6 @@ import math
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from torch import Tensor
|
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||||
|
|
||||||
@@ -46,10 +45,8 @@ class RexLR(LRScheduler):
|
|||||||
|
|
||||||
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
|
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
|
||||||
for group in optimizer.param_groups:
|
for group in optimizer.param_groups:
|
||||||
initial_lr = group["lr"]
|
group.setdefault("initial_lr", group["lr"])
|
||||||
if isinstance(initial_lr, Tensor):
|
|
||||||
initial_lr = initial_lr.clone()
|
|
||||||
group.setdefault("initial_lr", initial_lr)
|
|
||||||
# Pass self.last_step as last_epoch to the parent.
|
# Pass self.last_step as last_epoch to the parent.
|
||||||
super().__init__(optimizer, last_epoch=self.last_step)
|
super().__init__(optimizer, last_epoch=self.last_step)
|
||||||
|
|
||||||
|
|||||||
@@ -110,13 +110,6 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer_cls: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "module to custom trainer class to use for training"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
rl: RLType | None = Field(
|
rl: RLType | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -538,6 +531,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Whether to use flash-attention rms norm implementation - advanced use only"
|
"description": "Whether to use flash-attention rms norm implementation - advanced use only"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
flash_attn_fuse_qkv: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Whether to fuse QKV into a single operation"
|
||||||
|
},
|
||||||
|
)
|
||||||
flash_attn_fuse_mlp: bool | None = Field(
|
flash_attn_fuse_mlp: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -551,13 +550,6 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
eager_attention: bool | None = None
|
eager_attention: bool | None = None
|
||||||
|
|
||||||
attn_implementation: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Specify a custom attention implementation, used mostly for kernels."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
unsloth_cross_entropy_loss: bool | None = None
|
unsloth_cross_entropy_loss: bool | None = None
|
||||||
unsloth_lora_mlp: bool | None = None
|
unsloth_lora_mlp: bool | None = None
|
||||||
unsloth_lora_qkv: bool | None = None
|
unsloth_lora_qkv: bool | None = None
|
||||||
|
|||||||
@@ -118,18 +118,6 @@ class SFTDataset(BaseModel):
|
|||||||
"description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'
|
"description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
field_thinking: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": 'Key containing the reasoning trace (default: "reasoning_content").'
|
|
||||||
},
|
|
||||||
)
|
|
||||||
template_thinking_key: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "The key the chat template expects that indicates the reasoning trace."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# deprecated, use message_property_mappings
|
# deprecated, use message_property_mappings
|
||||||
message_field_role: str | None = None
|
message_field_role: str | None = None
|
||||||
# deprecated, use message_property_mappings
|
# deprecated, use message_property_mappings
|
||||||
|
|||||||
@@ -67,6 +67,8 @@ class ChatTemplate(str, Enum):
|
|||||||
command_a_tool_use = "command_a_tool_use"
|
command_a_tool_use = "command_a_tool_use"
|
||||||
command_a_rag = "command_a_rag"
|
command_a_rag = "command_a_rag"
|
||||||
aya = "aya"
|
aya = "aya"
|
||||||
|
granite = "granite"
|
||||||
|
granitemoe = "granitemoe"
|
||||||
|
|
||||||
|
|
||||||
class CustomSupportedOptimizers(str, Enum):
|
class CustomSupportedOptimizers(str, Enum):
|
||||||
@@ -79,7 +81,6 @@ class CustomSupportedOptimizers(str, Enum):
|
|||||||
adopt_adamw = "adopt_adamw"
|
adopt_adamw = "adopt_adamw"
|
||||||
came_pytorch = "came_pytorch"
|
came_pytorch = "came_pytorch"
|
||||||
muon = "muon"
|
muon = "muon"
|
||||||
dion = "dion"
|
|
||||||
|
|
||||||
|
|
||||||
class RingAttnFunc(str, Enum):
|
class RingAttnFunc(str, Enum):
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
"""Pydantic models for model input / output, etc. configuration"""
|
"""Pydantic models for model input / output, etc. configuration"""
|
||||||
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -64,28 +62,6 @@ class ModelInputConfig(BaseModel):
|
|||||||
json_schema_extra={"description": "Trust remote code for untrusted source"},
|
json_schema_extra={"description": "Trust remote code for untrusted source"},
|
||||||
)
|
)
|
||||||
|
|
||||||
experimental_skip_move_to_device: bool | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Don't move the model to the device before sharding. "
|
|
||||||
"This is an experimental feature that may be included in the future as the default."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
use_kernels: bool | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "Use custom kernels, e.g. MegaBlocks."},
|
|
||||||
)
|
|
||||||
|
|
||||||
model_quantization_config: Literal["Mxfp4Config"] | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "Model loading quantization config"},
|
|
||||||
)
|
|
||||||
model_quantization_config_kwargs: dict[str, Any] | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "kwargs for model quantization config"},
|
|
||||||
)
|
|
||||||
|
|
||||||
@field_validator("trust_remote_code")
|
@field_validator("trust_remote_code")
|
||||||
@classmethod
|
@classmethod
|
||||||
def hint_trust_remote_code(cls, trust_remote_code):
|
def hint_trust_remote_code(cls, trust_remote_code):
|
||||||
|
|||||||
@@ -54,7 +54,6 @@ class LoraConfig(BaseModel):
|
|||||||
lora_alpha: int | None = None
|
lora_alpha: int | None = None
|
||||||
lora_fan_in_fan_out: bool | None = None
|
lora_fan_in_fan_out: bool | None = None
|
||||||
lora_target_modules: str | list[str] | None = None
|
lora_target_modules: str | list[str] | None = None
|
||||||
lora_target_parameters: str | list[str] | None = None
|
|
||||||
lora_target_linear: bool | None = Field(
|
lora_target_linear: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "If true, will target all linear modules"},
|
json_schema_extra={"description": "If true, will target all linear modules"},
|
||||||
|
|||||||
@@ -138,26 +138,6 @@ class HyperparametersConfig(BaseModel):
|
|||||||
adam_beta3: float | None = Field(
|
adam_beta3: float | None = Field(
|
||||||
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
|
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
|
||||||
)
|
)
|
||||||
|
|
||||||
dion_lr: float | None = Field(
|
|
||||||
default=None, json_schema_extra={"description": "Dion Optimizer learning rate"}
|
|
||||||
)
|
|
||||||
dion_momentum: float | None = Field(
|
|
||||||
default=None, json_schema_extra={"description": "Dion Optimizer momentum"}
|
|
||||||
)
|
|
||||||
dion_rank_fraction: float | None = Field(
|
|
||||||
default=1.0,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Dion Optimizer: r/d fraction for low-rank approximation. Used to compute the low-rank dimension."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
dion_rank_multiple_of: int | None = Field(
|
|
||||||
default=1,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Dion Optimizer: Round up the low-rank dimension to a multiple of this number. This may be useful to ensure even sharding."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
max_grad_norm: float | None = Field(
|
max_grad_norm: float | None = Field(
|
||||||
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
|
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -559,6 +559,20 @@ class LoRAValidationMixin:
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_lora_8bit(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("lora_mlp_kernel")
|
||||||
|
or data.get("lora_qkv_kernel")
|
||||||
|
or data.get("lora_o_kernel")
|
||||||
|
):
|
||||||
|
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
||||||
|
raise ValueError(
|
||||||
|
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lora_axolotl_unsloth(cls, data):
|
def check_lora_axolotl_unsloth(cls, data):
|
||||||
@@ -577,7 +591,9 @@ class LoRAValidationMixin:
|
|||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_fused_lora(self):
|
def check_fused_lora(self):
|
||||||
if self.adapter in ["lora", "qlora"] and self.flash_attn_fuse_mlp:
|
if self.adapter in ["lora", "qlora"] and (
|
||||||
|
self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp
|
||||||
|
):
|
||||||
raise ValueError("Fused modules are not supported with LoRA/QLoRA")
|
raise ValueError("Fused modules are not supported with LoRA/QLoRA")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -603,7 +619,7 @@ class LoRAValidationMixin:
|
|||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lora_kernels_8bit(cls, data):
|
def check_lora_kernel_8bit(cls, data):
|
||||||
if (
|
if (
|
||||||
data.get("lora_mlp_kernel")
|
data.get("lora_mlp_kernel")
|
||||||
or data.get("lora_qkv_kernel")
|
or data.get("lora_qkv_kernel")
|
||||||
@@ -611,39 +627,36 @@ class LoRAValidationMixin:
|
|||||||
):
|
):
|
||||||
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with 8-bit LoRA"
|
||||||
"compatible with 8-bit LoRA a the moment."
|
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lora_kernels_dora(cls, data):
|
def check_lora_kernel_rl(cls, data):
|
||||||
if (
|
|
||||||
data.get("lora_mlp_kernel")
|
|
||||||
or data.get("lora_qkv_kernel")
|
|
||||||
or data.get("lora_o_kernel")
|
|
||||||
) and data.get("peft_use_dora"):
|
|
||||||
raise ValueError(
|
|
||||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
|
||||||
"compatible with DoRA at the moment."
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_lora_kernels_rl(cls, data):
|
|
||||||
if (
|
if (
|
||||||
data.get("lora_mlp_kernel")
|
data.get("lora_mlp_kernel")
|
||||||
or data.get("lora_qkv_kernel")
|
or data.get("lora_qkv_kernel")
|
||||||
or data.get("lora_o_kernel")
|
or data.get("lora_o_kernel")
|
||||||
) and data.get("rl"):
|
) and data.get("rl"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with RL at the moment."
|
||||||
"compatible with RL at the moment."
|
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_lora_dropout_parameters(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("lora_dropout", 0.0)
|
||||||
|
and data.get("lora_dropout") > 0.0
|
||||||
|
and data.get("lora_target_parameters")
|
||||||
|
):
|
||||||
|
# lora.ParamWrapper does not work with lora_dropout != 0
|
||||||
|
raise ValueError(
|
||||||
|
"`lora_dropout` does not work when using `lora_target_parameters`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RLValidationMixin:
|
class RLValidationMixin:
|
||||||
"""Validation methods related to RL training configuration."""
|
"""Validation methods related to RL training configuration."""
|
||||||
@@ -972,16 +985,6 @@ class SystemValidationMixin:
|
|||||||
raise ValueError("deepspeed and fsdp cannot be used together.")
|
raise ValueError("deepspeed and fsdp cannot be used together.")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_model_quantization_config_vs_bnb(cls, data):
|
|
||||||
if data.get("model_quantization_config"):
|
|
||||||
if data.get("load_in_8bit") or data.get("load_in_4bit"):
|
|
||||||
raise ValueError(
|
|
||||||
"model_quantization_config and load_in_8bit or load_in_4bit cannot be used together."
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_npu_config(cls, data):
|
def check_npu_config(cls, data):
|
||||||
@@ -1147,19 +1150,6 @@ class ModelCompatibilityValidationMixin:
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_gpt_oss_fsdp_loading(cls, data):
|
|
||||||
if data.get("model_quantization_config", "") == "Mxfp4Config":
|
|
||||||
if (
|
|
||||||
data.get("fsdp_config", {}).get("cpu_ram_efficient_loading", False)
|
|
||||||
is True
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"FSDP cpu_ram_efficient_loading is not supported for Mxfp4Config model quantization."
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class ComplexValidationMixin:
|
class ComplexValidationMixin:
|
||||||
"""Complex validation methods that involve multiple systems."""
|
"""Complex validation methods that involve multiple systems."""
|
||||||
@@ -1205,7 +1195,7 @@ class ComplexValidationMixin:
|
|||||||
"ReLoRA is not compatible with the one_cycle scheduler"
|
"ReLoRA is not compatible with the one_cycle scheduler"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.flash_attn_fuse_mlp:
|
if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
|
||||||
raise ValueError("Fused modules are not supported with ReLoRA")
|
raise ValueError("Fused modules are not supported with ReLoRA")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from datasets import IterableDataset, disable_caching, enable_caching
|
|||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
||||||
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
|
from axolotl.utils.distributed import init_distributed_state, reduce_and_broadcast
|
||||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -596,25 +597,6 @@ def setup_fsdp_envs(cfg):
|
|||||||
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = "true"
|
os.environ["FSDP_RESHARD_AFTER_FORWARD"] = "true"
|
||||||
|
|
||||||
|
|
||||||
def setup_parallelism_envs(cfg):
|
|
||||||
set_accelerate_parallelism_config = False
|
|
||||||
if cfg.tensor_parallel_size and cfg.tensor_parallel_size > 1:
|
|
||||||
set_accelerate_parallelism_config = True
|
|
||||||
os.environ["PARALLELISM_CONFIG_TP_SIZE"] = str(cfg.tensor_parallel_size)
|
|
||||||
if cfg.dp_shard_size and cfg.dp_shard_size > 1:
|
|
||||||
set_accelerate_parallelism_config = True
|
|
||||||
os.environ["PARALLELISM_CONFIG_DP_SHARD_SIZE"] = str(cfg.dp_shard_size)
|
|
||||||
if cfg.dp_replicate_size and cfg.dp_replicate_size > 1:
|
|
||||||
set_accelerate_parallelism_config = True
|
|
||||||
os.environ["PARALLELISM_CONFIG_DP_REPLICATE_SIZE"] = str(cfg.dp_replicate_size)
|
|
||||||
if cfg.context_parallel_size and cfg.context_parallel_size > 1:
|
|
||||||
set_accelerate_parallelism_config = True
|
|
||||||
os.environ["PARALLELISM_CONFIG_CP_SIZE"] = str(cfg.context_parallel_size)
|
|
||||||
os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
|
|
||||||
if set_accelerate_parallelism_config:
|
|
||||||
os.environ["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_optim_env(cfg):
|
def prepare_optim_env(cfg):
|
||||||
if not check_cuda_p2p_ib_support():
|
if not check_cuda_p2p_ib_support():
|
||||||
if os.getenv("NCCL_P2P_DISABLE") is None:
|
if os.getenv("NCCL_P2P_DISABLE") is None:
|
||||||
@@ -633,7 +615,6 @@ def prepare_optim_env(cfg):
|
|||||||
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
|
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
|
||||||
setup_deepspeed_env(cfg, stage=stage)
|
setup_deepspeed_env(cfg, stage=stage)
|
||||||
|
|
||||||
setup_parallelism_envs(cfg)
|
|
||||||
setup_torch_compile_env(cfg)
|
setup_torch_compile_env(cfg)
|
||||||
|
|
||||||
if cfg.fp8:
|
if cfg.fp8:
|
||||||
@@ -686,6 +667,8 @@ def setup_trainer(
|
|||||||
"""
|
"""
|
||||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
|
|
||||||
|
if cfg.torch_compile and cfg.fsdp_config and cfg.fsdp_version == 2:
|
||||||
|
patch_evaluation_loop_for_fsdp2()
|
||||||
if cfg.rl:
|
if cfg.rl:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
|
trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer, processor)
|
||||||
trainer_builder.model_ref = model_ref
|
trainer_builder.model_ref = model_ref
|
||||||
|
|||||||
@@ -47,9 +47,7 @@ class BaseCliTest:
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
mock_fn = "os.execvpe" if command == "train" else "subprocess.run"
|
with patch("subprocess.run") as mock:
|
||||||
|
|
||||||
with patch(mock_fn) as mock:
|
|
||||||
result = cli_runner.invoke(cli, [command, str(config_path)])
|
result = cli_runner.invoke(cli, [command, str(config_path)])
|
||||||
|
|
||||||
assert mock.called
|
assert mock.called
|
||||||
@@ -67,12 +65,8 @@ class BaseCliTest:
|
|||||||
if train:
|
if train:
|
||||||
expected.append("--shard=False")
|
expected.append("--shard=False")
|
||||||
|
|
||||||
if command == "train":
|
assert mock.call_args.args[0] == expected
|
||||||
assert mock.call_args.args[0] == "accelerate"
|
assert mock.call_args.kwargs == {"check": True}
|
||||||
assert mock.call_args.args[1] == expected
|
|
||||||
else:
|
|
||||||
assert mock.call_args.args[0] == expected
|
|
||||||
assert mock.call_args.kwargs == {"check": True}
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
|
|
||||||
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):
|
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("os.execvpe") as mock_subprocess:
|
with patch("subprocess.run") as mock_subprocess:
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@@ -104,7 +104,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
mock_subprocess.assert_called_once()
|
mock_subprocess.assert_called_once()
|
||||||
|
|
||||||
# Verify launcher args are passed to torchrun
|
# Verify launcher args are passed to torchrun
|
||||||
called_cmd = mock_subprocess.call_args.args[1]
|
called_cmd = mock_subprocess.call_args.args[0]
|
||||||
assert called_cmd[0] == "torchrun"
|
assert called_cmd[0] == "torchrun"
|
||||||
assert "--nproc_per_node=2" in called_cmd
|
assert "--nproc_per_node=2" in called_cmd
|
||||||
assert "--nnodes=1" in called_cmd
|
assert "--nnodes=1" in called_cmd
|
||||||
@@ -118,7 +118,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("os.execvpe") as mock_subprocess:
|
with patch("subprocess.run") as mock_subprocess:
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@@ -137,8 +137,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
mock_subprocess.assert_called_once()
|
mock_subprocess.assert_called_once()
|
||||||
|
|
||||||
# Verify launcher args are passed to accelerate
|
# Verify launcher args are passed to accelerate
|
||||||
assert mock_subprocess.call_args.args[0] == "accelerate"
|
called_cmd = mock_subprocess.call_args.args[0]
|
||||||
called_cmd = mock_subprocess.call_args.args[1]
|
|
||||||
assert called_cmd[0] == "accelerate"
|
assert called_cmd[0] == "accelerate"
|
||||||
assert called_cmd[1] == "launch"
|
assert called_cmd[1] == "launch"
|
||||||
assert "--config_file=accelerate_config.yml" in called_cmd
|
assert "--config_file=accelerate_config.yml" in called_cmd
|
||||||
@@ -153,7 +152,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("os.execvpe") as mock_subprocess:
|
with patch("subprocess.run") as mock_subprocess:
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@@ -171,8 +170,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
mock_subprocess.assert_called_once()
|
mock_subprocess.assert_called_once()
|
||||||
|
|
||||||
# Verify no launcher args contamination
|
# Verify no launcher args contamination
|
||||||
assert mock_subprocess.call_args.args[0] == "accelerate"
|
called_cmd = mock_subprocess.call_args.args[0]
|
||||||
called_cmd = mock_subprocess.call_args.args[1]
|
|
||||||
assert called_cmd[0] == "accelerate"
|
assert called_cmd[0] == "accelerate"
|
||||||
assert called_cmd[1] == "launch"
|
assert called_cmd[1] == "launch"
|
||||||
# Should not contain any extra launcher args
|
# Should not contain any extra launcher args
|
||||||
@@ -188,7 +186,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
config_path = tmp_path / "config.yml"
|
config_path = tmp_path / "config.yml"
|
||||||
config_path.write_text(valid_test_config)
|
config_path.write_text(valid_test_config)
|
||||||
|
|
||||||
with patch("os.execvpe") as mock_subprocess:
|
with patch("subprocess.run") as mock_subprocess:
|
||||||
result = cli_runner.invoke(
|
result = cli_runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@@ -209,8 +207,7 @@ class TestTrainCommand(BaseCliTest):
|
|||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
mock_subprocess.assert_called_once()
|
mock_subprocess.assert_called_once()
|
||||||
|
|
||||||
assert mock_subprocess.call_args.args[0] == "torchrun"
|
called_cmd = mock_subprocess.call_args.args[0]
|
||||||
called_cmd = mock_subprocess.call_args.args[1]
|
|
||||||
# Verify launcher args
|
# Verify launcher args
|
||||||
assert "--nproc_per_node=8" in called_cmd
|
assert "--nproc_per_node=8" in called_cmd
|
||||||
# Verify axolotl args are also present
|
# Verify axolotl args are also present
|
||||||
|
|||||||
@@ -281,9 +281,7 @@ class TestHFRLTrainerBuilder:
|
|||||||
# Other settings
|
# Other settings
|
||||||
assert training_arguments.dataloader_num_workers == 1
|
assert training_arguments.dataloader_num_workers == 1
|
||||||
assert training_arguments.dataloader_pin_memory is True
|
assert training_arguments.dataloader_pin_memory is True
|
||||||
|
assert training_arguments.gradient_checkpointing is False
|
||||||
# TODO(wing): restore once trl releases 0.22.0
|
|
||||||
# assert training_arguments.gradient_checkpointing is True
|
|
||||||
|
|
||||||
def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
|
def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
|
||||||
builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
|
builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
|
||||||
|
|||||||
@@ -64,7 +64,6 @@ def sample_tensors():
|
|||||||
batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
|
batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
|
||||||
),
|
),
|
||||||
"W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
|
"W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
|
||||||
"b": torch.randn(out_dim, device="cuda", dtype=torch.float16),
|
|
||||||
"scale": 0.5,
|
"scale": 0.5,
|
||||||
"shapes": {
|
"shapes": {
|
||||||
"batch": batch_size,
|
"batch": batch_size,
|
||||||
@@ -104,24 +103,23 @@ def mock_proj():
|
|||||||
def test_get_lora_parameters(mock_proj):
|
def test_get_lora_parameters(mock_proj):
|
||||||
"""Tests get_lora_parameters function"""
|
"""Tests get_lora_parameters function"""
|
||||||
# Test with LoRA enabled
|
# Test with LoRA enabled
|
||||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||||
|
|
||||||
assert isinstance(W, torch.Tensor)
|
assert isinstance(W, torch.Tensor)
|
||||||
assert W.shape == (128, 64)
|
assert W.shape == (128, 64)
|
||||||
assert b.shape == (128,)
|
|
||||||
assert A.shape == (8, 64)
|
assert A.shape == (8, 64)
|
||||||
assert B.shape == (128, 8)
|
assert B.shape == (128, 8)
|
||||||
assert s == 0.5
|
assert s == 0.5
|
||||||
|
|
||||||
# Test with LoRA disabled
|
# Test with LoRA disabled
|
||||||
mock_proj.disable_adapters = True
|
mock_proj.disable_adapters = True
|
||||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||||
assert A is None and B is None and s is None
|
assert A is None and B is None and s is None
|
||||||
|
|
||||||
# Test with merged state
|
# Test with merged state
|
||||||
mock_proj.disable_adapters = False
|
mock_proj.disable_adapters = False
|
||||||
mock_proj.merged = True
|
mock_proj.merged = True
|
||||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||||
assert A is None and B is None and s is None
|
assert A is None and B is None and s is None
|
||||||
|
|
||||||
|
|
||||||
@@ -129,7 +127,6 @@ def test_matmul_lora(sample_tensors):
|
|||||||
"""Tests matmul_lora function"""
|
"""Tests matmul_lora function"""
|
||||||
X = sample_tensors["X"]
|
X = sample_tensors["X"]
|
||||||
W = sample_tensors["W"]
|
W = sample_tensors["W"]
|
||||||
b = sample_tensors["b"]
|
|
||||||
scale = sample_tensors["scale"]
|
scale = sample_tensors["scale"]
|
||||||
|
|
||||||
shapes = sample_tensors["shapes"]
|
shapes = sample_tensors["shapes"]
|
||||||
@@ -141,20 +138,19 @@ def test_matmul_lora(sample_tensors):
|
|||||||
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
||||||
|
|
||||||
# Test base matmul
|
# Test base matmul
|
||||||
out1 = matmul_lora(X, W, b, None, None, None, None)
|
out1 = matmul_lora(X, W, None, None, None, None)
|
||||||
matmul = torch.matmul(X, W.t())
|
expected1 = torch.matmul(X, W.t())
|
||||||
expected1 = matmul + b
|
|
||||||
assert torch.allclose(out1, expected1, rtol=1e-3)
|
assert torch.allclose(out1, expected1, rtol=1e-3)
|
||||||
|
|
||||||
# Test with LoRA
|
# Test with LoRA
|
||||||
out2 = matmul_lora(X, W, b, None, A, B, scale)
|
out2 = matmul_lora(X, W, None, A, B, scale)
|
||||||
lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
|
lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
|
||||||
expected2 = matmul + lora_term + b
|
expected2 = expected1 + lora_term
|
||||||
assert torch.allclose(out2, expected2, rtol=1e-3)
|
assert torch.allclose(out2, expected2, rtol=1e-3)
|
||||||
|
|
||||||
# Test 3D input reshaping
|
# Test 3D input reshaping
|
||||||
X_3d = X.clone()
|
X_3d = X.clone()
|
||||||
out3 = matmul_lora(X_3d, W, b, None, A, B, scale)
|
out3 = matmul_lora(X_3d, W, None, A, B, scale)
|
||||||
assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
|
assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||||
|
|
||||||
|
|
||||||
@@ -179,19 +175,16 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
|
|||||||
output = LoRA_MLP.apply(
|
output = LoRA_MLP.apply(
|
||||||
X,
|
X,
|
||||||
gate_proj.weight,
|
gate_proj.weight,
|
||||||
gate_proj.bias,
|
|
||||||
None, # gate_quant
|
None, # gate_quant
|
||||||
None, # gate_A
|
None, # gate_A
|
||||||
None, # gate_B
|
None, # gate_B
|
||||||
None, # gate_scale
|
None, # gate_scale
|
||||||
up_proj.weight,
|
up_proj.weight,
|
||||||
up_proj.bias,
|
|
||||||
None, # up_quant
|
None, # up_quant
|
||||||
None, # up_A
|
None, # up_A
|
||||||
None, # up_B
|
None, # up_B
|
||||||
None, # up_scale
|
None, # up_scale
|
||||||
down_proj.weight,
|
down_proj.weight,
|
||||||
down_proj.bias,
|
|
||||||
None, # down_quant
|
None, # down_quant
|
||||||
None, # down_A
|
None, # down_A
|
||||||
None, # down_B
|
None, # down_B
|
||||||
@@ -250,19 +243,16 @@ def test_lora_mlp_with_adapters(
|
|||||||
output = LoRA_MLP.apply(
|
output = LoRA_MLP.apply(
|
||||||
X,
|
X,
|
||||||
gate_proj.weight,
|
gate_proj.weight,
|
||||||
gate_proj.bias,
|
|
||||||
None,
|
None,
|
||||||
gate_A,
|
gate_A,
|
||||||
gate_B,
|
gate_B,
|
||||||
scale,
|
scale,
|
||||||
up_proj.weight,
|
up_proj.weight,
|
||||||
up_proj.bias,
|
|
||||||
None,
|
None,
|
||||||
up_A,
|
up_A,
|
||||||
up_B,
|
up_B,
|
||||||
scale,
|
scale,
|
||||||
down_proj.weight,
|
down_proj.weight,
|
||||||
down_proj.bias,
|
|
||||||
None,
|
None,
|
||||||
down_A,
|
down_A,
|
||||||
down_B,
|
down_B,
|
||||||
@@ -333,7 +323,6 @@ def test_lora_qkv(sample_tensors):
|
|||||||
X.requires_grad = True
|
X.requires_grad = True
|
||||||
|
|
||||||
# Test without LoRA adapters
|
# Test without LoRA adapters
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
Q1, K1, V1 = LoRA_QKV.apply(
|
Q1, K1, V1 = LoRA_QKV.apply(
|
||||||
X,
|
X,
|
||||||
q_weight,
|
q_weight,
|
||||||
@@ -341,19 +330,16 @@ def test_lora_qkv(sample_tensors):
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
k_weight,
|
k_weight,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
v_weight,
|
v_weight,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -370,19 +356,16 @@ def test_lora_qkv(sample_tensors):
|
|||||||
X,
|
X,
|
||||||
q_weight,
|
q_weight,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
q_A,
|
q_A,
|
||||||
q_B,
|
q_B,
|
||||||
scale,
|
scale,
|
||||||
k_weight,
|
k_weight,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
k_A,
|
k_A,
|
||||||
k_B,
|
k_B,
|
||||||
scale,
|
scale,
|
||||||
v_weight,
|
v_weight,
|
||||||
None,
|
None,
|
||||||
None,
|
|
||||||
v_A,
|
v_A,
|
||||||
v_B,
|
v_B,
|
||||||
scale,
|
scale,
|
||||||
@@ -416,7 +399,6 @@ def test_lora_o(sample_tensors):
|
|||||||
"""Tests LoRA output projection"""
|
"""Tests LoRA output projection"""
|
||||||
X = sample_tensors["X"]
|
X = sample_tensors["X"]
|
||||||
W = sample_tensors["W"]
|
W = sample_tensors["W"]
|
||||||
b = sample_tensors["b"]
|
|
||||||
scale = sample_tensors["scale"]
|
scale = sample_tensors["scale"]
|
||||||
|
|
||||||
shapes = sample_tensors["shapes"]
|
shapes = sample_tensors["shapes"]
|
||||||
@@ -429,7 +411,7 @@ def test_lora_o(sample_tensors):
|
|||||||
|
|
||||||
# Test forward pass
|
# Test forward pass
|
||||||
X.requires_grad = True
|
X.requires_grad = True
|
||||||
output = LoRA_O.apply(X, W, b, None, A, B, scale)
|
output = LoRA_O.apply(X, W, None, A, B, scale)
|
||||||
|
|
||||||
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
|
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||||
|
|
||||||
@@ -443,7 +425,6 @@ def test_with_quantization(sample_tensors, mock_quantstate):
|
|||||||
"""Tests LoRA with quantized weights"""
|
"""Tests LoRA with quantized weights"""
|
||||||
X = sample_tensors["X"] # [batch, seq, hidden]
|
X = sample_tensors["X"] # [batch, seq, hidden]
|
||||||
W = sample_tensors["W"] # [out, hidden]
|
W = sample_tensors["W"] # [out, hidden]
|
||||||
b = sample_tensors["b"] # [out]
|
|
||||||
scale = 0.5
|
scale = 0.5
|
||||||
|
|
||||||
shapes = sample_tensors["shapes"]
|
shapes = sample_tensors["shapes"]
|
||||||
@@ -455,13 +436,13 @@ def test_with_quantization(sample_tensors, mock_quantstate):
|
|||||||
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
||||||
|
|
||||||
# Test matmul with quantization
|
# Test matmul with quantization
|
||||||
out = matmul_lora(X, W, b, mock_quantstate, A, B, scale)
|
out = matmul_lora(X, W, mock_quantstate, A, B, scale)
|
||||||
assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
|
assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||||
assert not torch.isnan(out).any()
|
assert not torch.isnan(out).any()
|
||||||
|
|
||||||
# Test with different batch sizes
|
# Test with different batch sizes
|
||||||
X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
|
X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
|
||||||
out2 = matmul_lora(X2, W, b, mock_quantstate, A, B, scale)
|
out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale)
|
||||||
assert out2.shape == (4, 6, W.shape[0])
|
assert out2.shape == (4, 6, W.shape[0])
|
||||||
assert not torch.isnan(out2).any()
|
assert not torch.isnan(out2).any()
|
||||||
|
|
||||||
@@ -478,12 +459,11 @@ def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
|
|||||||
"""Tests various input shapes and dimensions"""
|
"""Tests various input shapes and dimensions"""
|
||||||
X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
|
X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
|
||||||
W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
|
W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
|
||||||
b = torch.randn(out, device="cuda", dtype=torch.float16)
|
|
||||||
A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
|
A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
|
||||||
B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
|
B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
|
||||||
scale = 0.5
|
scale = 0.5
|
||||||
|
|
||||||
result = matmul_lora(X, W, b, None, A, B, scale)
|
result = matmul_lora(X, W, None, A, B, scale)
|
||||||
assert result.shape == (batch, seq, out)
|
assert result.shape == (batch, seq, out)
|
||||||
|
|
||||||
|
|
||||||
@@ -491,7 +471,6 @@ def test_gradient_flow(sample_tensors):
|
|||||||
"""Tests gradient flow through LoRA layers"""
|
"""Tests gradient flow through LoRA layers"""
|
||||||
X = sample_tensors["X"].clone()
|
X = sample_tensors["X"].clone()
|
||||||
W = sample_tensors["W"].clone()
|
W = sample_tensors["W"].clone()
|
||||||
b = sample_tensors["b"].clone()
|
|
||||||
scale = sample_tensors["scale"]
|
scale = sample_tensors["scale"]
|
||||||
|
|
||||||
shapes = sample_tensors["shapes"]
|
shapes = sample_tensors["shapes"]
|
||||||
@@ -507,7 +486,7 @@ def test_gradient_flow(sample_tensors):
|
|||||||
B.requires_grad = True
|
B.requires_grad = True
|
||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
out = matmul_lora(X, W, b, None, A, B, scale)
|
out = matmul_lora(X, W, None, A, B, scale)
|
||||||
loss = out.sum()
|
loss = out.sum()
|
||||||
|
|
||||||
# Backward pass
|
# Backward pass
|
||||||
|
|||||||
@@ -174,69 +174,6 @@ class TestFSDP2:
|
|||||||
|
|
||||||
verify_training_success(temp_dir)
|
verify_training_success(temp_dir)
|
||||||
|
|
||||||
@require_torch_2_7_0
|
|
||||||
def test_lora_sft_kernels(self, temp_dir):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
|
||||||
"sequence_len": 2048,
|
|
||||||
"val_set_size": 0.01,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "tatsu-lab/alpaca",
|
|
||||||
"type": "alpaca",
|
|
||||||
"split": "train[:10%]",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 2,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"fsdp_version": 2,
|
|
||||||
"fsdp_config": {
|
|
||||||
"offload_params": False,
|
|
||||||
"cpu_ram_efficient_loading": False,
|
|
||||||
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
|
||||||
"state_dict_type": "FULL_STATE_DICT",
|
|
||||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
|
||||||
"reshard_after_forward": True,
|
|
||||||
},
|
|
||||||
"use_tensorboard": True,
|
|
||||||
"bf16": True,
|
|
||||||
"lora_mlp_kernel": True,
|
|
||||||
"lora_qkv_kernel": True,
|
|
||||||
"lora_o_kernel": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# write cfg to yaml file
|
|
||||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
|
||||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
|
||||||
|
|
||||||
execute_subprocess_async(
|
|
||||||
[
|
|
||||||
"axolotl",
|
|
||||||
"train",
|
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
|
||||||
"--num-processes",
|
|
||||||
"2",
|
|
||||||
"--main-process-port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
verify_training_success(temp_dir)
|
|
||||||
|
|
||||||
@require_torch_2_7_0
|
@require_torch_2_7_0
|
||||||
def test_qlora_sft(self, temp_dir):
|
def test_qlora_sft(self, temp_dir):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -299,70 +236,6 @@ class TestFSDP2:
|
|||||||
|
|
||||||
verify_training_success(temp_dir)
|
verify_training_success(temp_dir)
|
||||||
|
|
||||||
@require_torch_2_7_0
|
|
||||||
def test_qlora_sft_kernels(self, temp_dir):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
|
||||||
"sequence_len": 2048,
|
|
||||||
"val_set_size": 0.01,
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "tatsu-lab/alpaca",
|
|
||||||
"type": "alpaca",
|
|
||||||
"split": "train[:10%]",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"adapter": "qlora",
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 2,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"fsdp_version": 2,
|
|
||||||
"fsdp_config": {
|
|
||||||
"offload_params": False,
|
|
||||||
"cpu_ram_efficient_loading": False,
|
|
||||||
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
|
||||||
"state_dict_type": "FULL_STATE_DICT",
|
|
||||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
|
||||||
"reshard_after_forward": True,
|
|
||||||
},
|
|
||||||
"use_tensorboard": True,
|
|
||||||
"bf16": True,
|
|
||||||
"lora_mlp_kernel": True,
|
|
||||||
"lora_qkv_kernel": True,
|
|
||||||
"lora_o_kernel": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# write cfg to yaml file
|
|
||||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
|
||||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
|
||||||
|
|
||||||
execute_subprocess_async(
|
|
||||||
[
|
|
||||||
"axolotl",
|
|
||||||
"train",
|
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
|
||||||
"--num-processes",
|
|
||||||
"2",
|
|
||||||
"--main-process-port",
|
|
||||||
f"{get_torch_dist_unique_port()}",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
verify_training_success(temp_dir)
|
|
||||||
|
|
||||||
@require_torch_2_7_0
|
@require_torch_2_7_0
|
||||||
def test_dpo_fft(self, temp_dir):
|
def test_dpo_fft(self, temp_dir):
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
|
|||||||
@@ -10,11 +10,7 @@ from accelerate.test_utils import execute_subprocess_async
|
|||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from tests.e2e.utils import (
|
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
|
||||||
check_tensorboard,
|
|
||||||
require_torch_2_7_0,
|
|
||||||
require_torch_lt_2_6_0,
|
|
||||||
)
|
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
@@ -143,71 +139,3 @@ class TestMultiGPURay:
|
|||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_torch_2_7_0
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"gradient_accumulation_steps",
|
|
||||||
[1, 2],
|
|
||||||
)
|
|
||||||
def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"sample_packing": True,
|
|
||||||
"pad_to_sequence_len": True,
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"val_set_size": 0.01,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "tatsu-lab/alpaca",
|
|
||||||
"type": "alpaca",
|
|
||||||
"split": "train[:10%]",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 2,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"fsdp_version": 2,
|
|
||||||
"fsdp_config": {
|
|
||||||
"offload_params": False,
|
|
||||||
"cpu_ram_efficient_loading": False,
|
|
||||||
"transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
|
||||||
"state_dict_type": "FULL_STATE_DICT",
|
|
||||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
|
||||||
"reshard_after_forward": True,
|
|
||||||
},
|
|
||||||
"use_tensorboard": True,
|
|
||||||
"save_first_step": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# write cfg to yaml file
|
|
||||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
|
||||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
|
||||||
|
|
||||||
execute_subprocess_async(
|
|
||||||
[
|
|
||||||
"axolotl",
|
|
||||||
"train",
|
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
|
||||||
"--use-ray",
|
|
||||||
"--ray-num-workers",
|
|
||||||
"2",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
check_tensorboard(
|
|
||||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,131 +0,0 @@
|
|||||||
"""Integration tests for FSDP Params4bit patches."""
|
|
||||||
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
|
||||||
apply_bnb_torch_function_patch,
|
|
||||||
patched_torch_function,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_params4bit():
|
|
||||||
"""Create a mock Params4bit instance with test attributes."""
|
|
||||||
mock_instance = Mock()
|
|
||||||
mock_instance.requires_grad = True
|
|
||||||
mock_instance.quant_state = "test_state"
|
|
||||||
mock_instance.blocksize = 128
|
|
||||||
mock_instance.compress_statistics = True
|
|
||||||
mock_instance.quant_type = "fp4"
|
|
||||||
mock_instance.quant_storage = "test_storage"
|
|
||||||
mock_instance.module = "test_module"
|
|
||||||
mock_instance.bnb_quantized = True
|
|
||||||
return mock_instance
|
|
||||||
|
|
||||||
|
|
||||||
class TestBnbTorchFunctionPatch:
|
|
||||||
"""Test the Params4bit.__torch_function__ patch."""
|
|
||||||
|
|
||||||
def test_apply_patch(self):
|
|
||||||
"""Test that the patch can be applied."""
|
|
||||||
with patch("bitsandbytes.nn.modules.Params4bit") as mock_cls:
|
|
||||||
apply_bnb_torch_function_patch()
|
|
||||||
assert hasattr(mock_cls, "__torch_function__")
|
|
||||||
assert isinstance(mock_cls.__torch_function__, classmethod)
|
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
|
||||||
def test_torch_chunk_preserves_attributes(self, mock_params4bit):
|
|
||||||
"""Test that torch.chunk preserves Params4bit attributes."""
|
|
||||||
mock_cls = Mock()
|
|
||||||
chunks = (torch.tensor([1, 2]), torch.tensor([3, 4]))
|
|
||||||
|
|
||||||
with patch("torch.nn.Parameter.__torch_function__", return_value=chunks):
|
|
||||||
result = patched_torch_function(
|
|
||||||
mock_cls,
|
|
||||||
torch.chunk,
|
|
||||||
(type(mock_params4bit),),
|
|
||||||
args=(mock_params4bit, 2),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(result, tuple)
|
|
||||||
assert len(result) == 2
|
|
||||||
|
|
||||||
# Check that Params4bit constructor was called with preserved attributes
|
|
||||||
assert mock_cls.call_count == 2
|
|
||||||
for call in mock_cls.call_args_list:
|
|
||||||
kwargs = call[1]
|
|
||||||
assert kwargs["requires_grad"] == mock_params4bit.requires_grad
|
|
||||||
assert kwargs["quant_state"] == mock_params4bit.quant_state
|
|
||||||
assert kwargs["blocksize"] == mock_params4bit.blocksize
|
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
|
||||||
def test_other_functions_fallback(self, mock_params4bit):
|
|
||||||
"""Test that non-chunk/split functions use Parameter fallback."""
|
|
||||||
mock_cls = Mock()
|
|
||||||
fallback_result = torch.tensor([5, 6, 7])
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"torch.nn.Parameter.__torch_function__", return_value=fallback_result
|
|
||||||
) as mock_fallback:
|
|
||||||
result = patched_torch_function(
|
|
||||||
mock_cls, torch.add, (type(mock_params4bit),), args=(mock_params4bit, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should call Parameter.__torch_function__ and return its result
|
|
||||||
mock_fallback.assert_called_once()
|
|
||||||
assert result is fallback_result
|
|
||||||
mock_cls.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
class TestFSDPPatchIntegration:
|
|
||||||
"""Test FSDP patch integration."""
|
|
||||||
|
|
||||||
@pytest.mark.integration
|
|
||||||
def test_all_patches_together(self):
|
|
||||||
"""Test that all patches can be applied together."""
|
|
||||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
|
||||||
apply_init_sharded_param_patch,
|
|
||||||
apply_init_unsharded_param_patch,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Store original methods before patching
|
|
||||||
original_torch_function = getattr(
|
|
||||||
bnb.nn.modules.Params4bit, "__torch_function__", None
|
|
||||||
)
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
original_init_sharded = FSDPParam._init_sharded_param
|
|
||||||
original_init_unsharded = FSDPParam.init_unsharded_param
|
|
||||||
|
|
||||||
# Apply patches
|
|
||||||
apply_bnb_torch_function_patch()
|
|
||||||
apply_init_sharded_param_patch()
|
|
||||||
apply_init_unsharded_param_patch()
|
|
||||||
|
|
||||||
# Verify patches were applied
|
|
||||||
current_torch_function = getattr(
|
|
||||||
bnb.nn.modules.Params4bit, "__torch_function__", None
|
|
||||||
)
|
|
||||||
if original_torch_function is not None:
|
|
||||||
assert (
|
|
||||||
current_torch_function != original_torch_function
|
|
||||||
), "Params4bit.__torch_function__ was not patched"
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
current_torch_function is not None
|
|
||||||
), "Params4bit.__torch_function__ was not added"
|
|
||||||
|
|
||||||
# Check that FSDP methods were patched
|
|
||||||
assert (
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
FSDPParam._init_sharded_param
|
|
||||||
!= original_init_sharded
|
|
||||||
), "_init_sharded_param was not patched"
|
|
||||||
assert (
|
|
||||||
FSDPParam.init_unsharded_param != original_init_unsharded
|
|
||||||
), "init_unsharded_param was not patched"
|
|
||||||
@@ -29,6 +29,7 @@ class TestFusedLlama(unittest.TestCase):
|
|||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
|
"flash_attn_fuse_qkv": True,
|
||||||
"flash_attn_fuse_mlp": True,
|
"flash_attn_fuse_mlp": True,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from .utils import (
|
|||||||
check_model_output_exists,
|
check_model_output_exists,
|
||||||
require_torch_2_5_1,
|
require_torch_2_5_1,
|
||||||
require_torch_2_6_0,
|
require_torch_2_6_0,
|
||||||
require_torch_2_7_0,
|
|
||||||
with_temp_dir,
|
with_temp_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -161,49 +160,6 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
|
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
@require_torch_2_7_0
|
|
||||||
def test_dion(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"model_type": "AutoModelForCausalLM",
|
|
||||||
"tokenizer_type": "AutoTokenizer",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"val_set_size": 0.0,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 5,
|
|
||||||
"micro_batch_size": 8,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "dion",
|
|
||||||
"dion_lr": 0.01,
|
|
||||||
"dion_momentum": 0.95,
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"weight_decay": 0.01,
|
|
||||||
"save_first_step": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
dataset_meta = load_datasets(cfg=cfg)
|
|
||||||
|
|
||||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
|
||||||
assert "Dion" in trainer.optimizer.optimizer.__class__.__name__
|
|
||||||
|
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_fft_schedule_free_adamw(self, temp_dir):
|
def test_fft_schedule_free_adamw(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|||||||
@@ -1,28 +0,0 @@
|
|||||||
"""Unit tests for trainer loss calc monkeypatch."""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
|
||||||
check_evaluation_loop_is_fsdp2_patchable,
|
|
||||||
check_evaluation_loop_is_patchable,
|
|
||||||
check_maybe_log_save_evaluate_is_patchable,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestTrainerLossCalc(unittest.TestCase):
|
|
||||||
"""
|
|
||||||
Unit test class for trainer loss calc monkeypatch
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_trainer_loss_calc_is_patchable(self):
|
|
||||||
"""
|
|
||||||
Test that the upstream transformers code is still patchable. This will fail if
|
|
||||||
the patched code changes upstream.
|
|
||||||
"""
|
|
||||||
assert check_evaluation_loop_is_patchable()
|
|
||||||
assert check_evaluation_loop_is_fsdp2_patchable()
|
|
||||||
assert check_maybe_log_save_evaluate_is_patchable()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -9,7 +9,6 @@ from transformers.utils.import_utils import is_torch_mps_available
|
|||||||
|
|
||||||
from axolotl.loaders import ModelLoader
|
from axolotl.loaders import ModelLoader
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import _get_parallel_config_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
class TestModelsUtils:
|
class TestModelsUtils:
|
||||||
@@ -194,13 +193,15 @@ class TestModelsUtils:
|
|||||||
is_fsdp,
|
is_fsdp,
|
||||||
expected,
|
expected,
|
||||||
):
|
):
|
||||||
res = _get_parallel_config_kwargs( # pylint: disable=protected-access
|
res = (
|
||||||
world_size,
|
ModelLoader._get_parallel_config_kwargs( # pylint: disable=protected-access
|
||||||
tensor_parallel_size,
|
world_size,
|
||||||
context_parallel_size,
|
tensor_parallel_size,
|
||||||
dp_shard_size,
|
context_parallel_size,
|
||||||
dp_replicate_size,
|
dp_shard_size,
|
||||||
is_fsdp,
|
dp_replicate_size,
|
||||||
|
is_fsdp,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if expected[0] > 1:
|
if expected[0] > 1:
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user