Compare commits
51 Commits
chat-templ
...
v0.12.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cf22ae23b | ||
|
|
c10eb811fa | ||
|
|
0eef385b1a | ||
|
|
ecbe8b2b61 | ||
|
|
130ef7c51a | ||
|
|
d1de6f5f3d | ||
|
|
48b7ae1677 | ||
|
|
506e3a3907 | ||
|
|
09145de8fa | ||
|
|
e0a2523a3b | ||
|
|
3d45620008 | ||
|
|
ce20e838b5 | ||
|
|
d4d84d48af | ||
|
|
9b12c05660 | ||
|
|
686933194e | ||
|
|
d12b461d19 | ||
|
|
d6b81b3683 | ||
|
|
05f1b4b2e8 | ||
|
|
7cfc80ec77 | ||
|
|
0da6a95efa | ||
|
|
2c8497e489 | ||
|
|
f70d4de8c7 | ||
|
|
0ae06d756d | ||
|
|
2974670bf8 | ||
|
|
50f2b94d50 | ||
|
|
eb2c87b525 | ||
|
|
4db7f023c6 | ||
|
|
4273d5cf7e | ||
|
|
c5e5aba547 | ||
|
|
9d5c95db6f | ||
|
|
ca796fb56e | ||
|
|
597953bef0 | ||
|
|
39fbd3b2b5 | ||
|
|
46dfacf255 | ||
|
|
4bce713b39 | ||
|
|
d09290f2f4 | ||
|
|
e442ff22aa | ||
|
|
ba3dba3e4f | ||
|
|
97e86c6d47 | ||
|
|
784f8c0e95 | ||
|
|
e3177c3210 | ||
|
|
70faea331f | ||
|
|
8021c718ce | ||
|
|
42f5e6f9e9 | ||
|
|
ab49d16e34 | ||
|
|
33d094721c | ||
|
|
a54c1be972 | ||
|
|
5691992d34 | ||
|
|
e758343cac | ||
|
|
deac7b18a1 | ||
|
|
10946afae7 |
7
.github/CONTRIBUTING.md
vendored
7
.github/CONTRIBUTING.md
vendored
@@ -57,6 +57,13 @@ We welcome ideas for improvements and new features. To suggest an enhancement, o
|
||||
5. Push your branch to your fork on GitHub.
|
||||
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
|
||||
|
||||
#### Skipping CI Checks
|
||||
|
||||
You can skip certain CI checks by including specific keywords in your commit messages:
|
||||
|
||||
- `[skip ci]` or `skip ci` - Skips all CI checks for that commit
|
||||
- `[skip-e2e]` or `skip-e2e` - Skips only end-to-end tests while running other CI checks. You may also include this in the title of your PR to disable end-to-end tests for the entire PR.
|
||||
|
||||
## Style Guidelines
|
||||
|
||||
### Code Style
|
||||
|
||||
27
.github/workflows/base.yml
vendored
27
.github/workflows/base.yml
vendored
@@ -54,7 +54,7 @@ jobs:
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.6.3
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
@@ -64,9 +64,16 @@ jobs:
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: nightly
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base-nightly"
|
||||
dockerfile: "Dockerfile-base"
|
||||
# - cuda: "128"
|
||||
# cuda_version: 12.8.1
|
||||
# cudnn_version: ""
|
||||
# python_version: "3.11"
|
||||
# pytorch: nightly
|
||||
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
# dockerfile: "Dockerfile-base-nightly"
|
||||
# # "next" is for release candidates of pytorch
|
||||
# - cuda: "128"
|
||||
# cuda_version: 12.8.1
|
||||
@@ -122,6 +129,13 @@ jobs:
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -129,6 +143,13 @@ jobs:
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.8.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
23
.github/workflows/main.yml
vendored
23
.github/workflows/main.yml
vendored
@@ -24,12 +24,13 @@ jobs:
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras: vllm
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
@@ -97,6 +98,12 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
@@ -150,6 +157,18 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
is_latest:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras: vllm
|
||||
is_latest: true
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
49
.github/workflows/tests.yml
vendored
49
.github/workflows/tests.yml
vendored
@@ -105,7 +105,8 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||
|
||||
@@ -179,21 +180,52 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||
pytest -v --durations=10 tests/patched/
|
||||
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
|
||||
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
|
||||
pytest -v --durations=10 tests/cli/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
gate-skip-e2e:
|
||||
needs: [pre-commit, pytest, pytest-sdist]
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
skip: ${{ steps.compute.outputs.skip }}
|
||||
steps:
|
||||
- uses: actions/github-script@v7
|
||||
id: compute
|
||||
with:
|
||||
script: |
|
||||
const token = /\[skip-e2e\]/i;
|
||||
let msg = '';
|
||||
if (context.eventName === 'push') {
|
||||
msg = context.payload.head_commit?.message || '';
|
||||
} else if (context.eventName === 'pull_request') {
|
||||
const { owner, repo } = context.repo;
|
||||
const prNumber = context.payload.pull_request.number;
|
||||
const commits = await github.paginate(
|
||||
github.rest.pulls.listCommits,
|
||||
{ owner, repo, pull_number: prNumber, per_page: 100 }
|
||||
);
|
||||
msg = commits.at(-1)?.commit?.message || '';
|
||||
}
|
||||
const title = context.payload.pull_request?.title || '';
|
||||
const body = context.payload.pull_request?.body || '';
|
||||
const skip = token.test(msg) || token.test(title) || token.test(body);
|
||||
core.setOutput('skip', String(skip));
|
||||
|
||||
docker-e2e-tests-1st:
|
||||
# Run this job first as a gate for running the remainder of the test matrix
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
|
||||
if: >
|
||||
github.repository_owner == 'axolotl-ai-cloud' &&
|
||||
(github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
|
||||
needs.gate-skip-e2e.outputs.skip != 'true'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
needs: [pre-commit, pytest, pytest-sdist]
|
||||
needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -239,13 +271,16 @@ jobs:
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
docker-e2e-tests:
|
||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
|
||||
if: >
|
||||
github.repository_owner == 'axolotl-ai-cloud' &&
|
||||
(github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
|
||||
needs.gate-skip-e2e.outputs.skip != 'true'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
# Only run the remainder of the matrix if the first e2e check passed;
|
||||
# this is to save on wasted compute costs for known failures that get caught in the first run
|
||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||
needs: [pre-commit, pytest, gate-skip-e2e, docker-e2e-tests-1st]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
@@ -3,7 +3,7 @@ default_language_version:
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
rev: v6.0.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
@@ -23,11 +23,11 @@ repos:
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/pylint-dev/pylint
|
||||
rev: v3.3.7
|
||||
rev: v3.3.8
|
||||
hooks:
|
||||
- id: pylint
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.17.0
|
||||
rev: v1.17.1
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
|
||||
@@ -185,7 +185,6 @@ datasets:
|
||||
| `flash_attention` | `false` | Use flash attention |
|
||||
| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
|
||||
| `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
|
||||
| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations |
|
||||
| `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
|
||||
| `sdp_attention` | `false` | Use scaled dot product |
|
||||
| `s2_attention` | `false` | Use shifted sparse attention |
|
||||
|
||||
@@ -296,7 +296,6 @@
|
||||
# flash_attention:
|
||||
# flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
# flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||
# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
|
||||
# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
||||
# # Whether to use scaled-dot-product attention
|
||||
# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
@@ -541,7 +540,6 @@ xformers_attention: ${XFORMERS_ATTENTION}
|
||||
flash_attention: ${FLASH_ATTENTION}
|
||||
flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
|
||||
flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
|
||||
flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV}
|
||||
flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
|
||||
sdp_attention: ${SDP_ATTENTION}
|
||||
s2_attention: ${S2_ATTENTION}
|
||||
|
||||
10
CITATION.cff
Normal file
10
CITATION.cff
Normal file
@@ -0,0 +1,10 @@
|
||||
cff-version: 1.2.0
|
||||
type: software
|
||||
title: "Axolotl: Post-Training for AI Models"
|
||||
message: "If you use this software, please cite it as below."
|
||||
authors:
|
||||
- name: "Axolotl maintainers and contributors"
|
||||
repository-code: "https://github.com/axolotl-ai-cloud/axolotl"
|
||||
url: "https://axolotl.ai/"
|
||||
license: Apache-2.0
|
||||
date-released: "2023-05-30"
|
||||
33
README.md
33
README.md
@@ -25,17 +25,28 @@
|
||||
|
||||
## 🎉 Latest Updates
|
||||
|
||||
- 2025/07: Voxtral with mistral-common tokenizer support has been integrated in Axolotl. Read the [docs](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral)!
|
||||
- 2025/07: TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
||||
- 2025/07:
|
||||
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
|
||||
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
|
||||
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
|
||||
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
|
||||
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Expand older updates</summary>
|
||||
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
|
||||
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
|
||||
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
|
||||
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
|
||||
|
||||
</details>
|
||||
|
||||
## ✨ Overview
|
||||
|
||||
Axolotl is a tool designed to streamline post-training for various AI models.
|
||||
@@ -138,6 +149,20 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
|
||||
|
||||
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
|
||||
|
||||
## 📝 Citing Axolotl
|
||||
|
||||
If you use Axolotl in your research or projects, please cite it as follows:
|
||||
|
||||
```bibtex
|
||||
@software{axolotl,
|
||||
title = {Axolotl: Post-Training for AI Models},
|
||||
author = {{Axolotl maintainers and contributors}},
|
||||
url = {https://github.com/axolotl-ai-cloud/axolotl},
|
||||
license = {Apache-2.0},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
|
||||
## 📜 License
|
||||
|
||||
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
10
TODO.md
10
TODO.md
@@ -1,10 +0,0 @@
|
||||
# todo list
|
||||
|
||||
- [] Validation of parameters for combinations that won't work
|
||||
|
||||
|
||||
|
||||
## things that are known not to work
|
||||
|
||||
- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
|
||||
- adamw_bnb_8bit doesn't play well with FSDP offload
|
||||
@@ -274,6 +274,7 @@ website:
|
||||
- docs/dataset_preprocessing.qmd
|
||||
- docs/multipack.qmd
|
||||
- docs/mixed_precision.qmd
|
||||
- docs/optimizers.qmd
|
||||
|
||||
- section: "Advanced Features"
|
||||
contents:
|
||||
|
||||
@@ -37,7 +37,7 @@ WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
|
||||
python3 -m pip cache purge
|
||||
|
||||
|
||||
@@ -212,10 +212,11 @@ Instead of passing `tools` via the system prompt, an alternative method would be
|
||||
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||
:::
|
||||
|
||||
Example config for Llama4:
|
||||
```yaml
|
||||
chat_template: llama4
|
||||
datasets:
|
||||
- path: ...
|
||||
- path: Nanobit/text-tools-2k-test
|
||||
type: chat_template
|
||||
# field_tools: tools # default is `tools`
|
||||
```
|
||||
|
||||
@@ -13,10 +13,13 @@ format:
|
||||
- [Pixtral](#sec-pixtral)
|
||||
- [Llava-1.5](#sec-llava-15)
|
||||
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||
- [Voxtral](#sec-voxtral)
|
||||
- [Gemma-3](#sec-gemma-3)
|
||||
- [Gemma-3n](#sec-gemma-3n)
|
||||
- [Qwen2-VL](#sec-qwen2-vl)
|
||||
- [Qwen2.5-VL](#sec-qwen25-vl)
|
||||
- [SmolVLM2](#sec-smolvlm2)
|
||||
- [LFM2-VL](#sec-lfm2-vl)
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -31,7 +34,7 @@ skip_prepare_dataset: true
|
||||
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
|
||||
sample_packing: false # not yet supported with multimodal
|
||||
|
||||
chat_template: # see in next section
|
||||
chat_template: # see in next section if specified
|
||||
|
||||
# example dataset
|
||||
datasets:
|
||||
@@ -97,6 +100,16 @@ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||
chat_template: mistral_v7_tekken
|
||||
```
|
||||
|
||||
### Voxtral {#sec-voxtral}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Voxtral-Mini-3B-2507
|
||||
```
|
||||
|
||||
### Gemma-3 {#sec-gemma-3}
|
||||
|
||||
::: {.callout-tip}
|
||||
@@ -143,6 +156,26 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
chat_template: qwen2_vl # same as qwen2-vl
|
||||
```
|
||||
|
||||
### SmolVLM2 {#sec-smolvlm2}
|
||||
|
||||
::: {.callout-tip}
|
||||
Please make sure to install `num2words` via `pip3 install num2words==0.5.14`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct
|
||||
```
|
||||
|
||||
### LFM2-VL {#sec-lfm2-vl}
|
||||
|
||||
::: {.callout-warning}
|
||||
Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
|
||||
:::
|
||||
|
||||
```yaml
|
||||
base_model: LiquidAI/LFM2-VL-450M
|
||||
```
|
||||
|
||||
## Dataset Format
|
||||
|
||||
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
|
||||
@@ -181,6 +214,20 @@ You may need to install `librosa` via `pip3 install librosa==0.11.0`.
|
||||
|
||||
:::
|
||||
|
||||
### Video
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
This is not well tested at the moment. We welcome contributors!
|
||||
|
||||
:::
|
||||
|
||||
For video loading, you can use the following keys within `content` alongside `"type": "video"`:
|
||||
|
||||
- `"path": "/path/to/video.mp4"`
|
||||
- `"url": "https://example.com/video.mp4"`
|
||||
- `"video": np.ndarray | list[PIL.Image.Image] | torch.Tensor` (or list of the aforementioned)
|
||||
|
||||
### Example
|
||||
|
||||
Here is an example of a multi-modal dataset:
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
# N-D Parallelism
|
||||
---
|
||||
title: "N-D Parallelism (Beta)"
|
||||
---
|
||||
|
||||
Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
|
||||
|
||||
@@ -71,6 +73,10 @@ Note: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size
|
||||
|
||||
## Examples
|
||||
|
||||
::: {.callout-tip}
|
||||
See our example configs [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/distributed-parallel).
|
||||
:::
|
||||
|
||||
1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
|
||||
- You want FSDP within each node and DDP across nodes.
|
||||
- Set `dp_shard_size: 4` and `dp_replicate_size: 2`.
|
||||
@@ -95,7 +101,7 @@ This matrix describes how different parallelism methods can be combined in Axolo
|
||||
| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |
|
||||
| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |
|
||||
| **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. |
|
||||
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP/CP without FSDP is inefficient and complex. You should use FSDP instead (`dp_shard_size > 1`). |
|
||||
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP or CP is currently not supported. You should use FSDP + TP/CP instead (`dp_shard_size > 1`). |
|
||||
| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |
|
||||
|
||||
- `tp_size` refers to `tensor_parallel_size`
|
||||
|
||||
129
docs/optimizers.qmd
Normal file
129
docs/optimizers.qmd
Normal file
@@ -0,0 +1,129 @@
|
||||
---
|
||||
title: Optimizers
|
||||
description: Configuring optimizers
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
Axolotl supports all optimizers supported by [transformers OptimizerNames](https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/training_args.py#L142-L187)
|
||||
|
||||
Here is a list of optimizers supported by transformers as of `v4.54.0`:
|
||||
|
||||
- `adamw_torch`
|
||||
- `adamw_torch_fused`
|
||||
- `adamw_torch_xla`
|
||||
- `adamw_torch_npu_fused`
|
||||
- `adamw_apex_fused`
|
||||
- `adafactor`
|
||||
- `adamw_anyprecision`
|
||||
- `adamw_torch_4bit`
|
||||
- `adamw_torch_8bit`
|
||||
- `ademamix`
|
||||
- `sgd`
|
||||
- `adagrad`
|
||||
- `adamw_bnb_8bit`
|
||||
- `adamw_8bit` # alias for adamw_bnb_8bit
|
||||
- `ademamix_8bit`
|
||||
- `lion_8bit`
|
||||
- `lion_32bit`
|
||||
- `paged_adamw_32bit`
|
||||
- `paged_adamw_8bit`
|
||||
- `paged_ademamix_32bit`
|
||||
- `paged_ademamix_8bit`
|
||||
- `paged_lion_32bit`
|
||||
- `paged_lion_8bit`
|
||||
- `rmsprop`
|
||||
- `rmsprop_bnb`
|
||||
- `rmsprop_bnb_8bit`
|
||||
- `rmsprop_bnb_32bit`
|
||||
- `galore_adamw`
|
||||
- `galore_adamw_8bit`
|
||||
- `galore_adafactor`
|
||||
- `galore_adamw_layerwise`
|
||||
- `galore_adamw_8bit_layerwise`
|
||||
- `galore_adafactor_layerwise`
|
||||
- `lomo`
|
||||
- `adalomo`
|
||||
- `grokadamw`
|
||||
- `schedule_free_radam`
|
||||
- `schedule_free_adamw`
|
||||
- `schedule_free_sgd`
|
||||
- `apollo_adamw`
|
||||
- `apollo_adamw_layerwise`
|
||||
- `stable_adamw`
|
||||
|
||||
|
||||
## Custom Optimizers
|
||||
|
||||
Enable custom optimizers by passing a string to the `optimizer` argument. Each optimizer will receive beta and epsilon args, however, some may accept additional args which are detailed below.
|
||||
|
||||
### optimi_adamw
|
||||
|
||||
```yaml
|
||||
optimizer: optimi_adamw
|
||||
```
|
||||
|
||||
### ao_adamw_4bit
|
||||
|
||||
Deprecated: Please use `adamw_torch_4bit`.
|
||||
|
||||
### ao_adamw_8bit
|
||||
|
||||
Deprecated: Please use `adamw_torch_8bit`.
|
||||
|
||||
### ao_adamw_fp8
|
||||
|
||||
|
||||
```yaml
|
||||
optimizer: ao_adamw_fp8
|
||||
```
|
||||
|
||||
### adopt_adamw
|
||||
|
||||
GitHub: [https://github.com/iShohei220/adopt](https://github.com/iShohei220/adopt)
|
||||
Paper: [https://arxiv.org/abs/2411.02853](https://arxiv.org/abs/2411.02853)
|
||||
|
||||
```yaml
|
||||
optimizer: adopt_adamw
|
||||
```
|
||||
|
||||
### came_pytorch
|
||||
|
||||
GitHub: [https://github.com/yangluo7/CAME/tree/master](https://github.com/yangluo7/CAME/tree/master)
|
||||
Paper: [https://arxiv.org/abs/2307.02047](https://arxiv.org/abs/2307.02047)
|
||||
|
||||
```yaml
|
||||
optimizer: came_pytorch
|
||||
|
||||
# optional args (defaults below)
|
||||
adam_beta1: 0.9
|
||||
adam_beta2: 0.999
|
||||
adam_beta3: 0.9999
|
||||
adam_epsilon: 1e-30
|
||||
adam_epsilon2: 1e-16
|
||||
```
|
||||
|
||||
### muon
|
||||
|
||||
Blog: [https://kellerjordan.github.io/posts/muon/](https://kellerjordan.github.io/posts/muon/)
|
||||
Paper: [https://arxiv.org/abs/2502.16982v1](https://arxiv.org/abs/2502.16982v1)
|
||||
|
||||
```yaml
|
||||
optimizer: muon
|
||||
```
|
||||
|
||||
### dion
|
||||
|
||||
Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient
|
||||
orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
|
||||
|
||||
GitHub: [https://github.com/microsoft/dion](https://github.com/microsoft/dion)
|
||||
Paper: [https://arxiv.org/pdf/2504.05295](https://arxiv.org/pdf/2504.05295)
|
||||
Note: Implementation written for PyTorch 2.7+ for DTensor
|
||||
|
||||
```yaml
|
||||
optimizer: dion
|
||||
dion_lr: 0.01
|
||||
dion_momentum: 0.95
|
||||
lr: 0.00001 # learning rate for embeddings and parameters that fallback to AdamW
|
||||
```
|
||||
58
examples/LiquidAI/README.md
Normal file
58
examples/LiquidAI/README.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# Finetune Liquid Foundation Models 2 (LFM2) with Axolotl
|
||||
|
||||
[Liquid Foundation Models 2 (LFM2)](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) are a family of small, open-weight models from [Liquid AI](https://www.liquid.ai/) focused on quality, speed, and memory efficiency. Liquid AI released text-only [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) and text+vision [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa) models.
|
||||
|
||||
LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference.
|
||||
|
||||
This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from pip:
|
||||
```bash
|
||||
# Ensure you have a compatible version of Pytorch installed
|
||||
pip3 install packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Run one of the finetuning examples below.
|
||||
|
||||
**LFM2**
|
||||
```bash
|
||||
# FFT SFT (1x48GB @ 25GiB)
|
||||
axolotl train examples/LiquidAI/lfm2-350m-fft.yaml
|
||||
```
|
||||
|
||||
**LFM2-VL**
|
||||
```bash
|
||||
# LoRA SFT (1x48GB @ 2.7GiB)
|
||||
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
|
||||
```
|
||||
|
||||
### TIPS
|
||||
|
||||
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
|
||||
```bash
|
||||
pip uninstall -y causal-conv1d
|
||||
```
|
||||
|
||||
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- **Dataset Formats**:
|
||||
- For LFM2 models, the dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
- For LFM2-VL models, Axolotl follows the multi-content Messages format. See our [Multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) for details.
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
|
||||
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -2,7 +2,6 @@ base_model: LiquidAI/LFM2-350M
|
||||
|
||||
chunked_cross_entropy: true
|
||||
|
||||
chat_template: tokenizer_default
|
||||
eot_tokens:
|
||||
- "<|im_end|>"
|
||||
datasets:
|
||||
58
examples/LiquidAI/lfm2-vl-lora.yaml
Normal file
58
examples/LiquidAI/lfm2-vl-lora.yaml
Normal file
@@ -0,0 +1,58 @@
|
||||
base_model: LiquidAI/LFM2-VL-450M
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForImageTextToText
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
53
examples/arcee/README.md
Normal file
53
examples/arcee/README.md
Normal file
@@ -0,0 +1,53 @@
|
||||
# Finetune ArceeAI's AFM with Axolotl
|
||||
|
||||
[Arcee Foundation Models (AFM)](https://huggingface.co/collections/arcee-ai/afm-45b-68823397c351603014963473) are a family of 4.5B parameter open weight models trained by Arcee.ai.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the AFM model.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as AFM is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/arcee/afm-4.5b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 7.8GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, the official Arcee.ai team recommends `top_p: 0.95`, `temperature: 0.5`, `top_k: 50`, and `repeat_penalty: 1.1`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [AFM Blog](https://docs.arcee.ai/arcee-foundation-models/introduction-to-arcee-foundation-models)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
64
examples/arcee/afm-4.5b-qlora.yaml
Normal file
64
examples/arcee/afm-4.5b-qlora.yaml
Normal file
@@ -0,0 +1,64 @@
|
||||
base_model: arcee-ai/AFM-4.5B
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -47,7 +47,6 @@ logging_steps: 1
|
||||
flash_attention: true
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -10,17 +10,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
Here is an example of how to install from pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0+)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
52
examples/distributed-parallel/README.md
Normal file
52
examples/distributed-parallel/README.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# ND Parallelism Examples
|
||||
|
||||
This directory contains example configurations for training models using ND Parallelism in Axolotl. These examples demonstrate how to compose different parallelism strategies (FSDP, TP, CP, HSDP) for efficient multi-GPU training.
|
||||
|
||||
## Quick Start
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Run the command below:
|
||||
|
||||
```bash
|
||||
# Train Qwen3 8B with FSDP + TP + CP on a single 8-GPU node
|
||||
axolotl train examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
|
||||
|
||||
# Train Llama 3.1 8B with HSDP + TP on 2 nodes (16 GPUs total)
|
||||
axolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
|
||||
```
|
||||
|
||||
## Example Configurations
|
||||
|
||||
### Single Node (8 GPUs)
|
||||
|
||||
**Qwen3 8B with FSDP + TP + CP** ([qwen3-8b-fsdp-tp-cp.yaml](./qwen3-8b-fsdp-tp-cp.yaml))
|
||||
- Uses all 3 parallelism dimensions on a single node
|
||||
- Ideal for: when model weights, activations, and/or context are too large to fit on single GPU
|
||||
|
||||
```yaml
|
||||
dp_shard_size: 2 # FSDP across 2 GPUs
|
||||
tensor_parallel_size: 2 # TP across 2 GPUs
|
||||
context_parallel_size: 2 # CP across 2 GPUs
|
||||
# Total: 2 × 2 × 2 = 8 GPUs
|
||||
```
|
||||
|
||||
### Multi-Node
|
||||
|
||||
**Llama 3.1 8B with HSDP + TP** ([llama-3_1-8b-hsdp-tp.yaml](./llama-3_1-8b-hsdp-tp.yaml))
|
||||
- FSDP & TP within nodes, DDP across nodes to minimize inter-node communication
|
||||
- Ideal for: Scaling to multiple nodes while maintaining training efficiency
|
||||
|
||||
```yaml
|
||||
dp_shard_size: 4 # FSDP within each 4-GPU group
|
||||
tensor_parallel_size: 2 # TP within each node
|
||||
dp_replicate_size: 2 # DDP across 2 groups
|
||||
# Total: (4 × 2) × 2 = 16 GPUs (2 nodes)
|
||||
```
|
||||
|
||||
## Learn More
|
||||
|
||||
- [ND Parallelism Documentation](https://docs.axolotl.ai/docs/nd_parallelism.html)
|
||||
- [Blog: Accelerate ND-Parallel Guide](https://huggingface.co/blog/accelerate-nd-parallel)
|
||||
- [Multi-GPU Training Guide](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
47
examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
Normal file
47
examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
base_model: meta-llama/Llama-3.1-8B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
dp_shard_size: 4
|
||||
dp_replicate_size: 2
|
||||
tensor_parallel_size: 2
|
||||
# context_parallel_size: 2
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
reshard_after_forward: true
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/ndp-out/
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 2
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-6
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
46
examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
Normal file
46
examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
Normal file
@@ -0,0 +1,46 @@
|
||||
base_model: Qwen/Qwen3-8B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
dp_shard_size: 2
|
||||
# dp_replicate_size: 1
|
||||
context_parallel_size: 2
|
||||
tensor_parallel_size: 2
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
|
||||
reshard_after_forward: true
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/ndp-out/
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
flash_attention: true
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1 # must be 1 when using context parallel
|
||||
num_epochs: 2
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-6
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
|
||||
special_tokens:
|
||||
@@ -4,17 +4,14 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Gemma3n is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
Here is an example of how to install from pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min recommended)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. In addition to Axolotl's requirements, Gemma-3n requires:
|
||||
|
||||
105
examples/gpt-oss/README.md
Normal file
105
examples/gpt-oss/README.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# Finetune OpenAI's GPT-OSS with Axolotl
|
||||
|
||||
[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
|
||||
|
||||
```bash
|
||||
# LoRA SFT linear layers (1x48GB @ ~44GiB)
|
||||
axolotl train examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
|
||||
|
||||
# FFT SFT with offloading (2x24GB @ ~21GiB/GPU)
|
||||
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
|
||||
|
||||
# FFT SFT (8x48GB @ ~36GiB/GPU or 4x80GB @ ~46GiB/GPU)
|
||||
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
|
||||
```
|
||||
|
||||
Note: Memory usage taken from `device_mem_reserved(gib)` from logs.
|
||||
|
||||
### Training 120B
|
||||
|
||||
On 8xH100s, make sure you have ~3TB of free disk space. With each checkpoint clocking in at ~720GB, along with the base
|
||||
model, and final model output, you may need at least 3TB of free disk space to keep at least 2 checkpoints.
|
||||
|
||||
```bash
|
||||
# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
|
||||
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
|
||||
```
|
||||
|
||||
ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
|
||||
See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
|
||||
|
||||
```bash
|
||||
sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
|
||||
```
|
||||
|
||||
When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your
|
||||
configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to
|
||||
merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded
|
||||
weights to `{output_dir}/merged`.
|
||||
|
||||
```bash
|
||||
axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
|
||||
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
|
||||
```
|
||||
|
||||
|
||||
### Inferencing your fine-tuned model
|
||||
|
||||
GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
|
||||
for more information about using a special vllm-openai docker image for inferencing with vLLM.
|
||||
|
||||
SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing
|
||||
SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:
|
||||
|
||||
```bash
|
||||
python3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8
|
||||
```
|
||||
|
||||
### Tool use
|
||||
|
||||
GPT-OSS has a comprehensive tool understanding. Axolotl supports tool calling datasets for Supervised Fine-tuning.
|
||||
|
||||
Here is an example dataset config:
|
||||
```yaml
|
||||
datasets:
|
||||
- path: Nanobit/text-tools-2k-test
|
||||
type: chat_template
|
||||
```
|
||||
|
||||
See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-tools-2k-test) for the sample dataset.
|
||||
|
||||
Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.
|
||||
|
||||
### TIPS
|
||||
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [GPT-OSS Blog](https://openai.com/index/introducing-gpt-oss/)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
68
examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
Normal file
68
examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
# the original mxfp4 quantized model is not supported with FSDP cpu_ram_efficient_loading
|
||||
# FSDP cpu_ram_efficient_loading is used to reduce the initial CPU memory usage when loading the model
|
||||
base_model: axolotl-ai-co/gpt-oss-120b-dequantized
|
||||
|
||||
use_kernels: false
|
||||
|
||||
dp_shard_size: 16 # requires 2x8xH100 nodes
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/Multilingual-Thinking
|
||||
type: chat_template
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
save_total_limit: 2 # the 120B model can use up to 720GB of disk space per checkpoint, so let's only keep the last 2
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.03
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: true
|
||||
state_dict_type: SHARDED_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
||||
reshard_after_forward: true
|
||||
cpu_ram_efficient_loading: true
|
||||
58
examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml
Normal file
58
examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml
Normal file
@@ -0,0 +1,58 @@
|
||||
base_model: openai/gpt-oss-20b
|
||||
use_kernels: false
|
||||
model_quantization_config: Mxfp4Config
|
||||
model_quantization_config_kwargs:
|
||||
dequantize: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/Multilingual-Thinking
|
||||
type: chat_template
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.03
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
|
||||
# choose the zero3 configuration that best fits your system capabilities
|
||||
deepspeed: deepspeed_configs/zero3_bf16.json
|
||||
68
examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
Normal file
68
examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
Normal file
@@ -0,0 +1,68 @@
|
||||
base_model: openai/gpt-oss-20b
|
||||
use_kernels: true
|
||||
model_quantization_config: Mxfp4Config
|
||||
model_quantization_config_kwargs:
|
||||
dequantize: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/Multilingual-Thinking
|
||||
type: chat_template
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.03
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: true
|
||||
state_dict_type: SHARDED_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
||||
reshard_after_forward: true
|
||||
# cpu_ram_efficient_loading: true
|
||||
|
||||
# cpu_ram_efficient_loading cannot be used with MXFP4 model quantization.
|
||||
# It can only be used with a dequantized model like `axolotl-ai-co/gpt-oss-120b-dequantized`
|
||||
64
examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
Normal file
64
examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
Normal file
@@ -0,0 +1,64 @@
|
||||
base_model: openai/gpt-oss-20b
|
||||
use_kernels: false
|
||||
model_quantization_config: Mxfp4Config
|
||||
model_quantization_config_kwargs:
|
||||
dequantize: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/Multilingual-Thinking
|
||||
type: chat_template
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.03
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
state_dict_type: SHARDED_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: GptOssDecoderLayer
|
||||
reshard_after_forward: true
|
||||
# cpu_ram_efficient_loading: true
|
||||
67
examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
Normal file
67
examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
Normal file
@@ -0,0 +1,67 @@
|
||||
base_model: openai/gpt-oss-20b
|
||||
use_kernels: true
|
||||
model_quantization_config: Mxfp4Config
|
||||
model_quantization_config_kwargs:
|
||||
dequantize: true
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/Multilingual-Thinking
|
||||
type: chat_template
|
||||
field_thinking: thinking
|
||||
template_thinking_key: thinking
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0
|
||||
output_dir: ./outputs/gpt-oss-out/
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
adapter: lora
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
|
||||
lora_target_linear: true
|
||||
|
||||
# TODO: not supported for now, see peft#2710
|
||||
#lora_target_parameters: # target the experts in the last two layers
|
||||
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
|
||||
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: constant_with_warmup
|
||||
learning_rate: 2e-4
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
flash_attention: true
|
||||
attn_implementation: kernels-community/vllm-flash-attn3
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
|
||||
logging_steps: 1
|
||||
saves_per_epoch: 1
|
||||
warmup_ratio: 0.1
|
||||
|
||||
special_tokens:
|
||||
eot_tokens:
|
||||
- "<|end|>"
|
||||
@@ -1,7 +0,0 @@
|
||||
# Liquid Foundation Models 2
|
||||
|
||||
LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release.
|
||||
|
||||
```bash
|
||||
pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
|
||||
```
|
||||
@@ -45,7 +45,6 @@ logging_steps: 1
|
||||
flash_attention: true
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -49,7 +49,6 @@ logging_steps: 1
|
||||
flash_attention: true
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
|
||||
@@ -8,17 +8,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Magistral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
Here is an example of how to install from pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Run the finetuning example:
|
||||
|
||||
@@ -27,7 +27,6 @@ sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: false
|
||||
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
|
||||
@@ -26,7 +26,6 @@ lora_model_dir:
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
|
||||
@@ -26,7 +26,6 @@ lora_model_dir:
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
|
||||
66
examples/slurm/README.md
Normal file
66
examples/slurm/README.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# SLURM Multi-Node Training
|
||||
|
||||
This directory contains an example SLURM script for running Axolotl training jobs across multiple nodes in a SLURM cluster.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Access to a SLURM cluster with GPU nodes
|
||||
- Axolotl installed on all nodes (see [installation docs](https://docs.axolotl.ai/docs/installation.html))
|
||||
|
||||
## Usage
|
||||
|
||||
### Standard SLURM Clusters
|
||||
|
||||
1. Copy [`axolotl.slurm`](./axolotl.slurm) to your working directory.
|
||||
2. Place your Axolotl config file (`train.yaml`) in the same directory.
|
||||
3. Set the appropriate environment variables for the job:
|
||||
```bash
|
||||
export HF_TOKEN="your-huggingface-token"
|
||||
|
||||
# metric tracking
|
||||
# export WANDB_API_KEY="your-wandb-api-key"
|
||||
# ...
|
||||
```
|
||||
4. Submit the job:
|
||||
```bash
|
||||
sbatch --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=<master-node>,PRIMARY_PORT=29400 axolotl.slurm
|
||||
```
|
||||
|
||||
Where:
|
||||
- `NUM_NODES`: Number of nodes to use
|
||||
- `NUM_TRAINERS`: GPUs per node (typically 8)
|
||||
- `PRIMARY_ADDR`: Hostname/IP of the master node
|
||||
- `PRIMARY_PORT`: Port for distributed training (default: 29400)
|
||||
|
||||
5. (Optional) Run other slurm commands:
|
||||
```bash
|
||||
# check job info
|
||||
scontrol show job axolotl-cli
|
||||
|
||||
# check job queue
|
||||
squeue
|
||||
|
||||
# check cluster status
|
||||
sinfo
|
||||
```
|
||||
|
||||
### RunPod Instant Clusters
|
||||
|
||||
Axolotl works with RunPod Instant Clusters. This feature provides managed SLURM clusters with zero configuration.
|
||||
|
||||
1. **Deploy a SLURM Cluster**:
|
||||
- Go to [RunPod Instant Clusters](https://console.runpod.io/cluster)
|
||||
- Click "Create a Cluster"
|
||||
- Choose your GPU type, node count, and region
|
||||
- Choose an [Axolotl cloud docker image](https://docs.axolotl.ai/docs/docker.html#cloud)
|
||||
- Deploy the cluster
|
||||
|
||||
2. **Connect to the Controller Node**: Find the controller node in the RunPod console and connect via SSH
|
||||
|
||||
3. **Follow the instructions in [Standard SLURM Clusters](#standard-slurm-clusters)**
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Axolotl Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [SLURM Documentation](https://slurm.schedmd.com/documentation.html)
|
||||
- [RunPod SLURM Clusters Guide](https://docs.runpod.io/instant-clusters/slurm-clusters)
|
||||
20
examples/slurm/axolotl.slurm
Normal file
20
examples/slurm/axolotl.slurm
Normal file
@@ -0,0 +1,20 @@
|
||||
#!/bin/bash
|
||||
# Prior to running this script, export your HF_TOKEN and WANDB_API_KEY to your environment; i.e.
|
||||
# export HF_TOKEN="..."
|
||||
# export WANDB_API_KEY="..."
|
||||
#
|
||||
|
||||
# ---------- SBATCH commands ---------- #
|
||||
#SBATCH --job-name=axolotl-slurm-multinode
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --nodes=$NUM_NODES
|
||||
#SBATCH --gpus-per-task=8
|
||||
#SBATCH --cpus-per-task=128
|
||||
|
||||
export TORCH_DIST_INIT_BARRIER=0
|
||||
|
||||
srun axolotl preprocess train.yaml
|
||||
|
||||
srun axolotl train train.yaml --launcher torchrun -- \
|
||||
--nproc_per_node=$NUM_TRAINERS --nnodes=$NUM_NODES \
|
||||
--rdzv_id axolotl-cli --rdzv_backend c10d --rdzv_endpoint "${PRIMARY_ADDR}:${PRIMARY_PORT}" --rdzv-conf="join_timeout=1800"
|
||||
49
examples/smolvlm2/README.md
Normal file
49
examples/smolvlm2/README.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# Finetune SmolVLM2 with Axolotl
|
||||
|
||||
[SmolVLM2](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7) are a family of lightweight, open-source multimodal models from HuggingFace designed to analyze and understand video, image, and text content.
|
||||
|
||||
These models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M.
|
||||
|
||||
This guide shows how to fine-tune SmolVLM2 models with Axolotl.
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from pip:
|
||||
```bash
|
||||
# Ensure you have a compatible version of Pytorch installed
|
||||
pip3 install packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Install an extra dependency:
|
||||
|
||||
```bash
|
||||
pip3 install num2words==0.5.14
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# LoRA SFT (1x48GB @ 6.8GiB)
|
||||
axolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml
|
||||
```
|
||||
|
||||
## TIPS
|
||||
|
||||
- **Dataset Format**: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on [Multimodal Formats](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
- **Dataset Loading**: Read more on how to prepare and load your own datasets in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [SmolVLM2 Blog](https://huggingface.co/blog/smolvlm2)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
56
examples/smolvlm2/smolvlm2-2B-lora.yaml
Normal file
56
examples/smolvlm2/smolvlm2-2B-lora.yaml
Normal file
@@ -0,0 +1,56 @@
|
||||
base_model: HuggingFaceTB/SmolVLM2-2.2B-Instruct
|
||||
trust_remote_code: true
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# these 3 lines are needed for now to handle vision chat templates w images
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.text_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
eager_attention:
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -6,17 +6,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Voxtral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
Here is an example of how to install from pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
||||
```
|
||||
|
||||
2. Please install the below.
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.46.0
|
||||
triton>=3.0.0
|
||||
bitsandbytes==0.47.0
|
||||
# triton 3.4.0 is not compatible with CCE
|
||||
triton>=3.0.0,<3.4.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
@@ -12,19 +13,21 @@ liger-kernel==0.6.1
|
||||
packaging==23.2
|
||||
|
||||
huggingface_hub>=0.33.0
|
||||
peft==0.16.0
|
||||
transformers==4.54.1
|
||||
peft==0.17.0
|
||||
transformers==4.55.2
|
||||
tokenizers>=0.21.1
|
||||
accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
|
||||
accelerate==1.10.0
|
||||
datasets==4.0.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.20.0
|
||||
trl==0.21.0
|
||||
hf_xet==1.1.5
|
||||
kernels==0.9.0
|
||||
trackio
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
sentencepiece
|
||||
gradio==5.23.3
|
||||
gradio==5.41.1
|
||||
|
||||
modal==1.0.2
|
||||
pydantic==2.10.6
|
||||
@@ -66,6 +69,6 @@ torchao==0.12.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.3
|
||||
axolotl-contribs-mit==0.0.5
|
||||
|
||||
mistral-common==1.8.3
|
||||
|
||||
@@ -44,8 +44,13 @@ add_keys_to_authorized() {
|
||||
chmod 700 -R ~/.ssh
|
||||
}
|
||||
|
||||
# Set SSH port
|
||||
if [ ! -z "$SSH_PORT" ]; then
|
||||
sed -i "s/#Port 22/Port $SSH_PORT/" /etc/ssh/sshd_config
|
||||
fi
|
||||
|
||||
if [[ $PUBLIC_KEY ]]; then
|
||||
# runpod
|
||||
# runpod, prime intellect
|
||||
add_keys_to_authorized "$PUBLIC_KEY"
|
||||
# Start the SSH service in the background
|
||||
service ssh start
|
||||
@@ -76,5 +81,13 @@ if [ ! -L "/workspace/axolotl/outputs" ]; then
|
||||
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
|
||||
fi
|
||||
|
||||
# start the runpod slurm init
|
||||
SLURM_INIT="${SLURM_INIT:-/slurm-init.sh}"
|
||||
|
||||
if [[ -f "$SLURM_INIT" ]]; then
|
||||
echo "[entrypoint] running $SLURM_INIT..."
|
||||
bash "$SLURM_INIT"
|
||||
fi
|
||||
|
||||
# Execute the passed arguments (CMD)
|
||||
exec "$@"
|
||||
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
|
||||
)
|
||||
|
||||
@@ -4,4 +4,4 @@ import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.12.0.dev"
|
||||
__version__ = "0.12.2"
|
||||
|
||||
@@ -40,6 +40,12 @@ class VllmServeCliArgs:
|
||||
default=None,
|
||||
metadata={"help": "Number of tensor parallel workers to use."},
|
||||
)
|
||||
data_parallel_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
|
||||
},
|
||||
)
|
||||
host: Optional[str] = field(
|
||||
default=None, # nosec B104
|
||||
metadata={"help": "Host address to run the server on."},
|
||||
|
||||
@@ -153,15 +153,14 @@ def prepare_plugins(cfg: DictDefault):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
for plugin_name in cfg["plugins"]:
|
||||
plugin_manager.register(plugin_name)
|
||||
for plugin in plugin_manager.plugins.values():
|
||||
plugin.register(cfg)
|
||||
|
||||
|
||||
def plugin_set_cfg(cfg: DictDefault):
|
||||
if cfg.get("plugins"):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.cfg = cfg
|
||||
# now that we have the finalized cfg, register the plugins individually
|
||||
for plugin in plugin_manager.plugins.values():
|
||||
plugin.register(cfg)
|
||||
|
||||
|
||||
def load_cfg(
|
||||
|
||||
@@ -123,9 +123,10 @@ def train(
|
||||
_launcher = None if kwargs.get("use_ray") else launcher
|
||||
|
||||
# Process each configuration
|
||||
for cfg_file in generate_config_files(config, sweep):
|
||||
for cfg_file, is_group in generate_config_files(config, sweep):
|
||||
try:
|
||||
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
|
||||
use_exec = is_group is not True
|
||||
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
|
||||
except subprocess.CalledProcessError as exc:
|
||||
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
||||
if not sweep:
|
||||
|
||||
@@ -10,6 +10,7 @@ import fire
|
||||
import torch
|
||||
import torch.distributed.checkpoint as dist_cp
|
||||
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
@@ -23,6 +24,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.train import determine_last_checkpoint
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -143,7 +145,6 @@ def merge_fsdp_weights(
|
||||
ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
|
||||
"""
|
||||
checkpoint_dir_ = Path(checkpoint_dir)
|
||||
from accelerate.state import PartialState
|
||||
|
||||
if not is_torch_version(">=", "2.3.0"):
|
||||
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
|
||||
@@ -180,7 +181,6 @@ def merge_fsdp_weights(
|
||||
if remove_checkpoint_dir:
|
||||
LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
|
||||
shutil.rmtree(checkpoint_dir_)
|
||||
state.wait_for_everyone()
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
@@ -195,11 +195,32 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
|
||||
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
||||
if not fsdp_dir.exists():
|
||||
checkpoint_dir = determine_last_checkpoint(parsed_cfg, update=False)
|
||||
if checkpoint_dir:
|
||||
fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
|
||||
if not fsdp_dir.exists():
|
||||
raise ValueError(
|
||||
f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}"
|
||||
)
|
||||
|
||||
output_path = str(Path(parsed_cfg.output_dir) / "merged")
|
||||
merge_fsdp_weights(
|
||||
checkpoint_dir=str(fsdp_dir),
|
||||
output_path=str(Path(parsed_cfg.output_dir) / "merged"),
|
||||
output_path=output_path,
|
||||
safe_serialization=True,
|
||||
)
|
||||
state = PartialState()
|
||||
state.wait_for_everyone()
|
||||
LOG.info(
|
||||
f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
|
||||
main_process_only=True,
|
||||
)
|
||||
LOG.info(
|
||||
"Merged weights are only the safetensors and doesn't include the model configuration "
|
||||
f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
|
||||
main_process_only=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
import subprocess # nosec
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import Any, Iterator, Literal
|
||||
|
||||
@@ -64,10 +65,18 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
||||
return cmd
|
||||
|
||||
|
||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||
"""Generate list of configuration files to process."""
|
||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
|
||||
"""
|
||||
Generate list of configuration files to process. Yields a tuple of the configuration file name and a boolean indicating
|
||||
whether this is a group of configurations (i.e., a sweep).
|
||||
|
||||
Args:
|
||||
config: Base configuration file
|
||||
sweep: Sweep configuration file
|
||||
"""
|
||||
|
||||
if not sweep:
|
||||
yield config
|
||||
yield config, False
|
||||
return
|
||||
|
||||
# Load sweep and base configurations
|
||||
@@ -78,6 +87,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||
|
||||
# Generate all possible configurations
|
||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||
is_group = len(permutations) > 1
|
||||
for permutation in permutations:
|
||||
# pylint: disable=consider-using-with
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
@@ -88,7 +98,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||
)
|
||||
yaml.dump(permutation, temp_file)
|
||||
temp_file.close()
|
||||
yield temp_file.name
|
||||
yield temp_file.name, is_group
|
||||
|
||||
|
||||
def launch_training(
|
||||
@@ -97,6 +107,7 @@ def launch_training(
|
||||
cloud: str | None,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training with the given configuration."""
|
||||
launcher_args = launcher_args or []
|
||||
@@ -105,11 +116,14 @@ def launch_training(
|
||||
_launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
|
||||
elif launcher:
|
||||
if launcher == "accelerate":
|
||||
_launch_accelerate_training(cfg_file, kwargs, launcher_args)
|
||||
_launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||
elif launcher == "torchrun":
|
||||
_launch_torchrun_training(cfg_file, kwargs, launcher_args)
|
||||
_launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
|
||||
elif launcher == "python":
|
||||
_launch_python_training(cfg_file, kwargs)
|
||||
elif launcher is None:
|
||||
# handle ray train launch
|
||||
_launch_python_training(cfg_file, kwargs)
|
||||
|
||||
|
||||
def _launch_cloud_training(
|
||||
@@ -136,7 +150,10 @@ def _launch_cloud_training(
|
||||
|
||||
|
||||
def _launch_accelerate_training(
|
||||
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
|
||||
cfg_file: str,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training via accelerate launcher."""
|
||||
launcher_args = launcher_args or []
|
||||
@@ -161,11 +178,20 @@ def _launch_accelerate_training(
|
||||
base_cmd.append(cfg_file)
|
||||
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
if use_exec:
|
||||
# make sure to flush stdout and stderr before replacing the process
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
||||
else:
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
|
||||
|
||||
def _launch_torchrun_training(
|
||||
cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
|
||||
cfg_file: str,
|
||||
kwargs: dict,
|
||||
launcher_args: list[str] | None = None,
|
||||
use_exec: bool = False,
|
||||
) -> None:
|
||||
"""Execute training via torchrun launcher."""
|
||||
launcher_args = launcher_args or []
|
||||
@@ -178,7 +204,13 @@ def _launch_torchrun_training(
|
||||
base_cmd.append(cfg_file)
|
||||
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
if use_exec:
|
||||
# make sure to flush stdout and stderr before replacing the process
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
os.execvpe(cmd[0], cmd, os.environ) # nosec B606
|
||||
else:
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
|
||||
|
||||
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
|
||||
|
||||
@@ -2,12 +2,10 @@
|
||||
CLI to start the vllm server for online RL
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import trl
|
||||
from trl.scripts.vllm_serve import ScriptArguments
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
@@ -42,13 +40,17 @@ def do_vllm_serve(
|
||||
|
||||
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
|
||||
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
|
||||
tensor_parallel_size = 1
|
||||
data_parallel_size = 1
|
||||
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
data_parallel_size = (
|
||||
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
|
||||
)
|
||||
if cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size:
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
if cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size:
|
||||
data_parallel_size = (
|
||||
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
|
||||
)
|
||||
host = cli_args.get("host") or cfg.vllm.host
|
||||
port = cli_args.get("port") or cfg.vllm.port
|
||||
gpu_memory_utilization = (
|
||||
@@ -81,63 +83,3 @@ def do_vllm_serve(
|
||||
enable_reasoning=enable_reasoning,
|
||||
)
|
||||
vllm_serve_main(vllm_script_args)
|
||||
|
||||
|
||||
def patch_vllm_worker():
|
||||
from multiprocessing.connection import Connection
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
def llm_worker(
|
||||
script_args: AxolotlScriptArguments,
|
||||
data_parallel_rank: int,
|
||||
master_port: int,
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
# Set required environment variables for DP to work with vLLM
|
||||
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
|
||||
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
|
||||
|
||||
llm = LLM(
|
||||
model=script_args.model,
|
||||
revision=script_args.revision,
|
||||
tensor_parallel_size=script_args.tensor_parallel_size,
|
||||
gpu_memory_utilization=script_args.gpu_memory_utilization,
|
||||
enforce_eager=script_args.enforce_eager,
|
||||
dtype=script_args.dtype,
|
||||
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
|
||||
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
|
||||
# This is particularly useful here because we generate completions from the same prompts.
|
||||
enable_prefix_caching=script_args.enable_prefix_caching,
|
||||
kv_cache_dtype=script_args.kv_cache_dtype,
|
||||
max_model_len=script_args.max_model_len,
|
||||
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
|
||||
enable_reasoning=script_args.enable_reasoning,
|
||||
reasoning_parser=script_args.reasoning_parser,
|
||||
)
|
||||
|
||||
# Send ready signal to parent process
|
||||
connection.send({"status": "ready"})
|
||||
|
||||
while True:
|
||||
# Wait for commands from the parent process
|
||||
try:
|
||||
command = connection.recv()
|
||||
except KeyboardInterrupt:
|
||||
llm.collective_rpc(method="close_communicator")
|
||||
break
|
||||
|
||||
# Handle commands
|
||||
if command["type"] in ["call", "fire_and_forget"]:
|
||||
method_name = command["method"]
|
||||
args, kwargs = command.get("args", ()), command.get("kwargs", {})
|
||||
method = getattr(llm, method_name)
|
||||
result = method(*args, **kwargs)
|
||||
if command["type"] == "call":
|
||||
connection.send(result)
|
||||
elif command["type"] == "shutdown":
|
||||
break
|
||||
|
||||
trl.scripts.vllm_serve.llm_worker = llm_worker
|
||||
|
||||
@@ -13,4 +13,5 @@ MOE_ARCH_BLOCK = {
|
||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||
"deepseek_v2": "DeepseekV2MoE",
|
||||
"gpt_oss": "GptOssDecoderLayer",
|
||||
}
|
||||
|
||||
@@ -24,12 +24,10 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from accelerate import PartialState
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
)
|
||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||
from transformers.training_args import OptimizerNames
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
||||
@@ -40,6 +38,7 @@ from axolotl.utils.callbacks import (
|
||||
SaveModelOnFirstStepCallback,
|
||||
)
|
||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||
from axolotl.utils.distributed import build_parallelism_config
|
||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@@ -267,27 +266,24 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
optimizer_cls = MuonOptimizerFactory
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "dion":
|
||||
from axolotl.contribs.mit.dion import ( # pylint: disable=no-name-in-module
|
||||
DionOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls = DionOptimizerFactory
|
||||
optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
|
||||
optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
_, device_mesh = build_parallelism_config(self.cfg)
|
||||
if device_mesh is not None:
|
||||
optimizer_kwargs["device_mesh"] = device_mesh
|
||||
elif self.cfg.optimizer == "optimi_adamw":
|
||||
from optimi import AdamW
|
||||
|
||||
optimizer_kwargs["foreach"] = False
|
||||
optimizer_cls = AdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_4bit":
|
||||
# TODO remove 20250401
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||
|
||||
optimizer_cls = AdamW4bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
|
||||
LOG.warning(
|
||||
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
||||
)
|
||||
elif self.cfg.optimizer == "ao_adamw_8bit":
|
||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||
|
||||
optimizer_cls = AdamW8bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||
|
||||
@@ -433,30 +429,12 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||
|
||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||
partial_state = PartialState()
|
||||
has_pc_attr = (
|
||||
hasattr(partial_state, "parallelism_config")
|
||||
and partial_state.parallelism_config
|
||||
)
|
||||
has_pc_key = (
|
||||
"parallelism_config"
|
||||
in partial_state._shared_state # pylint: disable=protected-access
|
||||
and partial_state._shared_state[ # pylint: disable=protected-access
|
||||
"parallelism_config"
|
||||
]
|
||||
)
|
||||
use_configured_state = has_pc_attr or has_pc_key
|
||||
if self.cfg.accelerator_config:
|
||||
use_configured_state = self.cfg.accelerator_config.pop(
|
||||
"use_configured_state", use_configured_state
|
||||
)
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||
use_configured_state=use_configured_state, **self.cfg.accelerator_config
|
||||
**self.cfg.accelerator_config
|
||||
)
|
||||
else:
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||
use_configured_state=use_configured_state,
|
||||
)
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.activation_offloading is True:
|
||||
@@ -516,10 +494,20 @@ class TrainerBuilderBase(abc.ABC):
|
||||
"include_tokens_per_second",
|
||||
"weight_decay",
|
||||
"seed",
|
||||
"dion_momentum",
|
||||
"dion_rank_fraction",
|
||||
"dion_rank_multiple_of",
|
||||
]:
|
||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||
|
||||
arg_map = {
|
||||
"dion_learning_rate": "dion_lr",
|
||||
}
|
||||
for kwarg, cfg_arg in arg_map.items():
|
||||
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
|
||||
training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)
|
||||
|
||||
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
||||
training_args_kwargs["average_tokens_across_devices"] = False
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ from axolotl.utils.collators import (
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -136,6 +137,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return AxolotlRewardTrainer
|
||||
if self.cfg.process_reward_model:
|
||||
return AxolotlPRMTrainer
|
||||
|
||||
if self.cfg.trainer_cls:
|
||||
# override the trainer cls
|
||||
try:
|
||||
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
|
||||
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
|
||||
return trainer_cls
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
raise ValueError(
|
||||
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
|
||||
) from e
|
||||
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
@@ -350,7 +363,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
||||
self.cfg.sequence_len / multiple
|
||||
)
|
||||
else:
|
||||
elif self.cfg.pad_to_sequence_len is None:
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
||||
|
||||
@@ -15,6 +15,7 @@ from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders.utils import ensure_dtype
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
@@ -72,6 +73,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
if self.cfg.trainer_cls:
|
||||
# override the trainer cls
|
||||
try:
|
||||
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
|
||||
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
raise ValueError(
|
||||
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
|
||||
) from e
|
||||
|
||||
return trainer_cls, trainer_cls_args
|
||||
|
||||
def _build_training_arguments(self, total_num_steps):
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
from .base import AxolotlTrainer
|
||||
from .dpo.trainer import AxolotlDPOTrainer
|
||||
from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
|
||||
from .mamba import AxolotlMambaTrainer
|
||||
from .trl import (
|
||||
AxolotlCPOTrainer,
|
||||
|
||||
@@ -10,8 +10,11 @@ from functools import partial, wraps
|
||||
from typing import Any, Callable, Literal, Optional
|
||||
|
||||
import datasets
|
||||
import safetensors
|
||||
import torch
|
||||
from accelerate.state import AcceleratorState
|
||||
from datasets import Dataset
|
||||
from peft import PeftModel
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
DataLoader,
|
||||
@@ -19,8 +22,10 @@ from torch.utils.data import (
|
||||
Sampler,
|
||||
SequentialSampler,
|
||||
)
|
||||
from transformers import Trainer
|
||||
from transformers import PreTrainedModel, Trainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -515,7 +520,18 @@ class AxolotlTrainer(
|
||||
|
||||
@wraps(Trainer.create_accelerator_and_postprocess)
|
||||
def create_accelerator_and_postprocess(self):
|
||||
res = super().create_accelerator_and_postprocess()
|
||||
# cleanup the PartialState states so Accelerate automatically configures everything from the env vars
|
||||
accelerator_config = self.args.accelerator_config.to_dict()
|
||||
use_configured_state = accelerator_config.get("use_configured_state", False)
|
||||
if not use_configured_state:
|
||||
AcceleratorState._reset_state( # pylint: disable=protected-access
|
||||
reset_partial_state=True
|
||||
)
|
||||
|
||||
super().create_accelerator_and_postprocess()
|
||||
|
||||
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
|
||||
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
|
||||
|
||||
if self.is_fsdp_enabled:
|
||||
if (
|
||||
@@ -524,8 +540,6 @@ class AxolotlTrainer(
|
||||
):
|
||||
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
|
||||
|
||||
return res
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def additional_accelerator_args(
|
||||
self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
|
||||
@@ -567,10 +581,10 @@ class AxolotlTrainer(
|
||||
# Add memory usage
|
||||
try:
|
||||
active, allocated, reserved = get_gpu_memory_usage()
|
||||
logs["memory/max_memory_active"] = active
|
||||
logs["memory/max_memory_allocated"] = allocated
|
||||
logs["memory/device_memory_reserved"] = reserved
|
||||
except (ValueError, FileNotFoundError):
|
||||
logs["memory/max_mem_active(gib)"] = round(active, 2)
|
||||
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
|
||||
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
|
||||
except (ValueError, TypeError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
del self._stored_metrics[train_eval]
|
||||
@@ -590,3 +604,64 @@ class AxolotlTrainer(
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||
# If we are executing this function, we are the process zero, so we don't check for that.
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
LOG.info(f"Saving model checkpoint to {output_dir}")
|
||||
supported_classes = (
|
||||
(PreTrainedModel,)
|
||||
if not is_peft_available()
|
||||
else (PreTrainedModel, PeftModel)
|
||||
)
|
||||
# Save a trained model and configuration using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
if not isinstance(self.model, supported_classes):
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
if isinstance(
|
||||
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
|
||||
supported_classes,
|
||||
):
|
||||
self.accelerator.unwrap_model(
|
||||
self.model, keep_torch_compile=False
|
||||
).save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
)
|
||||
else:
|
||||
LOG.info(
|
||||
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
|
||||
)
|
||||
if self.args.save_safetensors:
|
||||
safetensors.torch.save_file(
|
||||
state_dict,
|
||||
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
|
||||
metadata={"format": "pt"},
|
||||
)
|
||||
else:
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
else:
|
||||
self.model.save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
safe_serialization=self.args.save_safetensors,
|
||||
is_main_process=self.accelerator.is_main_process,
|
||||
)
|
||||
|
||||
if self.processing_class is not None:
|
||||
self.processing_class.save_pretrained(output_dir)
|
||||
elif (
|
||||
self.data_collator is not None
|
||||
and hasattr(self.data_collator, "tokenizer")
|
||||
and self.data_collator.tokenizer is not None
|
||||
):
|
||||
LOG.info(
|
||||
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
||||
)
|
||||
self.data_collator.tokenizer.save_pretrained(output_dir)
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Mixin for correctly saving fsdp
|
||||
"""
|
||||
|
||||
from accelerate import PartialState
|
||||
from transformers import Trainer
|
||||
|
||||
|
||||
@@ -18,3 +19,15 @@ class DistributedParallelMixin(Trainer):
|
||||
):
|
||||
state_dict = self.accelerator.get_state_dict(self.model)
|
||||
super()._save(output_dir, state_dict=state_dict)
|
||||
|
||||
def create_accelerator_and_postprocess(self):
|
||||
super().create_accelerator_and_postprocess()
|
||||
if (
|
||||
self.accelerator.distributed_type == "FSDP"
|
||||
and self.accelerator.state.fsdp_plugin is None
|
||||
):
|
||||
# pylint: disable=protected-access
|
||||
# handle Context Parallelism without FSDP
|
||||
self.accelerator.state.distributed_type = "MULTI_GPU"
|
||||
self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"
|
||||
PartialState().distributed_type = "MULTI_GPU"
|
||||
|
||||
@@ -243,3 +243,18 @@ class AxolotlTrainingMixins:
|
||||
)
|
||||
|
||||
# end of multi-modal section
|
||||
|
||||
dion_learning_rate: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The learning rate for Dion"},
|
||||
)
|
||||
dion_momentum: float | None = field(
|
||||
default=None,
|
||||
metadata={"help": "The momentum for Dion"},
|
||||
)
|
||||
dion_rank_fraction: float | None = field(
|
||||
default=None,
|
||||
)
|
||||
dion_rank_multiple_of: int | None = field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -26,9 +26,11 @@ import traceback
|
||||
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
||||
|
||||
from peft import PeftModel
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from transformers import PreTrainedModel, Trainer
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
@@ -74,8 +76,8 @@ class BasePlugin:
|
||||
def __init__(self):
|
||||
"""Initializes the BasePlugin."""
|
||||
|
||||
def register(self, cfg: DictDefault): # pylint: disable=unused-argument
|
||||
"""Registers the plugin with the given configuration.
|
||||
def register(self, cfg: dict): # pylint: disable=unused-argument
|
||||
"""Registers the plugin with the given configuration as an unparsed dict.
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the plugin.
|
||||
@@ -641,3 +643,24 @@ class BaseOptimizerFactory:
|
||||
self, opt_model, training_args, **optimizer_kwargs
|
||||
) -> Optimizer | None:
|
||||
pass
|
||||
|
||||
# duplicated from transformers
|
||||
def get_decay_parameter_names(self, model) -> list[str]:
|
||||
"""
|
||||
Get all parameter names that weight decay will be applied to.
|
||||
|
||||
This function filters out parameters in two ways:
|
||||
1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
|
||||
2. By parameter name patterns (containing 'bias', or variation of 'norm')
|
||||
"""
|
||||
forbidden_name_patterns = [
|
||||
r"bias",
|
||||
r"layernorm",
|
||||
r"rmsnorm",
|
||||
r"(?:^|\.)norm(?:$|\.)",
|
||||
r"_norm(?:$|\.)",
|
||||
]
|
||||
decay_parameters = get_parameter_names(
|
||||
model, [nn.LayerNorm], forbidden_name_patterns
|
||||
)
|
||||
return decay_parameters
|
||||
|
||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
- If you are installing from pip
|
||||
```bash
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
|
||||
```
|
||||
|
||||
## Usage
|
||||
@@ -31,6 +31,7 @@ plugins:
|
||||
|
||||
## Supported Models
|
||||
|
||||
- arcee
|
||||
- cohere
|
||||
- cohere2
|
||||
- gemma
|
||||
@@ -41,13 +42,17 @@ plugins:
|
||||
- gemma3n_text
|
||||
- glm
|
||||
- glm4
|
||||
- gpt_oss
|
||||
- granite
|
||||
- granitemoe
|
||||
- hunyuan_v1_dense
|
||||
- hunyuan_v1_moe
|
||||
- llama
|
||||
- llama4
|
||||
- llama4_text
|
||||
- mistral
|
||||
- mistral3
|
||||
- mixtral
|
||||
- mllama
|
||||
- phi
|
||||
- phi3
|
||||
|
||||
@@ -34,7 +34,7 @@ LOG = get_logger(__name__)
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -284,12 +284,12 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
|
||||
return sample
|
||||
|
||||
def _tokenize_single_prompt(self, prompt):
|
||||
logprobs = prompt.pop(self.logprobs_field)
|
||||
target_token_ids = prompt.pop("target_token_ids")
|
||||
target_token_ids = prompt.get("target_token_ids", None)
|
||||
|
||||
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
||||
tokenized_prompt[self.logprobs_field] = logprobs
|
||||
tokenized_prompt["target_token_ids"] = target_token_ids
|
||||
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||
|
||||
if target_token_ids is not None:
|
||||
tokenized_prompt["target_token_ids"] = target_token_ids
|
||||
|
||||
return tokenized_prompt
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import Callable
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState
|
||||
from torch import nn
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
from .geglu import geglu_backward, geglu_forward
|
||||
from .quantize import dequantize
|
||||
@@ -25,6 +26,7 @@ def get_lora_parameters(
|
||||
proj: nn.Module,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor | None,
|
||||
QuantState | None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
@@ -37,39 +39,54 @@ def get_lora_parameters(
|
||||
proj: The projection module to extract parameters from.
|
||||
|
||||
Returns:
|
||||
A tuple containing the base weight matrix, quantization state, LoRA A matrix,
|
||||
LoRA B matrix, and scaling factor. States and matrices may be None if not
|
||||
available.
|
||||
A tuple containing the base weights, quantization state, LoRA A and B weights,
|
||||
scaling factor, and base layer bias. Quant state, weights, and bias may be
|
||||
`None` if not available.
|
||||
"""
|
||||
# For DPO or disabled adapters
|
||||
base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
|
||||
W = base_layer.weight
|
||||
b = base_layer.bias
|
||||
|
||||
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
|
||||
quant_state = getattr(W, "quant_state", None)
|
||||
return W, quant_state, None, None, None
|
||||
return W, b, quant_state, None, None, None
|
||||
|
||||
quant_state = getattr(W, "quant_state", None)
|
||||
|
||||
active_adapter = (
|
||||
proj.active_adapters[0]
|
||||
if hasattr(proj, "active_adapters")
|
||||
else proj.active_adapter
|
||||
)
|
||||
A = proj.lora_A[active_adapter].weight
|
||||
B = proj.lora_B[active_adapter].weight
|
||||
|
||||
linear_A = proj.lora_A[active_adapter]
|
||||
linear_B = proj.lora_B[active_adapter]
|
||||
|
||||
# This manual unsharding is needed for FSDP2 + LoRA kernels compatibility.
|
||||
# We fuse linear layers + LoRA adapters calculations into a single
|
||||
# torch.autograd.Function, bypassing the registered unshard / reshard behavior.
|
||||
# Note that we don't apply resharding later in this module (it gets messy quickly),
|
||||
# but LoRA parameters are generally small enough that this is not an issue.
|
||||
if isinstance(linear_A.weight, DTensor):
|
||||
linear_A.unshard()
|
||||
linear_B.unshard()
|
||||
|
||||
A = linear_A.weight
|
||||
B = linear_B.weight
|
||||
s = proj.scaling[active_adapter]
|
||||
|
||||
quant_state = getattr(W, "quant_state", None)
|
||||
|
||||
return W, quant_state, A, B, s
|
||||
return W, b, quant_state, A, B, s
|
||||
|
||||
|
||||
def matmul_lora(
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
W_quant: QuantState,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
s: float,
|
||||
b: torch.Tensor | None,
|
||||
W_quant: QuantState | None,
|
||||
A: torch.Tensor | None,
|
||||
B: torch.Tensor | None,
|
||||
s: float | None,
|
||||
out: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -90,20 +107,22 @@ def matmul_lora(
|
||||
dtype = X.dtype
|
||||
W = dequantize(W.t(), W_quant)
|
||||
|
||||
reshape = False
|
||||
if X.dim() == 3:
|
||||
batch, seq_len, _ = X.shape
|
||||
X = X.view(-1, X.shape[-1])
|
||||
reshape = True
|
||||
else:
|
||||
reshape = False
|
||||
|
||||
out = torch.matmul(X, W, out=out)
|
||||
if W_quant is not None:
|
||||
del W
|
||||
|
||||
if A is not None:
|
||||
A, B = A.t(), B.t()
|
||||
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
|
||||
A, B = A.t().to(dtype), B.t().to(dtype) # type: ignore[union-attr]
|
||||
out += s * X @ A @ B
|
||||
|
||||
if b is not None:
|
||||
out += b
|
||||
|
||||
return out.view(batch, seq_len, -1) if reshape else out
|
||||
|
||||
@@ -117,17 +136,20 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
ctx,
|
||||
X: torch.Tensor,
|
||||
gate_weight: torch.Tensor,
|
||||
gate_quant: object | None,
|
||||
gate_bias: torch.Tensor | None,
|
||||
gate_quant: QuantState | None,
|
||||
gate_A: torch.Tensor | None,
|
||||
gate_B: torch.Tensor | None,
|
||||
gate_scale: float,
|
||||
up_weight: torch.Tensor,
|
||||
up_quant: object | None,
|
||||
up_bias: torch.Tensor | None,
|
||||
up_quant: QuantState | None,
|
||||
up_A: torch.Tensor | None,
|
||||
up_B: torch.Tensor | None,
|
||||
up_scale: float,
|
||||
down_weight: torch.Tensor,
|
||||
down_quant: object | None,
|
||||
down_bias: torch.Tensor | None,
|
||||
down_quant: QuantState | None,
|
||||
down_A: torch.Tensor | None,
|
||||
down_B: torch.Tensor | None,
|
||||
down_scale: float,
|
||||
@@ -142,20 +164,22 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
ctx: Autograd context
|
||||
X: Input features
|
||||
gate_weight: Gate projection weight
|
||||
gate_bias: Gate projection bias
|
||||
gate_quant: Gate quantization state
|
||||
gate_A: Gate LoRA A matrix
|
||||
gate_B: Gate LoRA B matrix
|
||||
gate_scale: Gate LoRA scale
|
||||
up_weight: Up-projection weight
|
||||
up_quant: Up-projection quantization state
|
||||
up_A: Up-projection LoRA A matrix
|
||||
up_B: Up-projection LoRA B matrix
|
||||
up_scale: Up-projection LoRA scale
|
||||
down_weight: Down-projection weight
|
||||
down_quant: Down-projection quantization state
|
||||
down_A: Down-projection LoRA A matrix
|
||||
down_B: Down-projection LoRA B matrix
|
||||
down_scale: Down-projection LoRA scale
|
||||
up_weight: Up projection weight
|
||||
up_quant: Up projection quantization state
|
||||
up_A: Up projection LoRA A matrix
|
||||
up_B: Up projection LoRA B matrix
|
||||
up_scale: Up projection LoRA scale
|
||||
down_weight: Down projection weight
|
||||
down_bias: Down projection bias
|
||||
down_quant: Down projection quantization state
|
||||
down_A: Down projection LoRA A matrix
|
||||
down_B: Down projection LoRA B matrix
|
||||
down_scale: Down projection LoRA scale
|
||||
activation_fn: Forward activation function
|
||||
activation_fn_backward: Backward activation function
|
||||
inplace: Whether to perform operations in-place
|
||||
@@ -164,15 +188,17 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
Output transformed by multi-layer perceptron and activation function
|
||||
"""
|
||||
# Compute projections
|
||||
gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
|
||||
up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)
|
||||
gate = matmul_lora(
|
||||
X, gate_weight, gate_bias, gate_quant, gate_A, gate_B, gate_scale
|
||||
)
|
||||
up = matmul_lora(X, up_weight, up_bias, up_quant, up_A, up_B, up_scale)
|
||||
|
||||
# Activation
|
||||
hidden = activation_fn(gate, up)
|
||||
|
||||
# Down projection
|
||||
output = matmul_lora(
|
||||
hidden, down_weight, down_quant, down_A, down_B, down_scale
|
||||
hidden, down_weight, down_bias, down_quant, down_A, down_B, down_scale
|
||||
)
|
||||
|
||||
# Save for backward
|
||||
@@ -195,22 +221,26 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Performs backward pass computation for LoRA MLP.
|
||||
@@ -222,7 +252,7 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
Returns:
|
||||
Tuple containing gradients for all inputs from forward pass:
|
||||
- Input gradient tensor (or `None`)
|
||||
- `None` for weights/quantization states
|
||||
- `None` for weights/biases/quantization states
|
||||
- LoRA A/B matrix gradients (or `None`)
|
||||
- `None` for scaling factors
|
||||
- `None` for activation functions and flags
|
||||
@@ -265,9 +295,10 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
dtype = X.dtype
|
||||
|
||||
# Down projection
|
||||
DW = matmul_lora(
|
||||
grad_down = matmul_lora(
|
||||
grad_output,
|
||||
down_weight.t(),
|
||||
None,
|
||||
down_quant,
|
||||
down_B,
|
||||
down_A,
|
||||
@@ -275,7 +306,7 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
)
|
||||
|
||||
# Activation backward
|
||||
h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)
|
||||
h, grad_gate, grad_up = ctx.activation_fn_backward(grad_down, gate, up)
|
||||
|
||||
# Initialize and compute LoRA gradients
|
||||
d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
|
||||
@@ -315,8 +346,8 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())
|
||||
|
||||
# Gate projection gradients
|
||||
gate_weight = dequantize(gate_weight.t(), gate_quant)
|
||||
dX += grad_gate @ gate_weight.t()
|
||||
gate_weight = dequantize(gate_weight, gate_quant)
|
||||
dX += grad_gate @ gate_weight
|
||||
del gate_weight
|
||||
|
||||
if gate_A is not None and gate_B is not None:
|
||||
@@ -334,22 +365,26 @@ class LoRA_MLP(torch.autograd.Function):
|
||||
dX,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_gate_A.t() if d_gate_A is not None else None,
|
||||
d_gate_B.t() if d_gate_B is not None else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_up_A.t() if d_up_A is not None else None,
|
||||
d_up_B.t() if d_up_B is not None else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_down_A.t() if d_down_A is not None else None,
|
||||
d_down_B.t() if d_down_B is not None else None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
@@ -364,23 +399,26 @@ def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.
|
||||
Returns:
|
||||
Output tensor after applying LoRA-adapted MLP with SwiGLU activation
|
||||
"""
|
||||
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||
downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
|
||||
out = LoRA_MLP.apply(
|
||||
X,
|
||||
gateW,
|
||||
gateb,
|
||||
gateW_quant,
|
||||
gateA,
|
||||
gateB,
|
||||
gateS,
|
||||
upW,
|
||||
upb,
|
||||
upW_quant,
|
||||
upA,
|
||||
upB,
|
||||
upS,
|
||||
downW,
|
||||
downb,
|
||||
downW_quant,
|
||||
downA,
|
||||
downB,
|
||||
@@ -404,22 +442,25 @@ def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.T
|
||||
Returns:
|
||||
Output tensor after applying LoRA-adapted MLP with GEGLU activation
|
||||
"""
|
||||
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
||||
upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
|
||||
downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
||||
out = LoRA_MLP.apply(
|
||||
X,
|
||||
gateW,
|
||||
gateb,
|
||||
gateW_quant,
|
||||
gateA,
|
||||
gateB,
|
||||
gateS,
|
||||
upW,
|
||||
upb,
|
||||
upW_quant,
|
||||
upA,
|
||||
upB,
|
||||
upS,
|
||||
downW,
|
||||
downb,
|
||||
downW_quant,
|
||||
downA,
|
||||
downB,
|
||||
@@ -446,16 +487,19 @@ class LoRA_QKV(torch.autograd.Function):
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
X: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
q_bias: torch.Tensor | None,
|
||||
q_quant: QuantState | None,
|
||||
q_A: torch.Tensor | None,
|
||||
q_B: torch.Tensor | None,
|
||||
q_scale: float,
|
||||
k_weight: torch.Tensor,
|
||||
k_bias: torch.Tensor | None,
|
||||
k_quant: QuantState | None,
|
||||
k_A: torch.Tensor | None,
|
||||
k_B: torch.Tensor | None,
|
||||
k_scale: float,
|
||||
v_weight: torch.Tensor,
|
||||
v_bias: torch.Tensor | None,
|
||||
v_quant: QuantState | None,
|
||||
v_A: torch.Tensor | None,
|
||||
v_B: torch.Tensor | None,
|
||||
@@ -469,16 +513,19 @@ class LoRA_QKV(torch.autograd.Function):
|
||||
ctx: Autograd context
|
||||
X: Input tensor
|
||||
q_weight: Query projection weight
|
||||
q_bias: Query projection bias
|
||||
q_quant: Query quantization state
|
||||
q_A: Query LoRA A matrix
|
||||
q_B: Query LoRA B matrix
|
||||
q_scale: Query LoRA scale
|
||||
k_weight: Key projection weight
|
||||
k_bias: Key projection bias
|
||||
k_quant: Key quantization state
|
||||
k_A: Key LoRA A matrix
|
||||
k_B: Key LoRA B matrix
|
||||
k_scale: Key LoRA scale
|
||||
v_weight: Value projection weight
|
||||
v_bias: Value projection bias
|
||||
v_quant: Value quantization state
|
||||
v_A: Value LoRA A matrix
|
||||
v_B: Value LoRA B matrix
|
||||
@@ -488,20 +535,21 @@ class LoRA_QKV(torch.autograd.Function):
|
||||
Returns:
|
||||
Tuple of (Query, Key, Value) projection tensors
|
||||
"""
|
||||
Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
|
||||
K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
|
||||
V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)
|
||||
Q = matmul_lora(X, q_weight, q_bias, q_quant, q_A, q_B, q_scale)
|
||||
K = matmul_lora(X, k_weight, k_bias, k_quant, k_A, k_B, k_scale)
|
||||
V = matmul_lora(X, v_weight, v_bias, v_quant, v_A, v_B, v_scale)
|
||||
|
||||
ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
|
||||
ctx.scales = (q_scale, k_scale, v_scale)
|
||||
ctx.quants = (q_quant, k_quant, v_quant)
|
||||
ctx.weights = (q_weight, k_weight, v_weight)
|
||||
ctx.biases = (q_bias, k_bias, v_bias)
|
||||
ctx.inplace = inplace
|
||||
|
||||
return Q, K, V
|
||||
|
||||
@staticmethod
|
||||
@torch_amp_custom_fwd
|
||||
@torch_amp_custom_bwd
|
||||
def backward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
q_grad: torch.Tensor,
|
||||
@@ -511,16 +559,19 @@ class LoRA_QKV(torch.autograd.Function):
|
||||
torch.Tensor,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
@@ -608,31 +659,31 @@ class LoRA_QKV(torch.autograd.Function):
|
||||
# Transpose gradients if needed
|
||||
if d_A_q is not None:
|
||||
d_A_q = d_A_q.t()
|
||||
if d_B_q is not None:
|
||||
d_B_q = d_B_q.t()
|
||||
d_B_q = d_B_q.t() # type: ignore[union-attr]
|
||||
if d_A_k is not None:
|
||||
d_A_k = d_A_k.t()
|
||||
if d_B_k is not None:
|
||||
d_B_k = d_B_k.t()
|
||||
d_B_k = d_B_k.t() # type: ignore[union-attr]
|
||||
if d_A_v is not None:
|
||||
d_A_v = d_A_v.t()
|
||||
if d_B_v is not None:
|
||||
d_B_v = d_B_v.t()
|
||||
d_B_v = d_B_v.t() # type: ignore[union-attr]
|
||||
|
||||
return (
|
||||
grad_X.view(batch, seq_len, -1),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_A_q,
|
||||
d_B_q,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_A_k,
|
||||
d_B_k,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
d_A_v,
|
||||
d_B_v,
|
||||
None,
|
||||
@@ -653,22 +704,25 @@ def apply_lora_qkv(
|
||||
Returns:
|
||||
Tuple of (Query, Key, Value) projection tensors
|
||||
"""
|
||||
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
|
||||
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
|
||||
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
|
||||
QW, Qb, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
|
||||
KW, Kb, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
|
||||
VW, Vb, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
|
||||
Q, K, V = LoRA_QKV.apply(
|
||||
X,
|
||||
QW,
|
||||
Qb,
|
||||
QW_quant,
|
||||
QA,
|
||||
QB,
|
||||
QS,
|
||||
KW,
|
||||
Kb,
|
||||
KW_quant,
|
||||
KA,
|
||||
KB,
|
||||
KS,
|
||||
VW,
|
||||
Vb,
|
||||
VW_quant,
|
||||
VA,
|
||||
VB,
|
||||
@@ -688,10 +742,11 @@ class LoRA_O(torch.autograd.Function):
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
X: torch.Tensor,
|
||||
W: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
W_quant: QuantState | None,
|
||||
A: torch.Tensor | None,
|
||||
B: torch.Tensor | None,
|
||||
S: float,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
s: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for output projection with LoRA.
|
||||
@@ -700,19 +755,20 @@ class LoRA_O(torch.autograd.Function):
|
||||
ctx: Autograd context
|
||||
X: Input tensor
|
||||
W: Output projection weight
|
||||
b: Output projection bias
|
||||
W_quant: Weight quantization state
|
||||
A: LoRA A matrix
|
||||
B: LoRA B matrix
|
||||
S: LoRA scaling factor
|
||||
s: LoRA scaling factor
|
||||
|
||||
Returns:
|
||||
Output projection tensor
|
||||
Output projection result
|
||||
"""
|
||||
XW = matmul_lora(X, W, W_quant, A, B, S)
|
||||
XW = matmul_lora(X, W, b, W_quant, A, B, s)
|
||||
ctx.custom_saved_tensors = (
|
||||
W,
|
||||
W_quant,
|
||||
S,
|
||||
s,
|
||||
)
|
||||
ctx.save_for_backward(A, B, X)
|
||||
|
||||
@@ -727,8 +783,9 @@ class LoRA_O(torch.autograd.Function):
|
||||
torch.Tensor,
|
||||
None,
|
||||
None,
|
||||
torch.Tensor | None,
|
||||
torch.Tensor | None,
|
||||
None,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
@@ -741,7 +798,7 @@ class LoRA_O(torch.autograd.Function):
|
||||
Returns:
|
||||
Tuple containing gradients for all forward inputs
|
||||
"""
|
||||
W, W_quant, S = ctx.custom_saved_tensors
|
||||
W, W_quant, s = ctx.custom_saved_tensors
|
||||
A, B, X = ctx.saved_tensors
|
||||
|
||||
batch, seq_len, hd = X.shape
|
||||
@@ -751,17 +808,19 @@ class LoRA_O(torch.autograd.Function):
|
||||
|
||||
# Weight projection
|
||||
dY_X = X.t() @ dY
|
||||
d_A = S * dY_X @ B
|
||||
d_B = S * A @ dY_X
|
||||
d_A = s * dY_X @ B
|
||||
d_B = s * A @ dY_X
|
||||
|
||||
# Get derivative for dX
|
||||
W = dequantize(W.t(), W_quant)
|
||||
dX = dY @ W.t()
|
||||
del W
|
||||
dX += dY @ B.to(dtype) @ (S * A.to(dtype))
|
||||
|
||||
# W, W_quant, A, B, S
|
||||
return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
|
||||
A, B = A.to(dtype), B.to(dtype)
|
||||
dX += s * dY @ B @ A
|
||||
|
||||
# W, b, W_quant, A, B, s
|
||||
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None
|
||||
|
||||
|
||||
def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
|
||||
@@ -774,7 +833,7 @@ def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
|
||||
Returns:
|
||||
Transformed output tensor
|
||||
"""
|
||||
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
|
||||
output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)
|
||||
OW, Ob, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
|
||||
output = LoRA_O.apply(X, OW, Ob, OW_quant, OA, OB, OS)
|
||||
|
||||
return output
|
||||
|
||||
@@ -76,6 +76,7 @@ def load_lora(
|
||||
config_only: bool = False,
|
||||
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
|
||||
lora_target_modules = cfg.lora_target_modules or []
|
||||
lora_target_parameters = cfg.lora_target_parameters or []
|
||||
|
||||
if cfg.lora_target_linear:
|
||||
linear_names = find_all_linear_names(model)
|
||||
@@ -106,6 +107,7 @@ def load_lora(
|
||||
r=cfg.lora_r,
|
||||
lora_alpha=cfg.lora_alpha,
|
||||
target_modules=lora_target_modules,
|
||||
target_parameters=lora_target_parameters,
|
||||
layers_to_transform=cfg.peft_layers_to_transform,
|
||||
layers_pattern=cfg.peft_layers_pattern,
|
||||
lora_dropout=cfg.lora_dropout,
|
||||
|
||||
@@ -1,26 +1,13 @@
|
||||
"""Shared constants for axolotl.loaders module"""
|
||||
|
||||
from transformers import (
|
||||
Gemma3ForConditionalGeneration,
|
||||
Gemma3nForConditionalGeneration,
|
||||
Llama4ForConditionalGeneration,
|
||||
LlavaForConditionalGeneration,
|
||||
Mistral3ForConditionalGeneration,
|
||||
MllamaForConditionalGeneration,
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
from transformers import AutoModelForImageTextToText
|
||||
from transformers.models.auto.modeling_auto import (
|
||||
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
|
||||
)
|
||||
|
||||
MULTIMODAL_AUTO_MODEL_MAPPING = {
|
||||
"mllama": MllamaForConditionalGeneration,
|
||||
"llama4": Llama4ForConditionalGeneration,
|
||||
"llava": LlavaForConditionalGeneration,
|
||||
"qwen2_vl": Qwen2VLForConditionalGeneration,
|
||||
"qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
|
||||
"mistral3": Mistral3ForConditionalGeneration,
|
||||
"gemma3": Gemma3ForConditionalGeneration,
|
||||
"gemma3n": Gemma3nForConditionalGeneration,
|
||||
}
|
||||
MULTIMODAL_AUTO_MODEL_MAPPING = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)
|
||||
|
||||
MULTIMODAL_AUTO_MODEL_MAPPING["lfm2-vl"] = AutoModelForImageTextToText
|
||||
|
||||
try:
|
||||
from transformers import VoxtralForConditionalGeneration
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""Model loader class implementation for loading, configuring, and patching various
|
||||
models.
|
||||
"""
|
||||
Model loader class implementation for loading, configuring, and patching various models.
|
||||
"""
|
||||
|
||||
import gc
|
||||
@@ -13,7 +13,7 @@ import peft
|
||||
import torch
|
||||
import transformers
|
||||
import transformers.modeling_utils
|
||||
from accelerate import PartialState, init_empty_weights
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.parallelism_config import ParallelismConfig
|
||||
from peft import (
|
||||
PeftConfig,
|
||||
@@ -22,8 +22,10 @@ from peft import (
|
||||
PeftModelForCausalLM,
|
||||
prepare_model_for_kbit_training,
|
||||
)
|
||||
from torch.distributed import DeviceMesh
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForImageTextToText,
|
||||
AutoModelForVision2Seq,
|
||||
AwqConfig,
|
||||
BitsAndBytesConfig,
|
||||
@@ -49,7 +51,11 @@ from axolotl.loaders.utils import (
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
|
||||
from axolotl.utils.distributed import (
|
||||
build_parallelism_config,
|
||||
get_device_count,
|
||||
get_device_type,
|
||||
)
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
@@ -87,6 +93,7 @@ class ModelLoader:
|
||||
|
||||
use_parallel_config: bool | None = False
|
||||
parallelism_config: ParallelismConfig | None = None
|
||||
device_mesh: DeviceMesh | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -202,8 +209,11 @@ class ModelLoader:
|
||||
self._set_device_map_config()
|
||||
if self.cfg.revision_of_model:
|
||||
self.model_kwargs["revision"] = self.cfg.revision_of_model
|
||||
if self.cfg.use_kernels:
|
||||
self.model_kwargs["use_kernels"] = self.cfg.use_kernels
|
||||
self._set_quantization_config()
|
||||
self._set_attention_config()
|
||||
self._check_model_requirements()
|
||||
|
||||
def _apply_post_model_load_setup(self):
|
||||
"""Configure the model after it has been loaded."""
|
||||
@@ -300,7 +310,10 @@ class ModelLoader:
|
||||
)
|
||||
|
||||
# Handle DeepSpeed Zero3
|
||||
if is_deepspeed_zero3_enabled():
|
||||
if (
|
||||
is_deepspeed_zero3_enabled()
|
||||
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
|
||||
):
|
||||
self._set_z3_leaf_modules()
|
||||
|
||||
# Apply gradient checkpointing if needed
|
||||
@@ -405,85 +418,12 @@ class ModelLoader:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@staticmethod
|
||||
def _get_parallel_config_kwargs(
|
||||
world_size: int,
|
||||
tensor_parallel_size: int = 1,
|
||||
context_parallel_size: int = 1,
|
||||
dp_shard_size: int | None = None,
|
||||
dp_replicate_size: int | None = None,
|
||||
is_fsdp: bool = False,
|
||||
):
|
||||
pc_kwargs = {}
|
||||
remaining_world_size = world_size
|
||||
|
||||
if tensor_parallel_size and tensor_parallel_size > 1:
|
||||
pc_kwargs["tp_size"] = tensor_parallel_size
|
||||
remaining_world_size = remaining_world_size // tensor_parallel_size
|
||||
|
||||
if context_parallel_size and context_parallel_size > 1:
|
||||
pc_kwargs["cp_size"] = context_parallel_size
|
||||
remaining_world_size = remaining_world_size // context_parallel_size
|
||||
|
||||
if dp_shard_size is None and dp_replicate_size in (None, 1):
|
||||
if remaining_world_size > 1:
|
||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if dp_replicate_size and dp_replicate_size > 1:
|
||||
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
||||
remaining_world_size = remaining_world_size // dp_replicate_size
|
||||
|
||||
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
|
||||
if not is_fsdp:
|
||||
raise ValueError(
|
||||
"dp_shard_size was configured without a corresponding fsdp_config! "
|
||||
"Please ensure you have configured FSDP using fsdp_config."
|
||||
)
|
||||
pc_kwargs["dp_shard_size"] = dp_shard_size
|
||||
remaining_world_size = remaining_world_size // dp_shard_size
|
||||
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
|
||||
pc_kwargs["dp_replicate_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if remaining_world_size > 1:
|
||||
if "dp_shard_size" not in pc_kwargs and is_fsdp:
|
||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if remaining_world_size > 1:
|
||||
raise ValueError(
|
||||
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
|
||||
f"{pc_kwargs}"
|
||||
)
|
||||
|
||||
return pc_kwargs
|
||||
|
||||
def _set_parallel_config(self):
|
||||
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
||||
pc_kwargs = ModelLoader._get_parallel_config_kwargs(
|
||||
get_world_size(),
|
||||
self.cfg.tensor_parallel_size,
|
||||
self.cfg.context_parallel_size,
|
||||
self.cfg.dp_shard_size,
|
||||
self.cfg.dp_replicate_size,
|
||||
bool(self.cfg.fsdp or self.cfg.fsdp_config),
|
||||
)
|
||||
|
||||
if pc_kwargs:
|
||||
self.parallelism_config = ParallelismConfig(
|
||||
**pc_kwargs,
|
||||
)
|
||||
device_mesh = self.parallelism_config.build_device_mesh("cuda")
|
||||
partial_state = PartialState()
|
||||
# fmt: off
|
||||
partial_state._shared_state["parallelism_config"] = ( # pylint: disable=protected-access
|
||||
self.parallelism_config
|
||||
)
|
||||
partial_state._shared_state["device_mesh"] = ( # pylint: disable=protected-access
|
||||
device_mesh
|
||||
)
|
||||
# fmt: on
|
||||
parallelism_config, device_mesh = build_parallelism_config(self.cfg)
|
||||
if parallelism_config:
|
||||
self.parallelism_config = parallelism_config
|
||||
self.device_mesh = device_mesh
|
||||
|
||||
def _set_auto_model_loader(self):
|
||||
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
||||
@@ -494,6 +434,8 @@ class ModelLoader:
|
||||
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
|
||||
self.model_config.model_type, AutoModelForVision2Seq
|
||||
)
|
||||
if isinstance(self.auto_model_loader, str):
|
||||
self.auto_model_loader = AutoModelForImageTextToText
|
||||
|
||||
def _set_device_map_config(self):
|
||||
"""Setup `device_map` according to config"""
|
||||
@@ -565,8 +507,17 @@ class ModelLoader:
|
||||
|
||||
def _set_quantization_config(self):
|
||||
"""Set up quantization config (bitsandbytes, awq, gptq, etc.)"""
|
||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
||||
|
||||
if self.cfg.model_quantization_config == "Mxfp4Config":
|
||||
from transformers import Mxfp4Config
|
||||
|
||||
mxfp4_kwargs = {}
|
||||
if self.cfg.model_quantization_config_kwargs:
|
||||
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
|
||||
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
|
||||
else:
|
||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
||||
|
||||
if self.cfg.gptq:
|
||||
if not hasattr(self.model_config, "quantization_config"):
|
||||
@@ -601,7 +552,9 @@ class ModelLoader:
|
||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**self.model_config.quantization_config
|
||||
)
|
||||
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
|
||||
elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
|
||||
"load_in_4bit", False
|
||||
):
|
||||
bnb_config = {
|
||||
"load_in_4bit": True,
|
||||
"llm_int8_threshold": 6.0,
|
||||
@@ -627,7 +580,9 @@ class ModelLoader:
|
||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**bnb_config,
|
||||
)
|
||||
elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
|
||||
elif self.cfg.adapter == "lora" and self.model_kwargs.get(
|
||||
"load_in_8bit", False
|
||||
):
|
||||
bnb_config = {
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
@@ -648,7 +603,9 @@ class ModelLoader:
|
||||
|
||||
def _set_attention_config(self):
|
||||
"""Sample packing uses custom FA2 patch"""
|
||||
if self.cfg.flex_attention:
|
||||
if self.cfg.attn_implementation:
|
||||
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
|
||||
elif self.cfg.flex_attention:
|
||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flex_attention"
|
||||
@@ -675,6 +632,16 @@ class ModelLoader:
|
||||
if self.cfg.low_cpu_mem_usage:
|
||||
self.model_kwargs["low_cpu_mem_usage"] = True
|
||||
|
||||
def _check_model_requirements(self):
|
||||
if self.cfg.model_config_type in ["lfm2-vl", "lfm2"]:
|
||||
from transformers.utils.import_utils import is_causal_conv1d_available
|
||||
|
||||
if is_causal_conv1d_available():
|
||||
raise ImportError(
|
||||
"The 'causal-conv1d' package is installed but causes compatibility issues with LFM2 models. "
|
||||
"Please uninstall it by running: `pip uninstall -y causal-conv1d`"
|
||||
)
|
||||
|
||||
def _configure_zero3_memory_efficient_loading(
|
||||
self,
|
||||
) -> HfTrainerDeepSpeedConfig | None:
|
||||
@@ -721,7 +688,7 @@ class ModelLoader:
|
||||
if self.cfg.tensor_parallel_size > 1:
|
||||
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
||||
self.model_kwargs["tp_plan"] = "auto"
|
||||
self.model_kwargs["device_mesh"] = PartialState().device_mesh
|
||||
self.model_kwargs["device_mesh"] = self.device_mesh
|
||||
if "device_map" in self.model_kwargs:
|
||||
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
|
||||
|
||||
@@ -737,6 +704,18 @@ class ModelLoader:
|
||||
elif self.is_qlora_and_fsdp_enabled:
|
||||
skip_move_to_device = True
|
||||
|
||||
if (
|
||||
self.cfg.tensor_parallel_size <= 1
|
||||
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
||||
and self.cfg.fsdp_version == 2
|
||||
):
|
||||
# setting device_map for TP is not supported
|
||||
local_rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
if local_rank == 0:
|
||||
self.model_kwargs["device_map"] = "cpu"
|
||||
else:
|
||||
self.model_kwargs["device_map"] = "meta"
|
||||
|
||||
if (
|
||||
self.is_qlora_and_fsdp_enabled
|
||||
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
||||
@@ -845,6 +824,9 @@ class ModelLoader:
|
||||
self.model._tp_size = self.cfg.tensor_parallel_size
|
||||
self.model._device_mesh = self.model_kwargs["device_mesh"]
|
||||
|
||||
if self.cfg.experimental_skip_move_to_device is not None:
|
||||
skip_move_to_device = self.cfg.experimental_skip_move_to_device
|
||||
|
||||
return skip_move_to_device
|
||||
|
||||
def _set_z3_leaf_modules(self):
|
||||
|
||||
@@ -65,6 +65,7 @@ class PatchManager:
|
||||
self._patch_llama_derived_model()
|
||||
self._apply_mistral_cross_entropy_patch()
|
||||
self._apply_self_attention_lora_patch()
|
||||
self._apply_fsdp2_bnb_patches()
|
||||
|
||||
def apply_post_plugin_pre_model_load_patches(self):
|
||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
||||
@@ -72,11 +73,19 @@ class PatchManager:
|
||||
self._apply_voxtral_patches()
|
||||
|
||||
def _apply_transformers_patches(self):
|
||||
from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
|
||||
patch_prepare_from_posids,
|
||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||
patch_evaluation_loop,
|
||||
patch_maybe_log_save_evaluate,
|
||||
)
|
||||
|
||||
patch_prepare_from_posids()
|
||||
patch_fsdp2 = (
|
||||
self.cfg.torch_compile
|
||||
and self.cfg.fsdp_config
|
||||
and self.cfg.fsdp_version == 2
|
||||
)
|
||||
|
||||
patch_evaluation_loop(patch_fsdp2)
|
||||
patch_maybe_log_save_evaluate()
|
||||
|
||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||
"""Apply patches that require the model instance."""
|
||||
@@ -103,6 +112,14 @@ class PatchManager:
|
||||
|
||||
def _apply_fsdp_patches(self):
|
||||
"""Apply patches for FSDP configurations."""
|
||||
if self.cfg.context_parallel_size > 1 or (
|
||||
self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
|
||||
):
|
||||
from axolotl.monkeypatch.accelerate.parallelism_config import (
|
||||
patch_parallelism_config,
|
||||
)
|
||||
|
||||
patch_parallelism_config()
|
||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
||||
|
||||
@@ -260,6 +277,21 @@ class PatchManager:
|
||||
has_remote_code=has_remote_code,
|
||||
)
|
||||
|
||||
def _apply_fsdp2_bnb_patches(self):
|
||||
"""Apply FSDP2 BNB patches."""
|
||||
if (
|
||||
self.cfg.fsdp_config
|
||||
and str(self.cfg.fsdp_version) == "2"
|
||||
and self.cfg.adapter == "qlora"
|
||||
):
|
||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||
apply_init_sharded_param_patch,
|
||||
apply_init_unsharded_param_patch,
|
||||
)
|
||||
|
||||
apply_init_sharded_param_patch()
|
||||
apply_init_unsharded_param_patch()
|
||||
|
||||
def _apply_tiled_mlp(self, model_type: str):
|
||||
if self.cfg.tiled_mlp:
|
||||
from axolotl.monkeypatch.tiled_mlp import (
|
||||
@@ -330,31 +362,21 @@ class PatchManager:
|
||||
|
||||
patch_self_attn_lora()
|
||||
|
||||
def _patch_llama_flash_attention(self, packed=False):
|
||||
def _patch_llama_flash_attention(self):
|
||||
"""Apply Flash Attention patches for LLaMA models."""
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_attn_with_flash_attn,
|
||||
)
|
||||
|
||||
if packed:
|
||||
if self.cfg.device not in ["mps", "cpu"] and not self.inference:
|
||||
LOG.info("patching with flash attention for sample packing")
|
||||
replace_llama_attn_with_flash_attn(
|
||||
packed=True,
|
||||
cross_entropy=self.cfg.flash_attn_cross_entropy,
|
||||
rms_norm=self.cfg.flash_attn_rms_norm,
|
||||
)
|
||||
elif self.cfg.s2_attention:
|
||||
if self.cfg.s2_attention:
|
||||
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
|
||||
replace_llama_attn_with_flash_attn(
|
||||
packed=False,
|
||||
cross_entropy=self.cfg.flash_attn_cross_entropy,
|
||||
rms_norm=self.cfg.flash_attn_rms_norm,
|
||||
use_shifted_sparse_attn=True,
|
||||
)
|
||||
elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm:
|
||||
replace_llama_attn_with_flash_attn(
|
||||
packed=False,
|
||||
cross_entropy=self.cfg.flash_attn_cross_entropy,
|
||||
rms_norm=self.cfg.flash_attn_rms_norm,
|
||||
)
|
||||
@@ -385,7 +407,7 @@ class PatchManager:
|
||||
and self.cfg.sample_packing
|
||||
):
|
||||
if self.cfg.flash_attention:
|
||||
self._patch_llama_flash_attention(packed=self.cfg.sample_packing)
|
||||
self._patch_llama_flash_attention()
|
||||
elif self.cfg.xformers_attention:
|
||||
self._patch_llama_xformers_attention()
|
||||
elif self.cfg.sample_packing:
|
||||
@@ -408,17 +430,12 @@ class PatchManager:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
is_xformers_swiglu_available,
|
||||
replace_llama_mlp_with_swiglu,
|
||||
replace_llama_qkv_with_fused,
|
||||
)
|
||||
|
||||
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
||||
LOG.info("Patching with SwiGLU...")
|
||||
replace_llama_mlp_with_swiglu(model)
|
||||
|
||||
if self.cfg.flash_attn_fuse_qkv:
|
||||
LOG.info("Patching with fused QKV...")
|
||||
replace_llama_qkv_with_fused(model)
|
||||
|
||||
def _apply_unsloth_patches(self, model):
|
||||
"""Apply unsloth optimization patches."""
|
||||
if self.cfg.unsloth_lora_mlp:
|
||||
|
||||
@@ -7,6 +7,7 @@ import functools
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
@@ -36,25 +37,49 @@ def fsdp2_load_full_state_dict(
|
||||
|
||||
meta_sharded_sd = model.state_dict()
|
||||
sharded_sd = {}
|
||||
for param_name, full_tensor in full_sd.items():
|
||||
sharded_meta_param = meta_sharded_sd.get(param_name)
|
||||
full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
|
||||
for param_name, sharded_meta_param in meta_sharded_sd.items():
|
||||
full_tensor = None
|
||||
if _accelerator.is_main_process:
|
||||
full_tensor = full_sd[param_name]
|
||||
full_tensor = full_tensor.to(sharded_meta_param.dtype)
|
||||
|
||||
if hasattr(sharded_meta_param, "device_mesh"):
|
||||
device_mesh = sharded_meta_param.device_mesh
|
||||
if _accelerator.is_main_process:
|
||||
full_tensor = full_tensor.to(device_mesh.device_type)
|
||||
else:
|
||||
full_tensor = torch.empty(
|
||||
sharded_meta_param.size(),
|
||||
device=device_mesh.device_type,
|
||||
dtype=sharded_meta_param.dtype,
|
||||
)
|
||||
sharded_param = distribute_tensor(
|
||||
full_tensor,
|
||||
sharded_meta_param.device_mesh,
|
||||
device_mesh,
|
||||
sharded_meta_param.placements,
|
||||
src_data_rank=0,
|
||||
)
|
||||
else:
|
||||
sharded_param = full_tensor
|
||||
# Non-sharded parameters
|
||||
if _accelerator.is_main_process:
|
||||
sharded_param = full_tensor.to(torch.device("cuda"))
|
||||
else:
|
||||
# broadcast manually
|
||||
sharded_param = torch.empty_like(
|
||||
sharded_meta_param,
|
||||
device=torch.device("cuda"),
|
||||
dtype=sharded_meta_param.dtype,
|
||||
)
|
||||
dist.broadcast(sharded_param, src=0)
|
||||
|
||||
if offload_to_cpu:
|
||||
sharded_param = sharded_param.cpu()
|
||||
|
||||
sharded_sd[param_name] = nn.Parameter(sharded_param)
|
||||
|
||||
del full_tensor
|
||||
full_sd[param_name] = None
|
||||
|
||||
model.load_state_dict(sharded_sd, assign=True, strict=True)
|
||||
end_time = time.time()
|
||||
LOG.debug(
|
||||
|
||||
77
src/axolotl/monkeypatch/accelerate/parallelism_config.py
Normal file
77
src/axolotl/monkeypatch/accelerate/parallelism_config.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
workaround to allow parallelism config for pure CP
|
||||
"""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from accelerate import DistributedType
|
||||
|
||||
|
||||
def _validate_accelerator(self, accelerator):
|
||||
_warnings = set()
|
||||
if not accelerator.multi_device and self.total_size == 1:
|
||||
# No distributed setup, valid parallelism config
|
||||
return
|
||||
|
||||
# We need this to ensure DDP works
|
||||
if self.total_size == 1:
|
||||
self._set_size("dp_replicate", accelerator.num_processes)
|
||||
|
||||
if self.total_size != accelerator.num_processes:
|
||||
raise ValueError(
|
||||
f"ParallelismConfig total_size ({self.total_size}) does not match "
|
||||
f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
|
||||
f"dp_shard_size/tp_size/cp_size."
|
||||
)
|
||||
|
||||
# allow parallelism config when not using fsdp if using pure context parallelism
|
||||
allow_parallelism_config = False
|
||||
|
||||
if (
|
||||
self.cp_size > 1 # pylint: disable=chained-comparison
|
||||
and self.dp_shard_size <= 1
|
||||
and os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true"
|
||||
):
|
||||
allow_parallelism_config = True
|
||||
|
||||
if (
|
||||
self.total_size > 1
|
||||
and not allow_parallelism_config
|
||||
and not (accelerator.is_fsdp2 or accelerator.multi_device)
|
||||
):
|
||||
raise ValueError(
|
||||
f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
|
||||
)
|
||||
|
||||
for parallelism, size in self._sizes.items():
|
||||
if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
|
||||
_warnings.add(
|
||||
f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
|
||||
)
|
||||
|
||||
if _warnings and accelerator.is_main_process:
|
||||
warnings.warn(
|
||||
"ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
|
||||
def patched_is_fsdp2(self) -> bool:
|
||||
"""
|
||||
Patched version of is_fsdp2 that guards against a None fsdp_plugin.
|
||||
"""
|
||||
# The new logic checks if fsdp_plugin exists before accessing its attributes
|
||||
return (
|
||||
self.distributed_type == DistributedType.FSDP
|
||||
and self.fsdp_plugin
|
||||
and self.fsdp_plugin.fsdp_version == 2
|
||||
)
|
||||
|
||||
|
||||
def patch_parallelism_config():
|
||||
from accelerate.accelerator import AcceleratorState, ParallelismConfig
|
||||
|
||||
ParallelismConfig._validate_accelerator = _validate_accelerator
|
||||
AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
|
||||
144
src/axolotl/monkeypatch/fsdp2_qlora.py
Normal file
144
src/axolotl/monkeypatch/fsdp2_qlora.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
|
||||
our LoRA / QLoRA Triton kernels to work with FSDP2.
|
||||
|
||||
This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
|
||||
Params4bit parameters.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def apply_init_sharded_param_patch():
|
||||
"""Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||
|
||||
# Get original source
|
||||
original_source = inspect.getsource(FSDPParam._init_sharded_param)
|
||||
original_source, _ = detab_code(original_source)
|
||||
|
||||
# Define the replacement
|
||||
original_param_creation = """ self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
||||
self.sharded_param.requires_grad_(param.requires_grad)"""
|
||||
|
||||
patched_param_creation = """ import bitsandbytes as bnb
|
||||
if isinstance(param, bnb.nn.modules.Params4bit):
|
||||
self.sharded_param = bnb.nn.modules.Params4bit(
|
||||
data=sharded_param,
|
||||
requires_grad=param.requires_grad,
|
||||
quant_state=param.quant_state,
|
||||
blocksize=param.blocksize,
|
||||
compress_statistics=param.compress_statistics,
|
||||
quant_type=param.quant_type,
|
||||
quant_storage=param.quant_storage,
|
||||
module=param.module,
|
||||
bnb_quantized=param.bnb_quantized,
|
||||
)
|
||||
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
||||
else:
|
||||
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
||||
self.sharded_param.requires_grad_(param.requires_grad)"""
|
||||
|
||||
# Apply the replacement
|
||||
if original_param_creation in original_source:
|
||||
patched_source = original_source.replace(
|
||||
original_param_creation, patched_param_creation
|
||||
)
|
||||
patched_source = patched_source.replace(
|
||||
"def _init_sharded_param(",
|
||||
"def patched_init_sharded_param(",
|
||||
1,
|
||||
)
|
||||
|
||||
# Load necessary imports
|
||||
module_name = FSDPParam.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in patched_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec(patched_source, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
# Replace the method
|
||||
FSDPParam._init_sharded_param = patched_init_sharded_param # pylint: disable=undefined-variable # noqa: F821
|
||||
LOG.info("Successfully applied FSDP _init_sharded_param patch")
|
||||
else:
|
||||
LOG.warning("Could not find target code for _init_sharded_param patching")
|
||||
|
||||
|
||||
def apply_init_unsharded_param_patch():
|
||||
"""Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||
|
||||
# Get original source
|
||||
original_source = inspect.getsource(FSDPParam.init_unsharded_param)
|
||||
original_source, _ = detab_code(original_source)
|
||||
|
||||
# Define the replacement
|
||||
original_param_creation = """ self._unsharded_param = nn.Parameter(
|
||||
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
||||
)"""
|
||||
|
||||
patched_param_creation = """ import bitsandbytes as bnb
|
||||
local_tensor = self.sharded_param._local_tensor
|
||||
if isinstance(local_tensor, bnb.nn.modules.Params4bit):
|
||||
self._unsharded_param = bnb.nn.modules.Params4bit(
|
||||
data=unsharded_param,
|
||||
requires_grad=self.sharded_param.requires_grad,
|
||||
quant_state=local_tensor.quant_state,
|
||||
blocksize=local_tensor.blocksize,
|
||||
compress_statistics=local_tensor.compress_statistics,
|
||||
quant_type=local_tensor.quant_type,
|
||||
quant_storage=local_tensor.quant_storage,
|
||||
module=local_tensor.module,
|
||||
bnb_quantized=local_tensor.bnb_quantized,
|
||||
)
|
||||
else:
|
||||
self._unsharded_param = nn.Parameter(
|
||||
unsharded_param, requires_grad=self.sharded_param.requires_grad
|
||||
)"""
|
||||
|
||||
# Apply the replacement
|
||||
if original_param_creation in original_source:
|
||||
patched_source = original_source.replace(
|
||||
original_param_creation, patched_param_creation
|
||||
)
|
||||
patched_source = patched_source.replace(
|
||||
"def init_unsharded_param(",
|
||||
"def patched_init_unsharded_param(",
|
||||
1,
|
||||
)
|
||||
|
||||
# Load necessary imports
|
||||
module_name = FSDPParam.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in patched_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec(patched_source, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
# Replace the method
|
||||
FSDPParam.init_unsharded_param = patched_init_unsharded_param # pylint: disable=undefined-variable # noqa: F821
|
||||
LOG.info("Successfully applied FSDP init_unsharded_param patch")
|
||||
else:
|
||||
LOG.warning("Could not find target code for patching")
|
||||
@@ -3,39 +3,26 @@
|
||||
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
||||
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaMLP,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||
from axolotl.monkeypatch.utils import set_module_name
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
flash_attn_kvpacked_func,
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
except ImportError:
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
|
||||
)
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
@@ -82,19 +69,6 @@ def replace_llama_mlp_with_swiglu(model):
|
||||
set_module_name(model, name, mlp)
|
||||
|
||||
|
||||
def replace_llama_qkv_with_fused(model):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
qkv = FusedAttention(
|
||||
module.config,
|
||||
module.q_proj,
|
||||
module.k_proj,
|
||||
module.v_proj,
|
||||
module.o_proj,
|
||||
)
|
||||
set_module_name(model, name, qkv)
|
||||
|
||||
|
||||
def patch_fa_llama_cross_entropy():
|
||||
LOG.info(
|
||||
"patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
|
||||
@@ -142,7 +116,6 @@ def patch_llama_rms_norm():
|
||||
|
||||
|
||||
def replace_llama_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
cross_entropy: Optional[bool] = False,
|
||||
rms_norm: Optional[bool] = False,
|
||||
use_shifted_sparse_attn: Optional[bool] = False,
|
||||
@@ -154,16 +127,6 @@ def replace_llama_attn_with_flash_attn(
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||
flashattn_forward_with_s2attn
|
||||
)
|
||||
else:
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||
flashattn_forward
|
||||
)
|
||||
|
||||
if packed:
|
||||
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
||||
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
||||
llama_model_forward
|
||||
)
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if cross_entropy:
|
||||
@@ -174,49 +137,6 @@ def replace_llama_attn_with_flash_attn(
|
||||
patch_llama_rms_norm()
|
||||
|
||||
|
||||
class FusedAttention(LlamaAttention):
|
||||
"""
|
||||
Fused QKV Attention layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
q: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
k: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
v: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
o: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.init_device = next(iter(q.state_dict().values())).device
|
||||
|
||||
# define equivalent fused qkv projection
|
||||
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
|
||||
self.qkv_proj = torch.nn.Linear(
|
||||
q.in_features, sum(self.out_features), device=self.init_device, bias=False
|
||||
)
|
||||
self.o_proj = o
|
||||
|
||||
# overwrite initialized weights with pretrained weights
|
||||
self.qkv_proj.weight.data = torch.cat(
|
||||
(q.weight.data, k.weight.data, v.weight.data), dim=0
|
||||
)
|
||||
|
||||
def _post_training(self, model, name):
|
||||
q_proj, k_proj, v_proj = torch.split(
|
||||
self.qkv_proj.weight.data, self.out_features, dim=0
|
||||
)
|
||||
|
||||
new_attn = LlamaAttention(self.config)
|
||||
new_attn.q_proj.weight.data = q_proj
|
||||
new_attn.k_proj.weight.data = k_proj
|
||||
new_attn.v_proj.weight.data = v_proj
|
||||
new_attn.o_proj.weight.data = self.o_proj.weight.data
|
||||
|
||||
set_module_name(model, name, new_attn)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
@@ -355,576 +275,3 @@ def flashattn_forward_with_s2attn(
|
||||
.reshape(bsz, q_len, nheads, self.head_dim)
|
||||
)
|
||||
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
|
||||
|
||||
|
||||
def flashattn_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
"""Input shape: Batch x Time x Channel
|
||||
|
||||
attention_mask: [bsz, q_len]
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if not hasattr(self, "pretraining_tp"):
|
||||
self.pretraining_tp = 1
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
key_value_slicing = (
|
||||
self.num_key_value_heads * self.head_dim
|
||||
) // self.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [
|
||||
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [
|
||||
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [
|
||||
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
if isinstance(self, FusedAttention):
|
||||
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
|
||||
self.out_features, dim=-1
|
||||
)
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
# [bsz, q_len, nh, hd]
|
||||
# [bsz, nh, q_len, hd]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if output_attentions:
|
||||
warnings.warn(
|
||||
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
||||
)
|
||||
|
||||
#
|
||||
# flash-attn v2 start
|
||||
#
|
||||
|
||||
if self.training:
|
||||
# during training q,k,v always have same seqlen
|
||||
assert key_states.shape == query_states.shape
|
||||
is_causal = True
|
||||
else:
|
||||
# turn off FA causal mask after first inference autoregressive iteration
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
|
||||
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
qkvpacked=True,
|
||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||
# the attention_mask should be the same as the key_padding_mask
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=(
|
||||
attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||
qkv_unpad,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
else:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
if attention_mask is None or attention_mask.all().item():
|
||||
output = flash_attn_kvpacked_func(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
dropout_p=dropout_rate,
|
||||
causal=is_causal,
|
||||
)
|
||||
else:
|
||||
( # pylint: disable=unbalanced-tuple-unpacking
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
_,
|
||||
_,
|
||||
output_pad_fn,
|
||||
) = generate_qkv(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
kvpacked=True,
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=(
|
||||
attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
if q_unpad.dtype != kv_unpad.dtype:
|
||||
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
|
||||
attn_output = output
|
||||
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||
|
||||
#
|
||||
# flash-attn v2 end
|
||||
#
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(
|
||||
self.hidden_size // self.pretraining_tp, dim=1
|
||||
)
|
||||
attn_output = sum(
|
||||
F.linear(attn_output[i], o_proj_slices[i])
|
||||
for i in range(self.pretraining_tp)
|
||||
)
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
||||
def generate_qkv(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_padding_mask=None,
|
||||
key_padding_mask=None,
|
||||
kvpacked=False,
|
||||
qkvpacked=False,
|
||||
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seqlen_q, nheads, d)
|
||||
k: (batch_size, seqlen_k, nheads_k, d)
|
||||
v: (batch_size, seqlen_k, nheads_k, d)
|
||||
query_padding_mask: (batch_size, seqlen), bool
|
||||
key_padding_mask: (batch_size, seqlen), bool
|
||||
"""
|
||||
assert not (kvpacked and qkvpacked)
|
||||
batch_size, seqlen_q, nheads, d = q.shape
|
||||
_, seqlen_k, nheads_k, _ = k.shape
|
||||
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||
|
||||
if query_padding_mask is not None:
|
||||
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
||||
q, query_padding_mask
|
||||
)
|
||||
|
||||
def output_pad_fn(output_unpad):
|
||||
return pad_input( # noqa: E731
|
||||
output_unpad, indices_q, batch_size, seqlen_q
|
||||
)
|
||||
|
||||
else:
|
||||
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||
cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen_q,
|
||||
step=seqlen_q,
|
||||
dtype=torch.int32,
|
||||
device=q_unpad.device,
|
||||
)
|
||||
max_seqlen_q = seqlen_q
|
||||
|
||||
def output_pad_fn(output_unpad):
|
||||
return rearrange( # noqa: E731
|
||||
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||
)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
||||
else:
|
||||
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||
cu_seqlens_k = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen_k,
|
||||
step=seqlen_k,
|
||||
dtype=torch.int32,
|
||||
device=k_unpad.device,
|
||||
)
|
||||
max_seqlen_k = seqlen_k
|
||||
|
||||
if qkvpacked:
|
||||
assert nheads == nheads_k
|
||||
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||
qkv = torch.stack([q, k, v], dim=2)
|
||||
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
||||
|
||||
if kvpacked:
|
||||
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||
kv = torch.stack([k, v], dim=2)
|
||||
return (
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
kv,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
return (
|
||||
q_unpad,
|
||||
k_unpad,
|
||||
v_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
|
||||
def llama_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[ # pylint: disable=unused-argument
|
||||
torch.LongTensor
|
||||
] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
if input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
cu_seqlens = None
|
||||
max_seqlen = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
padding_mask = None
|
||||
else:
|
||||
if 0 in attention_mask:
|
||||
padding_mask = attention_mask
|
||||
else:
|
||||
padding_mask = None
|
||||
|
||||
attention_mask = (
|
||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
transformers.logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(
|
||||
*inputs,
|
||||
)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
None,
|
||||
padding_mask,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||
"""
|
||||
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||
]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -156,6 +156,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
||||
|
||||
return Llama4TextAttention
|
||||
|
||||
if model_type == "mistral3":
|
||||
from transformers.models.mistral.modeling_mistral import MistralAttention
|
||||
|
||||
return MistralAttention
|
||||
|
||||
try:
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
@@ -390,7 +395,6 @@ def apply_lora_kernel_patches(
|
||||
]
|
||||
can_patch_qkv = all(
|
||||
hasattr(module, "lora_A")
|
||||
and getattr(module, "base_layer", module).bias is None
|
||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
@@ -400,7 +404,8 @@ def apply_lora_kernel_patches(
|
||||
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
|
||||
"Cannot patch some attention QKV projections - requires LoRA "
|
||||
"adapters and no lora_magnitude_vector (DoRA)"
|
||||
)
|
||||
if cfg.lora_o_kernel:
|
||||
# Output patching
|
||||
@@ -409,7 +414,6 @@ def apply_lora_kernel_patches(
|
||||
]
|
||||
can_patch_o = all(
|
||||
hasattr(module, "lora_A")
|
||||
and getattr(module, "base_layer", module).bias is None
|
||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
@@ -418,14 +422,14 @@ def apply_lora_kernel_patches(
|
||||
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some attention output projection - requires LoRA adapters with no bias"
|
||||
"Cannot patch some attention output projection - requires LoRA "
|
||||
"adapters and no lora_magnitude_vector (DoRA)"
|
||||
)
|
||||
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
||||
if cfg.lora_mlp_kernel:
|
||||
# MLP patching
|
||||
can_patch_mlp = all(
|
||||
hasattr(proj, "lora_A")
|
||||
and getattr(proj, "base_layer", proj).bias is None
|
||||
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
|
||||
for proj in (gate_proj, up_proj, down_proj)
|
||||
)
|
||||
@@ -435,7 +439,8 @@ def apply_lora_kernel_patches(
|
||||
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some MLP layers - requires LoRA adapters with no bias"
|
||||
"Cannot patch some MLP layers - requires LoRA adapters and no "
|
||||
"lora_magnitude_vector (DoRA)"
|
||||
)
|
||||
|
||||
LOG.setLevel(original_level)
|
||||
|
||||
@@ -3,53 +3,14 @@
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
flash_attn_kvpacked_func,
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralAttention as OriginalMistralAttention,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def replace_mistral_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
):
|
||||
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
||||
flashattn_forward
|
||||
)
|
||||
if packed:
|
||||
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
||||
MistralDecoderLayer
|
||||
)
|
||||
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
|
||||
mistral_model_forward
|
||||
)
|
||||
|
||||
|
||||
def patch_mistral_cross_entropy():
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
|
||||
@@ -57,604 +18,3 @@ def patch_mistral_cross_entropy():
|
||||
transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _make_sliding_window_causal_mask(
|
||||
bsz: int,
|
||||
tgt_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
past_key_values_length: int = 0,
|
||||
sliding_window: int = 4096,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for sliding window attention
|
||||
"""
|
||||
tensor = torch.full(
|
||||
(tgt_len, tgt_len),
|
||||
fill_value=1,
|
||||
device=device,
|
||||
)
|
||||
mask = torch.tril(tensor, diagonal=0)
|
||||
# make the mask banded to account for sliding window
|
||||
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
|
||||
mask = torch.triu(mask, diagonal=-sliding_window + 1)
|
||||
mask = torch.log(mask).to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
tgt_len, past_key_values_length, dtype=dtype, device=device
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return mask[None, None, :, :].expand(
|
||||
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
||||
)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
input_shape,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window,
|
||||
): # pylint: disable=unused-argument
|
||||
# [bsz, seq_len]
|
||||
if attention_mask is None or sliding_window is None:
|
||||
return attention_mask
|
||||
|
||||
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
||||
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
|
||||
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
|
||||
sliding_window_mask = _make_sliding_window_causal_mask(
|
||||
bsz=input_shape[0],
|
||||
tgt_len=input_shape[1],
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
attention_mask = attention_mask + sliding_window_mask
|
||||
else:
|
||||
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
def flashattn_forward(
|
||||
self: OriginalMistralAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
|
||||
use_sliding_windows = (
|
||||
getattr(self.config, "sliding_window") is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
)
|
||||
|
||||
if use_sliding_windows:
|
||||
window_size = (self.config.sliding_window, self.config.sliding_window)
|
||||
else:
|
||||
window_size = (-1, -1)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
if (
|
||||
hasattr(self.config, "sliding_window")
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
):
|
||||
slicing_tokens = kv_seq_len - self.config.sliding_window
|
||||
|
||||
past_key = past_key_value[0]
|
||||
past_value = past_key_value[1]
|
||||
|
||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||
|
||||
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||
raise ValueError(
|
||||
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||
f" {past_key.shape}"
|
||||
)
|
||||
|
||||
past_key_value = (past_key, past_value) if use_cache else None
|
||||
|
||||
if past_key_value is not None:
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if self.training:
|
||||
# during training q,k,v always have same seqlen
|
||||
assert key_states.shape == query_states.shape
|
||||
is_causal = True
|
||||
else:
|
||||
# turn off FA causal mask after first inference autoregressive iteration
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
|
||||
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
qkvpacked=True,
|
||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||
# the attention_mask should be the same as the key_padding_mask
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=(
|
||||
attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||
qkv_unpad,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
else:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
if attention_mask is None or attention_mask.all().item():
|
||||
output = flash_attn_kvpacked_func(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
dropout_p=dropout_rate,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
else:
|
||||
( # pylint: disable=unbalanced-tuple-unpacking
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
_,
|
||||
_,
|
||||
output_pad_fn,
|
||||
) = generate_qkv(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
kvpacked=True,
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=(
|
||||
attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
if q_unpad.dtype != kv_unpad.dtype:
|
||||
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
|
||||
attn_output = output
|
||||
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
||||
def generate_qkv(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_padding_mask=None,
|
||||
key_padding_mask=None,
|
||||
kvpacked=False,
|
||||
qkvpacked=False,
|
||||
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seqlen_q, nheads, d)
|
||||
k: (batch_size, seqlen_k, nheads_k, d)
|
||||
v: (batch_size, seqlen_k, nheads_k, d)
|
||||
query_padding_mask: (batch_size, seqlen), bool
|
||||
key_padding_mask: (batch_size, seqlen), bool
|
||||
"""
|
||||
assert not (kvpacked and qkvpacked)
|
||||
batch_size, seqlen_q, nheads, d = q.shape
|
||||
_, seqlen_k, nheads_k, _ = k.shape
|
||||
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||
|
||||
if query_padding_mask is not None:
|
||||
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
||||
q, query_padding_mask
|
||||
)
|
||||
|
||||
def output_pad_fn(output_unpad):
|
||||
return pad_input( # noqa: E731
|
||||
output_unpad, indices_q, batch_size, seqlen_q
|
||||
)
|
||||
|
||||
else:
|
||||
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||
cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen_q,
|
||||
step=seqlen_q,
|
||||
dtype=torch.int32,
|
||||
device=q_unpad.device,
|
||||
)
|
||||
max_seqlen_q = seqlen_q
|
||||
|
||||
def output_pad_fn(output_unpad):
|
||||
return rearrange( # noqa: E731
|
||||
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||
)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
||||
else:
|
||||
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||
cu_seqlens_k = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen_k,
|
||||
step=seqlen_k,
|
||||
dtype=torch.int32,
|
||||
device=k_unpad.device,
|
||||
)
|
||||
max_seqlen_k = seqlen_k
|
||||
|
||||
if qkvpacked:
|
||||
assert nheads == nheads_k
|
||||
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||
qkv = torch.stack([q, k, v], dim=2)
|
||||
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
||||
|
||||
if kvpacked:
|
||||
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||
kv = torch.stack([k, v], dim=2)
|
||||
return (
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
kv,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
return (
|
||||
q_unpad,
|
||||
k_unpad,
|
||||
v_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
|
||||
def mistral_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[ # pylint: disable=unused-argument
|
||||
torch.LongTensor
|
||||
] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
if input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
cu_seqlens = None
|
||||
max_seqlen = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
attention_mask = (
|
||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
transformers.logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = (
|
||||
self._gradient_checkpointing_func( # pylint: disable=protected-access
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
None,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class MistralDecoderLayer(OriginalMistralDecoderLayer):
|
||||
"""
|
||||
patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||
]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -36,8 +36,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"glm",
|
||||
"glm4",
|
||||
"smollm3",
|
||||
"granite",
|
||||
"granitemoe",
|
||||
"gpt_oss",
|
||||
"arcee",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -20,12 +20,15 @@ from ring_flash_attn import ring_flash_attn_func
|
||||
from ring_flash_attn.adapters.hf_adapter import check_params
|
||||
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
|
||||
|
||||
try:
|
||||
try: # pylint: disable=duplicate-code
|
||||
from transformers.modeling_flash_attention_utils import _flash_supports_window
|
||||
except ImportError:
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_flash_supports_window_size as _flash_supports_window,
|
||||
)
|
||||
try:
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_flash_supports_window_size as _flash_supports_window,
|
||||
)
|
||||
except ImportError:
|
||||
_flash_supports_window = True
|
||||
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
|
||||
@@ -15,12 +15,15 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import DeviceMesh
|
||||
|
||||
try:
|
||||
try: # pylint: disable=duplicate-code
|
||||
from transformers.modeling_flash_attention_utils import _flash_supports_window
|
||||
except ImportError:
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_flash_supports_window_size as _flash_supports_window,
|
||||
)
|
||||
try:
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_flash_supports_window_size as _flash_supports_window,
|
||||
)
|
||||
except ImportError:
|
||||
_flash_supports_window = True
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
"""
|
||||
fix for FSDP2 evals when using torch.compile
|
||||
"""
|
||||
|
||||
import inspect
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
ORIGINAL_TRAINER_CODE = """
|
||||
model.eval()
|
||||
"""
|
||||
|
||||
PATCHED_TRAINER_CODE = """
|
||||
if hasattr(model, "eval") and callable(model.eval):
|
||||
self.model.eval()
|
||||
"""
|
||||
|
||||
|
||||
def get_evaluation_loop_code() -> str:
|
||||
training_loop = inspect.getsource(Trainer.evaluation_loop)
|
||||
return training_loop
|
||||
|
||||
|
||||
def check_evaluation_loop_is_patchable() -> bool:
|
||||
eval_loop = get_evaluation_loop_code()
|
||||
eval_loop, _ = detab_code(eval_loop)
|
||||
return ORIGINAL_TRAINER_CODE in eval_loop
|
||||
|
||||
|
||||
def patch_evaluation_loop_for_fsdp2():
|
||||
"""
|
||||
monkeypatch for fixing the eval loop for fsdp2 with torch.compile
|
||||
"""
|
||||
|
||||
try:
|
||||
evaluation_loop = get_evaluation_loop_code()
|
||||
except OSError:
|
||||
return
|
||||
Trainer._original_evaluation_loop = ( # pylint: disable=protected-access
|
||||
evaluation_loop
|
||||
)
|
||||
evaluation_loop, _ = detab_code(evaluation_loop)
|
||||
if ORIGINAL_TRAINER_CODE not in evaluation_loop:
|
||||
return
|
||||
|
||||
evaluation_loop = evaluation_loop.replace(
|
||||
ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
|
||||
)
|
||||
evaluation_loop = evaluation_loop.replace(
|
||||
"def evaluation_loop(",
|
||||
"def _fixed_evaluation_loop(",
|
||||
1,
|
||||
)
|
||||
|
||||
# load imports necessary
|
||||
import transformers.trainer
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.trainer):
|
||||
if item in evaluation_loop:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from transformers.trainer import ("
|
||||
+ ", ".join(x for x in items_to_import)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
exec(evaluation_loop, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching _inner_training_loop for fsdp optimizer save")
|
||||
Trainer.evaluation_loop = ( # pylint: disable=protected-access
|
||||
_fixed_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
@@ -1,87 +0,0 @@
|
||||
"""
|
||||
Monkey patch to fix transformers.modeling_flash_attention_utils.
|
||||
|
||||
see https://github.com/huggingface/transformers/pull/39653/files
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _prepare_from_posids(query, key, value, position_ids):
|
||||
"""
|
||||
This function returns necessary arguments to call `flash_attn_varlen_func`.
|
||||
All three query, key, value states will be flattened.
|
||||
Cumulative lengths of each examples in the batch will be extracted from position_ids.
|
||||
NOTE: ideally cumulative lengths should be prepared at the data collator stage
|
||||
Arguments:
|
||||
query (`torch.Tensor`):
|
||||
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
|
||||
key (`torch.Tensor`):
|
||||
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||
value (`torch.Tensor`):
|
||||
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
|
||||
position_ids (`torch.Tensor`):
|
||||
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
|
||||
Return:
|
||||
query (`torch.Tensor`):
|
||||
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
|
||||
key (`torch.Tensor`):
|
||||
Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
||||
value (`torch.Tensor`):
|
||||
Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
|
||||
indices_q (`torch.Tensor`):
|
||||
The indices of non-masked tokens from the flattened input target sequence.
|
||||
(cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
|
||||
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
|
||||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
|
||||
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
|
||||
"""
|
||||
query = query.contiguous().view(-1, query.size(-2), query.size(-1))
|
||||
key = key.contiguous().view(-1, key.size(-2), key.size(-1))
|
||||
value = value.contiguous().view(-1, value.size(-2), value.size(-1))
|
||||
|
||||
position_ids = position_ids.flatten()
|
||||
indices_q = torch.arange(
|
||||
position_ids.size(0), device=position_ids.device, dtype=torch.int32
|
||||
)
|
||||
|
||||
cu_seq_lens = torch.cat(
|
||||
(
|
||||
indices_q[position_ids == 0],
|
||||
torch.tensor(
|
||||
position_ids.size(), device=position_ids.device, dtype=torch.int32
|
||||
),
|
||||
)
|
||||
)
|
||||
# NOTE: With torch compile, this will cause a graph break if you don't set
|
||||
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
|
||||
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
|
||||
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
|
||||
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
|
||||
# https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
|
||||
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
|
||||
# for some models (e.g. qwen2-vl).
|
||||
max_length = cu_seq_lens.diff().max().item()
|
||||
return (
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
indices_q,
|
||||
(cu_seq_lens, cu_seq_lens),
|
||||
(max_length, max_length),
|
||||
)
|
||||
|
||||
|
||||
def patch_prepare_from_posids():
|
||||
import transformers.modeling_flash_attention_utils
|
||||
|
||||
transformers.modeling_flash_attention_utils._prepare_from_posids = ( # pylint: disable=protected-access
|
||||
_prepare_from_posids
|
||||
)
|
||||
setattr(
|
||||
sys.modules["transformers.modeling_flash_attention_utils"],
|
||||
"_prepare_from_posids",
|
||||
_prepare_from_posids,
|
||||
)
|
||||
165
src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
Normal file
165
src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Module for patching transformers Trainer loss calculation to use nanmean.
|
||||
|
||||
This is needed for context parallelism since chunks of the input sequences may be fully
|
||||
masked and return NaNs in the loss calculation.
|
||||
|
||||
Also includes a patch for FSDP2 + torch.compile. We need to bundle this together with
|
||||
the other evaluation_loop patch because we can't patch the same code twice without
|
||||
raising an OSError.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
ORIGINAL_EVAL_CODE = {
|
||||
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()',
|
||||
"array": 'metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()',
|
||||
}
|
||||
PATCHED_EVAL_CODE = {
|
||||
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()',
|
||||
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
|
||||
}
|
||||
|
||||
ORIGINAL_FSDP2_CODE = """
|
||||
model.eval()
|
||||
"""
|
||||
|
||||
PATCHED_FSDP2_CODE = """
|
||||
if hasattr(model, "eval") and callable(model.eval):
|
||||
self.model.eval()
|
||||
"""
|
||||
|
||||
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
|
||||
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
|
||||
|
||||
|
||||
def check_evaluation_loop_is_patchable() -> bool:
|
||||
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||
return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values())
|
||||
|
||||
|
||||
def check_evaluation_loop_is_fsdp2_patchable() -> bool:
|
||||
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
|
||||
return ORIGINAL_FSDP2_CODE in evaluation_loop_source
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def patch_evaluation_loop(patch_fsdp2: bool):
|
||||
"""Patch the evaluation_loop method."""
|
||||
# Check if already patched
|
||||
if hasattr(Trainer, "_original_evaluation_loop"):
|
||||
LOG.info("Trainer.evaluation_loop already patched")
|
||||
return
|
||||
|
||||
# Check if the patterns exist
|
||||
try:
|
||||
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||
except OSError:
|
||||
return
|
||||
Trainer.evaluation = evaluation_loop_source
|
||||
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
|
||||
|
||||
# Apply the nanmean patches
|
||||
evaluation_loop_source = evaluation_loop_source.replace(
|
||||
ORIGINAL_EVAL_CODE["list"], PATCHED_EVAL_CODE["list"]
|
||||
)
|
||||
evaluation_loop_source = evaluation_loop_source.replace(
|
||||
ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"]
|
||||
)
|
||||
|
||||
# Apply FSDP2 eval guard patch if needed
|
||||
if patch_fsdp2 and ORIGINAL_FSDP2_CODE in evaluation_loop_source:
|
||||
evaluation_loop_source = evaluation_loop_source.replace(
|
||||
ORIGINAL_FSDP2_CODE, PATCHED_FSDP2_CODE
|
||||
)
|
||||
LOG.info("Applied FSDP2 eval guard patch to evaluation_loop")
|
||||
|
||||
# Rename the function to avoid conflicts
|
||||
evaluation_loop_source = evaluation_loop_source.replace(
|
||||
"def evaluation_loop(",
|
||||
"def axolotl_evaluation_loop(",
|
||||
1,
|
||||
)
|
||||
|
||||
# Get the module for necessary imports
|
||||
module_name = Trainer.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Import necessary items from the module
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in evaluation_loop_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
# Execute the imports and patched method
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec(evaluation_loop_source, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
|
||||
Trainer.evaluation_loop = (
|
||||
axolotl_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
|
||||
|
||||
def check_maybe_log_save_evaluate_is_patchable() -> bool:
|
||||
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
|
||||
return ORIGINAL_MAYBE_CODE in maybe_log_source
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def patch_maybe_log_save_evaluate():
|
||||
"""Patch the _maybe_log_save_evaluate method."""
|
||||
# Check if already patched
|
||||
if hasattr(Trainer, "_original_maybe_log_save_evaluate"):
|
||||
LOG.info("Trainer._maybe_log_save_evaluate already patched")
|
||||
return
|
||||
|
||||
# Check if the patterns exist
|
||||
try:
|
||||
maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
|
||||
except OSError:
|
||||
return
|
||||
Trainer._original_maybe_log_save_evaluate = maybe_log_source
|
||||
maybe_log_source, _ = detab_code(maybe_log_source)
|
||||
|
||||
# Apply the patch
|
||||
maybe_log_source = maybe_log_source.replace(ORIGINAL_MAYBE_CODE, PATCHED_MAYBE_CODE)
|
||||
|
||||
# Rename the function to avoid conflicts
|
||||
maybe_log_source = maybe_log_source.replace(
|
||||
"def _maybe_log_save_evaluate(",
|
||||
"def axolotl_maybe_log_save_evaluate(",
|
||||
1,
|
||||
)
|
||||
|
||||
# Get the module for necessary imports
|
||||
module_name = Trainer.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Import necessary items from the module
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in maybe_log_source:
|
||||
items_to_import.append(item)
|
||||
|
||||
# Execute the imports and patched method
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec(maybe_log_source, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
|
||||
Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate # pylint: disable=undefined-variable # noqa: F821
|
||||
@@ -6,7 +6,7 @@ from typing import Optional
|
||||
from PIL import Image, ImageOps
|
||||
from PIL.Image import Resampling
|
||||
from torch import Tensor, zeros_like
|
||||
from transformers import ProcessorMixin, VoxtralProcessor
|
||||
from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
|
||||
from transformers.image_utils import load_image
|
||||
|
||||
from axolotl.utils.dict import remove_none_values
|
||||
@@ -138,7 +138,7 @@ class ProcessingStrategy:
|
||||
image_key = key
|
||||
break
|
||||
|
||||
# if the image key exists, add the image to the first message
|
||||
# if the image key exists, add the image to the first user message
|
||||
if image_key is not None and processed_example[image_key] is not None:
|
||||
# TODO: check if it's normal to be single image only for common datasets
|
||||
# From observation, it's usually a list of single image but some datasets may have several columns for images
|
||||
@@ -179,26 +179,34 @@ class ProcessingStrategy:
|
||||
|
||||
# Look for any image type in the first message
|
||||
# some dataset have an {type: "image"} in the first message
|
||||
msg_ind_to_add = None
|
||||
ind_to_add = None
|
||||
first_user_idx = None
|
||||
|
||||
for i, content in enumerate(
|
||||
processed_example["messages"][0]["content"]
|
||||
):
|
||||
# Usually datasets created with image columns, don't have it in the messages itself
|
||||
if content["type"] == "image" and all(
|
||||
k not in content for k in ["image", "url", "path", "base64"]
|
||||
for msg_idx, msg_content in enumerate(processed_example["messages"]):
|
||||
if first_user_idx is None and msg_content["role"] == "user":
|
||||
first_user_idx = msg_idx
|
||||
for i, content in enumerate(
|
||||
processed_example["messages"][msg_idx]["content"]
|
||||
):
|
||||
ind_to_add = i
|
||||
break
|
||||
# Usually datasets created with image columns, don't have it in the messages itself
|
||||
if content["type"] == "image" and all(
|
||||
k not in content for k in ["image", "url", "path", "base64"]
|
||||
):
|
||||
msg_ind_to_add = msg_idx
|
||||
ind_to_add = i
|
||||
break
|
||||
|
||||
# If an image type is found, add the image to that index
|
||||
if ind_to_add is not None:
|
||||
processed_example["messages"][0]["content"][ind_to_add][
|
||||
"image"
|
||||
] = image_value
|
||||
if ind_to_add is not None and msg_ind_to_add is not None:
|
||||
processed_example["messages"][msg_ind_to_add]["content"][
|
||||
ind_to_add
|
||||
]["image"] = image_value
|
||||
else:
|
||||
# if no image type is found, add it to end of the first message
|
||||
processed_example["messages"][0]["content"].append(
|
||||
# if no image type is found, add it to end of the first user message
|
||||
if first_user_idx is None:
|
||||
first_user_idx = 0
|
||||
processed_example["messages"][first_user_idx]["content"].append(
|
||||
{
|
||||
"type": "image",
|
||||
"image": image_value,
|
||||
@@ -395,6 +403,24 @@ class VoxtralProcessingStrategy(ProcessingStrategy):
|
||||
return labels
|
||||
|
||||
|
||||
class SmolVLM2ProcessingStrategy(ProcessingStrategy):
|
||||
"""Processing Strategy class for SmolVLM2"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processor: ProcessorMixin,
|
||||
chat_template: Optional[str] = None,
|
||||
image_size: int | tuple[int, int] | None = None,
|
||||
image_resize_algorithm: Resampling | None = None,
|
||||
):
|
||||
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
|
||||
self.image_token = "<image>" # nosec
|
||||
|
||||
self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
|
||||
processor.tokenizer.additional_special_tokens.index(self.image_token)
|
||||
]
|
||||
|
||||
|
||||
def get_processing_strategy(
|
||||
processor: ProcessorMixin,
|
||||
chat_template,
|
||||
@@ -402,32 +428,43 @@ def get_processing_strategy(
|
||||
image_size: int | tuple[int, int] | None = None,
|
||||
image_resize_algorithm: Resampling | None = None,
|
||||
):
|
||||
processing_kwargs = {
|
||||
"processor": processor,
|
||||
"chat_template": chat_template,
|
||||
"image_size": image_size,
|
||||
"image_resize_algorithm": image_resize_algorithm,
|
||||
}
|
||||
|
||||
if chat_template_type in [None, "tokenizer_default"] and hasattr(
|
||||
processor.tokenizer, "chat_template"
|
||||
):
|
||||
processing_kwargs["chat_template"] = processor.tokenizer.chat_template
|
||||
|
||||
if chat_template_type == "qwen2_vl":
|
||||
return Qwen2VLProcessingStrategy(
|
||||
processor, chat_template, image_size, image_resize_algorithm
|
||||
**processing_kwargs,
|
||||
)
|
||||
if chat_template_type == "gemma3":
|
||||
return Gemma3ProcessingStrategy(
|
||||
processor, chat_template, image_size, image_resize_algorithm
|
||||
**processing_kwargs,
|
||||
)
|
||||
if chat_template_type == "gemma3n":
|
||||
return Gemma3nProcessingStrategy(
|
||||
processor, chat_template, image_size, image_resize_algorithm
|
||||
)
|
||||
if chat_template_type in [
|
||||
"llama3_2_vision",
|
||||
"llama4",
|
||||
"llava",
|
||||
"mistral_v7_tekken",
|
||||
"pixtral",
|
||||
]:
|
||||
return ProcessingStrategy(
|
||||
processor, chat_template, image_size, image_resize_algorithm
|
||||
**processing_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(processor, VoxtralProcessor):
|
||||
return VoxtralProcessingStrategy(
|
||||
processor, chat_template, image_size, image_resize_algorithm
|
||||
**processing_kwargs,
|
||||
)
|
||||
|
||||
raise ValueError(f"Unsupported chat template type: {chat_template_type}")
|
||||
if isinstance(processor, SmolVLMProcessor):
|
||||
return SmolVLM2ProcessingStrategy(
|
||||
**processing_kwargs,
|
||||
)
|
||||
|
||||
# llama3_2_vision, llama4, llava
|
||||
# mistral_v7_tekken, pixtral, lfm2vl
|
||||
return ProcessingStrategy(
|
||||
**processing_kwargs,
|
||||
)
|
||||
|
||||
@@ -41,7 +41,9 @@ class ChatTemplatePrompter(Prompter):
|
||||
field_messages: str = "messages",
|
||||
field_system: str = "system",
|
||||
field_tools: str = "tools",
|
||||
field_thinking: str = "reasoning_content",
|
||||
roles: dict[str, list[str]] | None = None,
|
||||
template_thinking_key: str | None = "reasoning_content",
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
drop_system_message: bool = False,
|
||||
):
|
||||
@@ -50,8 +52,9 @@ class ChatTemplatePrompter(Prompter):
|
||||
message_property_mappings = {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
"reasoning_content": "reasoning_content",
|
||||
}
|
||||
if template_thinking_key and field_thinking:
|
||||
message_property_mappings[template_thinking_key] = field_thinking
|
||||
|
||||
if roles:
|
||||
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
||||
@@ -74,10 +77,12 @@ class ChatTemplatePrompter(Prompter):
|
||||
self.field_messages = field_messages
|
||||
self.field_system = field_system
|
||||
self.field_tools = field_tools
|
||||
self.field_thinking = field_thinking
|
||||
self.tokenizer = tokenizer
|
||||
self.processor: ProcessorMixin | None = processor
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_kwargs = chat_template_kwargs or {}
|
||||
self.template_thinking_key: str = template_thinking_key or "reasoning_content"
|
||||
self.max_length = max_length
|
||||
self.drop_system_message = drop_system_message
|
||||
|
||||
@@ -124,13 +129,21 @@ class ChatTemplatePrompter(Prompter):
|
||||
images=images,
|
||||
return_tensors="pt",
|
||||
)
|
||||
if hasattr(batch, "to_dict"):
|
||||
batch = batch.to_dict()
|
||||
else:
|
||||
batch = dict(batch)
|
||||
|
||||
# workaround since processor works in batches instead of single examples
|
||||
out = {}
|
||||
for k, val in batch.items():
|
||||
if k in ["pixel_values"]:
|
||||
batch[k] = val.tolist()
|
||||
if hasattr(val, "tolist"):
|
||||
out[k] = (
|
||||
val.tolist() if k == "pixel_values" else val.squeeze(0).tolist()
|
||||
)
|
||||
else:
|
||||
batch[k] = val.squeeze().tolist()
|
||||
return batch
|
||||
out[k] = val
|
||||
return out
|
||||
|
||||
return self.tokenizer.apply_chat_template(
|
||||
conversation,
|
||||
@@ -428,10 +441,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
tokenized_prompt["attention_mask"] = [1] * len(input_ids)
|
||||
else:
|
||||
input_ids = tokenized_res["input_ids"]
|
||||
tokenized_prompt = tokenized_res
|
||||
tokenized_prompt = dict(tokenized_res)
|
||||
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(prompt_ids)
|
||||
if isinstance(prompt_ids, dict):
|
||||
user_prompt_len = len(prompt_ids["input_ids"])
|
||||
else:
|
||||
user_prompt_len = len(prompt_ids)
|
||||
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
|
||||
else:
|
||||
labels = input_ids
|
||||
@@ -742,7 +758,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
# get the thinking content
|
||||
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
|
||||
transformed_message["reasoning_content"] = thinking_content.strip()
|
||||
transformed_message[self.prompter.template_thinking_key] = (
|
||||
thinking_content.strip()
|
||||
)
|
||||
|
||||
# take remainder of the content
|
||||
# strip whitespace from beginning of the remainder (thinking tokens)
|
||||
@@ -953,6 +971,10 @@ class StrategyLoader:
|
||||
None,
|
||||
),
|
||||
"field_messages": dataset_config.get("field_messages", "messages"),
|
||||
"field_thinking": dataset_config.get("field_thinking", "reasoning_content"),
|
||||
"template_thinking_key": dataset_config.get(
|
||||
"template_thinking_key", "reasoning_content"
|
||||
),
|
||||
"roles": dataset_config.get("roles"),
|
||||
"drop_system_message": dataset_config.get("drop_system_message", False),
|
||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||
|
||||
@@ -4,11 +4,14 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import signal
|
||||
import sys
|
||||
import typing
|
||||
import weakref
|
||||
from collections import OrderedDict
|
||||
from contextlib import ExitStack
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
@@ -38,6 +41,7 @@ from axolotl.utils.distributed import cleanup_distributed
|
||||
from axolotl.utils.freeze import freeze_layers_except
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
from axolotl.utils.train import determine_last_checkpoint
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
try:
|
||||
@@ -46,7 +50,7 @@ except ImportError:
|
||||
BetterTransformer = None
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -124,32 +128,6 @@ def setup_reference_model(
|
||||
return model_ref
|
||||
|
||||
|
||||
def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
|
||||
"""
|
||||
Determine the checkpoint to resume from based on configuration.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
|
||||
Returns:
|
||||
Path to the checkpoint to resume from, or `None` if not resuming.
|
||||
"""
|
||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||
possible_checkpoints = [
|
||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||
]
|
||||
if len(possible_checkpoints) > 0:
|
||||
sorted_paths = sorted(
|
||||
possible_checkpoints,
|
||||
key=lambda path: int(path.split("-")[-1]),
|
||||
)
|
||||
cfg.resume_from_checkpoint = sorted_paths[-1]
|
||||
LOG.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||
)
|
||||
return cfg.resume_from_checkpoint
|
||||
|
||||
|
||||
def setup_signal_handler(
|
||||
cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
|
||||
):
|
||||
@@ -218,6 +196,7 @@ def execute_training(
|
||||
ring_attn_func=cfg.ring_attn_func,
|
||||
heads_k_stride=cfg.heads_k_stride,
|
||||
gather_outputs=cfg.rl is RLType.GRPO,
|
||||
device_mesh=trainer.accelerator.torch_device_mesh,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -274,19 +253,56 @@ def save_trained_model(
|
||||
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
||||
return
|
||||
|
||||
if trainer.is_fsdp_enabled:
|
||||
if trainer.is_fsdp_enabled or cfg.fsdp_config:
|
||||
if cfg.fsdp_config or cfg.fsdp:
|
||||
if cfg.fsdp_config.final_state_dict_type:
|
||||
state_dict_type = cfg.fsdp_config.final_state_dict_type
|
||||
else:
|
||||
state_dict_type = cfg.fsdp_config.state_dict_type
|
||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
|
||||
trainer.save_model(cfg.output_dir)
|
||||
trainer.save_model(cfg.output_dir) # only handles FULL_STATE_DICT
|
||||
if state_dict_type == "SHARDED_STATE_DICT":
|
||||
LOG.info(
|
||||
"The final model was saved with a sharded state dict. Please ensure you merge "
|
||||
"the sharded weights with `merge-sharded-fsdp-weights`."
|
||||
)
|
||||
checkpoint_dir = determine_last_checkpoint(cfg, update=False)
|
||||
if (
|
||||
not (Path(cfg.output_dir) / "model.safetensors.index.json").exists()
|
||||
and checkpoint_dir
|
||||
):
|
||||
# import here to prevent circular import
|
||||
from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights
|
||||
|
||||
fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
|
||||
merged_path = str(Path(cfg.output_dir) / "merged")
|
||||
merge_fsdp_weights(
|
||||
checkpoint_dir=str(fsdp_dir),
|
||||
output_path=merged_path,
|
||||
safe_serialization=True,
|
||||
)
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
if trainer.accelerator.is_main_process:
|
||||
# move all files in merged_path to cfg.output_dir
|
||||
for merged_file in Path(merged_path).iterdir():
|
||||
shutil.move(str(merged_file), cfg.output_dir)
|
||||
shutil.rmtree(merged_path) # remove what should be an empty dir
|
||||
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207
|
||||
# cleanup the FSDP prefix in the model config.json
|
||||
if trainer.accelerator.is_main_process:
|
||||
with open(
|
||||
Path(cfg.output_dir) / "config.json", "r", encoding="utf-8"
|
||||
) as config_file_io:
|
||||
# read the model config as an OrderedDict
|
||||
config = json.load(config_file_io, object_pairs_hook=OrderedDict)
|
||||
config["architectures"] = [
|
||||
name.lstrip("FSDP") for name in config["architectures"]
|
||||
]
|
||||
# write the updated model config back
|
||||
with open(
|
||||
os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8"
|
||||
) as config_file_io:
|
||||
json.dump(config, config_file_io, indent=2)
|
||||
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
||||
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
@@ -563,9 +579,13 @@ def train(
|
||||
setup_model_card(cfg)
|
||||
|
||||
# Execute the training
|
||||
resume_from_checkpoint = determine_resume_checkpoint(cfg)
|
||||
resume_from_checkpoint = determine_last_checkpoint(cfg)
|
||||
execute_training(cfg, trainer, resume_from_checkpoint)
|
||||
|
||||
# clear cache
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Save the trained model and cleanup
|
||||
save_trained_model(cfg, trainer, model, safe_serialization)
|
||||
create_model_card(cfg, trainer)
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
{# Alias tools -> available_tools #}
|
||||
{%- if tools and not available_tools -%}
|
||||
{%- set available_tools = tools -%}
|
||||
{%- endif -%}
|
||||
{%- if messages[0]['role'] == 'system' %}
|
||||
{%- set system_message = messages[0]['content'] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set system_message = "Knowledge Cutoff Date: April 2024.
|
||||
Today's Date: " + strftime_now('%B %d, %Y') + ".
|
||||
You are Granite, developed by IBM." %}
|
||||
{%- if available_tools and documents %}
|
||||
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.
|
||||
Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
|
||||
{%- elif available_tools %}
|
||||
{%- set system_message = system_message + " You are a helpful assistant with access to the following tools. When a tool is required to answer the user's query, respond only with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
|
||||
{%- elif documents %}
|
||||
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
|
||||
{%- elif thinking %}
|
||||
{%- set system_message = system_message + " You are a helpful AI assistant.
|
||||
Respond to every user query in a comprehensive and detailed way. You can write down your thoughts and reasoning process before responding. In the thought process, engage in a comprehensive cycle of analysis, summarization, exploration, reassessment, reflection, backtracing, and iteration to develop well-considered thinking process. In the response section, based on various attempts, explorations, and reflections from the thoughts section, systematically present the final solution that you deem correct. The response should summarize the thought process. Write your thoughts between <think></think> and write your response between <response></response> for each user query." %}
|
||||
{%- else %}
|
||||
{%- set system_message = system_message + " You are a helpful AI assistant." %}
|
||||
{%- endif %}
|
||||
{%- if 'citations' in controls and documents %}
|
||||
{%- set system_message = system_message + '
|
||||
Use the symbols <|start_of_cite|> and <|end_of_cite|> to indicate when a fact comes from a document in the search result, e.g <|start_of_cite|> {document_id: 1}my fact <|end_of_cite|> for a fact from document 1. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
|
||||
{%- endif %}
|
||||
{%- if 'hallucinations' in controls and documents %}
|
||||
{%- set system_message = system_message + '
|
||||
Finally, after the response is written, include a numbered list of sentences from the response with a corresponding risk value that are hallucinated and not based in the documents.' %}
|
||||
{%- endif %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- endif %}
|
||||
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>
|
||||
' }}
|
||||
{%- if available_tools %}
|
||||
{{- '<|start_of_role|>available_tools<|end_of_role|>' }}
|
||||
{{- available_tools | tojson(indent=4) }}
|
||||
{{- '<|end_of_text|>
|
||||
' }}
|
||||
{%- endif %}
|
||||
{%- if documents %}
|
||||
{%- for document in documents %}
|
||||
{{- '<|start_of_role|>document {"document_id": "' + document['doc_id'] | string + '"}<|end_of_role|>
|
||||
' }}
|
||||
{{- document['text'] }}
|
||||
{{- '<|end_of_text|>
|
||||
' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- for message in loop_messages %}
|
||||
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>
|
||||
' }}
|
||||
{%- if loop.last and add_generation_prompt %}
|
||||
{{- '<|start_of_role|>assistant' }}
|
||||
{%- if controls %}
|
||||
{{- ' ' + controls | tojson()}}
|
||||
{%- endif %}
|
||||
{{- '<|end_of_role|>' }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
@@ -1,64 +0,0 @@
|
||||
{%- if messages[0]['role'] == 'system' %}
|
||||
{%- set system_message = messages[0]['content'] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set system_message = "Knowledge Cutoff Date: April 2024.
|
||||
Today's Date: " + strftime_now('%B %d, %Y') + ".
|
||||
You are Granite, developed by IBM." %}
|
||||
{%- if tools and documents %}
|
||||
{%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.
|
||||
|
||||
Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
|
||||
{%- elif tools %}
|
||||
{%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
|
||||
{%- elif documents %}
|
||||
{%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
|
||||
{%- else %}
|
||||
{%- set system_message = system_message + " You are a helpful AI assistant." %}
|
||||
{%- endif %}
|
||||
{%- if 'citations' in controls and documents %}
|
||||
{%- set system_message = system_message + '
|
||||
|
||||
In your response, use the symbols <co> and </co> to indicate when a fact comes from a document in the search result, e.g <co>0</co> for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
|
||||
{%- endif %}
|
||||
{%- if 'hallucinations' in controls and documents %}
|
||||
{%- set system_message = system_message + '
|
||||
|
||||
Finally, after the response is written, include a numbered list of sentences from the response that are potentially hallucinated and not based in the documents.' %}
|
||||
{%- endif %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- endif %}
|
||||
{{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>
|
||||
' }}
|
||||
{%- if tools %}
|
||||
{{- '<|start_of_role|>tools<|end_of_role|>' }}
|
||||
{{- tools | tojson(indent=4) }}
|
||||
{{- '<|end_of_text|>
|
||||
' }}
|
||||
{%- endif %}
|
||||
{%- if documents %}
|
||||
{{- '<|start_of_role|>documents<|end_of_role|>' }}
|
||||
{%- for document in documents %}
|
||||
{{- 'Document ' + loop.index0 | string + '
|
||||
' }}
|
||||
{{- document['text'] }}
|
||||
{%- if not loop.last %}
|
||||
{{- '
|
||||
|
||||
'}}
|
||||
{%- endif%}
|
||||
{%- endfor %}
|
||||
{{- '<|end_of_text|>
|
||||
' }}
|
||||
{%- endif %}
|
||||
{%- for message in loop_messages %}
|
||||
{{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>
|
||||
' }}
|
||||
{%- if loop.last and add_generation_prompt %}
|
||||
{{- '<|start_of_role|>assistant' }}
|
||||
{%- if controls %}
|
||||
{{- ' ' + controls | tojson()}}
|
||||
{%- endif %}
|
||||
{{- '<|end_of_role|>' }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
@@ -161,6 +161,8 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
Collator for multipack specific to the using the BatchSampler
|
||||
"""
|
||||
|
||||
squash_position_ids: bool = False
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
if not isinstance(features[0], list):
|
||||
features: List[List[dict]] = [features]
|
||||
@@ -176,6 +178,15 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
elif feature == "position_ids" and self.squash_position_ids:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
# concatenate, get total length and create arange of new total position ids
|
||||
position_ids = np.concatenate(arrays)
|
||||
total_length = position_ids.shape[0]
|
||||
position_ids = np.arange(total_length)
|
||||
out_features[i][feature] = position_ids
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
|
||||
@@ -5,7 +5,6 @@ Collators for multi-modal chat messages and packing
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.data.data_collator import DataCollatorMixin
|
||||
@@ -42,62 +41,19 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
examples = self.processing_strategy(examples)
|
||||
|
||||
# Initialize batch
|
||||
batch: dict[str, Any] = {}
|
||||
messages = [ex["messages"] for ex in examples]
|
||||
|
||||
# Process each example
|
||||
for example in examples:
|
||||
# Apply chat template to process the example
|
||||
# This method requires transformers>=4.49.0
|
||||
result = self.processing_strategy.processor.apply_chat_template(
|
||||
example["messages"],
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
return_dict=True,
|
||||
chat_template=self.processing_strategy.chat_template,
|
||||
)
|
||||
|
||||
# TODO: Check if need handling for len(input_ids) > sequence_len
|
||||
|
||||
# Add the processed tensors to our batch
|
||||
for key in result.keys():
|
||||
if key not in batch:
|
||||
batch[key] = []
|
||||
|
||||
batch[key].append(result[key].squeeze(0))
|
||||
|
||||
# Pad sequences to the same length
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
batch["input_ids"],
|
||||
batch_first=True,
|
||||
padding_value=self.tokenizer.pad_token_id,
|
||||
batch = self.processing_strategy.processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=False,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
return_dict=True,
|
||||
chat_template=self.processing_strategy.chat_template,
|
||||
)
|
||||
|
||||
attention_mask = torch.nn.utils.rnn.pad_sequence(
|
||||
batch["attention_mask"], batch_first=True, padding_value=0
|
||||
)
|
||||
|
||||
# Create the final batch
|
||||
final_batch = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
for key, val in batch.items():
|
||||
if key in ["input_ids", "attention_mask"]:
|
||||
continue
|
||||
|
||||
if key in ["token_type_ids", "cross_attention_mask"]:
|
||||
final_batch[key] = torch.nn.utils.rnn.pad_sequence(
|
||||
val, batch_first=True, padding_value=0
|
||||
)
|
||||
else:
|
||||
final_batch[key] = torch.stack(val)
|
||||
|
||||
# Process the labels
|
||||
final_batch["labels"] = self.processing_strategy.process_labels(
|
||||
final_batch["input_ids"]
|
||||
)
|
||||
batch["labels"] = self.processing_strategy.process_labels(batch["input_ids"])
|
||||
|
||||
return final_batch
|
||||
return batch
|
||||
|
||||
@@ -5,8 +5,8 @@ import inspect
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate import PartialState
|
||||
from torch import nn
|
||||
from torch.distributed import DeviceMesh
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import ModelOutput
|
||||
@@ -194,6 +194,7 @@ class SequenceParallelContextManager:
|
||||
ring_attn_func: RingAttnFunc,
|
||||
heads_k_stride: int | None,
|
||||
gather_outputs: bool,
|
||||
device_mesh: DeviceMesh | None = None,
|
||||
):
|
||||
self.models = models
|
||||
self.context_parallel_size = context_parallel_size
|
||||
@@ -201,6 +202,7 @@ class SequenceParallelContextManager:
|
||||
self.ring_attn_func = ring_attn_func
|
||||
self.heads_k_stride = heads_k_stride
|
||||
self.gather_outputs = gather_outputs
|
||||
self.device_mesh = device_mesh
|
||||
|
||||
self._register_ring_attn()
|
||||
|
||||
@@ -240,9 +242,8 @@ class SequenceParallelContextManager:
|
||||
|
||||
def _register_ring_attn(self):
|
||||
# Initialize ring attn for sequence parallelism
|
||||
partial_state = PartialState()
|
||||
register_ring_attn_from_device_mesh(
|
||||
device_mesh=partial_state.device_mesh,
|
||||
device_mesh=self.device_mesh,
|
||||
context_parallel_dim=("cp",),
|
||||
heads_k_stride=self.heads_k_stride,
|
||||
ring_attn_func=self.ring_attn_func,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Data handling specific to SFT."""
|
||||
|
||||
import functools
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Literal
|
||||
|
||||
@@ -27,7 +28,7 @@ from axolotl.utils.data.shared import (
|
||||
)
|
||||
from axolotl.utils.data.utils import (
|
||||
deduplicate_and_log_datasets,
|
||||
drop_long_seq_in_dataset,
|
||||
handle_long_seq_in_dataset,
|
||||
retry_on_request_exceptions,
|
||||
)
|
||||
from axolotl.utils.data.wrappers import get_dataset_wrapper
|
||||
@@ -104,6 +105,9 @@ def _prepare_standard_dataset(
|
||||
finally:
|
||||
loader.cleanup()
|
||||
|
||||
if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
|
||||
return train_dataset, eval_dataset, -1, prompters
|
||||
|
||||
# Validate sample packing configuration for evaluation
|
||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||
@@ -335,9 +339,9 @@ def _load_raw_datasets(
|
||||
|
||||
if not cfg.skip_prepare_dataset:
|
||||
if split == "test" and cfg.eval_sequence_len:
|
||||
dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
||||
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
||||
else:
|
||||
dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
||||
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
||||
if cfg.sample_packing:
|
||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||
|
||||
|
||||
@@ -148,7 +148,36 @@ def deduplicate_and_log_datasets(
|
||||
return dataset, other_dataset
|
||||
|
||||
|
||||
def drop_long_seq_in_dataset(
|
||||
def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
"""
|
||||
Truncate samples whose sequence length is too long (> sequence_len)
|
||||
or drop those too short (< min_sequence_len).
|
||||
"""
|
||||
min_sequence_len = min_sequence_len or 2
|
||||
|
||||
input_ids = sample["input_ids"]
|
||||
results = []
|
||||
|
||||
# Batched (input_ids is a list of lists)
|
||||
for i, seq in enumerate(input_ids):
|
||||
length = len(seq)
|
||||
if length < min_sequence_len:
|
||||
results.append(False)
|
||||
elif length > sequence_len:
|
||||
sample["input_ids"][i] = seq[:sequence_len]
|
||||
if "attention_mask" in sample:
|
||||
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
|
||||
if "labels" in sample:
|
||||
sample["labels"][i] = sample["labels"][i][:sequence_len]
|
||||
if "position_ids" in sample:
|
||||
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
|
||||
results.append(True)
|
||||
else:
|
||||
results.append(True)
|
||||
return results
|
||||
|
||||
|
||||
def handle_long_seq_in_dataset(
|
||||
dataset: Dataset, sequence_len: int, cfg: DictDefault
|
||||
) -> Dataset:
|
||||
"""Remove sequences longer than configured maximum from dataset.
|
||||
@@ -192,8 +221,21 @@ def drop_long_seq_in_dataset(
|
||||
if filter_map_kwargs:
|
||||
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
|
||||
|
||||
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
||||
if excess_length_strategy == "truncate":
|
||||
process_fn = functools.partial(
|
||||
truncate_long_seq,
|
||||
sequence_len=sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
)
|
||||
drop_long_kwargs["desc"] = (
|
||||
f"Truncating/Filtering Sequences (target_len={sequence_len})"
|
||||
)
|
||||
else:
|
||||
process_fn = drop_long
|
||||
|
||||
dataset = dataset.filter(
|
||||
drop_long,
|
||||
process_fn,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
@@ -201,6 +243,11 @@ def drop_long_seq_in_dataset(
|
||||
if prior_len:
|
||||
dropped = prior_len - len(dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
||||
action = (
|
||||
"truncated/filtered"
|
||||
if excess_length_strategy == "truncate"
|
||||
else "dropped"
|
||||
)
|
||||
LOG.warning(f"{action.title()} {dropped} samples from dataset")
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -8,6 +8,7 @@ from datetime import timedelta
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import ParallelismConfig
|
||||
from transformers.utils.import_utils import (
|
||||
is_torch_cuda_available,
|
||||
is_torch_mps_available,
|
||||
@@ -50,7 +51,10 @@ def init_distributed_state():
|
||||
global distributed_state # pylint: disable=global-statement
|
||||
if distributed_state is None:
|
||||
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
|
||||
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
|
||||
try:
|
||||
distributed_state = PartialState(timeout=timedelta(seconds=timeout))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def get_distributed_state() -> PartialState | None:
|
||||
@@ -290,3 +294,77 @@ def reduce_and_broadcast(fn1, fn2):
|
||||
# Use compute_and_broadcast to compute the reduced value on the main process
|
||||
# and then broadcast it to all ranks
|
||||
return compute_and_broadcast(lambda: fn2(gathered_values))
|
||||
|
||||
|
||||
def build_parallelism_config(cfg):
|
||||
pc_kwargs = _get_parallel_config_kwargs(
|
||||
get_world_size(),
|
||||
cfg.tensor_parallel_size,
|
||||
cfg.context_parallel_size,
|
||||
cfg.dp_shard_size,
|
||||
cfg.dp_replicate_size,
|
||||
bool(cfg.fsdp or cfg.fsdp_config),
|
||||
)
|
||||
|
||||
if pc_kwargs:
|
||||
parallelism_config = ParallelismConfig(
|
||||
**pc_kwargs,
|
||||
)
|
||||
device_mesh = parallelism_config.build_device_mesh("cuda")
|
||||
|
||||
return parallelism_config, device_mesh
|
||||
return None, None
|
||||
|
||||
|
||||
def _get_parallel_config_kwargs(
|
||||
world_size: int,
|
||||
tensor_parallel_size: int = 1,
|
||||
context_parallel_size: int = 1,
|
||||
dp_shard_size: int | None = None,
|
||||
dp_replicate_size: int | None = None,
|
||||
is_fsdp: bool = False,
|
||||
):
|
||||
pc_kwargs = {}
|
||||
remaining_world_size = world_size
|
||||
|
||||
if tensor_parallel_size and tensor_parallel_size > 1:
|
||||
pc_kwargs["tp_size"] = tensor_parallel_size
|
||||
remaining_world_size = remaining_world_size // tensor_parallel_size
|
||||
|
||||
if context_parallel_size and context_parallel_size > 1:
|
||||
pc_kwargs["cp_size"] = context_parallel_size
|
||||
remaining_world_size = remaining_world_size // context_parallel_size
|
||||
|
||||
if dp_shard_size is None and dp_replicate_size in (None, 1):
|
||||
if remaining_world_size > 1:
|
||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if dp_replicate_size and dp_replicate_size > 1:
|
||||
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
||||
remaining_world_size = remaining_world_size // dp_replicate_size
|
||||
|
||||
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
|
||||
if not is_fsdp:
|
||||
raise ValueError(
|
||||
"dp_shard_size was configured without a corresponding fsdp_config! "
|
||||
"Please ensure you have configured FSDP using fsdp_config."
|
||||
)
|
||||
pc_kwargs["dp_shard_size"] = dp_shard_size
|
||||
remaining_world_size = remaining_world_size // dp_shard_size
|
||||
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
|
||||
pc_kwargs["dp_replicate_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if remaining_world_size > 1:
|
||||
if "dp_shard_size" not in pc_kwargs and is_fsdp:
|
||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
||||
remaining_world_size = 1
|
||||
|
||||
if remaining_world_size > 1:
|
||||
raise ValueError(
|
||||
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
|
||||
f"{pc_kwargs}"
|
||||
)
|
||||
|
||||
return pc_kwargs
|
||||
|
||||
28
src/axolotl/utils/import_helper.py
Normal file
28
src/axolotl/utils/import_helper.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Helper for importing modules from strings
|
||||
"""
|
||||
|
||||
import importlib
|
||||
|
||||
|
||||
def get_cls_from_module_str(module_str: str):
|
||||
# use importlib to dynamically load the reward function from the module
|
||||
if not isinstance(module_str, str) or not module_str.strip():
|
||||
raise ValueError("module_str must be a non-empty string")
|
||||
|
||||
parts = module_str.split(".")
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid module string format: {module_str}")
|
||||
|
||||
try:
|
||||
cls_name = parts[-1]
|
||||
module_path = ".".join(parts[:-1])
|
||||
mod = importlib.import_module(module_path)
|
||||
mod_cls = getattr(mod, cls_name)
|
||||
return mod_cls
|
||||
except ImportError as e:
|
||||
raise ImportError(f"Failed to import module '{module_path}': {e}") from e
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
f"Class '{cls_name}' not found in module '{module_path}': {e}"
|
||||
) from e
|
||||
@@ -4,6 +4,7 @@ import math
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
@@ -45,8 +46,10 @@ class RexLR(LRScheduler):
|
||||
|
||||
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault("initial_lr", group["lr"])
|
||||
|
||||
initial_lr = group["lr"]
|
||||
if isinstance(initial_lr, Tensor):
|
||||
initial_lr = initial_lr.clone()
|
||||
group.setdefault("initial_lr", initial_lr)
|
||||
# Pass self.last_step as last_epoch to the parent.
|
||||
super().__init__(optimizer, last_epoch=self.last_step)
|
||||
|
||||
|
||||
@@ -110,6 +110,13 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
|
||||
trainer_cls: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "module to custom trainer class to use for training"
|
||||
},
|
||||
)
|
||||
|
||||
rl: RLType | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -407,6 +414,12 @@ class AxolotlInputConfig(
|
||||
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
||||
},
|
||||
)
|
||||
excess_length_strategy: Literal["drop", "truncate"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
|
||||
},
|
||||
)
|
||||
eval_sequence_len: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -531,12 +544,6 @@ class AxolotlInputConfig(
|
||||
"description": "Whether to use flash-attention rms norm implementation - advanced use only"
|
||||
},
|
||||
)
|
||||
flash_attn_fuse_qkv: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to fuse QKV into a single operation"
|
||||
},
|
||||
)
|
||||
flash_attn_fuse_mlp: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -550,6 +557,13 @@ class AxolotlInputConfig(
|
||||
|
||||
eager_attention: bool | None = None
|
||||
|
||||
attn_implementation: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Specify a custom attention implementation, used mostly for kernels."
|
||||
},
|
||||
)
|
||||
|
||||
unsloth_cross_entropy_loss: bool | None = None
|
||||
unsloth_lora_mlp: bool | None = None
|
||||
unsloth_lora_qkv: bool | None = None
|
||||
|
||||
@@ -118,6 +118,18 @@ class SFTDataset(BaseModel):
|
||||
"description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'
|
||||
},
|
||||
)
|
||||
field_thinking: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": 'Key containing the reasoning trace (default: "reasoning_content").'
|
||||
},
|
||||
)
|
||||
template_thinking_key: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The key the chat template expects that indicates the reasoning trace."
|
||||
},
|
||||
)
|
||||
# deprecated, use message_property_mappings
|
||||
message_field_role: str | None = None
|
||||
# deprecated, use message_property_mappings
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user