Compare commits

62 commits comparing `chat-templ...` with `diffusion-...`. Only the commit SHAs survived extraction:

| SHA1 |
|---|
| 1f75287a3a |
| 63d2280999 |
| b210db2d15 |
| 556a69118f |
| 8569675b26 |
| c10eb811fa |
| 0eef385b1a |
| 077b5a4358 |
| ecbe8b2b61 |
| 234b7b3126 |
| 130ef7c51a |
| e19be0c2d9 |
| 479a454ae3 |
| 0a9341acde |
| d8b63804bc |
| 3156c605d4 |
| d1de6f5f3d |
| 48b7ae1677 |
| 506e3a3907 |
| 09145de8fa |
| e0a2523a3b |
| 3d45620008 |
| ce20e838b5 |
| d4d84d48af |
| 9b12c05660 |
| 686933194e |
| d12b461d19 |
| d6b81b3683 |
| 05f1b4b2e8 |
| 7cfc80ec77 |
| 0da6a95efa |
| 2c8497e489 |
| f70d4de8c7 |
| 0ae06d756d |
| 2974670bf8 |
| 50f2b94d50 |
| eb2c87b525 |
| 4db7f023c6 |
| 4273d5cf7e |
| c5e5aba547 |
| 9d5c95db6f |
| ca796fb56e |
| 597953bef0 |
| 39fbd3b2b5 |
| 46dfacf255 |
| 4bce713b39 |
| d09290f2f4 |
| e442ff22aa |
| ba3dba3e4f |
| 97e86c6d47 |
| 784f8c0e95 |
| e3177c3210 |
| 70faea331f |
| 8021c718ce |
| 42f5e6f9e9 |
| ab49d16e34 |
| 33d094721c |
| a54c1be972 |
| 5691992d34 |
| e758343cac |
| deac7b18a1 |
| 10946afae7 |
**`.github/CONTRIBUTING.md`** (vendored, 7 changes)

````diff
@@ -57,6 +57,13 @@ We welcome ideas for improvements and new features. To suggest an enhancement, o
 5. Push your branch to your fork on GitHub.
 6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
 
+#### Skipping CI Checks
+
+You can skip certain CI checks by including specific keywords in your commit messages:
+
+- `[skip ci]` or `skip ci` - Skips all CI checks for that commit
+- `[skip-e2e]` or `skip-e2e` - Skips only end-to-end tests while running other CI checks. You may also include this in the title of your PR to disable end-to-end tests for the entire PR.
+
 ## Style Guidelines
 
 ### Code Style
````
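For instance (illustrative commit messages only): `git commit -m "docs: fix typo [skip ci]"` skips all CI for that commit, while `git commit -m "chore: tidy comments [skip-e2e]"` runs the regular checks but skips the end-to-end suite.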
**`.github/workflows/base.yml`** (vendored, 27 changes)

````diff
@@ -54,7 +54,7 @@ jobs:
           torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
           dockerfile: "Dockerfile-base"
         - cuda: "128"
-          cuda_version: 12.6.3
+          cuda_version: 12.8.1
           cudnn_version: ""
           python_version: "3.11"
           pytorch: 2.7.1
@@ -64,9 +64,16 @@ jobs:
           cuda_version: 12.8.1
           cudnn_version: ""
           python_version: "3.11"
-          pytorch: nightly
+          pytorch: 2.8.0
           torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-          dockerfile: "Dockerfile-base-nightly"
+          dockerfile: "Dockerfile-base"
+        # - cuda: "128"
+        #   cuda_version: 12.8.1
+        #   cudnn_version: ""
+        #   python_version: "3.11"
+        #   pytorch: nightly
+        #   torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+        #   dockerfile: "Dockerfile-base-nightly"
         # # "next" is for release candidates of pytorch
         # - cuda: "128"
         #   cuda_version: 12.8.1
@@ -122,6 +129,13 @@ jobs:
           pytorch: 2.6.0
           torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
           dockerfile: "Dockerfile-uv-base"
+        - cuda: "126"
+          cuda_version: 12.6.3
+          cudnn_version: ""
+          python_version: "3.11"
+          pytorch: 2.7.1
+          torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+          dockerfile: "Dockerfile-uv-base"
         - cuda: "128"
           cuda_version: 12.8.1
           cudnn_version: ""
@@ -129,6 +143,13 @@ jobs:
           pytorch: 2.7.1
           torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
           dockerfile: "Dockerfile-uv-base"
+        - cuda: "128"
+          cuda_version: 12.8.1
+          cudnn_version: ""
+          python_version: "3.11"
+          pytorch: 2.8.0
+          torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+          dockerfile: "Dockerfile-uv-base"
     steps:
       - name: Checkout
         uses: actions/checkout@v4
````
**`.github/workflows/main.yml`** (vendored, 23 changes)

````diff
@@ -24,12 +24,13 @@ jobs:
           cuda_version: 12.6.3
           python_version: "3.11"
           pytorch: 2.7.0
-          axolotl_extras: vllm
+          axolotl_extras:
         - cuda: 126
           cuda_version: 12.6.3
           python_version: "3.11"
           pytorch: 2.7.1
-          axolotl_extras:
+          axolotl_extras: vllm
+          is_latest: true
         - cuda: 128
           cuda_version: 12.8.1
           python_version: "3.11"
@@ -97,6 +98,12 @@ jobs:
           python_version: "3.11"
           pytorch: 2.7.1
           axolotl_extras:
+          is_latest:
+        - cuda: 126
+          cuda_version: 12.6.3
+          python_version: "3.11"
+          pytorch: 2.7.1
+          axolotl_extras: vllm
           is_latest: true
         - cuda: 128
           cuda_version: 12.8.1
@@ -150,6 +157,18 @@ jobs:
           python_version: "3.11"
           pytorch: 2.6.0
           axolotl_extras:
+        - cuda: 126
+          cuda_version: 12.6.3
+          python_version: "3.11"
+          pytorch: 2.7.1
+          axolotl_extras:
+          is_latest:
+        - cuda: 126
+          cuda_version: 12.6.3
+          python_version: "3.11"
+          pytorch: 2.7.1
+          axolotl_extras: vllm
+          is_latest: true
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
````
**`.github/workflows/tests.yml`** (vendored, 49 changes)

````diff
@@ -105,7 +105,8 @@ jobs:
 
       - name: Run tests
         run: |
-          pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
+          pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
+          pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
           pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
           pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
 
@@ -179,21 +180,52 @@ jobs:
 
       - name: Run tests
         run: |
-          pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
-          pytest -v --durations=10 tests/patched/
+          pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
+          pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
           pytest -v --durations=10 tests/cli/
 
       - name: cleanup pip cache
         run: |
           find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
 
+  gate-skip-e2e:
+    needs: [pre-commit, pytest, pytest-sdist]
+    runs-on: ubuntu-latest
+    outputs:
+      skip: ${{ steps.compute.outputs.skip }}
+    steps:
+      - uses: actions/github-script@v7
+        id: compute
+        with:
+          script: |
+            const token = /\[skip-e2e\]/i;
+            let msg = '';
+            if (context.eventName === 'push') {
+              msg = context.payload.head_commit?.message || '';
+            } else if (context.eventName === 'pull_request') {
+              const { owner, repo } = context.repo;
+              const prNumber = context.payload.pull_request.number;
+              const commits = await github.paginate(
+                github.rest.pulls.listCommits,
+                { owner, repo, pull_number: prNumber, per_page: 100 }
+              );
+              msg = commits.at(-1)?.commit?.message || '';
+            }
+            const title = context.payload.pull_request?.title || '';
+            const body = context.payload.pull_request?.body || '';
+            const skip = token.test(msg) || token.test(title) || token.test(body);
+            core.setOutput('skip', String(skip));
+
   docker-e2e-tests-1st:
     # Run this job first as a gate for running the remainder of the test matrix
-    if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
+    if: >
+      github.repository_owner == 'axolotl-ai-cloud' &&
+      (github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
+      needs.gate-skip-e2e.outputs.skip != 'true'
     # this job needs to be run on self-hosted GPU runners...
     runs-on: [self-hosted, modal]
     timeout-minutes: 120
-    needs: [pre-commit, pytest, pytest-sdist]
+    needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
 
     strategy:
       fail-fast: false
@@ -239,13 +271,16 @@ jobs:
           modal run cicd.e2e_tests
 
   docker-e2e-tests:
-    if: ${{ github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
+    if: >
+      github.repository_owner == 'axolotl-ai-cloud' &&
+      (github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
+      needs.gate-skip-e2e.outputs.skip != 'true'
     # this job needs to be run on self-hosted GPU runners...
    runs-on: [self-hosted, modal]
     timeout-minutes: 120
     # Only run the remainder of the matrix if the first e2e check passed;
     # this is to save on wasted compute costs for known failures that get caught in the first run
-    needs: [pre-commit, pytest, docker-e2e-tests-1st]
+    needs: [pre-commit, pytest, gate-skip-e2e, docker-e2e-tests-1st]
 
     strategy:
       fail-fast: false
````
````diff
@@ -3,7 +3,7 @@ default_language_version:
 
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v5.0.0
+    rev: v6.0.0
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
@@ -23,11 +23,11 @@ repos:
     hooks:
       - id: flake8
   - repo: https://github.com/pylint-dev/pylint
-    rev: v3.3.7
+    rev: v3.3.8
     hooks:
       - id: pylint
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.17.0
+    rev: v1.17.1
     hooks:
       - id: mypy
         additional_dependencies:
````
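After pulling these `rev` bumps, contributors can exercise the updated hooks locally with `pre-commit run --all-files`; pre-commit installs the new hook versions automatically on first run.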
````diff
@@ -185,7 +185,6 @@ datasets:
 | `flash_attention` | `false` | Use flash attention |
 | `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
 | `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
-| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations |
 | `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
 | `sdp_attention` | `false` | Use scaled dot product |
 | `s2_attention` | `false` | Use shifted sparse attention |
@@ -296,7 +296,6 @@
 # flash_attention:
 # flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
 # flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
-# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
 # flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
 # # Whether to use scaled-dot-product attention
 # # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
@@ -541,7 +540,6 @@ xformers_attention: ${XFORMERS_ATTENTION}
 flash_attention: ${FLASH_ATTENTION}
 flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
 flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
-flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV}
 flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
 sdp_attention: ${SDP_ATTENTION}
 s2_attention: ${S2_ATTENTION}
````
**`CITATION.cff`** (new file, 10 lines)

````diff
@@ -0,0 +1,10 @@
+cff-version: 1.2.0
+type: software
+title: "Axolotl: Post-Training for AI Models"
+message: "If you use this software, please cite it as below."
+authors:
+  - name: "Axolotl maintainers and contributors"
+repository-code: "https://github.com/axolotl-ai-cloud/axolotl"
+url: "https://axolotl.ai/"
+license: Apache-2.0
+date-released: "2023-05-30"
````
**`README.md`** (33 changes)

````diff
@@ -25,17 +25,28 @@
 
 ## 🎉 Latest Updates
 
-- 2025/07: Voxtral with mistral-common tokenizer support has been integrated in Axolotl. Read the [docs](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral)!
-- 2025/07: TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
-- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
+- 2025/07:
+  - ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
+  - Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
+  - FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
+  - [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support have been integrated in Axolotl!
+  - TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
 - 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
-- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
 - 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
+
+<details>
+
+<summary>Expand older updates</summary>
+
+- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
+- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
 - 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
 - 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
 - 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
 - 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
+
+</details>
 
 ## ✨ Overview
 
 Axolotl is a tool designed to streamline post-training for various AI models.
@@ -138,6 +149,20 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
 
 Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
 
+## 📝 Citing Axolotl
+
+If you use Axolotl in your research or projects, please cite it as follows:
+
+```bibtex
+@software{axolotl,
+  title = {Axolotl: Post-Training for AI Models},
+  author = {{Axolotl maintainers and contributors}},
+  url = {https://github.com/axolotl-ai-cloud/axolotl},
+  license = {Apache-2.0},
+  year = {2023}
+}
+```
+
 ## 📜 License
 
 This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
````
**`TODO.md`** (deleted, 10 lines)

````diff
@@ -1,10 +0,0 @@
-# todo list
-
-- [] Validation of parameters for combinations that won't work
-
-
-
-## things that are known not to work
-
-- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
-- adamw_bnb_8bit doesn't play well with FSDP offload
````
````diff
@@ -274,6 +274,7 @@ website:
         - docs/dataset_preprocessing.qmd
         - docs/multipack.qmd
         - docs/mixed_precision.qmd
+        - docs/optimizers.qmd
 
     - section: "Advanced Features"
       contents:
````
````diff
@@ -37,7 +37,7 @@ WORKDIR /workspace
 
 RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
     python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
-    python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
+    CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
     python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
     python3 -m pip cache purge
````
````diff
@@ -212,10 +212,11 @@ Instead of passing `tools` via the system prompt, an alternative method would be
 Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
 :::
 
+Example config for Llama4:
 ```yaml
 chat_template: llama4
 datasets:
-  - path: ...
+  - path: Nanobit/text-tools-2k-test
     type: chat_template
     # field_tools: tools # default is `tools`
 ```
````
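For reference, one entry in the dataset's `tools` field could look like the sketch below. The weather function and its parameters are hypothetical; what matters is that `parameters` is a valid JSON Schema object:

```yaml
# Hypothetical tool entry (OpenAI-style function tool); the names and
# fields below are illustrative, only the JSON-schema shape is prescribed.
tools:
  - type: function
    function:
      name: get_current_weather
      description: Get the current weather for a city
      parameters:
        type: object
        properties:
          city:
            type: string
            description: Name of the city, e.g. "Paris"
        required:
          - city
```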
````diff
@@ -13,10 +13,13 @@ format:
 - [Pixtral](#sec-pixtral)
 - [Llava-1.5](#sec-llava-15)
 - [Mistral-Small-3.1](#sec-mistral-small-31)
+- [Voxtral](#sec-voxtral)
 - [Gemma-3](#sec-gemma-3)
 - [Gemma-3n](#sec-gemma-3n)
 - [Qwen2-VL](#sec-qwen2-vl)
 - [Qwen2.5-VL](#sec-qwen25-vl)
+- [SmolVLM2](#sec-smolvlm2)
+- [LFM2-VL](#sec-lfm2-vl)
 
 ## Usage
 
@@ -31,7 +34,7 @@ skip_prepare_dataset: true
 remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
 sample_packing: false # not yet supported with multimodal
 
-chat_template: # see in next section
+chat_template: # see in next section if specified
 
 # example dataset
 datasets:
@@ -97,6 +100,16 @@ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
 chat_template: mistral_v7_tekken
 ```
 
+### Voxtral {#sec-voxtral}
+
+::: {.callout-tip}
+Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'`
+:::
+
+```yaml
+base_model: mistralai/Voxtral-Mini-3B-2507
+```
+
 ### Gemma-3 {#sec-gemma-3}
 
 ::: {.callout-tip}
@@ -143,6 +156,26 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
 chat_template: qwen2_vl # same as qwen2-vl
 ```
 
+### SmolVLM2 {#sec-smolvlm2}
+
+::: {.callout-tip}
+Please make sure to install `num2words` via `pip3 install num2words==0.5.14`
+:::
+
+```yaml
+base_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct
+```
+
+### LFM2-VL {#sec-lfm2-vl}
+
+::: {.callout-warning}
+Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
+:::
+
+```yaml
+base_model: LiquidAI/LFM2-VL-450M
+```
+
 ## Dataset Format
 
 For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
@@ -181,6 +214,20 @@ You may need to install `librosa` via `pip3 install librosa==0.11.0`.
 
 :::
 
+### Video
+
+::: {.callout-warning}
+
+This is not well tested at the moment. We welcome contributors!
+
+:::
+
+For video loading, you can use the following keys within `content` alongside `"type": "video"`:
+
+- `"path": "/path/to/video.mp4"`
+- `"url": "https://example.com/video.mp4"`
+- `"video": np.ndarray | list[PIL.Image.Image] | torch.Tensor` (or list of the aforementioned)
+
 ### Example
 
 Here is an example of a multi-modal dataset:
````
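To make the video keys concrete, one hypothetical dataset row could look like this (the path and text are illustrative, and only one of `path`/`url`/`video` is needed per video entry):

```yaml
# Hypothetical multimodal row with video content
messages:
  - role: user
    content:
      - type: video
        path: /path/to/video.mp4
      - type: text
        text: What happens in this clip?
  - role: assistant
    content:
      - type: text
        text: A short description of the clip.
```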
````diff
@@ -1,4 +1,6 @@
-# N-D Parallelism
+---
+title: "N-D Parallelism (Beta)"
+---
 
 Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
 
@@ -71,6 +73,10 @@ Note: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size
 
 ## Examples
 
+::: {.callout-tip}
+See our example configs [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/distributed-parallel).
+:::
+
 1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
    - You want FSDP within each node and DDP across nodes.
    - Set `dp_shard_size: 4` and `dp_replicate_size: 2`.
@@ -95,7 +101,7 @@ This matrix describes how different parallelism methods can be combined in Axolo
 | **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |
 | **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |
 | **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. |
-| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP/CP without FSDP is inefficient and complex. You should use FSDP instead (`dp_shard_size > 1`). |
+| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP or CP is currently not supported. You should use FSDP + TP/CP instead (`dp_shard_size > 1`). |
 | Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |
 
 - `tp_size` refers to `tensor_parallel_size`
````
**`docs/optimizers.qmd`** (new file, 129 lines)

````diff
@@ -0,0 +1,129 @@
+---
+title: Optimizers
+description: Configuring optimizers
+---
+
+## Overview
+
+Axolotl supports all optimizers supported by [transformers OptimizerNames](https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/training_args.py#L142-L187).
+
+Here is a list of optimizers supported by transformers as of `v4.54.0`:
+
+- `adamw_torch`
+- `adamw_torch_fused`
+- `adamw_torch_xla`
+- `adamw_torch_npu_fused`
+- `adamw_apex_fused`
+- `adafactor`
+- `adamw_anyprecision`
+- `adamw_torch_4bit`
+- `adamw_torch_8bit`
+- `ademamix`
+- `sgd`
+- `adagrad`
+- `adamw_bnb_8bit`
+- `adamw_8bit` # alias for adamw_bnb_8bit
+- `ademamix_8bit`
+- `lion_8bit`
+- `lion_32bit`
+- `paged_adamw_32bit`
+- `paged_adamw_8bit`
+- `paged_ademamix_32bit`
+- `paged_ademamix_8bit`
+- `paged_lion_32bit`
+- `paged_lion_8bit`
+- `rmsprop`
+- `rmsprop_bnb`
+- `rmsprop_bnb_8bit`
+- `rmsprop_bnb_32bit`
+- `galore_adamw`
+- `galore_adamw_8bit`
+- `galore_adafactor`
+- `galore_adamw_layerwise`
+- `galore_adamw_8bit_layerwise`
+- `galore_adafactor_layerwise`
+- `lomo`
+- `adalomo`
+- `grokadamw`
+- `schedule_free_radam`
+- `schedule_free_adamw`
+- `schedule_free_sgd`
+- `apollo_adamw`
+- `apollo_adamw_layerwise`
+- `stable_adamw`
+
+## Custom Optimizers
+
+Enable custom optimizers by passing a string to the `optimizer` argument. Each optimizer will receive beta and epsilon args; however, some may accept additional args, which are detailed below.
+
+### optimi_adamw
+
+```yaml
+optimizer: optimi_adamw
+```
+
+### ao_adamw_4bit
+
+Deprecated: Please use `adamw_torch_4bit`.
+
+### ao_adamw_8bit
+
+Deprecated: Please use `adamw_torch_8bit`.
+
+### ao_adamw_fp8
+
+```yaml
+optimizer: ao_adamw_fp8
+```
+
+### adopt_adamw
+
+GitHub: [https://github.com/iShohei220/adopt](https://github.com/iShohei220/adopt)
+Paper: [https://arxiv.org/abs/2411.02853](https://arxiv.org/abs/2411.02853)
+
+```yaml
+optimizer: adopt_adamw
+```
+
+### came_pytorch
+
+GitHub: [https://github.com/yangluo7/CAME/tree/master](https://github.com/yangluo7/CAME/tree/master)
+Paper: [https://arxiv.org/abs/2307.02047](https://arxiv.org/abs/2307.02047)
+
+```yaml
+optimizer: came_pytorch
+
+# optional args (defaults below)
+adam_beta1: 0.9
+adam_beta2: 0.999
+adam_beta3: 0.9999
+adam_epsilon: 1e-30
+adam_epsilon2: 1e-16
+```
+
+### muon
+
+Blog: [https://kellerjordan.github.io/posts/muon/](https://kellerjordan.github.io/posts/muon/)
+Paper: [https://arxiv.org/abs/2502.16982v1](https://arxiv.org/abs/2502.16982v1)
+
+```yaml
+optimizer: muon
+```
+
+### dion
+
+Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient
+orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
+
+GitHub: [https://github.com/microsoft/dion](https://github.com/microsoft/dion)
+Paper: [https://arxiv.org/pdf/2504.05295](https://arxiv.org/pdf/2504.05295)
+Note: Implementation written for PyTorch 2.7+ for DTensor
+
+```yaml
+optimizer: dion
+dion_lr: 0.01
+dion_momentum: 0.95
+lr: 0.00001 # learning rate for embeddings and parameters that fallback to AdamW
+```
````
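As a sketch of how the shared args compose with a custom optimizer (the values below are illustrative, not tuned recommendations):

```yaml
# Illustrative only: per the "Custom Optimizers" note above, a custom
# optimizer also receives the shared beta/epsilon settings.
optimizer: optimi_adamw
learning_rate: 0.0002
adam_beta1: 0.9
adam_beta2: 0.999
adam_epsilon: 1e-8
```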
**`examples/LiquidAI/README.md`** (new file, 58 lines)

````diff
@@ -0,0 +1,58 @@
+# Finetune Liquid Foundation Models 2 (LFM2) with Axolotl
+
+[Liquid Foundation Models 2 (LFM2)](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) are a family of small, open-weight models from [Liquid AI](https://www.liquid.ai/) focused on quality, speed, and memory efficiency. Liquid AI released text-only [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) and text+vision [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa) models.
+
+LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference.
+
+This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
+
+## Getting Started
+
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
+
+   Here is an example of how to install from pip:
+
+   ```bash
+   # Ensure you have a compatible version of Pytorch installed
+   pip3 install packaging setuptools wheel ninja
+   pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
+   ```
+
+2. Run one of the finetuning examples below.
+
+   **LFM2**
+   ```bash
+   # FFT SFT (1x48GB @ 25GiB)
+   axolotl train examples/LiquidAI/lfm2-350m-fft.yaml
+   ```
+
+   **LFM2-VL**
+   ```bash
+   # LoRA SFT (1x48GB @ 2.7GiB)
+   axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
+   ```
+
+### TIPS
+
+- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
+
+  ```bash
+  pip uninstall -y causal-conv1d
+  ```
+
+- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
+- **Dataset Formats**:
+  - For LFM2 models, the dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
+  - For LFM2-VL models, Axolotl follows the multi-content Messages format. See our [Multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) for details.
+
+## Optimization Guides
+
+- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
+- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
+- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
+
+## Related Resources
+
+- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
+- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
+- [Axolotl Docs](https://docs.axolotl.ai)
+- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
+- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
````
````diff
@@ -2,7 +2,6 @@ base_model: LiquidAI/LFM2-350M
 
 chunked_cross_entropy: true
 
-chat_template: tokenizer_default
 eot_tokens:
   - "<|im_end|>"
 datasets:
````
**`examples/LiquidAI/lfm2-vl-lora.yaml`** (new file, 58 lines)

````diff
@@ -0,0 +1,58 @@
+base_model: LiquidAI/LFM2-VL-450M
+trust_remote_code: true
+model_type: AutoModelForImageTextToText
+processor_type: AutoProcessor
+
+# these 3 lines are needed for now to handle vision chat templates w images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
+
+datasets:
+  - path: HuggingFaceH4/llava-instruct-mix-vsft
+    type: chat_template
+    split: train[:1%]
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: lora
+lora_model_dir:
+
+sequence_len: 8192
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+logging_steps: 1
+flash_attention: true
+eager_attention:
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
**`examples/arcee/README.md`** (new file, 53 lines)

````diff
@@ -0,0 +1,53 @@
+# Finetune ArceeAI's AFM with Axolotl
+
+[Arcee Foundation Models (AFM)](https://huggingface.co/collections/arcee-ai/afm-45b-68823397c351603014963473) are a family of 4.5B parameter open weight models trained by Arcee.ai.
+
+This guide shows how to fine-tune them with Axolotl with multi-turn conversations and proper masking.
+
+Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the AFM model.
+
+## Getting started
+
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as AFM is only on nightly, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
+
+   Here is an example of how to install from main for pip:
+
+   ```bash
+   # Ensure you have Pytorch installed (Pytorch 2.6.0 min)
+   git clone https://github.com/axolotl-ai-cloud/axolotl.git
+   cd axolotl
+
+   pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
+   pip3 install --no-build-isolation -e '.[flash-attn]'
+   ```
+
+2. Run the finetuning example:
+
+   ```bash
+   axolotl train examples/arcee/afm-4.5b-qlora.yaml
+   ```
+
+   This config uses about 7.8GiB VRAM.
+
+Let us know how it goes. Happy finetuning! 🚀
+
+### TIPS
+
+- For inference, the official Arcee.ai team recommends `top_p: 0.95`, `temperature: 0.5`, `top_k: 50`, and `repeat_penalty: 1.1`.
+- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
+- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
+- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
+
+## Optimization Guides
+
+- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
+- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
+- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
+
+## Related Resources
+
+- [AFM Blog](https://docs.arcee.ai/arcee-foundation-models/introduction-to-arcee-foundation-models)
+- [Axolotl Docs](https://docs.axolotl.ai)
+- [Axolotl Website](https://axolotl.ai)
+- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
+- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
````
**`examples/arcee/afm-4.5b-qlora.yaml`** (new file, 64 lines)

````diff
@@ -0,0 +1,64 @@
+base_model: arcee-ai/AFM-4.5B
+
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+load_in_8bit: false
+load_in_4bit: true
+
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
+
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+output_dir: ./outputs/lora-out
+
+adapter: qlora
+lora_model_dir:
+
+sequence_len: 2048
+sample_packing: true
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_target_modules:
+  - gate_proj
+  - down_proj
+  - up_proj
+  - q_proj
+  - v_proj
+  - k_proj
+  - o_proj
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: auto
+tf32: false
+
+gradient_checkpointing: true
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
````
````diff
@@ -47,7 +47,6 @@ logging_steps: 1
 flash_attention: true
 flash_attn_cross_entropy: false
 flash_attn_rms_norm: true
-flash_attn_fuse_qkv: false
 flash_attn_fuse_mlp: true
 
 warmup_ratio: 0.1
````
````diff
@@ -40,7 +40,7 @@
     "%%capture\n",
     "# This step can take ~5-10 minutes to install dependencies\n",
     "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0\""
+    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8\""
    ]
   },
   {
````
````diff
@@ -10,17 +10,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
 
 ## Getting started
 
-1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
 
-Here is an example of how to install from main for pip:
+Here is an example of how to install from pip:
 
 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0+)
-git clone https://github.com/axolotl-ai-cloud/axolotl.git
-cd axolotl
-
+# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
 pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
-pip3 install --no-build-isolation -e '.[flash-attn]'
+pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
 ```
 
 2. Run the finetuning example:
````
**`examples/distributed-parallel/README.md`** (new file, 52 lines)

````diff
@@ -0,0 +1,52 @@
+# ND Parallelism Examples
+
+This directory contains example configurations for training models using ND Parallelism in Axolotl. These examples demonstrate how to compose different parallelism strategies (FSDP, TP, CP, HSDP) for efficient multi-GPU training.
+
+## Quick Start
+
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
+
+2. Run the command below:
+
+   ```bash
+   # Train Qwen3 8B with FSDP + TP + CP on a single 8-GPU node
+   axolotl train examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
+
+   # Train Llama 3.1 8B with HSDP + TP on 2 nodes (16 GPUs total)
+   axolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
+   ```
+
+## Example Configurations
+
+### Single Node (8 GPUs)
+
+**Qwen3 8B with FSDP + TP + CP** ([qwen3-8b-fsdp-tp-cp.yaml](./qwen3-8b-fsdp-tp-cp.yaml))
+
+- Uses all 3 parallelism dimensions on a single node
+- Ideal for: when model weights, activations, and/or context are too large to fit on a single GPU
+
+```yaml
+dp_shard_size: 2          # FSDP across 2 GPUs
+tensor_parallel_size: 2   # TP across 2 GPUs
+context_parallel_size: 2  # CP across 2 GPUs
+# Total: 2 × 2 × 2 = 8 GPUs
+```
+
+### Multi-Node
+
+**Llama 3.1 8B with HSDP + TP** ([llama-3_1-8b-hsdp-tp.yaml](./llama-3_1-8b-hsdp-tp.yaml))
+
+- FSDP & TP within nodes, DDP across nodes to minimize inter-node communication
+- Ideal for: Scaling to multiple nodes while maintaining training efficiency
+
+```yaml
+dp_shard_size: 4          # FSDP within each 4-GPU group
+tensor_parallel_size: 2   # TP within each node
+dp_replicate_size: 2      # DDP across 2 groups
+# Total: (4 × 2) × 2 = 16 GPUs (2 nodes)
+```
+
+## Learn More
+
+- [ND Parallelism Documentation](https://docs.axolotl.ai/docs/nd_parallelism.html)
+- [Blog: Accelerate ND-Parallel Guide](https://huggingface.co/blog/accelerate-nd-parallel)
+- [Multi-GPU Training Guide](https://docs.axolotl.ai/docs/multi-gpu.html)
+- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
````
**`examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml`** (new file, 47 lines)

````diff
@@ -0,0 +1,47 @@
+base_model: meta-llama/Llama-3.1-8B
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+dp_shard_size: 4
+dp_replicate_size: 2
+tensor_parallel_size: 2
+# context_parallel_size: 2
+
+dataset_prepared_path: last_run_prepared
+
+special_tokens:
+  pad_token: <|end_of_text|>
+
+fsdp_version: 2
+fsdp_config:
+  offload_params: false
+  state_dict_type: FULL_STATE_DICT
+  auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  transformer_layer_cls_to_wrap: LlamaDecoderLayer
+  reshard_after_forward: true
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+
+output_dir: ./outputs/ndp-out/
+
+sequence_len: 2048
+sample_packing: true
+flash_attention: true
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 2
+optimizer: adamw_torch_fused
+lr_scheduler: constant_with_warmup
+learning_rate: 2e-6
+
+bf16: true
+tf32: true
+
+logging_steps: 1
+saves_per_epoch: 1
+
+warmup_ratio: 0.1
````
**`examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml`** (new file, 46 lines)

````diff
@@ -0,0 +1,46 @@
+base_model: Qwen/Qwen3-8B
+
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+dp_shard_size: 2
+# dp_replicate_size: 1
+context_parallel_size: 2
+tensor_parallel_size: 2
+
+dataset_prepared_path: last_run_prepared
+
+fsdp_version: 2
+fsdp_config:
+  offload_params: false
+  state_dict_type: FULL_STATE_DICT
+  auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  transformer_layer_cls_to_wrap: Qwen3DecoderLayer
+  reshard_after_forward: true
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+
+output_dir: ./outputs/ndp-out/
+
+sequence_len: 8192
+sample_packing: true
+flash_attention: true
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1 # must be 1 when using context parallel
+num_epochs: 2
+optimizer: adamw_torch_fused
+lr_scheduler: constant_with_warmup
+learning_rate: 2e-6
+
+bf16: true
+tf32: true
+
+logging_steps: 1
+saves_per_epoch: 1
+
+warmup_ratio: 0.1
+
+special_tokens:
````
````diff
@@ -4,17 +4,14 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
 
 ## Getting started
 
-1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Gemma3n is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
 
-Here is an example of how to install from main for pip:
+Here is an example of how to install from pip:
 
 ```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min recommended)
-git clone https://github.com/axolotl-ai-cloud/axolotl.git
-cd axolotl
-
+# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
 pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
-pip3 install --no-build-isolation -e '.[flash-attn]'
+pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
 ```
 
 2. In addition to Axolotl's requirements, Gemma-3n requires:
````
105 examples/gpt-oss/README.md Normal file
@@ -0,0 +1,105 @@
# Finetune OpenAI's GPT-OSS with Axolotl

[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) is a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.

This guide shows how to fine-tune them with Axolotl on multi-turn conversations with proper masking.

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

Here is an example of how to install from pip:

```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

2. Choose one of the following configs for training the 20B model (for 120B, see [below](#training-120b)):

```bash
# LoRA SFT linear layers (1x48GB @ ~44GiB)
axolotl train examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml

# FFT SFT with offloading (2x24GB @ ~21GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml

# FFT SFT (8x48GB @ ~36GiB/GPU or 4x80GB @ ~46GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
```

Note: Memory usage is taken from `device_mem_reserved(gib)` in the logs.

### Training 120B

On 8xH100s, make sure you have ~3TB of free disk space: each checkpoint clocks in at ~720GB, so along with the base model and the final model output, you need at least that much to keep two checkpoints.

```bash
# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
```

ERRATA: Transformers saves the model architecture prefixed with `FSDP`, which needs to be manually renamed in `config.json`. See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.

```bash
sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
```

When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded weights to `{output_dir}/merged`.

```bash
axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
```

### Inferencing your fine-tuned model

GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425 for more information about using a special vllm-openai docker image for inferencing with vLLM.

SGLang has 0-day support in main; see https://github.com/sgl-project/sglang/issues/8833 for information on installing SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:

```bash
python3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8
```

### Tool use

GPT-OSS has comprehensive tool-use understanding. Axolotl supports tool-calling datasets for Supervised Fine-tuning.

Here is an example dataset config:

```yaml
datasets:
  - path: Nanobit/text-tools-2k-test
    type: chat_template
```

See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-tools-2k-test) for the sample dataset.

Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.

### TIPS

- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).

## Optimization Guides

- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)

## Related Resources

- [GPT-OSS Blog](https://openai.com/index/introducing-gpt-oss/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
68 examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml Normal file
@@ -0,0 +1,68 @@
# the original mxfp4 quantized model is not supported with FSDP cpu_ram_efficient_loading
# FSDP cpu_ram_efficient_loading is used to reduce the initial CPU memory usage when loading the model
base_model: axolotl-ai-co/gpt-oss-120b-dequantized

use_kernels: false

dp_shard_size: 16  # requires 2x8xH100 nodes

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

experimental_skip_move_to_device: true  # prevent OOM by NOT putting model to GPU before sharding

datasets:
  - path: HuggingFaceH4/Multilingual-Thinking
    type: chat_template
    field_thinking: thinking
    template_thinking_key: thinking

dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
save_total_limit: 2  # the 120B model can use up to 720GB of disk space per checkpoint, so let's only keep the last 2

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

optimizer: adamw_torch_fused  # 8bit optimizers do not work with FSDP2 offload
lr_scheduler: constant_with_warmup
learning_rate: 2e-5

bf16: true
tf32: true

flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3

gradient_checkpointing: true
activation_offloading: true

logging_steps: 1
saves_per_epoch: 1

warmup_ratio: 0.03

special_tokens:
eot_tokens:
  - "<|end|>"

fsdp_version: 2
fsdp_config:
  offload_params: true
  state_dict_type: SHARDED_STATE_DICT
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: GptOssDecoderLayer
  reshard_after_forward: true
  cpu_ram_efficient_loading: true
58 examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml Normal file
@@ -0,0 +1,58 @@
base_model: openai/gpt-oss-20b
use_kernels: false
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
  dequantize: true

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

experimental_skip_move_to_device: true  # prevent OOM by NOT putting model to GPU before sharding

datasets:
  - path: HuggingFaceH4/Multilingual-Thinking
    type: chat_template
    field_thinking: thinking
    template_thinking_key: thinking

dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/

sequence_len: 4096
sample_packing: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-5

bf16: true
tf32: true

flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3

gradient_checkpointing: true
activation_offloading: true

logging_steps: 1
saves_per_epoch: 1

warmup_ratio: 0.03

special_tokens:
eot_tokens:
  - "<|end|>"

# choose the zero3 configuration that best fits your system capabilities
deepspeed: deepspeed_configs/zero3_bf16.json
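This DeepSpeed variant is not in the launch list of the README above; it follows the same pattern (GPU count and memory were not benchmarked in the README):

```bash
axolotl train examples/gpt-oss/gpt-oss-20b-fft-deepspeed-zero3.yaml
```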
68 examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml Normal file
@@ -0,0 +1,68 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
  dequantize: true

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

experimental_skip_move_to_device: true  # prevent OOM by NOT putting model to GPU before sharding

datasets:
  - path: HuggingFaceH4/Multilingual-Thinking
    type: chat_template
    field_thinking: thinking
    template_thinking_key: thinking

dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

optimizer: adamw_torch_fused  # 8bit optimizers do not work with FSDP2 offload
lr_scheduler: constant_with_warmup
learning_rate: 2e-5

bf16: true
tf32: true

flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3

gradient_checkpointing: true
activation_offloading: true

logging_steps: 1
saves_per_epoch: 1

warmup_ratio: 0.03

special_tokens:
eot_tokens:
  - "<|end|>"

fsdp_version: 2
fsdp_config:
  offload_params: true
  state_dict_type: SHARDED_STATE_DICT
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: GptOssDecoderLayer
  reshard_after_forward: true
  # cpu_ram_efficient_loading: true

# cpu_ram_efficient_loading cannot be used with MXFP4 model quantization.
# It can only be used with a dequantized model like `axolotl-ai-co/gpt-oss-120b-dequantized`
64 examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml Normal file
@@ -0,0 +1,64 @@
base_model: openai/gpt-oss-20b
use_kernels: false
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
  dequantize: true

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

experimental_skip_move_to_device: true  # prevent OOM by NOT putting model to GPU before sharding

datasets:
  - path: HuggingFaceH4/Multilingual-Thinking
    type: chat_template
    field_thinking: thinking
    template_thinking_key: thinking

dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/

sequence_len: 4096
sample_packing: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-5

bf16: true
tf32: true

flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3

gradient_checkpointing: true
activation_offloading: true

logging_steps: 1
saves_per_epoch: 1

warmup_ratio: 0.03

special_tokens:
eot_tokens:
  - "<|end|>"

fsdp_version: 2
fsdp_config:
  offload_params: false
  state_dict_type: SHARDED_STATE_DICT
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: GptOssDecoderLayer
  reshard_after_forward: true
  # cpu_ram_efficient_loading: true
67 examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml Normal file
@@ -0,0 +1,67 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
  dequantize: true

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

experimental_skip_move_to_device: true  # prevent OOM by not putting model to GPU before sharding

datasets:
  - path: HuggingFaceH4/Multilingual-Thinking
    type: chat_template
    field_thinking: thinking
    template_thinking_key: thinking

dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/

sequence_len: 4096
sample_packing: true

adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.0  # dropout not supported when using LoRA over expert parameters
lora_target_linear: true

# TODO: not supported for now, see peft#2710
#lora_target_parameters: # target the experts in the last two layers
#  - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
#  - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
#  - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
#  - "23._checkpoint_wrapped_module.mlp.experts.down_proj"

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1

optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-4

bf16: true
tf32: true

flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3

gradient_checkpointing: true
activation_offloading: true

logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1

special_tokens:
eot_tokens:
  - "<|end|>"
@@ -1,7 +0,0 @@
-# Liquid Foundation Models 2
-
-LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release.
-
-```bash
-pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
-```
@@ -45,7 +45,6 @@ logging_steps: 1
 flash_attention: true
 flash_attn_cross_entropy: false
 flash_attn_rms_norm: true
-flash_attn_fuse_qkv: false
 flash_attn_fuse_mlp: true

 warmup_ratio: 0.1
@@ -49,7 +49,6 @@ logging_steps: 1
 flash_attention: true
 flash_attn_cross_entropy: false
 flash_attn_rms_norm: true
-flash_attn_fuse_qkv: false
 flash_attn_fuse_mlp: true

 warmup_ratio: 0.1
57 examples/llama-3/diffusion-3.2-1b-pretrain.yaml Normal file
@@ -0,0 +1,57 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

pretraining_dataset:
  - path: wikitext
    name: wikitext-103-raw-v1
    type: completion
    field: text

plugins:
  - diffusion.DiffusionPlugin
noise_schedule: cosine
min_mask_ratio: 0.15
max_mask_ratio: 0.85
eps: 5e-4
importance_weighting: true
mask_token_id: 128002
generate_samples: true
generation_interval: 10

output_dir: ./outputs/model-out

sequence_len: 512
sample_packing: true

gradient_accumulation_steps: 8
micro_batch_size: 4
max_steps: 10000

optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 3e-4

bf16: auto
tf32: true

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
sdp_attention: true

warmup_steps: 1000

save_strategy: steps
save_steps: 1000

special_tokens:
  pad_token: "<|end_of_text|>"

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
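A minimal launch sketch for the diffusion pretraining example, assuming the `diffusion.DiffusionPlugin` referenced above is available in your install:

```bash
axolotl train examples/llama-3/diffusion-3.2-1b-pretrain.yaml
```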
58 examples/llama-3/diffusion-3.2-1b-sft.yaml Normal file
@@ -0,0 +1,58 @@
base_model: meta-llama/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
val_set_size: 0.05

plugins:
  - diffusion.DiffusionPlugin
noise_schedule: cosine
min_mask_ratio: 0.1
max_mask_ratio: 0.9
num_diffusion_steps: 128
eps: 1e-3
importance_weighting: true
mask_token_id: 128002

output_dir: ./outputs/model-out

sequence_len: 512
sample_packing: true
eval_sample_packing: true

gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 1

optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 1e-5

bf16: auto
tf32: true

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
sdp_attention: true

warmup_steps: 1000

save_strategy: steps
eval_strategy: steps
save_steps: 500
eval_steps: 500

special_tokens:
  pad_token: "<|end_of_text|>"

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
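And the SFT counterpart, under the same assumption about plugin availability:

```bash
axolotl train examples/llama-3/diffusion-3.2-1b-sft.yaml
```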
@@ -8,17 +8,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
 ## Getting started

-1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Magistral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

-Here is an example of how to install from main for pip:
+Here is an example of how to install from pip:

 ```bash
 # Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-git clone https://github.com/axolotl-ai-cloud/axolotl.git
-cd axolotl
 pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
-pip3 install --no-build-isolation -e '.[flash-attn]'
+pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
 ```

 2. Run the finetuning example:
@@ -27,7 +27,6 @@ sequence_len: 2048
 sample_packing: true
 eval_sample_packing: false
-

 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
@@ -26,7 +26,6 @@ lora_model_dir:
 sequence_len: 2048
 sample_packing: true
-

 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
@@ -26,7 +26,6 @@ lora_model_dir:
 sequence_len: 2048
 sample_packing: true
-

 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
66 examples/slurm/README.md Normal file
@@ -0,0 +1,66 @@
# SLURM Multi-Node Training

This directory contains an example SLURM script for running Axolotl training jobs across multiple nodes in a SLURM cluster.

## Prerequisites

- Access to a SLURM cluster with GPU nodes
- Axolotl installed on all nodes (see [installation docs](https://docs.axolotl.ai/docs/installation.html))

## Usage

### Standard SLURM Clusters

1. Copy [`axolotl.slurm`](./axolotl.slurm) to your working directory.
2. Place your Axolotl config file (`train.yaml`) in the same directory.
3. Set the appropriate environment variables for the job:

```bash
export HF_TOKEN="your-huggingface-token"

# metric tracking
# export WANDB_API_KEY="your-wandb-api-key"
# ...
```

4. Submit the job:

```bash
sbatch --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=<master-node>,PRIMARY_PORT=29400 axolotl.slurm
```

Where:
- `NUM_NODES`: Number of nodes to use
- `NUM_TRAINERS`: GPUs per node (typically 8)
- `PRIMARY_ADDR`: Hostname/IP of the master node
- `PRIMARY_PORT`: Port for distributed training (default: 29400)

5. (Optional) Run other SLURM commands:

```bash
# check job info
scontrol show job axolotl-cli

# check job queue
squeue

# check cluster status
sinfo
```

### RunPod Instant Clusters

Axolotl works with RunPod Instant Clusters, which provide managed SLURM clusters with zero configuration.

1. **Deploy a SLURM Cluster**:
   - Go to [RunPod Instant Clusters](https://console.runpod.io/cluster)
   - Click "Create a Cluster"
   - Choose your GPU type, node count, and region
   - Choose an [Axolotl cloud docker image](https://docs.axolotl.ai/docs/docker.html#cloud)
   - Deploy the cluster

2. **Connect to the Controller Node**: Find the controller node in the RunPod console and connect via SSH

3. **Follow the instructions in [Standard SLURM Clusters](#standard-slurm-clusters)**

## Additional Resources

- [Axolotl Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [SLURM Documentation](https://slurm.schedmd.com/documentation.html)
- [RunPod SLURM Clusters Guide](https://docs.runpod.io/instant-clusters/slurm-clusters)
20 examples/slurm/axolotl.slurm Normal file
@@ -0,0 +1,20 @@
#!/bin/bash
# Prior to running this script, export your HF_TOKEN and WANDB_API_KEY to your environment, e.g.
# export HF_TOKEN="..."
# export WANDB_API_KEY="..."
#

# ---------- SBATCH commands ---------- #
#SBATCH --job-name=axolotl-slurm-multinode
#SBATCH --ntasks-per-node=1
#SBATCH --nodes=$NUM_NODES
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-task=128

export TORCH_DIST_INIT_BARRIER=0

srun axolotl preprocess train.yaml

srun axolotl train train.yaml --launcher torchrun -- \
    --nproc_per_node=$NUM_TRAINERS --nnodes=$NUM_NODES \
    --rdzv_id axolotl-cli --rdzv_backend c10d --rdzv_endpoint "${PRIMARY_ADDR}:${PRIMARY_PORT}" --rdzv-conf="join_timeout=1800"
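One caveat with the script above: SLURM parses `#SBATCH` directives before any shell runs, so `--nodes=$NUM_NODES` will not expand from the exported variable. Two common workarounds, sketched here as assumptions rather than something this example ships:

```bash
# Option 1: pass the node count as a real sbatch option; command-line options
# override #SBATCH directives in the script
sbatch --nodes="$NUM_NODES" --export=ALL,NUM_TRAINERS=8,PRIMARY_ADDR=node01,PRIMARY_PORT=29400 axolotl.slurm

# Option 2: derive the rendezvous host inside the job script instead of passing it in
PRIMARY_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
```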
49 examples/smolvlm2/README.md Normal file
@@ -0,0 +1,49 @@
# Finetune SmolVLM2 with Axolotl

[SmolVLM2](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7) is a family of lightweight, open-source multimodal models from HuggingFace designed to analyze and understand video, image, and text content.

These models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M.

This guide shows how to fine-tune SmolVLM2 models with Axolotl.

## Getting Started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

Here is an example of how to install from pip:

```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

2. Install an extra dependency:

```bash
pip3 install num2words==0.5.14
```

3. Run the finetuning example:

```bash
# LoRA SFT (1x48GB @ 6.8GiB)
axolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml
```

## TIPS

- **Dataset Format**: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on [Multimodal Formats](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
- **Dataset Loading**: Read more on how to prepare and load your own datasets in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).

## Optimization Guides

- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)

## Related Resources

- [SmolVLM2 Blog](https://huggingface.co/blog/smolvlm2)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
56 examples/smolvlm2/smolvlm2-2B-lora.yaml Normal file
@@ -0,0 +1,56 @@
base_model: HuggingFaceTB/SmolVLM2-2.2B-Instruct
trust_remote_code: true
processor_type: AutoProcessor

# these 3 lines are needed for now to handle vision chat templates with images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

adapter: lora
lora_model_dir:

sequence_len: 8192
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.text_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: true
fp16:
tf32: true

gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
@@ -6,17 +6,14 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
 ## Getting started

-1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Voxtral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

-Here is an example of how to install from main for pip:
+Here is an example of how to install from pip:

 ```bash
 # Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-git clone https://github.com/axolotl-ai-cloud/axolotl.git
-cd axolotl
 pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
-pip3 install --no-build-isolation -e '.[flash-attn]'
+pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
 ```

 2. Please install the below.
@@ -1,8 +1,9 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

 # START section of dependencies that don't install on Darwin/MacOS
-bitsandbytes==0.46.0
+bitsandbytes==0.47.0
-triton>=3.0.0
+# triton 3.4.0 is not compatible with CCE
+triton>=3.0.0,<3.4.0
 mamba-ssm==1.2.0.post1
 xformers>=0.0.23.post1
 autoawq==0.2.7.post3
@@ -12,19 +13,21 @@ liger-kernel==0.6.1
 packaging==23.2

 huggingface_hub>=0.33.0
-peft==0.16.0
+peft==0.17.0
-transformers==4.54.1
+transformers==4.55.2
 tokenizers>=0.21.1
-accelerate @ git+https://github.com/huggingface/accelerate.git@9359a0194f210624f1e6e85c3d838fdd55c11152
+accelerate==1.10.0
 datasets==4.0.0
 deepspeed>=0.17.0
-trl==0.20.0
+trl==0.21.0
 hf_xet==1.1.5
+kernels==0.9.0
+trackio

 optimum==1.16.2
 hf_transfer
 sentencepiece
-gradio==5.23.3
+gradio==5.41.1

 modal==1.0.2
 pydantic==2.10.6
@@ -66,6 +69,6 @@ torchao==0.12.0
 schedulefree==1.4.1

 axolotl-contribs-lgpl==0.0.6
-axolotl-contribs-mit==0.0.3
+axolotl-contribs-mit==0.0.5

 mistral-common==1.8.3
@@ -44,8 +44,13 @@ add_keys_to_authorized() {
     chmod 700 -R ~/.ssh
 }

+# Set SSH port
+if [ ! -z "$SSH_PORT" ]; then
+    sed -i "s/#Port 22/Port $SSH_PORT/" /etc/ssh/sshd_config
+fi
+
 if [[ $PUBLIC_KEY ]]; then
-    # runpod
+    # runpod, prime intellect
     add_keys_to_authorized "$PUBLIC_KEY"
     # Start the SSH service in the background
     service ssh start
@@ -76,5 +81,13 @@ if [ ! -L "/workspace/axolotl/outputs" ]; then
     ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
 fi

+# start the runpod slurm init
+SLURM_INIT="${SLURM_INIT:-/slurm-init.sh}"
+
+if [[ -f "$SLURM_INIT" ]]; then
+    echo "[entrypoint] running $SLURM_INIT..."
+    bash "$SLURM_INIT"
+fi
+
 # Execute the passed arguments (CMD)
 exec "$@"
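A sketch of exercising the new entrypoint hooks; the image tag is an assumption, so substitute whichever Axolotl cloud image you actually use:

```bash
# SSH_PORT rewrites sshd_config before the service starts; PUBLIC_KEY lands in
# authorized_keys; /slurm-init.sh (or $SLURM_INIT) runs just before the CMD
docker run --gpus all \
  -e SSH_PORT=2222 \
  -e PUBLIC_KEY="ssh-ed25519 AAAA... user@host" \
  axolotlai/axolotl-cloud:main-latest  # assumed image tag
```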
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""

 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
 )
@@ -4,4 +4,4 @@ import pkgutil

 __path__ = pkgutil.extend_path(__path__, __name__)  # Make this a namespace package

-__version__ = "0.12.0.dev"
+__version__ = "0.13.0.dev"
@@ -40,6 +40,12 @@ class VllmServeCliArgs:
         default=None,
         metadata={"help": "Number of tensor parallel workers to use."},
     )
+    data_parallel_size: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
+        },
+    )
     host: Optional[str] = field(
         default=None,  # nosec B104
         metadata={"help": "Host address to run the server on."},
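A usage sketch for the new argument, assuming the dataclass field surfaces as `--data-parallel-size` the same way `tensor_parallel_size` surfaces as `--tensor-parallel-size` (the config filename is a placeholder):

```bash
# serve 2 model replicas, each sharded across 2 GPUs (4 GPUs total)
axolotl vllm-serve grpo.yaml --data-parallel-size 2 --tensor-parallel-size 2
```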
@@ -153,15 +153,14 @@ def prepare_plugins(cfg: DictDefault):
     plugin_manager = PluginManager.get_instance()
     for plugin_name in cfg["plugins"]:
         plugin_manager.register(plugin_name)
+    for plugin in plugin_manager.plugins.values():
+        plugin.register(cfg)


 def plugin_set_cfg(cfg: DictDefault):
     if cfg.get("plugins"):
         plugin_manager = PluginManager.get_instance()
         plugin_manager.cfg = cfg
-        # now that we have the finalized cfg, register the plugins individually
-        for plugin in plugin_manager.plugins.values():
-            plugin.register(cfg)


 def load_cfg(
@@ -123,9 +123,10 @@ def train(
     _launcher = None if kwargs.get("use_ray") else launcher

     # Process each configuration
-    for cfg_file in generate_config_files(config, sweep):
+    for cfg_file, is_group in generate_config_files(config, sweep):
         try:
-            launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
+            use_exec = is_group is not True
+            launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
         except subprocess.CalledProcessError as exc:
             LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
             if not sweep:
@@ -10,6 +10,7 @@ import fire
 import torch
 import torch.distributed.checkpoint as dist_cp
 import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
+from accelerate import PartialState
 from accelerate.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
@@ -23,6 +24,7 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner

 from axolotl.cli.config import load_cfg
 from axolotl.utils.logging import get_logger
+from axolotl.utils.train import determine_last_checkpoint

 LOG = get_logger(__name__)

@@ -143,7 +145,6 @@ def merge_fsdp_weights(
         ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
     """
     checkpoint_dir_ = Path(checkpoint_dir)
-    from accelerate.state import PartialState

     if not is_torch_version(">=", "2.3.0"):
         raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
@@ -180,7 +181,6 @@ def merge_fsdp_weights(
     if remove_checkpoint_dir:
         LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
         shutil.rmtree(checkpoint_dir_)
-    state.wait_for_everyone()


 def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
@@ -195,11 +195,32 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
     parsed_cfg = load_cfg(config, **kwargs)

     fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
+    if not fsdp_dir.exists():
+        checkpoint_dir = determine_last_checkpoint(parsed_cfg, update=False)
+        if checkpoint_dir:
+            fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
+        if not fsdp_dir.exists():
+            raise ValueError(
+                f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}"
+            )
+
+    output_path = str(Path(parsed_cfg.output_dir) / "merged")
     merge_fsdp_weights(
         checkpoint_dir=str(fsdp_dir),
-        output_path=str(Path(parsed_cfg.output_dir) / "merged"),
+        output_path=output_path,
         safe_serialization=True,
     )
+    state = PartialState()
+    state.wait_for_everyone()
+    LOG.info(
+        f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
+        main_process_only=True,
+    )
+    LOG.info(
+        "Merged weights are only the safetensors and doesn't include the model configuration "
+        f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
+        main_process_only=True,
+    )


 if __name__ == "__main__":
@@ -2,6 +2,7 @@

 import os
 import subprocess  # nosec
+import sys
 import tempfile
 from typing import Any, Iterator, Literal

@@ -64,10 +65,18 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
     return cmd


-def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
+def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str, bool]]:
-    """Generate list of configuration files to process."""
+    """
+    Generate list of configuration files to process. Yields a tuple of the configuration file name
+    and a boolean indicating whether this is a group of configurations (i.e., a sweep).
+
+    Args:
+        config: Base configuration file
+        sweep: Sweep configuration file
+    """
+
     if not sweep:
-        yield config
+        yield config, False
         return

     # Load sweep and base configurations
@@ -78,6 +87,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:

     # Generate all possible configurations
     permutations = generate_sweep_configs(base_config, sweep_config)
+    is_group = len(permutations) > 1
     for permutation in permutations:
         # pylint: disable=consider-using-with
         temp_file = tempfile.NamedTemporaryFile(
@@ -88,7 +98,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
         )
         yaml.dump(permutation, temp_file)
         temp_file.close()
-        yield temp_file.name
+        yield temp_file.name, is_group


 def launch_training(
@@ -97,6 +107,7 @@ def launch_training(
     cloud: str | None,
     kwargs: dict,
     launcher_args: list[str] | None = None,
+    use_exec: bool = False,
 ) -> None:
     """Execute training with the given configuration."""
     launcher_args = launcher_args or []
@@ -105,11 +116,14 @@ def launch_training(
         _launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
     elif launcher:
         if launcher == "accelerate":
-            _launch_accelerate_training(cfg_file, kwargs, launcher_args)
+            _launch_accelerate_training(cfg_file, kwargs, launcher_args, use_exec)
         elif launcher == "torchrun":
-            _launch_torchrun_training(cfg_file, kwargs, launcher_args)
+            _launch_torchrun_training(cfg_file, kwargs, launcher_args, use_exec)
         elif launcher == "python":
             _launch_python_training(cfg_file, kwargs)
+    elif launcher is None:
+        # handle ray train launch
+        _launch_python_training(cfg_file, kwargs)


 def _launch_cloud_training(
@@ -136,7 +150,10 @@ def _launch_cloud_training(


 def _launch_accelerate_training(
-    cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
+    cfg_file: str,
+    kwargs: dict,
+    launcher_args: list[str] | None = None,
+    use_exec: bool = False,
 ) -> None:
     """Execute training via accelerate launcher."""
     launcher_args = launcher_args or []
@@ -161,11 +178,20 @@ def _launch_accelerate_training(
     base_cmd.append(cfg_file)

     cmd = build_command(base_cmd, kwargs)
-    subprocess.run(cmd, check=True)  # nosec B603
+    if use_exec:
+        # make sure to flush stdout and stderr before replacing the process
+        sys.stdout.flush()
+        sys.stderr.flush()
+        os.execvpe(cmd[0], cmd, os.environ)  # nosec B606
+    else:
+        subprocess.run(cmd, check=True)  # nosec B603


 def _launch_torchrun_training(
-    cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
+    cfg_file: str,
+    kwargs: dict,
+    launcher_args: list[str] | None = None,
+    use_exec: bool = False,
 ) -> None:
     """Execute training via torchrun launcher."""
     launcher_args = launcher_args or []
@@ -178,7 +204,13 @@ def _launch_torchrun_training(
     base_cmd.append(cfg_file)

     cmd = build_command(base_cmd, kwargs)
-    subprocess.run(cmd, check=True)  # nosec B603
+    if use_exec:
+        # make sure to flush stdout and stderr before replacing the process
+        sys.stdout.flush()
+        sys.stderr.flush()
+        os.execvpe(cmd[0], cmd, os.environ)  # nosec B606
+    else:
+        subprocess.run(cmd, check=True)  # nosec B603


 def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
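The exec-versus-subprocess split above is the point of this change: a single (non-sweep) run now replaces the CLI process with the launcher via `os.execvpe`, so signals and exit codes pass straight through, while sweep groups keep using subprocesses so the loop can continue past failures. A sketch of triggering both paths, assuming `--sweep` is the flag that feeds `generate_config_files` and the YAML names are placeholders:

```bash
# single config: the axolotl process exec's into torchrun (use_exec=True)
axolotl train base.yaml --launcher torchrun

# sweep with multiple permutations: each generated config runs as a subprocess (use_exec=False)
axolotl train base.yaml --sweep sweep.yaml --launcher torchrun
```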
@@ -2,12 +2,10 @@
 CLI to start the vllm server for online RL
 """

-import os
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Union

-import trl
 from trl.scripts.vllm_serve import ScriptArguments

 from axolotl.cli.config import load_cfg
@@ -42,13 +40,17 @@ def do_vllm_serve(

     serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
     vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
+    tensor_parallel_size = 1
+    data_parallel_size = 1

-    tensor_parallel_size = (
-        cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
-    )
-    data_parallel_size = (
-        cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
-    )
+    if cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size:
+        tensor_parallel_size = (
+            cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
+        )
+    if cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size:
+        data_parallel_size = (
+            cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
+        )
     host = cli_args.get("host") or cfg.vllm.host
     port = cli_args.get("port") or cfg.vllm.port
     gpu_memory_utilization = (
@@ -81,63 +83,3 @@ def do_vllm_serve(
         enable_reasoning=enable_reasoning,
     )
     vllm_serve_main(vllm_script_args)
-
-
-def patch_vllm_worker():
-    from multiprocessing.connection import Connection
-
-    from vllm import LLM
-
-    def llm_worker(
-        script_args: AxolotlScriptArguments,
-        data_parallel_rank: int,
-        master_port: int,
-        connection: Connection,
-    ) -> None:
-        # Set required environment variables for DP to work with vLLM
-        os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
-        os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
-        os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
-        os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
-
-        llm = LLM(
-            model=script_args.model,
-            revision=script_args.revision,
-            tensor_parallel_size=script_args.tensor_parallel_size,
-            gpu_memory_utilization=script_args.gpu_memory_utilization,
-            enforce_eager=script_args.enforce_eager,
-            dtype=script_args.dtype,
-            # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
-            # directly reuse the KV cache if it shares the same prefix with one of the existing queries.
-            # This is particularly useful here because we generate completions from the same prompts.
-            enable_prefix_caching=script_args.enable_prefix_caching,
-            kv_cache_dtype=script_args.kv_cache_dtype,
-            max_model_len=script_args.max_model_len,
-            worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
-            enable_reasoning=script_args.enable_reasoning,
-            reasoning_parser=script_args.reasoning_parser,
-        )
-
-        # Send ready signal to parent process
-        connection.send({"status": "ready"})
-
-        while True:
-            # Wait for commands from the parent process
-            try:
-                command = connection.recv()
-            except KeyboardInterrupt:
-                llm.collective_rpc(method="close_communicator")
-                break
-
-            # Handle commands
-            if command["type"] in ["call", "fire_and_forget"]:
-                method_name = command["method"]
-                args, kwargs = command.get("args", ()), command.get("kwargs", {})
-                method = getattr(llm, method_name)
-                result = method(*args, **kwargs)
-                if command["type"] == "call":
-                    connection.send(result)
-            elif command["type"] == "shutdown":
-                break
-
-    trl.scripts.vllm_serve.llm_worker = llm_worker
@@ -13,4 +13,5 @@ MOE_ARCH_BLOCK = {
     "qwen2_moe": "Qwen2MoeSparseMoeBlock",
     "qwen3_moe": "Qwen3MoeSparseMoeBlock",
     "deepseek_v2": "DeepseekV2MoE",
+    "gpt_oss": "GptOssDecoderLayer",
 }
@@ -24,12 +24,10 @@ from pathlib import Path
 from typing import Any
 
 import torch
-from accelerate import PartialState
 from transformers import (
     TrainerCallback,
 )
 from transformers.trainer_pt_utils import AcceleratorConfig
-from transformers.training_args import OptimizerNames
 
 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
@@ -40,6 +38,7 @@ from axolotl.utils.callbacks import (
     SaveModelOnFirstStepCallback,
 )
 from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
+from axolotl.utils.distributed import build_parallelism_config
 from axolotl.utils.schemas.enums import CustomSupportedOptimizers
 
 LOG = logging.getLogger(__name__)
@@ -267,27 +266,24 @@ class TrainerBuilderBase(abc.ABC):
 
             optimizer_cls = MuonOptimizerFactory
             optimizer_kwargs.update(adam_kwargs)
+        elif self.cfg.optimizer == "dion":
+            from axolotl.contribs.mit.dion import (  # pylint: disable=no-name-in-module
+                DionOptimizerFactory,
+            )
+
+            optimizer_cls = DionOptimizerFactory
+            optimizer_kwargs["dion_lr"] = training_args_kwargs["dion_learning_rate"]
+            optimizer_kwargs["dion_mu"] = training_args_kwargs["dion_momentum"]
+            optimizer_kwargs.update(adam_kwargs)
+            _, device_mesh = build_parallelism_config(self.cfg)
+            if device_mesh is not None:
+                optimizer_kwargs["device_mesh"] = device_mesh
         elif self.cfg.optimizer == "optimi_adamw":
             from optimi import AdamW
 
             optimizer_kwargs["foreach"] = False
             optimizer_cls = AdamW
             optimizer_kwargs.update(adam_kwargs)
-        elif self.cfg.optimizer == "ao_adamw_4bit":
-            # TODO remove 20250401
-            from torchao.prototype.low_bit_optim import AdamW4bit
-
-            optimizer_cls = AdamW4bit
-            optimizer_kwargs.update(adam_kwargs)
-
-            LOG.warning(
-                f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
-            )
-        elif self.cfg.optimizer == "ao_adamw_8bit":
-            from torchao.prototype.low_bit_optim import AdamW8bit
-
-            optimizer_cls = AdamW8bit
-            optimizer_kwargs.update(adam_kwargs)
         elif self.cfg.optimizer == "ao_adamw_fp8":
             from torchao.prototype.low_bit_optim import AdamWFp8
 
@@ -433,30 +429,12 @@ class TrainerBuilderBase(abc.ABC):
         training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
 
     def _configure_accelerator_config(self, training_args_kwargs: dict):
-        partial_state = PartialState()
-        has_pc_attr = (
-            hasattr(partial_state, "parallelism_config")
-            and partial_state.parallelism_config
-        )
-        has_pc_key = (
-            "parallelism_config"
-            in partial_state._shared_state  # pylint: disable=protected-access
-            and partial_state._shared_state[  # pylint: disable=protected-access
-                "parallelism_config"
-            ]
-        )
-        use_configured_state = has_pc_attr or has_pc_key
         if self.cfg.accelerator_config:
-            use_configured_state = self.cfg.accelerator_config.pop(
-                "use_configured_state", use_configured_state
-            )
             training_args_kwargs["accelerator_config"] = AcceleratorConfig(
-                use_configured_state=use_configured_state, **self.cfg.accelerator_config
+                **self.cfg.accelerator_config
             )
         else:
-            training_args_kwargs["accelerator_config"] = AcceleratorConfig(
-                use_configured_state=use_configured_state,
-            )
+            training_args_kwargs["accelerator_config"] = AcceleratorConfig()
 
     def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
         if self.cfg.activation_offloading is True:
@@ -516,10 +494,20 @@ class TrainerBuilderBase(abc.ABC):
             "include_tokens_per_second",
             "weight_decay",
             "seed",
+            "dion_momentum",
+            "dion_rank_fraction",
+            "dion_rank_multiple_of",
         ]:
             if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
                 training_args_kwargs[arg] = getattr(self.cfg, arg)
 
+        arg_map = {
+            "dion_learning_rate": "dion_lr",
+        }
+        for kwarg, cfg_arg in arg_map.items():
+            if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:
+                training_args_kwargs[kwarg] = getattr(self.cfg, cfg_arg)
+
         training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
         training_args_kwargs["average_tokens_across_devices"] = False
 
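Read together with the `dion` branch above, these pass-throughs suggest a config along these lines (a sketch; `optimizer: dion` and the `dion_*` keys come from this diff, the surrounding values are illustrative):

```yaml
optimizer: dion
dion_lr: 0.01              # config key; surfaced as the `dion_learning_rate` training arg
dion_momentum: 0.95
dion_rank_fraction: 0.25
dion_rank_multiple_of: 8
learning_rate: 2e-5        # base LR, presumably for the non-Dion (AdamW) parameter groups
```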
@@ -10,6 +10,7 @@ import transformers
 from transformers import (
     DataCollatorWithFlattening,
     EarlyStoppingCallback,
+    Trainer,
 )
 from trl.trainer.utils import RewardDataCollatorWithPadding
 
@@ -43,6 +44,7 @@ from axolotl.utils.collators import (
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
 from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
+from axolotl.utils.import_helper import get_cls_from_module_str
 from axolotl.utils.logging import get_logger
 
 LOG = get_logger(__name__)
@@ -136,6 +138,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             return AxolotlRewardTrainer
         if self.cfg.process_reward_model:
             return AxolotlPRMTrainer
+
+        if self.cfg.trainer_cls:
+            # override the trainer cls
+            try:
+                trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
+                LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
+                return trainer_cls
+            except (ImportError, AttributeError, ValueError) as e:
+                raise ValueError(
+                    f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
+                ) from e
+
         return AxolotlTrainer
 
     def build(self, total_num_steps):
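With this hook, a config can point `trainer_cls` at any importable trainer. A minimal sketch (the module path and class here are hypothetical; `get_cls_from_module_str` is assumed to take a dotted `module.Class` string):

```yaml
trainer_cls: my_project.trainers.MyTrainer
```

```python
# my_project/trainers.py -- hypothetical module housing the override
from axolotl.core.trainers.base import AxolotlTrainer


class MyTrainer(AxolotlTrainer):
    """Tag every metrics dict before deferring to the stock logging path."""

    def log(self, logs, start_time=None):
        logs["custom/tag"] = 1.0
        return super().log(logs, start_time)
```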
@@ -350,7 +364,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
                 self.cfg.sequence_len / multiple
             )
-        else:
+        elif self.cfg.pad_to_sequence_len is None:
             # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
             # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
             data_collator_kwargs["pad_to_multiple_of"] = multiple
@@ -372,10 +386,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             **data_collator_kwargs,
         )
         sig = inspect.signature(trainer_cls)
-        if "processing_class" in sig.parameters:
+        if "processing_class" in sig.parameters or issubclass(trainer_cls, Trainer):
             trainer_kwargs["processing_class"] = self.tokenizer
         elif "tokenizer" in sig.parameters:
             trainer_kwargs["tokenizer"] = self.tokenizer
+
         if (
             trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
             and self.cfg.datasets is not None
@@ -15,6 +15,7 @@ from axolotl.core.trainers.grpo import GRPOStrategy
 from axolotl.integrations.base import PluginManager
 from axolotl.loaders.utils import ensure_dtype
 from axolotl.utils.callbacks.qat import QATCallback
+from axolotl.utils.import_helper import get_cls_from_module_str
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import RLType
 
@@ -72,6 +73,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")
 
+        if self.cfg.trainer_cls:
+            # override the trainer cls
+            try:
+                trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
+                LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
+            except (ImportError, AttributeError, ValueError) as e:
+                raise ValueError(
+                    f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
+                ) from e
+
         return trainer_cls, trainer_cls_args
 
     def _build_training_arguments(self, total_num_steps):
@@ -5,7 +5,6 @@
 
 from .base import AxolotlTrainer
 from .dpo.trainer import AxolotlDPOTrainer
-from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
 from .mamba import AxolotlMambaTrainer
 from .trl import (
     AxolotlCPOTrainer,
@@ -10,8 +10,11 @@ from functools import partial, wraps
 from typing import Any, Callable, Literal, Optional
 
 import datasets
+import safetensors
 import torch
+from accelerate.state import AcceleratorState
 from datasets import Dataset
+from peft import PeftModel
 from torch.utils.data import (
     BatchSampler,
     DataLoader,
@@ -19,8 +22,10 @@ from torch.utils.data import (
     Sampler,
     SequentialSampler,
 )
-from transformers import Trainer
+from transformers import PreTrainedModel, Trainer
+from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
+from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
 from trl.trainer.utils import pad_to_length
 from typing_extensions import override
 
@@ -77,7 +82,9 @@ class AxolotlTrainer(
         super().__init__(*_args, **kwargs)
 
         self.train_data_collator = self.data_collator
-        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+        self._stored_metrics = defaultdict(
+            lambda: defaultdict(lambda: {"values": [], "reduction": "mean"})
+        )
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
 
@@ -515,7 +522,18 @@ class AxolotlTrainer(
 
     @wraps(Trainer.create_accelerator_and_postprocess)
    def create_accelerator_and_postprocess(self):
-        res = super().create_accelerator_and_postprocess()
+        # cleanup the PartialState states so Accelerate automatically configures everything from the env vars
+        accelerator_config = self.args.accelerator_config.to_dict()
+        use_configured_state = accelerator_config.get("use_configured_state", False)
+        if not use_configured_state:
+            AcceleratorState._reset_state(  # pylint: disable=protected-access
+                reset_partial_state=True
+            )
+
+        super().create_accelerator_and_postprocess()
+
+        # now we need to put parallelism_config back on the PartialState since we rely on that info in other places
+        # PartialState().parallelism_config = self.accelerator.state.parallelism_config
 
         if self.is_fsdp_enabled:
             if (
@@ -524,8 +542,6 @@ class AxolotlTrainer(
             ):
                 self.accelerator.state.fsdp_plugin.limit_all_gathers = True
 
-        return res
-
     # pylint: disable=unused-argument
     def additional_accelerator_args(
         self, fp8: bool = False, enable_fsdp_float8_all_gather: bool = False, **kwargs
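Since `use_configured_state` is now read back out of `args.accelerator_config`, opting out of the state reset would plausibly be a one-liner in the YAML (a sketch; assumes the keys are forwarded into `AcceleratorConfig` as in `_configure_accelerator_config` above):

```yaml
accelerator_config:
  use_configured_state: true   # keep a pre-configured PartialState instead of resetting it
```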
@@ -559,18 +575,35 @@ class AxolotlTrainer(
         """
         # logs either has 'loss' or 'eval_loss'
         train_eval = "train" if "loss" in logs else "eval"
-        # Add averaged stored metrics to logs
-        for key, metrics in self._stored_metrics[train_eval].items():
-            logs[key] = torch.tensor(metrics).mean().item()
+        # Add reduced stored metrics to logs
+        for key, metric_data in self._stored_metrics[train_eval].items():
+            values = torch.tensor(metric_data["values"])
+            reduction_type = metric_data["reduction"]
+
+            if reduction_type == "mean":
+                logs[key] = values.mean().item()
+            elif reduction_type == "min":
+                logs[key] = values.min().item()
+            elif reduction_type == "max":
+                logs[key] = values.max().item()
+            elif reduction_type == "sum":
+                logs[key] = values.sum().item()
+            else:
+                raise NotImplementedError(
+                    "Metric reduction must be one of [mean, min, max, sum]"
+                )
+
+            logs[key] = round(logs[key], 4)
+
         if is_main_process():
             # Add memory usage
             try:
                 active, allocated, reserved = get_gpu_memory_usage()
-                logs["memory/max_memory_active"] = active
-                logs["memory/max_memory_allocated"] = allocated
-                logs["memory/device_memory_reserved"] = reserved
-            except (ValueError, FileNotFoundError):
+                logs["memory/max_mem_active(gib)"] = round(active, 2)
+                logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
+                logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
+            except (ValueError, TypeError, FileNotFoundError):
                 pass
 
         del self._stored_metrics[train_eval]
@@ -578,10 +611,27 @@ class AxolotlTrainer(
         return super().log(logs, start_time)
 
     def store_metrics(
-        self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
+        self,
+        metrics: dict[str, float] | dict[str, tuple[int | float, str]],
+        train_eval: Literal["train", "eval"] = "train",
+        reduction: Literal["mean", "min", "max", "sum"] = "mean",
     ) -> None:
+        """
+        Store metrics with specified reduction type.
+
+        Args:
+            metrics: Dictionary of metric names to values, or metric names to (value,
+                reduction_type) tuples.
+            train_eval: Whether this is for training or evaluation.
+        """
        for key, value in metrics.items():
-            self._stored_metrics[train_eval][key].append(value)
+            if isinstance(value, tuple):
+                metric_value, metric_reduction = value
+            else:
+                metric_value, metric_reduction = value, reduction
+
+            self._stored_metrics[train_eval][key]["values"].append(metric_value)
+            self._stored_metrics[train_eval][key]["reduction"] = metric_reduction
 
     def _save_checkpoint(self, model, trial, **kwargs):
         # make sure the checkpoint dir exists, since trainer is flakey
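Usage of the extended `store_metrics` from inside a trainer subclass would look like this (values are illustrative):

```python
self.store_metrics(
    {
        "rewards/mean_reward": 0.42,           # default "mean" reduction
        "rewards/best_reward": (0.97, "max"),  # per-key (value, reduction) override
        "tokens/num_masked": (1024, "sum"),
    },
    train_eval="train",
)
# On the next log() call each key is reduced over its stored values
# and rounded to 4 decimals before being emitted.
```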
@@ -590,3 +640,64 @@ class AxolotlTrainer(
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
         return super()._save_checkpoint(model, trial, **kwargs)
+
+    # TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        # If we are executing this function, we are the process zero, so we don't check for that.
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+        os.makedirs(output_dir, exist_ok=True)
+        LOG.info(f"Saving model checkpoint to {output_dir}")
+        supported_classes = (
+            (PreTrainedModel,)
+            if not is_peft_available()
+            else (PreTrainedModel, PeftModel)
+        )
+        # Save a trained model and configuration using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        if not isinstance(self.model, supported_classes):
+            if state_dict is None:
+                state_dict = self.model.state_dict()
+            if isinstance(
+                self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
+                supported_classes,
+            ):
+                self.accelerator.unwrap_model(
+                    self.model, keep_torch_compile=False
+                ).save_pretrained(
+                    output_dir,
+                    state_dict=state_dict,
+                    safe_serialization=self.args.save_safetensors,
+                )
+            else:
+                LOG.info(
+                    "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
+                )
+                if self.args.save_safetensors:
+                    safetensors.torch.save_file(
+                        state_dict,
+                        os.path.join(output_dir, SAFE_WEIGHTS_NAME),
+                        metadata={"format": "pt"},
+                    )
+                else:
+                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        else:
+            self.model.save_pretrained(
+                output_dir,
+                state_dict=state_dict,
+                safe_serialization=self.args.save_safetensors,
+                is_main_process=self.accelerator.is_main_process,
+            )
+
+        if self.processing_class is not None:
+            self.processing_class.save_pretrained(output_dir)
+        elif (
+            self.data_collator is not None
+            and hasattr(self.data_collator, "tokenizer")
+            and self.data_collator.tokenizer is not None
+        ):
+            LOG.info(
+                "Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
+            )
+            self.data_collator.tokenizer.save_pretrained(output_dir)
+        # Good practice: save your training arguments together with the trained model
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
@@ -2,6 +2,7 @@
 Mixin for correctly saving fsdp
 """
 
+from accelerate import PartialState
 from transformers import Trainer
 
 
@@ -18,3 +19,15 @@ class DistributedParallelMixin(Trainer):
         ):
             state_dict = self.accelerator.get_state_dict(self.model)
             super()._save(output_dir, state_dict=state_dict)
+
+    def create_accelerator_and_postprocess(self):
+        super().create_accelerator_and_postprocess()
+        if (
+            self.accelerator.distributed_type == "FSDP"
+            and self.accelerator.state.fsdp_plugin is None
+        ):
+            # pylint: disable=protected-access
+            # handle Context Parallelism without FSDP
+            self.accelerator.state.distributed_type = "MULTI_GPU"
+            self.accelerator.state._shared_state["distributed_type"] = "MULTI_GPU"
+            PartialState().distributed_type = "MULTI_GPU"
@@ -243,3 +243,18 @@ class AxolotlTrainingMixins:
         )
 
     # end of multi-modal section
+
+    dion_learning_rate: float | None = field(
+        default=None,
+        metadata={"help": "The learning rate for Dion"},
+    )
+    dion_momentum: float | None = field(
+        default=None,
+        metadata={"help": "The momentum for Dion"},
+    )
+    dion_rank_fraction: float | None = field(
+        default=None,
+    )
+    dion_rank_multiple_of: int | None = field(
+        default=None,
+    )
@@ -26,9 +26,11 @@ import traceback
 from typing import TYPE_CHECKING, Callable, OrderedDict, Union
 
 from peft import PeftModel
+from torch import nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 from transformers import PreTrainedModel, Trainer
+from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger
@@ -74,8 +76,8 @@ class BasePlugin:
     def __init__(self):
         """Initializes the BasePlugin."""
 
-    def register(self, cfg: DictDefault):  # pylint: disable=unused-argument
-        """Registers the plugin with the given configuration.
+    def register(self, cfg: dict):  # pylint: disable=unused-argument
+        """Registers the plugin with the given configuration as an unparsed dict.
 
         Args:
             cfg: The configuration for the plugin.
@@ -145,7 +147,7 @@ class BasePlugin:
         """
 
     # pylint: disable=unused-argument
-    def get_trainer_cls(self, cfg: DictDefault) -> Trainer | None:
+    def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
         """Returns a custom class for the trainer.
 
         Args:
@@ -641,3 +643,24 @@ class BaseOptimizerFactory:
         self, opt_model, training_args, **optimizer_kwargs
     ) -> Optimizer | None:
         pass
+
+    # duplicated from transformers
+    def get_decay_parameter_names(self, model) -> list[str]:
+        """
+        Get all parameter names that weight decay will be applied to.
+
+        This function filters out parameters in two ways:
+        1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
+        2. By parameter name patterns (containing 'bias', or variation of 'norm')
+        """
+        forbidden_name_patterns = [
+            r"bias",
+            r"layernorm",
+            r"rmsnorm",
+            r"(?:^|\.)norm(?:$|\.)",
+            r"_norm(?:$|\.)",
+        ]
+        decay_parameters = get_parameter_names(
+            model, [nn.LayerNorm], forbidden_name_patterns
+        )
+        return decay_parameters
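The name patterns exclude bias and normalization parameters from weight decay; for example (parameter names are illustrative):

```python
import re

patterns = [r"bias", r"layernorm", r"rmsnorm", r"(?:^|\.)norm(?:$|\.)", r"_norm(?:$|\.)"]
names = [
    "model.layers.0.self_attn.q_proj.weight",  # kept -> receives weight decay
    "model.layers.0.input_layernorm.weight",   # excluded by "layernorm"
    "model.norm.weight",                       # excluded by (?:^|\.)norm(?:$|\.)
    "lm_head.bias",                            # excluded by "bias"
]
decayed = [n for n in names if not any(re.search(p, n) for p in patterns)]
assert decayed == ["model.layers.0.self_attn.q_proj.weight"]
```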
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
 
 - If you are installing from pip
 ```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
 ```
 
 ## Usage
@@ -31,6 +31,7 @@ plugins:
 
 ## Supported Models
 
+- arcee
 - cohere
 - cohere2
 - gemma
@@ -41,13 +42,17 @@ plugins:
 - gemma3n_text
 - glm
 - glm4
+- gpt_oss
 - granite
 - granitemoe
+- hunyuan_v1_dense
+- hunyuan_v1_moe
 - llama
 - llama4
 - llama4_text
 - mistral
 - mistral3
+- mixtral
 - mllama
 - phi
 - phi3
@@ -34,7 +34,7 @@ LOG = get_logger(__name__)
 
 _CCE_INSTALL_MESSAGE = (
     "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
-    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@cbd58e0"`'
+    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
 )
 
 

164 src/axolotl/integrations/diffusion/README.md Normal file
@@ -0,0 +1,164 @@
+# Diffusion LM Training Plugin for Axolotl
+
+This plugin enables diffusion language model training using the LLaDA (Large Language
+And Diffusion Assistant) approach within the Axolotl framework.
+
+## Overview
+
+LLaDA is a diffusion-based approach to language model training that uses:
+- **Random token masking** during training instead of next-token prediction
+- **Bidirectional attention** to allow the model to see the full context
+- **Importance weighting** based on masking probabilities for stable training
+
+This approach can lead to more robust language models with better understanding of
+bidirectional context.
+
+## Installation
+
+The plugin is included with Axolotl. To use it, simply add the plugin configuration to
+your training config.
+
+## Quickstart
+
+### Basic Configuration
+
+Add the following to your Axolotl configuration YAML:
+
+```yaml
+# Enable diffusion LM training plugin
+plugins:
+  - axolotl.integrations.diffusion.DiffusionPlugin
+
+# Diffusion-specific configuration
+noise_schedule: linear  # or "cosine"
+min_mask_ratio: 0.1
+max_mask_ratio: 0.9
+num_diffusion_steps: 128
+eps: 1e-3
+importance_weighting: true
+mask_token_id: 128002
+
+# Sample generation (optional)
+generate_samples: true
+generation_interval: 100
+num_generation_samples: 3
+generation_steps: 128
+generation_temperature: 0.0
+generation_max_length: 100
+
+# Model configuration
+base_model: meta-llama/Llama-3.2-1B
+model_type: llama
+
+# Standard Axolotl configuration
+datasets:
+  - path: your_dataset
+    ...
+
+# Other config
+sequence_len: 1024
+micro_batch_size: 8
+gradient_accumulation_steps: 4
+learning_rate: 3e-4
+```
+
+## Supported Models
+
+Currently supported base model types:
+- **Llama** (meta-llama/Llama-*, etc.) - Uses `LlamaForDiffusionLM`
+- **Mistral** (mistralai/Mistral-*, etc.) - Uses `MistralForDiffusionLM`
+
+The plugin automatically creates custom model classes that inherit from the base model
+while adding diffusion training capabilities. This provides full compatibility with
+HuggingFace's ecosystem for saving, loading, and inference.
+
+## How It Works
+
+### Custom Model Architecture
+
+The plugin creates custom model classes (`LlamaForDiffusionLM`, `MistralForDiffusionLM`) that inherit from
+standard HuggingFace models. During training, these models:
+
+1. **Apply forward diffusion process**: Randomly mask tokens based on sampled timesteps
+2. **Use bidirectional attention**: Override causal attention with full bidirectional attention
+3. **Compute diffusion loss**: Calculate loss only on masked tokens with optional importance weighting
+
+### Random Masking
+During training, tokens are randomly masked based on a sampled timestep:
+- Sample timestep `t` uniformly from [0, 1]
+- Calculate masking probability: `p = (1 - eps) * t + eps`
+- Randomly mask tokens with probability `p`
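As a minimal sketch of that schedule (illustrative, mirroring the linear formula above; not part of the diff):

```python
import torch

def forward_mask(input_ids: torch.Tensor, mask_token_id: int, eps: float = 1e-3):
    """Mask each token i.i.d. with probability p = (1 - eps) * t + eps."""
    batch_size, seq_len = input_ids.shape
    t = torch.rand(batch_size, device=input_ids.device)        # one timestep per sample
    p_mask = ((1 - eps) * t + eps).unsqueeze(1).expand(-1, seq_len)
    masked = torch.rand_like(p_mask) < p_mask                  # Bernoulli(p) per token
    noisy = torch.where(masked, torch.full_like(input_ids, mask_token_id), input_ids)
    return noisy, masked, p_mask
```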
+
+### Bidirectional Attention
+The models override causal attention with bidirectional attention:
+- Creates 4D attention masks allowing all-to-all attention
+- Maintains proper padding and sample packing masks
+- Compatible with standard HuggingFace attention implementations
+
+### Diffusion Loss
+
+Loss is computed only on masked tokens with (optional) importance weighting:
+
+```python
+loss = sum(cross_entropy(pred, target) / p_mask) / total_tokens
+```
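Spelled out in torch, that weighted objective amounts to (a sketch; `logits`, `labels`, `masked`, and `p_mask` as produced by the forward process above):

```python
import torch.nn.functional as F

# Cross-entropy on masked positions only, importance-weighted by 1 / p_mask.
ce = F.cross_entropy(logits[masked], labels[masked], reduction="none")
loss = (ce / p_mask[masked]).sum() / labels.numel()  # normalized by total token count
```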
+
+### Model Loading and Saving
+
+The custom models work seamlessly with HuggingFace's AutoModel system:
+
+```python
+from transformers import AutoModel, AutoConfig
+
+# Load a diffusion model
+model = AutoModel.from_pretrained("path/to/diffusion/model", trust_remote_code=True)
+
+# Save a diffusion model
+model.save_pretrained("path/to/save/diffusion/model")
+```
+
+During inference, the models behave like standard causal language models.
+
+## Sample Generation
+
+When `generate_samples: true`, the plugin generates samples during training:
+
+```
+Sample 1:
+  Original (45 tokens): The quick brown fox jumps over the lazy dog...
+  Masked (18/45 tokens, 40.0%): The [MASK] [MASK] fox [MASK] over [MASK] lazy [MASK]...
+  Generated: The quick brown fox jumps over the lazy dog...
+```
+
+Samples are logged to console and wandb (if enabled).
+
+## Metrics and Monitoring
+
+The plugin adds several metrics to track diffusion training:
+
+- `train/loss`: Weighted diffusion loss
+- `train/accuracy`: Accuracy on masked tokens
+- `train/mask_ratio`: Average fraction of tokens masked
+- `train/num_masked_tokens`: Number of tokens masked
+- `train/avg_p_mask`: Average masking probability
+- `train/ce_loss`: Unweighted cross-entropy loss
+- `train/importance_weight_avg`: Average importance weight
+
+## Benefits of Custom Model Approach
+
+✅ **Type Safety**: Full IDE support and type checking
+✅ **HuggingFace Integration**: Works with AutoModel, Hub, pipelines
+✅ **Maintainability**: Clean architecture, no monkey patching
+✅ **Ecosystem Compatibility**: Standard save/load, PEFT support
+✅ **Testing**: Easier to test and debug
+
+## Limitations
+
+- **Model Support**: Currently limited to Llama and Mistral architectures
+- **Flash Attention**: Not yet optimized for flash attention
+- **Inference Speed**: Bidirectional attention is slower than causal for generation
+
+## References
+
+- [LLaDA Paper](https://arxiv.org/abs/2404.10406)
+- [Axolotl Documentation](https://docs.axolotl.ai/)

26 src/axolotl/integrations/diffusion/__init__.py Normal file
@@ -0,0 +1,26 @@
+"""Diffusion LM training plugin init."""
+
+from transformers import AutoConfig, AutoModel
+
+from .args import DiffusionArgs
+from .configuration import DiffusionConfig, LlamaForDiffusionConfig, MistralForDiffusionConfig
+from .models import LlamaForDiffusionLM, MistralForDiffusionLM
+from .plugin import DiffusionPlugin
+
+# Register custom configurations
+AutoConfig.register("llama_diffusion", LlamaForDiffusionConfig)
+AutoConfig.register("mistral_diffusion", MistralForDiffusionConfig)
+
+# Register custom models
+AutoModel.register(LlamaForDiffusionConfig, LlamaForDiffusionLM)
+AutoModel.register(MistralForDiffusionConfig, MistralForDiffusionLM)
+
+__all__ = [
+    "DiffusionArgs",
+    "DiffusionPlugin",
+    "DiffusionConfig",
+    "LlamaForDiffusionConfig",
+    "MistralForDiffusionConfig",
+    "LlamaForDiffusionLM",
+    "MistralForDiffusionLM",
+]

70 src/axolotl/integrations/diffusion/args.py Normal file
@@ -0,0 +1,70 @@
+"""Config args for diffusion LM training."""
+
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+
+class DiffusionArgs(BaseModel):
+    """Arguments for diffusion LM training plugin."""
+
+    # Noise schedule config
+    noise_schedule: Literal["linear", "cosine"] = Field(
+        default="linear", description="Type of noise schedule for diffusion training"
+    )
+    min_mask_ratio: float = Field(
+        default=0.1,
+        ge=0.0,
+        le=1.0,
+        description="Minimum masking ratio for diffusion noise schedule",
+    )
+    max_mask_ratio: float = Field(
+        default=0.9,
+        ge=0.0,
+        le=1.0,
+        description="Maximum masking ratio for diffusion noise schedule",
+    )
+    num_diffusion_steps: int = Field(
+        default=128, ge=1, description="Number of diffusion timesteps"
+    )
+    eps: float = Field(
+        default=1e-3,
+        ge=0.0,
+        le=1.0,
+        description="Epsilon value for minimum masking probability in forward process",
+    )
+
+    # Training config
+    importance_weighting: bool = Field(
+        default=True,
+        description="Apply importance weighting to loss based on masking probability",
+    )
+    mask_token_id: int = Field(
+        default=128002,
+        description=(
+            "Token ID to use for masking. Default is 128002 "
+            "(<|reserved_special_token_0|> for Llama 3.2)"
+        ),
+    )
+
+    # Sample generation config
+    generate_samples: bool = Field(
+        default=True, description="Enable sample generation during training"
+    )
+    generation_interval: int = Field(
+        default=100, ge=1, description="Generate samples every N steps"
+    )
+    num_generation_samples: int = Field(
+        default=3, ge=1, description="Number of samples to generate each time"
+    )
+    generation_steps: int = Field(
+        default=128, ge=1, description="Number of diffusion steps for generation"
+    )
+    generation_temperature: float = Field(
+        default=0.0,
+        ge=0.0,
+        description="Temperature for generation sampling (0.0 = deterministic)",
+    )
+    generation_max_length: int = Field(
+        default=100, ge=1, description="Maximum sequence length for generation"
+    )

116 src/axolotl/integrations/diffusion/callbacks.py Normal file
@@ -0,0 +1,116 @@
+"""Callbacks for diffusion training."""
+
+import wandb
+from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
+from transformers.training_args import TrainingArguments
+
+from axolotl.utils.logging import get_logger
+
+from .generation import generate_samples
+
+LOG = get_logger(__name__)
+
+
+class DiffusionGenerationCallback(TrainerCallback):
+    """Callback for generating samples during diffusion training."""
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+
+    # pylint: disable=unused-argument
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        """Generate samples at specified intervals."""
+        config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
+
+        if (
+            state.global_step > 0
+            and state.global_step % config.get('generation_interval', 100) == 0
+        ):
+            # Use eval dataloader if available, otherwise use train dataloader
+            if (
+                hasattr(self.trainer, "eval_dataset")
+                and self.trainer.eval_dataset is not None
+            ):
+                dataloader = self.trainer.get_eval_dataloader()
+            else:
+                dataloader = self.trainer.get_train_dataloader()
+
+            # Generate samples
+            samples = generate_samples(
+                model=self.trainer.model,
+                tokenizer=self.trainer.tokenizer,
+                dataloader=dataloader,
+                num_generation_samples=config.get('num_generation_samples', 3),
+                max_length=config.get('generation_max_length', 256),
+                num_diffusion_steps=config.get('generation_steps', 10),
+                temperature=config.get('generation_temperature', 1.0),
+                mask_token_id=config.get('mask_token_id', 32000),
+            )
+
+            # Log samples
+            self._log_samples(samples, state.global_step)
+
+    def _log_samples(self, samples: list, step: int):
+        """Log generated samples."""
+        if not samples:
+            return
+
+        LOG.info("=" * 60)
+        LOG.info("GENERATED SAMPLES")
+        LOG.info("=" * 60)
+
+        for i, sample_data in enumerate(samples, 1):
+            original = sample_data["original"]
+            masked = sample_data["masked"]
+            generated = sample_data["generated"]
+            mask_ratio = sample_data["mask_ratio"]
+            masked_tokens = sample_data["masked_tokens"]
+            total_tokens = sample_data["total_tokens"]
+
+            LOG.info(f"\nSample {i}:")
+            LOG.info(f"\tOriginal ({total_tokens} tokens): {original}")
+            LOG.info(
+                f"\tMasked ({masked_tokens}/{total_tokens} tokens, "
+                f"{mask_ratio:.1%}): {masked}"
+            )
+            LOG.info(f"\tGenerated: {generated}")
+
+        LOG.info("=" * 60)
+
+        config = getattr(self.trainer, 'diffusion_config', self.trainer.args)
+        if config.get('use_wandb', False) and self.trainer.state.is_world_process_zero:
+            if wandb.run is not None:
+                wandb.log(
+                    {
+                        "generated_samples": wandb.Table(
+                            columns=[
+                                "step",
+                                "original",
+                                "masked",
+                                "generated",
+                                "mask_ratio",
+                                "masked_tokens",
+                                "total_tokens",
+                            ],
+                            data=[
+                                [
+                                    step,
+                                    sample["original"],
+                                    sample["masked"],
+                                    sample["generated"],
+                                    f"{sample['mask_ratio']:.1%}",
+                                    sample["masked_tokens"],
+                                    sample["total_tokens"],
+                                ]
+                                for sample in samples
+                            ],
+                        )
+                    },
+                    step=step,
+                )
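The callback takes the trainer itself so it can reach the model, tokenizer, and dataloaders; wiring it up is a one-liner (a sketch; where the plugin actually registers it is outside this diff):

```python
trainer.add_callback(DiffusionGenerationCallback(trainer))  # standard Trainer API
```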

71 src/axolotl/integrations/diffusion/configuration.py Normal file
@@ -0,0 +1,71 @@
+"""Configuration classes for diffusion language models."""
+
+from transformers import LlamaConfig, MistralConfig
+
+
+class LlamaForDiffusionConfig(LlamaConfig):
+    """Configuration class for Llama models with diffusion training."""
+
+    model_type = "llama_diffusion"
+
+    def __init__(
+        self,
+        mask_token_id: int = 32000,
+        eps: float = 1e-3,
+        importance_weighting: bool = False,
+        sample_packing: bool = False,
+        min_mask_ratio: float = 0.0,
+        max_mask_ratio: float = 1.0,
+        noise_schedule: str = "linear",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # Diffusion-specific parameters
+        self.mask_token_id = mask_token_id
+        self.eps = eps
+        self.importance_weighting = importance_weighting
+        self.sample_packing = sample_packing
+        self.min_mask_ratio = min_mask_ratio
+        self.max_mask_ratio = max_mask_ratio
+        self.noise_schedule = noise_schedule
+
+
+class MistralForDiffusionConfig(MistralConfig):
+    """Configuration class for Mistral models with diffusion training."""
+
+    model_type = "mistral_diffusion"
+
+    def __init__(
+        self,
+        mask_token_id: int = 32000,
+        eps: float = 1e-3,
+        importance_weighting: bool = False,
+        sample_packing: bool = False,
+        min_mask_ratio: float = 0.0,
+        max_mask_ratio: float = 1.0,
+        noise_schedule: str = "linear",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # Diffusion-specific parameters
+        self.mask_token_id = mask_token_id
+        self.eps = eps
+        self.importance_weighting = importance_weighting
+        self.sample_packing = sample_packing
+        self.min_mask_ratio = min_mask_ratio
+        self.max_mask_ratio = max_mask_ratio
+        self.noise_schedule = noise_schedule
+
+
+# Keep the base class for backward compatibility but mark as deprecated
+class DiffusionConfig(LlamaForDiffusionConfig):
+    """
+    Deprecated: Use LlamaForDiffusionConfig or MistralForDiffusionConfig instead.
+    """
+
+    model_type = "diffusion"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
269
src/axolotl/integrations/diffusion/generation.py
Normal file
269
src/axolotl/integrations/diffusion/generation.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
"""Sample generation utilities for diffusion training."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_samples(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
tokenizer: Any,
|
||||||
|
dataloader: Optional[Any] = None,
|
||||||
|
num_generation_samples: int = 3,
|
||||||
|
max_length: int = 100,
|
||||||
|
num_diffusion_steps: int = 128,
|
||||||
|
temperature: float = 0.0,
|
||||||
|
mask_token_id: int = 32000,
|
||||||
|
) -> List[dict]:
|
||||||
|
"""
|
||||||
|
Generate text samples using the diffusion model by randomly masking sequences from
|
||||||
|
the given dataset and running the reverse diffusion process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The wrapped or unwrapped model
|
||||||
|
tokenizer: Tokenizer for encoding/decoding
|
||||||
|
dataloader: Validation dataloader (for sampling sequences)
|
||||||
|
num_generation_samples: Number of samples to generate
|
||||||
|
max_length: Maximum length of sequences to use
|
||||||
|
num_diffusion_steps: Number of diffusion steps for generation
|
||||||
|
temperature: Temperature for sampling (0.0 = deterministic)
|
||||||
|
mask_token_id: Token ID used for masking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dictionaries with original text, masked text, and generated text
|
||||||
|
"""
|
||||||
|
if dataloader is None:
|
||||||
|
logger.warning("No validation dataloader provided, cannot generate samples")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Get the actual model (unwrap if needed)
|
||||||
|
unwrapped_model = model.module if hasattr(model, "module") else model
|
||||||
|
unwrapped_model.eval()
|
||||||
|
generations = []
|
||||||
|
|
||||||
|
# Sample sequences from validation dataset
|
||||||
|
sampled_sequences = _sample_sequences_from_dataloader(
|
||||||
|
dataloader, num_generation_samples, max_length, unwrapped_model.device
|
||||||
|
)
|
||||||
|
logger.info(f"Sampled {len(sampled_sequences)} sequences from validation dataset")
|
||||||
|
|
||||||
|
# Generate samples using reverse diffusion process
|
||||||
|
with torch.no_grad():
|
||||||
|
for original_sequence in sampled_sequences:
|
||||||
|
generation_result = _generate(
|
||||||
|
unwrapped_model,
|
||||||
|
tokenizer,
|
||||||
|
original_sequence,
|
||||||
|
num_diffusion_steps,
|
||||||
|
temperature,
|
||||||
|
mask_token_id,
|
||||||
|
)
|
||||||
|
generations.append(generation_result)
|
||||||
|
|
||||||
|
unwrapped_model.train()
|
||||||
|
return generations
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_sequences_from_dataloader(
|
||||||
|
dataloader: Any, num_samples: int, max_length: int, device: torch.device
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Sample sequences from validation dataloader."""
|
||||||
|
sampled_sequences = []
|
||||||
|
sample_count = 0
|
||||||
|
|
||||||
|
# Add randomness by skipping a random number of batches
|
||||||
|
skip_batches = torch.randint(0, 6, (1,)).item()
|
||||||
|
batch_count = 0
|
||||||
|
|
||||||
|
for batch in dataloader:
|
||||||
|
# Skip some batches for variety
|
||||||
|
if batch_count < skip_batches:
|
||||||
|
batch_count += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if sample_count >= num_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
batch_count += 1
|
||||||
|
input_ids = batch["input_ids"]
|
||||||
|
attention_mask = batch.get("attention_mask")
|
||||||
|
|
||||||
|
# Randomly sample from sequences in this batch
|
||||||
|
batch_indices = torch.randperm(input_ids.size(0)).tolist()
|
||||||
|
|
||||||
|
for i in batch_indices:
|
||||||
|
if sample_count >= num_samples:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Get actual sequence length (non-padded)
|
||||||
|
if attention_mask is not None:
|
||||||
|
seq_len = attention_mask[i].sum().item()
|
||||||
|
else:
|
||||||
|
seq_len = input_ids.size(1)
|
||||||
|
|
||||||
|
# Limit sequence length to max_length
|
||||||
|
actual_length = min(seq_len, max_length)
|
||||||
|
if actual_length < 10: # Skip very short sequences
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract the sequence
|
||||||
|
sequence = input_ids[i][:actual_length].unsqueeze(0).to(device)
|
||||||
|
sampled_sequences.append(sequence)
|
||||||
|
sample_count += 1
|
||||||
|
|
||||||
|
return sampled_sequences
|
||||||
|
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
tokenizer: Any,
|
||||||
|
original_sequence: torch.Tensor,
|
||||||
|
num_diffusion_steps: int,
|
||||||
|
temperature: float,
|
||||||
|
mask_token_id: int,
|
||||||
|
) -> dict:
|
||||||
|
"""Generate a single sample using reverse diffusion."""
|
||||||
|
# Get original text for comparison
|
||||||
|
original_text = tokenizer.decode(
|
||||||
|
original_sequence[0].cpu(), skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply custom masking with random ratio (10% to 70%)
|
||||||
|
total_tokens = original_sequence.size(1)
|
||||||
|
min_ratio, max_ratio = 0.1, 0.7
|
||||||
|
target_mask_ratio = torch.rand(1).item() * (max_ratio - min_ratio) + min_ratio
|
||||||
|
target_masked_tokens = int(total_tokens * target_mask_ratio)
|
||||||
|
|
||||||
|
# Create random mask indices
|
||||||
|
mask_positions = torch.randperm(total_tokens)[:target_masked_tokens]
|
||||||
|
masked_indices = torch.zeros(
|
||||||
|
1, total_tokens, dtype=torch.bool, device=original_sequence.device
|
||||||
|
)
|
||||||
|
masked_indices[0, mask_positions] = True
|
||||||
|
|
||||||
|
# Create masked sequence
|
||||||
|
masked_sequence = original_sequence.clone()
|
||||||
|
masked_sequence[masked_indices] = mask_token_id
|
||||||
|
|
||||||
|
# Calculate actual mask ratio
|
||||||
|
masked_tokens = masked_indices.sum().item()
|
||||||
|
mask_ratio = masked_tokens / total_tokens
|
||||||
|
|
||||||
|
# Get masked text for comparison
|
||||||
|
masked_text = tokenizer.decode(masked_sequence[0].cpu(), skip_special_tokens=False)
|
||||||
|
# Clean up mask token representation
|
||||||
|
masked_text = _clean_masked_text(masked_text, tokenizer, mask_token_id)
|
||||||
|
|
||||||
|
# Run reverse diffusion process
|
||||||
|
sequence = masked_sequence.clone()
|
||||||
|
for step in range(num_diffusion_steps):
|
||||||
|
sequence = _diffusion_step(
|
||||||
|
model, sequence, step, num_diffusion_steps, temperature, mask_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get final generated text
|
||||||
|
generated_text = tokenizer.decode(sequence[0].cpu(), skip_special_tokens=True)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"original": original_text,
|
||||||
|
"masked": masked_text,
|
||||||
|
"generated": generated_text,
|
||||||
|
"mask_ratio": mask_ratio,
|
||||||
|
"masked_tokens": masked_tokens,
|
||||||
|
"total_tokens": total_tokens,
|
||||||
|
"formatted": (
|
||||||
|
f"Original: '{original_text}' → Masked: '{masked_text}' "
|
||||||
|
f"({mask_ratio:.1%}) → Generated: '{generated_text}'"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_masked_text(masked_text: str, tokenizer: Any, mask_token_id: int) -> str:
|
||||||
|
"""Clean up masked text for display."""
|
||||||
|
mask_token_repr = tokenizer.decode([mask_token_id], skip_special_tokens=False)
|
||||||
|
cleaned = masked_text.replace(mask_token_repr, "[MASK]")
|
||||||
|
|
||||||
|
if hasattr(tokenizer, "special_tokens_map"):
|
||||||
|
for token_value in tokenizer.special_tokens_map.values():
|
||||||
|
if token_value and isinstance(token_value, str):
|
||||||
|
cleaned = cleaned.replace(token_value, "")
|
||||||
|
|
||||||
|
cleaned = " ".join(cleaned.split()).strip()
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def _diffusion_step(
    model: torch.nn.Module,
    sequence: torch.Tensor,
    step: int,
    num_diffusion_steps: int,
    temperature: float,
    mask_token_id: int,
) -> torch.Tensor:
    """Perform a single diffusion step with remasking."""
    # Only process if there are masked tokens remaining
    current_mask = sequence == mask_token_id
    if not current_mask.any():
        return sequence

    # Create bidirectional attention mask for diffusion
    batch_size, seq_len = sequence.shape
    attention_mask = torch.ones(
        batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=sequence.device
    )

    # Forward pass
    outputs = model(input_ids=sequence, attention_mask=attention_mask)
    logits = outputs.logits

    # Only sample at currently masked positions
    if current_mask.any():
        masked_logits = logits[current_mask]

        # Apply temperature scaling
        if temperature > 0:
            scaled_logits = masked_logits / temperature
        else:
            scaled_logits = masked_logits

        # Suppress mask token in outputs
        scaled_logits[:, mask_token_id] = -float("inf")

        # Sample predictions
        if temperature > 0:
            # Add Gumbel noise for sampling
            gumbel_noise = -torch.log(
                -torch.log(torch.rand_like(scaled_logits, dtype=torch.float32))
            )
            gumbel_logits = scaled_logits + gumbel_noise
            predicted_tokens = torch.argmax(gumbel_logits, dim=-1)
        else:
            # Deterministic sampling when temperature is 0
            predicted_tokens = torch.argmax(scaled_logits, dim=-1)

        # Calculate probabilities for confidence scoring
        probs = torch.softmax(scaled_logits, dim=-1)
        predicted_token_probs = probs[range(len(predicted_tokens)), predicted_tokens]

        # Determine how many tokens to unmask this step
        remaining_masked = current_mask.sum().item()
        if step == num_diffusion_steps - 1:
            num_to_unmask = remaining_masked
        else:
            unmask_ratio = 1.0 / (num_diffusion_steps - step)
            num_to_unmask = max(1, int(remaining_masked * unmask_ratio))

        # Select highest confidence predictions to unmask
        if num_to_unmask >= remaining_masked:
            sequence[current_mask] = predicted_tokens
        else:
            _, top_indices = predicted_token_probs.topk(num_to_unmask)
            mask_positions = torch.where(current_mask)[1]
            positions_to_unmask = mask_positions[top_indices]
            sequence[0, positions_to_unmask] = predicted_tokens[top_indices]

    return sequence
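The unmasking schedule above reveals `1 / (num_diffusion_steps - step)` of the still-masked tokens at each step, so reveals are spread roughly evenly and the final step flushes whatever remains. A standalone sketch of how that plays out (illustrative numbers, not from the source):

# Sketch: 16 masked tokens over 4 steps with the schedule used in _diffusion_step.
num_diffusion_steps, remaining_masked = 4, 16
for step in range(num_diffusion_steps):
    if step == num_diffusion_steps - 1:
        num_to_unmask = remaining_masked
    else:
        num_to_unmask = max(1, int(remaining_masked / (num_diffusion_steps - step)))
    print(f"step {step}: unmask {num_to_unmask} of {remaining_masked}")
    remaining_masked -= num_to_unmask
# step 0: unmask 4 of 16 ... step 3: unmask 4 of 4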

src/axolotl/integrations/diffusion/models.py (426 lines, new file)
@@ -0,0 +1,426 @@
"""Custom model classes for diffusion language models."""

from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, MistralForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig


class DiffusionModelMixin:
    """Mixin class providing diffusion functionality to language models."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._special_token_ids = None

    def _cache_special_token_ids(self, tokenizer=None):
        """Cache special token IDs to avoid repeated tokenizer access."""
        if tokenizer is None:
            self._special_token_ids = set()
            return

        special_tokens = set()

        if hasattr(tokenizer, "bos_token_id") and tokenizer.bos_token_id is not None:
            special_tokens.add(tokenizer.bos_token_id)
        if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
            special_tokens.add(tokenizer.eos_token_id)
        if hasattr(tokenizer, "pad_token_id") and tokenizer.pad_token_id is not None:
            special_tokens.add(tokenizer.pad_token_id)

        self._special_token_ids = special_tokens

    def _forward_process(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        eps: float = 1e-3,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward noising process. A timestep is sampled along the process, and tokens are
        masked with probability determined by the configured noise schedule.

        Args:
            input_ids: Input token ids [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].
            labels: Labels for SFT training [batch_size, seq_len].
            eps: Small epsilon value for minimum masking probability.

        Returns:
            noisy_batch: Input with some tokens masked.
            masked_indices: Boolean mask indicating which tokens were masked.
            p_mask: Masking probabilities for each token [batch_size, seq_len].
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Sample random timesteps for each sample in batch
        t = torch.rand(batch_size, device=device)

        # Calculate masking probability with epsilon
        p_mask = (1 - eps) * t + eps  # [batch_size]
        p_mask = p_mask[:, None].repeat(1, seq_len)  # [batch_size, seq_len]

        # Don't mask padding tokens if attention_mask is provided
        if attention_mask is not None:
            valid_mask = attention_mask.bool()
            p_mask = p_mask * valid_mask.float()

        # Create mask to exclude special tokens
        special_token_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        if self._special_token_ids:
            for token_id in self._special_token_ids:
                special_token_mask |= input_ids == token_id

        # Create random mask based on p_mask
        masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask
        masked_indices = masked_indices & ~special_token_mask
        if attention_mask is not None:
            masked_indices = masked_indices & attention_mask.bool()

        # For SFT data, only mask answer tokens
        if labels is not None:
            answer_mask = labels != -100
            masked_indices = masked_indices & answer_mask

        # Create masked input
        mask_token_id = self.config.mask_token_id
        noisy_batch = torch.where(masked_indices, mask_token_id, input_ids)

        return noisy_batch, masked_indices, p_mask

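For intuition: with the linear schedule above, the per-sample masking probability is p_mask = (1 - eps) * t + eps, so a draw near t = 0 masks almost nothing (floored at eps) while a draw near t = 1 masks nearly everything. A quick sketch with illustrative values:

# Sketch: expected masked fraction at a few sampled timesteps (eps = 1e-3).
eps = 1e-3
for t in (0.0, 0.25, 0.5, 1.0):
    print(t, (1 - eps) * t + eps)  # 0.001, ~0.251, ~0.5, 1.0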
    def _create_bidirectional_attention_mask(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Create bidirectional attention mask to override default causal masking. Handles
        sample-packed sequences where different samples are identified by different
        attention mask values.

        Args:
            input_ids: Input token ids [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            bidirectional_mask: 4D attention mask [batch_size, 1, seq_len, seq_len].
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        if attention_mask is None or not self.config.sample_packing:
            return torch.ones(
                batch_size, 1, seq_len, seq_len, dtype=torch.bool, device=device
            )

        # Create attention mask by comparing sample IDs element-wise
        mask_i = attention_mask.unsqueeze(2)  # [batch_size, seq_len, 1]
        mask_j = attention_mask.unsqueeze(1)  # [batch_size, 1, seq_len]

        # Tokens can attend to each other if they have the same non-zero sample ID
        bidirectional_mask = (mask_i == mask_j) & (mask_i > 0)

        # Add head dimension: [batch_size, 1, seq_len, seq_len]
        bidirectional_mask = bidirectional_mask.unsqueeze(1)

        return bidirectional_mask

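To see what the packed branch produces, consider a toy attention_mask where sample IDs 1 and 2 share a packed row and 0 marks padding; a minimal check in plain torch:

import torch

# Sketch: packed attention_mask [1, 1, 2, 2, 0] yields a block-diagonal boolean mask;
# padding (id 0) neither attends nor is attended to.
attention_mask = torch.tensor([[1, 1, 2, 2, 0]])
mask_i = attention_mask.unsqueeze(2)
mask_j = attention_mask.unsqueeze(1)
bidirectional = (mask_i == mask_j) & (mask_i > 0)
print(bidirectional.int())
# tensor([[[1, 1, 0, 0, 0],
#          [1, 1, 0, 0, 0],
#          [0, 0, 1, 1, 0],
#          [0, 0, 1, 1, 0],
#          [0, 0, 0, 0, 0]]])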
    def _compute_diffusion_loss(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        logits: torch.Tensor | None = None,
        masked_indices: torch.Tensor | None = None,
        p_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Compute diffusion loss given logits and masking information.

        Args:
            input_ids: Ground truth token ids [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].
            labels: Labels for SFT training [batch_size, seq_len].
            logits: Model logits [batch_size, seq_len, vocab_size].
            masked_indices: Boolean mask indicating which tokens were masked.
            p_mask: Masking probabilities for each token [batch_size, seq_len].

        Returns:
            loss: Cross-entropy loss.
        """
        if masked_indices.sum() > 0:
            valid_indices = torch.where(masked_indices)
            batch_indices, seq_indices = valid_indices

            masked_logits = logits[batch_indices, seq_indices]
            masked_targets = input_ids[batch_indices, seq_indices]
            masked_p_mask = p_mask[batch_indices, seq_indices]

            # Compute cross-entropy loss without reduction
            token_loss = F.cross_entropy(
                masked_logits.float(), masked_targets, reduction="none"
            )

            if self.config.importance_weighting:
                masked_p_mask = masked_p_mask.float()
                weighted_loss = token_loss / masked_p_mask
            else:
                weighted_loss = token_loss

            # Final loss: sum weighted losses, normalize
            if labels is not None:
                # For SFT data: normalize by answer length per sample
                answer_mask = labels != -100
                answer_lengths = answer_mask.sum(dim=1).float()  # [batch_size]

                # Get batch indices for masked tokens
                masked_batch_indices = batch_indices

                # Sum losses per sample and divide by answer length
                loss_per_sample = torch.zeros(
                    input_ids.shape[0], device=input_ids.device
                )
                for i in range(input_ids.shape[0]):
                    sample_mask = masked_batch_indices == i
                    if sample_mask.sum() > 0:
                        sample_loss = weighted_loss[sample_mask].sum()
                        loss_per_sample[i] = sample_loss / answer_lengths[i]

                loss = loss_per_sample.mean()
            else:
                # Original normalization for non-SFT data
                loss = weighted_loss.sum() / (input_ids.shape[0] * input_ids.shape[1])
        else:
            loss = torch.tensor(0.0, device=input_ids.device, requires_grad=True)

        return loss

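The `token_loss / masked_p_mask` term mirrors the 1/t importance weighting of LLaDA-style diffusion objectives: tokens masked at low-noise timesteps are sampled rarely, so they are upweighted to keep the loss estimate comparable across timesteps. A toy illustration (made-up numbers):

import torch

# Sketch: equal raw losses, but the token masked at p_mask = 0.1 is upweighted 10x.
token_loss = torch.tensor([2.0, 2.0])
p_mask = torch.tensor([0.1, 0.9])
print(token_loss / p_mask)  # tensor([20.0000,  2.2222])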
class LlamaForDiffusionLM(DiffusionModelMixin, LlamaForCausalLM):
    """
    Llama model for diffusion language modeling.

    This model extends LlamaForCausalLM with diffusion training capabilities,
    including bidirectional attention and forward diffusion process.
    """

    config_class = LlamaForDiffusionConfig

    def __init__(self, config):
        super().__init__(config)

        # Initialize diffusion-specific attributes
        self._special_token_ids = None

        # Initialize weights and apply final processing
        self.post_init()

    def set_tokenizer(self, tokenizer):
        """Set tokenizer for special token handling."""
        self._cache_special_token_ids(tokenizer)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass with diffusion training logic.

        During training, applies forward diffusion process and bidirectional attention.
        During inference, behaves like standard causal language model.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.training and input_ids is not None:
            # Apply diffusion process during training
            original_input_ids = input_ids.clone()

            # Apply forward process to get noisy input
            noisy_input_ids, masked_indices, p_mask = self._forward_process(
                input_ids, attention_mask, labels, self.config.eps
            )

            # Create bidirectional attention mask
            bidirectional_attention_mask = self._create_bidirectional_attention_mask(
                input_ids, attention_mask
            )

            # Forward pass with noisy input and bidirectional attention
            outputs = super().forward(
                input_ids=noisy_input_ids,
                attention_mask=bidirectional_attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=None,  # Don't use standard loss computation
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
                **kwargs,
            )

            # Compute diffusion loss
            loss = self._compute_diffusion_loss(
                original_input_ids,
                attention_mask,
                labels,
                outputs.logits,
                masked_indices,
                p_mask,
            )

            if return_dict:
                outputs.loss = loss
                return outputs
            else:
                return (loss,) + outputs[1:]
        else:
            # Standard forward pass for inference
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
                **kwargs,
            )


class MistralForDiffusionLM(DiffusionModelMixin, MistralForCausalLM):
    """
    Mistral model for diffusion language modeling.

    This model extends MistralForCausalLM with diffusion training capabilities,
    including bidirectional attention and forward diffusion process.
    """

    config_class = MistralForDiffusionConfig

    def __init__(self, config):
        super().__init__(config)

        # Initialize diffusion-specific attributes
        self._special_token_ids = None

        # Initialize weights and apply final processing
        self.post_init()

    def set_tokenizer(self, tokenizer):
        """Set tokenizer for special token handling."""
        self._cache_special_token_ids(tokenizer)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass with diffusion training logic.

        During training, applies forward diffusion process and bidirectional attention.
        During inference, behaves like standard causal language model.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.training and input_ids is not None:
            # Apply diffusion process during training
            original_input_ids = input_ids.clone()

            # Apply forward process to get noisy input
            noisy_input_ids, masked_indices, p_mask = self._forward_process(
                input_ids, attention_mask, labels, self.config.eps
            )

            # Create bidirectional attention mask
            bidirectional_attention_mask = self._create_bidirectional_attention_mask(
                input_ids, attention_mask
            )

            # Forward pass with noisy input and bidirectional attention
            outputs = super().forward(
                input_ids=noisy_input_ids,
                attention_mask=bidirectional_attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=None,  # Don't use standard loss computation
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
                **kwargs,
            )

            # Compute diffusion loss
            loss = self._compute_diffusion_loss(
                original_input_ids,
                attention_mask,
                labels,
                outputs.logits,
                masked_indices,
                p_mask,
            )

            if return_dict:
                outputs.loss = loss
                return outputs
            else:
                return (loss,) + outputs[1:]
        else:
            # Standard forward pass for inference
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
                **kwargs,
            )
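A minimal usage sketch for the classes above. The checkpoint path is a placeholder, and passing the mask_token_id override through `from_pretrained` is an assumption about how `LlamaForDiffusionConfig` is meant to be constructed:

from transformers import AutoTokenizer

ckpt = "path/to/llama-checkpoint"  # hypothetical path, not from the source
config = LlamaForDiffusionConfig.from_pretrained(ckpt, mask_token_id=32000)
model = LlamaForDiffusionLM.from_pretrained(ckpt, config=config)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model.set_tokenizer(tokenizer)  # cached special tokens are never masked

model.train()  # training mode routes forward() through the noising process
batch = tokenizer(["The quick brown fox"], return_tensors="pt")
outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
print(outputs.loss)  # diffusion loss from _compute_diffusion_loss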

src/axolotl/integrations/diffusion/plugin.py (98 lines, new file)
@@ -0,0 +1,98 @@
"""Diffusion LM training plugin for Axolotl."""

from typing import TYPE_CHECKING

from peft import PeftModel
from transformers import AutoConfig, AutoModel, PreTrainedModel

from axolotl.integrations.base import BasePlugin
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from .callbacks import DiffusionGenerationCallback
from .configuration import LlamaForDiffusionConfig, MistralForDiffusionConfig
from .models import LlamaForDiffusionLM, MistralForDiffusionLM

if TYPE_CHECKING:
    from transformers import Trainer

LOG = get_logger(__name__)


class DiffusionPlugin(BasePlugin):
    """
    Plugin for diffusion language model training.

    This plugin enables diffusion-based training using the LLaDA approach, which uses
    random masking and bidirectional attention to train language models.
    """

    def __init__(self):
        super().__init__()
        self.cfg = None

    def get_input_args(self) -> str:
        """Returns the pydantic model for LLaDA plugin arguments."""
        return "axolotl.integrations.diffusion.DiffusionArgs"

    def pre_model_load(self, cfg: DictDefault):
        """Configure model loading to use diffusion model classes."""
        # Map base model types to diffusion equivalents
        base_model_type = cfg.get("model_type")

        if base_model_type == "llama":
            # Create diffusion config from base config
            diffusion_config = LlamaForDiffusionConfig(
                mask_token_id=getattr(cfg, "mask_token_id", 32000),
                eps=getattr(cfg, "eps", 1e-3),
                importance_weighting=getattr(cfg, "importance_weighting", False),
                sample_packing=getattr(cfg, "sample_packing", False),
                min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0),
                max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0),
                noise_schedule=getattr(cfg, "noise_schedule", "linear"),
            )

            # Override model type for loading
            cfg.model_type = "llama_diffusion"

        elif base_model_type == "mistral":
            # Create diffusion config from base config
            diffusion_config = MistralForDiffusionConfig(
                mask_token_id=getattr(cfg, "mask_token_id", 32000),
                eps=getattr(cfg, "eps", 1e-3),
                importance_weighting=getattr(cfg, "importance_weighting", False),
                sample_packing=getattr(cfg, "sample_packing", False),
                min_mask_ratio=getattr(cfg, "min_mask_ratio", 0.0),
                max_mask_ratio=getattr(cfg, "max_mask_ratio", 1.0),
                noise_schedule=getattr(cfg, "noise_schedule", "linear"),
            )

            # Override model type for loading
            cfg.model_type = "mistral_diffusion"
        else:
            LOG.warning(f"Diffusion plugin not implemented for model type: {base_model_type}")

    def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
        """Configure model after loading."""
        self.cfg = cfg

        # Set tokenizer on diffusion models for special token handling
        if hasattr(model, "set_tokenizer"):
            # Get tokenizer from cfg if available
            tokenizer = getattr(cfg, "tokenizer", None)
            if tokenizer is not None:
                model.set_tokenizer(tokenizer)

    def add_callbacks_post_trainer(self, cfg: DictDefault, trainer: "Trainer"):
        """Add diffusion-specific callbacks after trainer creation."""
        callbacks = []

        # Store diffusion config on trainer for callbacks
        trainer.diffusion_config = cfg

        # Add generation callback if enabled
        if cfg.get("generate_samples", False):
            generation_callback = DiffusionGenerationCallback(trainer)
            callbacks.append(generation_callback)

        return callbacks
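A sketch of the plugin lifecycle in isolation (axolotl normally drives this from the training config; the cfg keys mirror the `getattr` defaults read above):

# Sketch: pre_model_load rewrites model_type so loading routes to the diffusion class.
cfg = DictDefault(
    {
        "model_type": "llama",
        "mask_token_id": 32000,
        "noise_schedule": "linear",
        "generate_samples": True,
    }
)
plugin = DiffusionPlugin()
plugin.pre_model_load(cfg)
assert cfg.model_type == "llama_diffusion"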
@@ -284,12 +284,12 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
         return sample

     def _tokenize_single_prompt(self, prompt):
-        logprobs = prompt.pop(self.logprobs_field)
-        target_token_ids = prompt.pop("target_token_ids")
+        target_token_ids = prompt.get("target_token_ids", None)
         tokenized_prompt = super()._tokenize_single_prompt(prompt)
-        tokenized_prompt[self.logprobs_field] = logprobs
-        tokenized_prompt["target_token_ids"] = target_token_ids
-        tokenized_prompt = self.transform_logprobs(tokenized_prompt)
+        if target_token_ids is not None:
+            tokenized_prompt["target_token_ids"] = target_token_ids

         return tokenized_prompt
@@ -14,6 +14,7 @@ from typing import Callable

 import torch
 from bitsandbytes.functional import QuantState
 from torch import nn
+from torch.distributed.tensor import DTensor

 from .geglu import geglu_backward, geglu_forward
 from .quantize import dequantize
@@ -25,6 +26,7 @@ def get_lora_parameters(
     proj: nn.Module,
 ) -> tuple[
     torch.Tensor,
+    torch.Tensor | None,
     QuantState | None,
     torch.Tensor | None,
     torch.Tensor | None,
@@ -37,39 +39,54 @@ def get_lora_parameters(
         proj: The projection module to extract parameters from.

     Returns:
-        A tuple containing the base weight matrix, quantization state, LoRA A matrix,
-        LoRA B matrix, and scaling factor. States and matrices may be None if not
-        available.
+        A tuple containing the base weights, quantization state, LoRA A and B weights,
+        scaling factor, and base layer bias. Quant state, weights, and bias may be
+        `None` if not available.
     """
     # For DPO or disabled adapters
     base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
     W = base_layer.weight
+    b = base_layer.bias

     if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
         quant_state = getattr(W, "quant_state", None)
-        return W, quant_state, None, None, None
+        return W, b, quant_state, None, None, None
+
+    quant_state = getattr(W, "quant_state", None)

     active_adapter = (
         proj.active_adapters[0]
         if hasattr(proj, "active_adapters")
         else proj.active_adapter
     )
-    A = proj.lora_A[active_adapter].weight
-    B = proj.lora_B[active_adapter].weight
+    linear_A = proj.lora_A[active_adapter]
+    linear_B = proj.lora_B[active_adapter]
+
+    # This manual unsharding is needed for FSDP2 + LoRA kernels compatibility.
+    # We fuse linear layers + LoRA adapters calculations into a single
+    # torch.autograd.Function, bypassing the registered unshard / reshard behavior.
+    # Note that we don't apply resharding later in this module (it gets messy quickly),
+    # but LoRA parameters are generally small enough that this is not an issue.
+    if isinstance(linear_A.weight, DTensor):
+        linear_A.unshard()
+        linear_B.unshard()
+
+    A = linear_A.weight
+    B = linear_B.weight
     s = proj.scaling[active_adapter]

-    quant_state = getattr(W, "quant_state", None)
-
-    return W, quant_state, A, B, s
+    return W, b, quant_state, A, B, s


 def matmul_lora(
     X: torch.Tensor,
     W: torch.Tensor,
-    W_quant: QuantState,
-    A: torch.Tensor,
-    B: torch.Tensor,
-    s: float,
+    b: torch.Tensor | None,
+    W_quant: QuantState | None,
+    A: torch.Tensor | None,
+    B: torch.Tensor | None,
+    s: float | None,
     out: torch.Tensor | None = None,
 ) -> torch.Tensor:
     """
@@ -90,20 +107,22 @@ def matmul_lora(
     dtype = X.dtype
     W = dequantize(W.t(), W_quant)
+
+    reshape = False
     if X.dim() == 3:
         batch, seq_len, _ = X.shape
         X = X.view(-1, X.shape[-1])
         reshape = True
-    else:
-        reshape = False

     out = torch.matmul(X, W, out=out)
     if W_quant is not None:
         del W

     if A is not None:
-        A, B = A.t(), B.t()
-        out += (X @ A.to(dtype)) @ (s * B.to(dtype))
+        A, B = A.t().to(dtype), B.t().to(dtype)  # type: ignore[union-attr]
+        out += s * X @ A @ B
+
+    if b is not None:
+        out += b

     return out.view(batch, seq_len, -1) if reshape else out
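The rewritten kernel still computes the standard folded-adapter identity for a LoRA linear layer: X @ W.T + s * (X @ A.T) @ B.T + b equals a plain linear layer whose weight is W + s * B @ A. A standalone numeric check in plain torch (no quantization; shapes chosen arbitrarily):

import torch
import torch.nn.functional as F

X = torch.randn(2, 8)
W = torch.randn(16, 8)   # base weight [out_features, in_features]
A = torch.randn(4, 8)    # LoRA A [r, in_features]
B = torch.randn(16, 4)   # LoRA B [out_features, r]
b = torch.randn(16)
s = 0.5

fused = X @ W.t() + s * (X @ A.t()) @ B.t() + b
reference = F.linear(X, W + s * B @ A, b)
assert torch.allclose(fused, reference, atol=1e-5)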
@@ -117,17 +136,20 @@ class LoRA_MLP(torch.autograd.Function):
         ctx,
         X: torch.Tensor,
         gate_weight: torch.Tensor,
-        gate_quant: object | None,
+        gate_bias: torch.Tensor | None,
+        gate_quant: QuantState | None,
         gate_A: torch.Tensor | None,
         gate_B: torch.Tensor | None,
         gate_scale: float,
         up_weight: torch.Tensor,
-        up_quant: object | None,
+        up_bias: torch.Tensor | None,
+        up_quant: QuantState | None,
         up_A: torch.Tensor | None,
         up_B: torch.Tensor | None,
         up_scale: float,
         down_weight: torch.Tensor,
-        down_quant: object | None,
+        down_bias: torch.Tensor | None,
+        down_quant: QuantState | None,
         down_A: torch.Tensor | None,
         down_B: torch.Tensor | None,
         down_scale: float,
@@ -142,20 +164,22 @@ class LoRA_MLP(torch.autograd.Function):
             ctx: Autograd context
             X: Input features
             gate_weight: Gate projection weight
+            gate_bias: Gate projection bias
             gate_quant: Gate quantization state
             gate_A: Gate LoRA A matrix
             gate_B: Gate LoRA B matrix
             gate_scale: Gate LoRA scale
-            up_weight: Up-projection weight
-            up_quant: Up-projection quantization state
-            up_A: Up-projection LoRA A matrix
-            up_B: Up-projection LoRA B matrix
-            up_scale: Up-projection LoRA scale
-            down_weight: Down-projection weight
-            down_quant: Down-projection quantization state
-            down_A: Down-projection LoRA A matrix
-            down_B: Down-projection LoRA B matrix
-            down_scale: Down-projection LoRA scale
+            up_weight: Up projection weight
+            up_quant: Up projection quantization state
+            up_A: Up projection LoRA A matrix
+            up_B: Up projection LoRA B matrix
+            up_scale: Up projection LoRA scale
+            down_weight: Down projection weight
+            down_bias: Down projection bias
+            down_quant: Down projection quantization state
+            down_A: Down projection LoRA A matrix
+            down_B: Down projection LoRA B matrix
+            down_scale: Down projection LoRA scale
             activation_fn: Forward activation function
             activation_fn_backward: Backward activation function
             inplace: Whether to perform operations in-place
@@ -164,15 +188,17 @@ class LoRA_MLP(torch.autograd.Function):
             Output transformed by multi-layer perceptron and activation function
         """
         # Compute projections
-        gate = matmul_lora(X, gate_weight, gate_quant, gate_A, gate_B, gate_scale)
-        up = matmul_lora(X, up_weight, up_quant, up_A, up_B, up_scale)
+        gate = matmul_lora(
+            X, gate_weight, gate_bias, gate_quant, gate_A, gate_B, gate_scale
+        )
+        up = matmul_lora(X, up_weight, up_bias, up_quant, up_A, up_B, up_scale)

         # Activation
         hidden = activation_fn(gate, up)

         # Down projection
         output = matmul_lora(
-            hidden, down_weight, down_quant, down_A, down_B, down_scale
+            hidden, down_weight, down_bias, down_quant, down_A, down_B, down_scale
         )

         # Save for backward
@@ -195,22 +221,26 @@ class LoRA_MLP(torch.autograd.Function):
         torch.Tensor | None,
         None,
         None,
+        None,
         torch.Tensor | None,
         torch.Tensor | None,
         None,
         None,
         None,
+        None,
         torch.Tensor | None,
         torch.Tensor | None,
         None,
         None,
         None,
+        None,
         torch.Tensor | None,
         torch.Tensor | None,
         None,
         None,
         None,
         None,
+        None,
     ]:
         """
         Performs backward pass computation for LoRA MLP.
@@ -222,7 +252,7 @@ class LoRA_MLP(torch.autograd.Function):
         Returns:
             Tuple containing gradients for all inputs from forward pass:
             - Input gradient tensor (or `None`)
-            - `None` for weights/quantization states
+            - `None` for weights/biases/quantization states
             - LoRA A/B matrix gradients (or `None`)
             - `None` for scaling factors
             - `None` for activation functions and flags
@@ -265,9 +295,10 @@ class LoRA_MLP(torch.autograd.Function):
         dtype = X.dtype

         # Down projection
-        DW = matmul_lora(
+        grad_down = matmul_lora(
             grad_output,
             down_weight.t(),
+            None,
             down_quant,
             down_B,
             down_A,
@@ -275,7 +306,7 @@ class LoRA_MLP(torch.autograd.Function):
         )

         # Activation backward
-        h, grad_gate, grad_up = ctx.activation_fn_backward(DW, gate, up)
+        h, grad_gate, grad_up = ctx.activation_fn_backward(grad_down, gate, up)

         # Initialize and compute LoRA gradients
         d_down_A = d_down_B = d_up_A = d_up_B = d_gate_A = d_gate_B = None
@@ -315,8 +346,8 @@ class LoRA_MLP(torch.autograd.Function):
         dX += grad_up @ up_B.to(dtype).t() @ (up_scale * up_A.to(dtype).t())

         # Gate projection gradients
-        gate_weight = dequantize(gate_weight.t(), gate_quant)
-        dX += grad_gate @ gate_weight.t()
+        gate_weight = dequantize(gate_weight, gate_quant)
+        dX += grad_gate @ gate_weight
         del gate_weight

         if gate_A is not None and gate_B is not None:
@@ -334,22 +365,26 @@ class LoRA_MLP(torch.autograd.Function):
             dX,
             None,
             None,
+            None,
             d_gate_A.t() if d_gate_A is not None else None,
             d_gate_B.t() if d_gate_B is not None else None,
             None,
             None,
             None,
+            None,
             d_up_A.t() if d_up_A is not None else None,
             d_up_B.t() if d_up_B is not None else None,
             None,
             None,
             None,
+            None,
             d_down_A.t() if d_down_A is not None else None,
             d_down_B.t() if d_down_B is not None else None,
             None,
             None,
             None,
             None,
+            None,
         )
@@ -364,23 +399,26 @@ def apply_lora_mlp_swiglu(self, X: torch.Tensor, inplace: bool = True) -> torch.
     Returns:
         Output tensor after applying LoRA-adapted MLP with SwiGLU activation
     """
-    gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
-    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
-    downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
+    gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
+    upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
+    downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)

     out = LoRA_MLP.apply(
         X,
         gateW,
+        gateb,
         gateW_quant,
         gateA,
         gateB,
         gateS,
         upW,
+        upb,
         upW_quant,
         upA,
         upB,
         upS,
         downW,
+        downb,
         downW_quant,
         downA,
         downB,
@@ -404,22 +442,25 @@ def apply_lora_mlp_geglu(self, X: torch.Tensor, inplace: bool = True) -> torch.T
     Returns:
         Output tensor after applying LoRA-adapted MLP with GEGLU activation
     """
-    gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
-    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
-    downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
+    gateW, gateb, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
+    upW, upb, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
+    downW, downb, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
     out = LoRA_MLP.apply(
         X,
         gateW,
+        gateb,
         gateW_quant,
         gateA,
         gateB,
         gateS,
         upW,
+        upb,
         upW_quant,
         upA,
         upB,
         upS,
         downW,
+        downb,
         downW_quant,
         downA,
         downB,
@@ -446,16 +487,19 @@ class LoRA_QKV(torch.autograd.Function):
         ctx: torch.autograd.function.FunctionCtx,
         X: torch.Tensor,
         q_weight: torch.Tensor,
+        q_bias: torch.Tensor | None,
         q_quant: QuantState | None,
         q_A: torch.Tensor | None,
         q_B: torch.Tensor | None,
         q_scale: float,
         k_weight: torch.Tensor,
+        k_bias: torch.Tensor | None,
         k_quant: QuantState | None,
         k_A: torch.Tensor | None,
         k_B: torch.Tensor | None,
         k_scale: float,
         v_weight: torch.Tensor,
+        v_bias: torch.Tensor | None,
         v_quant: QuantState | None,
         v_A: torch.Tensor | None,
         v_B: torch.Tensor | None,
@@ -469,16 +513,19 @@ class LoRA_QKV(torch.autograd.Function):
             ctx: Autograd context
             X: Input tensor
             q_weight: Query projection weight
+            q_bias: Query projection bias
             q_quant: Query quantization state
             q_A: Query LoRA A matrix
             q_B: Query LoRA B matrix
             q_scale: Query LoRA scale
             k_weight: Key projection weight
+            k_bias: Key projection bias
             k_quant: Key quantization state
             k_A: Key LoRA A matrix
             k_B: Key LoRA B matrix
             k_scale: Key LoRA scale
             v_weight: Value projection weight
+            v_bias: Value projection bias
             v_quant: Value quantization state
             v_A: Value LoRA A matrix
             v_B: Value LoRA B matrix
@@ -488,20 +535,21 @@ class LoRA_QKV(torch.autograd.Function):
         Returns:
             Tuple of (Query, Key, Value) projection tensors
         """
-        Q = matmul_lora(X, q_weight, q_quant, q_A, q_B, q_scale)
-        K = matmul_lora(X, k_weight, k_quant, k_A, k_B, k_scale)
-        V = matmul_lora(X, v_weight, v_quant, v_A, v_B, v_scale)
+        Q = matmul_lora(X, q_weight, q_bias, q_quant, q_A, q_B, q_scale)
+        K = matmul_lora(X, k_weight, k_bias, k_quant, k_A, k_B, k_scale)
+        V = matmul_lora(X, v_weight, v_bias, v_quant, v_A, v_B, v_scale)

         ctx.save_for_backward(X, q_A, q_B, k_A, k_B, v_A, v_B)
         ctx.scales = (q_scale, k_scale, v_scale)
         ctx.quants = (q_quant, k_quant, v_quant)
         ctx.weights = (q_weight, k_weight, v_weight)
+        ctx.biases = (q_bias, k_bias, v_bias)
         ctx.inplace = inplace

         return Q, K, V

     @staticmethod
-    @torch_amp_custom_fwd
+    @torch_amp_custom_bwd
     def backward(
         ctx: torch.autograd.function.FunctionCtx,
         q_grad: torch.Tensor,
@@ -511,16 +559,19 @@ class LoRA_QKV(torch.autograd.Function):
         torch.Tensor,
         None,
         None,
+        None,
         torch.Tensor | None,
         torch.Tensor | None,
         None,
         None,
         None,
+        None,
         torch.Tensor | None,
         torch.Tensor | None,
         None,
         None,
         None,
+        None,
         torch.Tensor | None,
         torch.Tensor | None,
         None,
@@ -608,31 +659,31 @@ class LoRA_QKV(torch.autograd.Function):
         # Transpose gradients if needed
         if d_A_q is not None:
             d_A_q = d_A_q.t()
-        if d_B_q is not None:
-            d_B_q = d_B_q.t()
+        d_B_q = d_B_q.t()  # type: ignore[union-attr]
         if d_A_k is not None:
             d_A_k = d_A_k.t()
-        if d_B_k is not None:
-            d_B_k = d_B_k.t()
+        d_B_k = d_B_k.t()  # type: ignore[union-attr]
         if d_A_v is not None:
             d_A_v = d_A_v.t()
-        if d_B_v is not None:
-            d_B_v = d_B_v.t()
+        d_B_v = d_B_v.t()  # type: ignore[union-attr]

         return (
             grad_X.view(batch, seq_len, -1),
             None,
             None,
+            None,
             d_A_q,
             d_B_q,
             None,
             None,
             None,
+            None,
             d_A_k,
             d_B_k,
             None,
             None,
             None,
+            None,
             d_A_v,
             d_B_v,
             None,
@@ -653,22 +704,25 @@ def apply_lora_qkv(
     Returns:
         Tuple of (Query, Key, Value) projection tensors
     """
-    QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
-    KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
-    VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
+    QW, Qb, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
+    KW, Kb, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
+    VW, Vb, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
    Q, K, V = LoRA_QKV.apply(
         X,
         QW,
+        Qb,
         QW_quant,
         QA,
         QB,
         QS,
         KW,
+        Kb,
         KW_quant,
         KA,
         KB,
         KS,
         VW,
+        Vb,
         VW_quant,
         VA,
         VB,
@@ -688,10 +742,11 @@ class LoRA_O(torch.autograd.Function):
         ctx: torch.autograd.function.FunctionCtx,
         X: torch.Tensor,
         W: torch.Tensor,
+        b: torch.Tensor,
         W_quant: QuantState | None,
-        A: torch.Tensor | None,
-        B: torch.Tensor | None,
-        S: float,
+        A: torch.Tensor,
+        B: torch.Tensor,
+        s: float,
     ) -> torch.Tensor:
         """
         Forward pass for output projection with LoRA.
@@ -700,19 +755,20 @@ class LoRA_O(torch.autograd.Function):
             ctx: Autograd context
             X: Input tensor
             W: Output projection weight
+            b: Output projection bias
             W_quant: Weight quantization state
             A: LoRA A matrix
             B: LoRA B matrix
-            S: LoRA scaling factor
+            s: LoRA scaling factor

         Returns:
-            Output projection tensor
+            Output projection result
         """
-        XW = matmul_lora(X, W, W_quant, A, B, S)
+        XW = matmul_lora(X, W, b, W_quant, A, B, s)
         ctx.custom_saved_tensors = (
             W,
             W_quant,
-            S,
+            s,
         )
         ctx.save_for_backward(A, B, X)
@@ -727,8 +783,9 @@ class LoRA_O(torch.autograd.Function):
         torch.Tensor,
         None,
         None,
-        torch.Tensor | None,
-        torch.Tensor | None,
+        None,
+        torch.Tensor,
+        torch.Tensor,
         None,
     ]:
         """
@@ -741,7 +798,7 @@ class LoRA_O(torch.autograd.Function):
         Returns:
             Tuple containing gradients for all forward inputs
         """
-        W, W_quant, S = ctx.custom_saved_tensors
+        W, W_quant, s = ctx.custom_saved_tensors
         A, B, X = ctx.saved_tensors

         batch, seq_len, hd = X.shape
@@ -751,17 +808,19 @@ class LoRA_O(torch.autograd.Function):

         # Weight projection
         dY_X = X.t() @ dY
-        d_A = S * dY_X @ B
-        d_B = S * A @ dY_X
+        d_A = s * dY_X @ B
+        d_B = s * A @ dY_X

         # Get derivative for dX
         W = dequantize(W.t(), W_quant)
         dX = dY @ W.t()
         del W
-        dX += dY @ B.to(dtype) @ (S * A.to(dtype))
+
+        A, B = A.to(dtype), B.to(dtype)
+        dX += s * dY @ B @ A

-        # W, W_quant, A, B, S
-        return dX.view(batch, seq_len, hd), None, None, d_A.t(), d_B.t(), None
+        # W, b, W_quant, A, B, s
+        return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None


 def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
@@ -774,7 +833,7 @@ def apply_lora_o(self, X: torch.Tensor) -> torch.Tensor:
     Returns:
         Transformed output tensor
     """
-    OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
-    output = LoRA_O.apply(X, OW, OW_quant, OA, OB, OS)
+    OW, Ob, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
+    output = LoRA_O.apply(X, OW, Ob, OW_quant, OA, OB, OS)

     return output
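For the LoRA_O backward above, with the flattened forward contribution out = s * (X @ A.T) @ B.T, matrix calculus gives dL/dA = s * B.T @ dY.T @ X and dL/dB = s * dY.T @ X @ A.T, which is what the transposed d_A and d_B work out to. A quick autograd cross-check (plain torch, no quantization):

import torch

X = torch.randn(6, 8)                      # flattened [batch * seq, hidden]
A = torch.randn(4, 8, requires_grad=True)  # [r, hidden]
B = torch.randn(8, 4, requires_grad=True)  # [out, r]
s = 0.5

out = s * (X @ A.t()) @ B.t()
dY = torch.randn_like(out)
out.backward(dY)

dY_X = X.t() @ dY
assert torch.allclose(A.grad, (s * dY_X @ B).t(), atol=1e-5)
assert torch.allclose(B.grad, (s * A @ dY_X).t(), atol=1e-5)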
@@ -76,6 +76,7 @@ def load_lora(
     config_only: bool = False,
 ) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
     lora_target_modules = cfg.lora_target_modules or []
+    lora_target_parameters = cfg.lora_target_parameters or []

     if cfg.lora_target_linear:
         linear_names = find_all_linear_names(model)
@@ -106,6 +107,7 @@ def load_lora(
         r=cfg.lora_r,
         lora_alpha=cfg.lora_alpha,
         target_modules=lora_target_modules,
+        target_parameters=lora_target_parameters,
         layers_to_transform=cfg.peft_layers_to_transform,
         layers_pattern=cfg.peft_layers_pattern,
         lora_dropout=cfg.lora_dropout,
@@ -1,26 +1,13 @@
 """Shared constants for axolotl.loaders module"""

-from transformers import (
-    Gemma3ForConditionalGeneration,
-    Gemma3nForConditionalGeneration,
-    Llama4ForConditionalGeneration,
-    LlavaForConditionalGeneration,
-    Mistral3ForConditionalGeneration,
-    MllamaForConditionalGeneration,
-    Qwen2_5_VLForConditionalGeneration,
-    Qwen2VLForConditionalGeneration,
-)
+from transformers import AutoModelForImageTextToText
+from transformers.models.auto.modeling_auto import (
+    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
+)

-MULTIMODAL_AUTO_MODEL_MAPPING = {
-    "mllama": MllamaForConditionalGeneration,
-    "llama4": Llama4ForConditionalGeneration,
-    "llava": LlavaForConditionalGeneration,
-    "qwen2_vl": Qwen2VLForConditionalGeneration,
-    "qwen2_5_vl": Qwen2_5_VLForConditionalGeneration,
-    "mistral3": Mistral3ForConditionalGeneration,
-    "gemma3": Gemma3ForConditionalGeneration,
-    "gemma3n": Gemma3nForConditionalGeneration,
-}
+MULTIMODAL_AUTO_MODEL_MAPPING = dict(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)
+
+MULTIMODAL_AUTO_MODEL_MAPPING["lfm2-vl"] = AutoModelForImageTextToText

 try:
     from transformers import VoxtralForConditionalGeneration
@@ -1,5 +1,5 @@
-"""Model loader class implementation for loading, configuring, and patching various
-models.
+"""
+Model loader class implementation for loading, configuring, and patching various models.
 """

 import gc
@@ -13,7 +13,7 @@ import peft
 import torch
 import transformers
 import transformers.modeling_utils
-from accelerate import PartialState, init_empty_weights
+from accelerate import init_empty_weights
 from accelerate.parallelism_config import ParallelismConfig
 from peft import (
     PeftConfig,
@@ -22,8 +22,10 @@ from peft import (
     PeftModelForCausalLM,
     prepare_model_for_kbit_training,
 )
+from torch.distributed import DeviceMesh
 from transformers import (
     AutoModelForCausalLM,
+    AutoModelForImageTextToText,
     AutoModelForVision2Seq,
     AwqConfig,
     BitsAndBytesConfig,
@@ -49,7 +51,11 @@ from axolotl.loaders.utils import (
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import get_device_count, get_device_type, get_world_size
+from axolotl.utils.distributed import (
+    build_parallelism_config,
+    get_device_count,
+    get_device_type,
+)
 from axolotl.utils.logging import get_logger
 from axolotl.utils.model_shard_quant import load_sharded_model_quant
 from axolotl.utils.schemas.enums import RLType
@@ -87,6 +93,7 @@ class ModelLoader:

     use_parallel_config: bool | None = False
     parallelism_config: ParallelismConfig | None = None
+    device_mesh: DeviceMesh | None = None

     def __init__(
         self,
@@ -202,8 +209,11 @@ class ModelLoader:
         self._set_device_map_config()
         if self.cfg.revision_of_model:
             self.model_kwargs["revision"] = self.cfg.revision_of_model
+        if self.cfg.use_kernels:
+            self.model_kwargs["use_kernels"] = self.cfg.use_kernels
         self._set_quantization_config()
         self._set_attention_config()
+        self._check_model_requirements()

     def _apply_post_model_load_setup(self):
         """Configure the model after it has been loaded."""
@@ -300,7 +310,10 @@ class ModelLoader:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Handle DeepSpeed Zero3
|
# Handle DeepSpeed Zero3
|
||||||
if is_deepspeed_zero3_enabled():
|
if (
|
||||||
|
is_deepspeed_zero3_enabled()
|
||||||
|
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
|
||||||
|
):
|
||||||
self._set_z3_leaf_modules()
|
self._set_z3_leaf_modules()
|
||||||
|
|
||||||
# Apply gradient checkpointing if needed
|
# Apply gradient checkpointing if needed
|
||||||
@@ -405,85 +418,12 @@ class ModelLoader:
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_parallel_config_kwargs(
|
|
||||||
world_size: int,
|
|
||||||
tensor_parallel_size: int = 1,
|
|
||||||
context_parallel_size: int = 1,
|
|
||||||
dp_shard_size: int | None = None,
|
|
||||||
dp_replicate_size: int | None = None,
|
|
||||||
is_fsdp: bool = False,
|
|
||||||
):
|
|
||||||
pc_kwargs = {}
|
|
||||||
remaining_world_size = world_size
|
|
||||||
|
|
||||||
if tensor_parallel_size and tensor_parallel_size > 1:
|
|
||||||
pc_kwargs["tp_size"] = tensor_parallel_size
|
|
||||||
remaining_world_size = remaining_world_size // tensor_parallel_size
|
|
||||||
|
|
||||||
if context_parallel_size and context_parallel_size > 1:
|
|
||||||
pc_kwargs["cp_size"] = context_parallel_size
|
|
||||||
remaining_world_size = remaining_world_size // context_parallel_size
|
|
||||||
|
|
||||||
if dp_shard_size is None and dp_replicate_size in (None, 1):
|
|
||||||
if remaining_world_size > 1:
|
|
||||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
|
||||||
remaining_world_size = 1
|
|
||||||
|
|
||||||
if dp_replicate_size and dp_replicate_size > 1:
|
|
||||||
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
|
||||||
remaining_world_size = remaining_world_size // dp_replicate_size
|
|
||||||
|
|
||||||
if remaining_world_size > 1 and dp_shard_size and dp_shard_size > 1:
|
|
||||||
if not is_fsdp:
|
|
||||||
raise ValueError(
|
|
||||||
"dp_shard_size was configured without a corresponding fsdp_config! "
|
|
||||||
"Please ensure you have configured FSDP using fsdp_config."
|
|
||||||
)
|
|
||||||
pc_kwargs["dp_shard_size"] = dp_shard_size
|
|
||||||
remaining_world_size = remaining_world_size // dp_shard_size
|
|
||||||
if remaining_world_size > 1 and "dp_replicate_size" not in pc_kwargs:
|
|
||||||
pc_kwargs["dp_replicate_size"] = remaining_world_size
|
|
||||||
remaining_world_size = 1
|
|
||||||
|
|
||||||
if remaining_world_size > 1:
|
|
||||||
if "dp_shard_size" not in pc_kwargs and is_fsdp:
|
|
||||||
pc_kwargs["dp_shard_size"] = remaining_world_size
|
|
||||||
remaining_world_size = 1
|
|
||||||
|
|
||||||
if remaining_world_size > 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"The configured parallelisms are incompatible with the current world size ({get_world_size()})!\n"
|
|
||||||
f"{pc_kwargs}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return pc_kwargs
|
|
||||||
|
|
||||||
def _set_parallel_config(self):
|
def _set_parallel_config(self):
|
||||||
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
||||||
pc_kwargs = ModelLoader._get_parallel_config_kwargs(
|
parallelism_config, device_mesh = build_parallelism_config(self.cfg)
|
||||||
get_world_size(),
|
if parallelism_config:
|
||||||
self.cfg.tensor_parallel_size,
|
self.parallelism_config = parallelism_config
|
||||||
self.cfg.context_parallel_size,
|
self.device_mesh = device_mesh
|
||||||
self.cfg.dp_shard_size,
|
|
||||||
self.cfg.dp_replicate_size,
|
|
||||||
bool(self.cfg.fsdp or self.cfg.fsdp_config),
|
|
||||||
)
|
|
||||||
|
|
||||||
if pc_kwargs:
|
|
||||||
self.parallelism_config = ParallelismConfig(
|
|
||||||
**pc_kwargs,
|
|
||||||
)
|
|
||||||
device_mesh = self.parallelism_config.build_device_mesh("cuda")
|
|
||||||
partial_state = PartialState()
|
|
||||||
# fmt: off
|
|
||||||
partial_state._shared_state["parallelism_config"] = ( # pylint: disable=protected-access
|
|
||||||
self.parallelism_config
|
|
||||||
)
|
|
||||||
partial_state._shared_state["device_mesh"] = ( # pylint: disable=protected-access
|
|
||||||
device_mesh
|
|
||||||
)
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
def _set_auto_model_loader(self):
|
def _set_auto_model_loader(self):
|
||||||
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
||||||
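
The deleted helper and its `build_parallelism_config` replacement both implement the same bookkeeping: factor the world size into TP × CP × DP and fail loudly on a mismatch. A minimal sketch of that decomposition arithmetic, with illustrative names (this is not axolotl's API, just the rule the removed code encoded):

```python
# Hypothetical sketch: split a world size into parallelism factors the way the
# removed helper did. All names here are illustrative.
def split_world_size(world_size: int, tp: int = 1, cp: int = 1) -> dict:
    remaining = world_size
    sizes = {}
    if tp > 1:
        sizes["tp_size"] = tp
        remaining //= tp
    if cp > 1:
        sizes["cp_size"] = cp
        remaining //= cp
    if remaining > 1:
        sizes["dp_shard_size"] = remaining  # leftover ranks become FSDP shards
    return sizes

# 8 GPUs with tensor_parallel_size=2 and context_parallel_size=2
# leave 8 // 2 // 2 = 2 ranks for data-parallel sharding.
assert split_world_size(8, tp=2, cp=2) == {"tp_size": 2, "cp_size": 2, "dp_shard_size": 2}
```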
@@ -494,6 +434,8 @@ class ModelLoader:
         self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
             self.model_config.model_type, AutoModelForVision2Seq
         )
+        if isinstance(self.auto_model_loader, str):
+            self.auto_model_loader = AutoModelForImageTextToText
 
     def _set_device_map_config(self):
         """Setup `device_map` according to config"""
@@ -565,8 +507,17 @@ class ModelLoader:
 
     def _set_quantization_config(self):
         """Set up quantization config (bitsandbytes, awq, gptq, etc.)"""
-        self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
-        self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
+        if self.cfg.model_quantization_config == "Mxfp4Config":
+            from transformers import Mxfp4Config
+
+            mxfp4_kwargs = {}
+            if self.cfg.model_quantization_config_kwargs:
+                mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
+            self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
+        else:
+            self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
+            self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
 
         if self.cfg.gptq:
             if not hasattr(self.model_config, "quantization_config"):
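
This branch lets a config opt into MXFP4 quantization instead of the bitsandbytes flags. A sketch of the config keys the code above reads; the kwargs dict contents are a placeholder, since valid `Mxfp4Config` arguments depend on the installed transformers version:

```python
# Illustrative only: the keys mirror what _set_quantization_config reads above.
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "model_quantization_config": "Mxfp4Config",
        # passed through verbatim as Mxfp4Config(**kwargs); the contents shown
        # here are an assumption, not a documented schema
        "model_quantization_config_kwargs": {"dequantize": True},
    }
)
```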
@@ -601,7 +552,9 @@ class ModelLoader:
             self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                 **self.model_config.quantization_config
             )
-        elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
+        elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
+            "load_in_4bit", False
+        ):
             bnb_config = {
                 "load_in_4bit": True,
                 "llm_int8_threshold": 6.0,
@@ -627,7 +580,9 @@ class ModelLoader:
             self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                 **bnb_config,
             )
-        elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
+        elif self.cfg.adapter == "lora" and self.model_kwargs.get(
+            "load_in_8bit", False
+        ):
             bnb_config = {
                 "load_in_8bit": True,
             }
@@ -648,7 +603,9 @@ class ModelLoader:
 
     def _set_attention_config(self):
         """Sample packing uses custom FA2 patch"""
-        if self.cfg.flex_attention:
+        if self.cfg.attn_implementation:
+            self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
+        elif self.cfg.flex_attention:
             self.model_kwargs["attn_implementation"] = "flex_attention"
             self.model_config._attn_implementation = (  # pylint: disable=protected-access
                 "flex_attention"
@@ -675,6 +632,16 @@ class ModelLoader:
         if self.cfg.low_cpu_mem_usage:
             self.model_kwargs["low_cpu_mem_usage"] = True
 
+    def _check_model_requirements(self):
+        if self.cfg.model_config_type in ["lfm2-vl", "lfm2"]:
+            from transformers.utils.import_utils import is_causal_conv1d_available
+
+            if is_causal_conv1d_available():
+                raise ImportError(
+                    "The 'causal-conv1d' package is installed but causes compatibility issues with LFM2 models. "
+                    "Please uninstall it by running: `pip uninstall -y causal-conv1d`"
+                )
+
     def _configure_zero3_memory_efficient_loading(
         self,
     ) -> HfTrainerDeepSpeedConfig | None:
@@ -714,6 +681,23 @@ class ModelLoader:
 
         return hf_ds_cfg
 
+    def _load_model_from_config(self) -> PreTrainedModel:
+        """Load model with random initialization using from_config."""
+        if self.auto_model_loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
+            return self.auto_model_loader.from_config(config=self.model_config)
+        return self.auto_model_loader(config=self.model_config)
+
+    def _load_model_from_pretrained(self, model_loader_class=None) -> PreTrainedModel:
+        """Load model from pretrained weights."""
+        loader = model_loader_class or self.auto_model_loader
+        kwargs = {
+            **self.model_kwargs,
+            "config": self.model_config,
+            "trust_remote_code": self.cfg.trust_remote_code or False,
+        }
+        return loader.from_pretrained(self.base_model, **kwargs)
+
     def _build_model(self) -> bool:
         """Load model, with load strategy depending on config."""
         skip_move_to_device = False
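
The split into `_load_model_from_config` and `_load_model_from_pretrained` mirrors the two construction paths transformers offers: `from_config` builds the architecture with freshly initialized weights, while `from_pretrained` also loads the checkpoint tensors. A minimal stand-alone illustration (the model name is just an example):

```python
# from_config: random initialization; from_pretrained: trained weights.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2")
random_model = AutoModelForCausalLM.from_config(config)          # random init
pretrained_model = AutoModelForCausalLM.from_pretrained("gpt2")  # checkpoint weights
```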
@@ -721,14 +705,15 @@ class ModelLoader:
         if self.cfg.tensor_parallel_size > 1:
             self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
             self.model_kwargs["tp_plan"] = "auto"
-            self.model_kwargs["device_mesh"] = PartialState().device_mesh
+            self.model_kwargs["device_mesh"] = self.device_mesh
             if "device_map" in self.model_kwargs:
                 del self.model_kwargs["device_map"]  # not compatible with `tp_plan`
 
         if self.is_fsdp_enabled:
             if self.cfg.fsdp_config.cpu_ram_efficient_loading:
                 skip_move_to_device = True
-                # Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
+                # Don't delete device_map for QLoRA + FSDP - it was set correctly in
+                # _set_device_map
                 if (
                     "device_map" in self.model_kwargs
                     and not self.is_qlora_and_fsdp_enabled
@@ -737,6 +722,18 @@ class ModelLoader:
             elif self.is_qlora_and_fsdp_enabled:
                 skip_move_to_device = True
 
+            if (
+                self.cfg.tensor_parallel_size <= 1
+                and self.cfg.fsdp_config.cpu_ram_efficient_loading
+                and self.cfg.fsdp_version == 2
+            ):
+                # setting device_map for TP is not supported
+                local_rank = int(os.getenv("LOCAL_RANK", "0"))
+                if local_rank == 0:
+                    self.model_kwargs["device_map"] = "cpu"
+                else:
+                    self.model_kwargs["device_map"] = "meta"
+
         if (
             self.is_qlora_and_fsdp_enabled
             and self.cfg.fsdp_config.cpu_ram_efficient_loading
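
The rank split above is the usual CPU-RAM-efficient loading recipe: one process holds real weights, the rest hold placeholders until FSDP distributes shards. A stand-alone sketch of the same idea; it simply mirrors the kwargs the loader sets above, and whether plain `from_pretrained` accepts `"meta"` directly depends on the transformers version:

```python
# Minimal sketch: rank 0 loads real weights on CPU, other ranks keep the model
# on the meta device so host RAM is not consumed per rank. LOCAL_RANK is set by
# torchrun; the model name is a placeholder.
import os

from transformers import AutoModelForCausalLM

local_rank = int(os.getenv("LOCAL_RANK", "0"))
device_map = "cpu" if local_rank == 0 else "meta"
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map=device_map)
```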
@@ -745,6 +742,11 @@ class ModelLoader:
                 or self.cfg.qlora_sharded_model_loading
             )
         ):
+            if self.cfg.reinit_weights:
+                LOG.warning(
+                    "reinit_weights is not supported with sharded quantized loading. "
+                    "Loading from pretrained weights instead."
+                )
             quant_storage = self.cfg.torch_dtype
             quantization_config = getattr(
                 self.model_config, "quantization_config", None
@@ -760,33 +762,12 @@ class ModelLoader:
                 quantization_config=quantization_config,
             )
             skip_move_to_device = True
-        elif (
-            self.model_config.model_type in ["llama", "llama4"]
-            and not self.cfg.trust_remote_code
-            and not self.cfg.gptq
-        ):
-            # Please don't remove underscore binding without reading the fn docstring.
-            _ = self._configure_zero3_memory_efficient_loading()
-
-            # Load model with random initialization if specified
-            if self.cfg.random_init_weights:
-                # AutoModel classes support the from_config method
-                if self.auto_model_loader in [
-                    AutoModelForCausalLM,
-                    AutoModelForVision2Seq,
-                ]:
-                    self.model = self.auto_model_loader.from_config(
-                        config=self.model_config,
-                    )
-                else:
-                    self.model = self.auto_model_loader(config=self.model_config)
-            else:
-                self.model = self.auto_model_loader.from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    **self.model_kwargs,
-                )
         elif self.model_type == "MambaLMHeadModel":
+            if self.cfg.reinit_weights:
+                LOG.warning(
+                    "reinit_weights is not supported with MambaLMHeadModel. "
+                    "Loading from pretrained weights instead."
+                )
             # FIXME this is janky at best and hacked together to make it work
             MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
 
@@ -799,41 +780,27 @@ class ModelLoader:
                 self.base_model,
                 **self.model_kwargs,
             )
-        elif (
-            self.model_type
-            and self.model_type != "AutoModelForCausalLM"
-            and not self.cfg.trust_remote_code
-        ):
-            if self.cfg.gptq:
-                self.model = self.auto_model_loader.from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    trust_remote_code=self.cfg.trust_remote_code or False,
-                    **self.model_kwargs,
-                )
-            else:
-                self.model = getattr(transformers, self.model_type).from_pretrained(
-                    self.base_model,
-                    config=self.model_config,
-                    trust_remote_code=self.cfg.trust_remote_code or False,
-                    **self.model_kwargs,
-                )
-        elif self.cfg.gptq:
-            self.model = self.auto_model_loader.from_pretrained(
-                self.base_model,
-                config=self.model_config,
-                trust_remote_code=self.cfg.trust_remote_code or False,
-                **self.model_kwargs,
-            )
         else:
-            # Please don't remove underscore binding without reading the fn docstring.
+            # Please don't remove underscore binding without reading the fn docstring
             _ = self._configure_zero3_memory_efficient_loading()
-            self.model = self.auto_model_loader.from_pretrained(
-                self.base_model,
-                config=self.model_config,
-                trust_remote_code=self.cfg.trust_remote_code or False,
-                **self.model_kwargs,
-            )
+
+            if (
+                self.model_type
+                and self.model_type != "AutoModelForCausalLM"
+                and not self.cfg.trust_remote_code
+                and not self.cfg.gptq
+            ):
+                # Use model type from transformers
+                model_loader_class = getattr(transformers, self.model_type)
+            else:
+                # Use auto model loader (handles gptq and default cases)
+                model_loader_class = self.auto_model_loader
+
+            if self.cfg.reinit_weights:
+                self.model = self._load_model_from_config()
+            else:
+                self.model = self._load_model_from_pretrained(model_loader_class)
+
         if is_deepspeed_zero3_enabled():
             skip_move_to_device = True
 
@@ -845,6 +812,9 @@ class ModelLoader:
             self.model._tp_size = self.cfg.tensor_parallel_size
             self.model._device_mesh = self.model_kwargs["device_mesh"]
 
+        if self.cfg.experimental_skip_move_to_device is not None:
+            skip_move_to_device = self.cfg.experimental_skip_move_to_device
+
         return skip_move_to_device
 
     def _set_z3_leaf_modules(self):
@@ -65,6 +65,7 @@ class PatchManager:
         self._patch_llama_derived_model()
         self._apply_mistral_cross_entropy_patch()
         self._apply_self_attention_lora_patch()
+        self._apply_fsdp2_bnb_patches()
 
     def apply_post_plugin_pre_model_load_patches(self):
         """Apply post plugin-pre_model_load load patches based on config."""
@@ -72,11 +73,19 @@ class PatchManager:
         self._apply_voxtral_patches()
 
     def _apply_transformers_patches(self):
-        from axolotl.monkeypatch.transformers.modeling_flash_attention_utils import (
-            patch_prepare_from_posids,
+        from axolotl.monkeypatch.transformers.trainer_loss_calc import (
+            patch_evaluation_loop,
+            patch_maybe_log_save_evaluate,
         )
 
-        patch_prepare_from_posids()
+        patch_fsdp2 = (
+            self.cfg.torch_compile
+            and self.cfg.fsdp_config
+            and self.cfg.fsdp_version == 2
+        )
+
+        patch_evaluation_loop(patch_fsdp2)
+        patch_maybe_log_save_evaluate()
 
     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
@@ -103,6 +112,14 @@ class PatchManager:
 
     def _apply_fsdp_patches(self):
         """Apply patches for FSDP configurations."""
+        if self.cfg.context_parallel_size > 1 or (
+            self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2"
+        ):
+            from axolotl.monkeypatch.accelerate.parallelism_config import (
+                patch_parallelism_config,
+            )
+
+            patch_parallelism_config()
         if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
             from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
 
@@ -260,6 +277,21 @@ class PatchManager:
             has_remote_code=has_remote_code,
         )
 
+    def _apply_fsdp2_bnb_patches(self):
+        """Apply FSDP2 BNB patches."""
+        if (
+            self.cfg.fsdp_config
+            and str(self.cfg.fsdp_version) == "2"
+            and self.cfg.adapter == "qlora"
+        ):
+            from axolotl.monkeypatch.fsdp2_qlora import (
+                apply_init_sharded_param_patch,
+                apply_init_unsharded_param_patch,
+            )
+
+            apply_init_sharded_param_patch()
+            apply_init_unsharded_param_patch()
+
     def _apply_tiled_mlp(self, model_type: str):
         if self.cfg.tiled_mlp:
             from axolotl.monkeypatch.tiled_mlp import (
@@ -330,31 +362,21 @@ class PatchManager:
 
         patch_self_attn_lora()
 
-    def _patch_llama_flash_attention(self, packed=False):
+    def _patch_llama_flash_attention(self):
        """Apply Flash Attention patches for LLaMA models."""
        from axolotl.monkeypatch.llama_attn_hijack_flash import (
            replace_llama_attn_with_flash_attn,
        )
 
-        if packed:
-            if self.cfg.device not in ["mps", "cpu"] and not self.inference:
-                LOG.info("patching with flash attention for sample packing")
-                replace_llama_attn_with_flash_attn(
-                    packed=True,
-                    cross_entropy=self.cfg.flash_attn_cross_entropy,
-                    rms_norm=self.cfg.flash_attn_rms_norm,
-                )
-        elif self.cfg.s2_attention:
+        if self.cfg.s2_attention:
             LOG.info("patching w/ flash-enabled, shifted-sparse attention")
             replace_llama_attn_with_flash_attn(
-                packed=False,
                 cross_entropy=self.cfg.flash_attn_cross_entropy,
                 rms_norm=self.cfg.flash_attn_rms_norm,
                 use_shifted_sparse_attn=True,
             )
         elif self.cfg.flash_attn_cross_entropy or self.cfg.flash_attn_rms_norm:
             replace_llama_attn_with_flash_attn(
-                packed=False,
                 cross_entropy=self.cfg.flash_attn_cross_entropy,
                 rms_norm=self.cfg.flash_attn_rms_norm,
             )
@@ -385,7 +407,7 @@ class PatchManager:
             and self.cfg.sample_packing
         ):
             if self.cfg.flash_attention:
-                self._patch_llama_flash_attention(packed=self.cfg.sample_packing)
+                self._patch_llama_flash_attention()
             elif self.cfg.xformers_attention:
                 self._patch_llama_xformers_attention()
             elif self.cfg.sample_packing:
@@ -408,17 +430,12 @@ class PatchManager:
         from axolotl.monkeypatch.llama_attn_hijack_flash import (
             is_xformers_swiglu_available,
             replace_llama_mlp_with_swiglu,
-            replace_llama_qkv_with_fused,
         )
 
         if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
             LOG.info("Patching with SwiGLU...")
             replace_llama_mlp_with_swiglu(model)
 
-        if self.cfg.flash_attn_fuse_qkv:
-            LOG.info("Patching with fused QKV...")
-            replace_llama_qkv_with_fused(model)
-
     def _apply_unsloth_patches(self, model):
         """Apply unsloth optimization patches."""
         if self.cfg.unsloth_lora_mlp:
@@ -7,6 +7,7 @@ import functools
 import sys
 
 import torch
+import torch.distributed as dist
 from torch import nn
 
 from axolotl.utils.bench import log_gpu_memory_usage
@@ -36,25 +37,49 @@ def fsdp2_load_full_state_dict(
 
     meta_sharded_sd = model.state_dict()
     sharded_sd = {}
-    for param_name, full_tensor in full_sd.items():
-        sharded_meta_param = meta_sharded_sd.get(param_name)
-        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
+    for param_name, sharded_meta_param in meta_sharded_sd.items():
+        full_tensor = None
+        if _accelerator.is_main_process:
+            full_tensor = full_sd[param_name]
+            full_tensor = full_tensor.to(sharded_meta_param.dtype)
 
         if hasattr(sharded_meta_param, "device_mesh"):
+            device_mesh = sharded_meta_param.device_mesh
+            if _accelerator.is_main_process:
+                full_tensor = full_tensor.to(device_mesh.device_type)
+            else:
+                full_tensor = torch.empty(
+                    sharded_meta_param.size(),
+                    device=device_mesh.device_type,
+                    dtype=sharded_meta_param.dtype,
+                )
             sharded_param = distribute_tensor(
                 full_tensor,
-                sharded_meta_param.device_mesh,
+                device_mesh,
                 sharded_meta_param.placements,
                 src_data_rank=0,
             )
         else:
-            sharded_param = full_tensor
+            # Non-sharded parameters
+            if _accelerator.is_main_process:
+                sharded_param = full_tensor.to(torch.device("cuda"))
+            else:
+                # broadcast manually
+                sharded_param = torch.empty_like(
+                    sharded_meta_param,
+                    device=torch.device("cuda"),
+                    dtype=sharded_meta_param.dtype,
+                )
+            dist.broadcast(sharded_param, src=0)
 
         if offload_to_cpu:
             sharded_param = sharded_param.cpu()
 
         sharded_sd[param_name] = nn.Parameter(sharded_param)
 
         del full_tensor
         full_sd[param_name] = None
 
     model.load_state_dict(sharded_sd, assign=True, strict=True)
     end_time = time.time()
     LOG.debug(
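
The rewritten loop materializes each full tensor only on the main process and replicates non-sharded parameters with an explicit collective. A self-contained sketch of that rank-0-plus-broadcast pattern, assuming an already-initialized process group (illustrative, not the function above):

```python
# Minimal sketch: rank 0 holds the real value, every other rank allocates an
# empty buffer of the same shape, then a broadcast makes all ranks agree.
import torch
import torch.distributed as dist


def replicate_from_rank0(value_on_rank0, shape, dtype) -> torch.Tensor:
    if dist.get_rank() == 0:
        tensor = value_on_rank0.to("cuda")
    else:
        tensor = torch.empty(shape, dtype=dtype, device="cuda")
    dist.broadcast(tensor, src=0)  # collective: must run on every rank
    return tensor
```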
77
src/axolotl/monkeypatch/accelerate/parallelism_config.py
Normal file
@@ -0,0 +1,77 @@
+"""
+workaround to allow parallelism config for pure CP
+"""
+
+# pylint: disable=protected-access
+import os
+import warnings
+
+from accelerate import DistributedType
+
+
+def _validate_accelerator(self, accelerator):
+    _warnings = set()
+    if not accelerator.multi_device and self.total_size == 1:
+        # No distributed setup, valid parallelism config
+        return
+
+    # We need this to ensure DDP works
+    if self.total_size == 1:
+        self._set_size("dp_replicate", accelerator.num_processes)
+
+    if self.total_size != accelerator.num_processes:
+        raise ValueError(
+            f"ParallelismConfig total_size ({self.total_size}) does not match "
+            f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
+            f"dp_shard_size/tp_size/cp_size."
+        )
+
+    # allow parallelism config when not using fsdp if using pure context parallelism
+    allow_parallelism_config = False
+
+    if (
+        self.cp_size > 1  # pylint: disable=chained-comparison
+        and self.dp_shard_size <= 1
+        and os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true"
+    ):
+        allow_parallelism_config = True
+
+    if (
+        self.total_size > 1
+        and not allow_parallelism_config
+        and not (accelerator.is_fsdp2 or accelerator.multi_device)
+    ):
+        raise ValueError(
+            f"ParallelismConfig is only compatible with DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
+        )
+
+    for parallelism, size in self._sizes.items():
+        if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
+            _warnings.add(
+                f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
+            )
+
+    if _warnings and accelerator.is_main_process:
+        warnings.warn(
+            "ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
+            UserWarning,
+        )
+
+
+def patched_is_fsdp2(self) -> bool:
+    """
+    Patched version of is_fsdp2 that guards against a None fsdp_plugin.
+    """
+    # The new logic checks if fsdp_plugin exists before accessing its attributes
+    return (
+        self.distributed_type == DistributedType.FSDP
+        and self.fsdp_plugin
+        and self.fsdp_plugin.fsdp_version == 2
+    )
+
+
+def patch_parallelism_config():
+    from accelerate.accelerator import AcceleratorState, ParallelismConfig
+
+    ParallelismConfig._validate_accelerator = _validate_accelerator
+    AcceleratorState.is_fsdp2 = property(patched_is_fsdp2)
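
The patch relaxes accelerate's validation so a run with only `cp_size > 1` can proceed without FSDP; the opt-in is the `ACCELERATE_ALLOW_CP_STANDALONE` environment variable checked above. A minimal sketch of how that escape hatch would be armed (illustrative wiring, not a documented entry point):

```python
# Illustrative: enable standalone context parallelism before accelerate
# validates the ParallelismConfig. The env var name comes from the patch above.
import os

from axolotl.monkeypatch.accelerate.parallelism_config import patch_parallelism_config

os.environ["ACCELERATE_ALLOW_CP_STANDALONE"] = "true"
patch_parallelism_config()  # must run before the Accelerator is constructed
```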
144
src/axolotl/monkeypatch/fsdp2_qlora.py
Normal file
@@ -0,0 +1,144 @@
+"""
+Monkeypatch to add Params4bit support to FSDP2. This enables QLoRA + FSDP2, as well as
+our LoRA / QLoRA Triton kernels to work with FSDP2.
+
+This patch modifies the _init_sharded_param method in FSDPParam to handle bitsandbytes
+Params4bit parameters.
+"""
+
+import importlib
+import inspect
+
+from axolotl.monkeypatch.utils import detab_code
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+
+# pylint: disable=protected-access
+def apply_init_sharded_param_patch():
+    """Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
+    from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+
+    # Get original source
+    original_source = inspect.getsource(FSDPParam._init_sharded_param)
+    original_source, _ = detab_code(original_source)
+
+    # Define the replacement
+    original_param_creation = """        self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
+        self.sharded_param.requires_grad_(param.requires_grad)"""
+
+    patched_param_creation = """        import bitsandbytes as bnb
+        if isinstance(param, bnb.nn.modules.Params4bit):
+            self.sharded_param = bnb.nn.modules.Params4bit(
+                data=sharded_param,
+                requires_grad=param.requires_grad,
+                quant_state=param.quant_state,
+                blocksize=param.blocksize,
+                compress_statistics=param.compress_statistics,
+                quant_type=param.quant_type,
+                quant_storage=param.quant_storage,
+                module=param.module,
+                bnb_quantized=param.bnb_quantized,
+            )
+            self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
+        else:
+            self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
+        self.sharded_param.requires_grad_(param.requires_grad)"""
+
+    # Apply the replacement
+    if original_param_creation in original_source:
+        patched_source = original_source.replace(
+            original_param_creation, patched_param_creation
+        )
+        patched_source = patched_source.replace(
+            "def _init_sharded_param(",
+            "def patched_init_sharded_param(",
+            1,
+        )
+
+        # Load necessary imports
+        module_name = FSDPParam.__module__
+        module = importlib.import_module(module_name)
+
+        items_to_import = []
+        for item in dir(module):
+            if item in patched_source:
+                items_to_import.append(item)
+
+        exec(  # pylint: disable=exec-used  # nosec B102
+            f"from {module_name} import ({', '.join(items_to_import)})",
+            globals(),
+        )
+        exec(patched_source, globals())  # pylint: disable=exec-used  # nosec B102
+
+        # Replace the method
+        FSDPParam._init_sharded_param = patched_init_sharded_param  # pylint: disable=undefined-variable  # noqa: F821
+        LOG.info("Successfully applied FSDP _init_sharded_param patch")
+    else:
+        LOG.warning("Could not find target code for _init_sharded_param patching")
+
+
+def apply_init_unsharded_param_patch():
+    """Apply patch to FSDPParam.init_unsharded_param to support Params4bit."""
+    from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+
+    # Get original source
+    original_source = inspect.getsource(FSDPParam.init_unsharded_param)
+    original_source, _ = detab_code(original_source)
+
+    # Define the replacement
+    original_param_creation = """        self._unsharded_param = nn.Parameter(
+            unsharded_param, requires_grad=self.sharded_param.requires_grad
+        )"""
+
+    patched_param_creation = """        import bitsandbytes as bnb
+        local_tensor = self.sharded_param._local_tensor
+        if isinstance(local_tensor, bnb.nn.modules.Params4bit):
+            self._unsharded_param = bnb.nn.modules.Params4bit(
+                data=unsharded_param,
+                requires_grad=self.sharded_param.requires_grad,
+                quant_state=local_tensor.quant_state,
+                blocksize=local_tensor.blocksize,
+                compress_statistics=local_tensor.compress_statistics,
+                quant_type=local_tensor.quant_type,
+                quant_storage=local_tensor.quant_storage,
+                module=local_tensor.module,
+                bnb_quantized=local_tensor.bnb_quantized,
+            )
+        else:
+            self._unsharded_param = nn.Parameter(
+                unsharded_param, requires_grad=self.sharded_param.requires_grad
+            )"""
+
+    # Apply the replacement
+    if original_param_creation in original_source:
+        patched_source = original_source.replace(
+            original_param_creation, patched_param_creation
+        )
+        patched_source = patched_source.replace(
+            "def init_unsharded_param(",
+            "def patched_init_unsharded_param(",
+            1,
+        )
+
+        # Load necessary imports
+        module_name = FSDPParam.__module__
+        module = importlib.import_module(module_name)
+
+        items_to_import = []
+        for item in dir(module):
+            if item in patched_source:
+                items_to_import.append(item)
+
+        exec(  # pylint: disable=exec-used  # nosec B102
+            f"from {module_name} import ({', '.join(items_to_import)})",
+            globals(),
+        )
+        exec(patched_source, globals())  # pylint: disable=exec-used  # nosec B102
+
+        # Replace the method
+        FSDPParam.init_unsharded_param = patched_init_unsharded_param  # pylint: disable=undefined-variable  # noqa: F821
+        LOG.info("Successfully applied FSDP init_unsharded_param patch")
+    else:
+        LOG.warning("Could not find target code for patching")
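
Both helpers rely on the same source-rewrite trick: fetch the target's source with `inspect.getsource`, splice in new code with `str.replace`, `exec` the result, and rebind the method. A toy version of the mechanism against a stand-in function (illustrative only, not FSDP internals):

```python
# Toy demonstration of source-rewrite monkeypatching, the same mechanism the
# two patches above apply to FSDPParam. The target here is a stand-in function.
import inspect


def greet() -> str:
    return "hello"


source = inspect.getsource(greet)
patched = source.replace('return "hello"', 'return "hello, patched"')
patched = patched.replace("def greet(", "def patched_greet(", 1)

namespace: dict = {}
exec(patched, namespace)  # defines patched_greet inside `namespace`
greet = namespace["patched_greet"]

assert greet() == "hello, patched"
```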
@@ -3,39 +3,26 @@
 # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
 
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import torch
-import torch.nn.functional as F
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
-from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-)
-from transformers.models.llama.modeling_llama import (
-    LlamaDecoderLayer as OriginalLlamaDecoderLayer,
-)
 from transformers.models.llama.modeling_llama import (
     LlamaMLP,
     apply_rotary_pos_emb,
     repeat_kv,
 )
 
-from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
+from axolotl.monkeypatch.utils import set_module_name
 from axolotl.utils.logging import get_logger
 
 try:
     from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
-        flash_attn_kvpacked_func,
-        flash_attn_varlen_kvpacked_func,
         flash_attn_varlen_qkvpacked_func,
     )
 except ImportError:
-    from flash_attn.flash_attn_interface import (
-        flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
-    )
     from flash_attn.flash_attn_interface import (
         flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
     )
@@ -82,19 +69,6 @@ def replace_llama_mlp_with_swiglu(model):
         set_module_name(model, name, mlp)
 
 
-def replace_llama_qkv_with_fused(model):
-    for name, module in model.named_modules():
-        if isinstance(module, LlamaAttention):
-            qkv = FusedAttention(
-                module.config,
-                module.q_proj,
-                module.k_proj,
-                module.v_proj,
-                module.o_proj,
-            )
-            set_module_name(model, name, qkv)
-
-
 def patch_fa_llama_cross_entropy():
     LOG.info(
         "patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
@@ -142,7 +116,6 @@ def patch_llama_rms_norm():
 
 
 def replace_llama_attn_with_flash_attn(
-    packed: Optional[bool] = False,
     cross_entropy: Optional[bool] = False,
     rms_norm: Optional[bool] = False,
     use_shifted_sparse_attn: Optional[bool] = False,
@@ -154,16 +127,6 @@ def replace_llama_attn_with_flash_attn(
         transformers.models.llama.modeling_llama.LlamaAttention.forward = (
             flashattn_forward_with_s2attn
         )
-    else:
-        transformers.models.llama.modeling_llama.LlamaAttention.forward = (
-            flashattn_forward
-        )
-
-    if packed:
-        transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
-        transformers.models.llama.modeling_llama.LlamaModel.forward = (
-            llama_model_forward
-        )
 
     # skip only if explicitly disabled
     if cross_entropy:
@@ -174,49 +137,6 @@ def replace_llama_attn_with_flash_attn(
         patch_llama_rms_norm()
 
 
-class FusedAttention(LlamaAttention):
-    """
-    Fused QKV Attention layer for incrementally improved training efficiency
-    """
-
-    def __init__(
-        self,
-        config,
-        q: torch.nn.Linear,  # pylint: disable=invalid-name
-        k: torch.nn.Linear,  # pylint: disable=invalid-name
-        v: torch.nn.Linear,  # pylint: disable=invalid-name
-        o: torch.nn.Linear,  # pylint: disable=invalid-name
-    ):
-        super().__init__(config)
-        self.config = config
-        self.init_device = next(iter(q.state_dict().values())).device
-
-        # define equivalent fused qkv projection
-        self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
-        self.qkv_proj = torch.nn.Linear(
-            q.in_features, sum(self.out_features), device=self.init_device, bias=False
-        )
-        self.o_proj = o
-
-        # overwrite initialized weights with pretrained weights
-        self.qkv_proj.weight.data = torch.cat(
-            (q.weight.data, k.weight.data, v.weight.data), dim=0
-        )
-
-    def _post_training(self, model, name):
-        q_proj, k_proj, v_proj = torch.split(
-            self.qkv_proj.weight.data, self.out_features, dim=0
-        )
-
-        new_attn = LlamaAttention(self.config)
-        new_attn.q_proj.weight.data = q_proj
-        new_attn.k_proj.weight.data = k_proj
-        new_attn.v_proj.weight.data = v_proj
-        new_attn.o_proj.weight.data = self.o_proj.weight.data
-
-        set_module_name(model, name, new_attn)
-
-
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask
 def _prepare_decoder_attention_mask(
@@ -355,576 +275,3 @@ def flashattn_forward_with_s2attn(
         .reshape(bsz, q_len, nheads, self.head_dim)
     )
     return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
-
-
-def flashattn_forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.Tensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
-    cu_seqlens: Optional[torch.Tensor] = None,
-    max_seqlen: Optional[torch.Tensor] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    """Input shape: Batch x Time x Channel
-
-    attention_mask: [bsz, q_len]
-    """
-    # pylint: disable=duplicate-code
-    bsz, q_len, _ = hidden_states.size()
-
-    if not hasattr(self, "pretraining_tp"):
-        self.pretraining_tp = 1
-
-    if self.pretraining_tp > 1:
-        key_value_slicing = (
-            self.num_key_value_heads * self.head_dim
-        ) // self.pretraining_tp
-        query_slices = self.q_proj.weight.split(
-            (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
-        )
-        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-        query_states = [
-            F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
-        ]
-        query_states = torch.cat(query_states, dim=-1)
-
-        key_states = [
-            F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
-        ]
-        key_states = torch.cat(key_states, dim=-1)
-
-        value_states = [
-            F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
-        ]
-        value_states = torch.cat(value_states, dim=-1)
-
-    else:
-        if isinstance(self, FusedAttention):
-            query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
-                self.out_features, dim=-1
-            )
-        else:
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
-
-    query_states = query_states.view(
-        bsz, q_len, self.num_heads, self.head_dim
-    ).transpose(1, 2)
-    key_states = key_states.view(
-        bsz, q_len, self.num_key_value_heads, self.head_dim
-    ).transpose(1, 2)
-    value_states = value_states.view(
-        bsz, q_len, self.num_key_value_heads, self.head_dim
-    ).transpose(1, 2)
-    # [bsz, q_len, nh, hd]
-    # [bsz, nh, q_len, hd]
-
-    cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
-    query_states, key_states = apply_rotary_pos_emb(
-        query_states, key_states, cos, sin, position_ids
-    )
-    # [bsz, nh, t, hd]
-
-    if past_key_value is not None:
-        # reuse k, v, self_attention
-        key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-    past_key_value = (key_states, value_states) if use_cache else None
-
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-    if output_attentions:
-        warnings.warn(
-            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
-        )
-
-    #
-    # flash-attn v2 start
-    #
-
-    if self.training:
-        # during training q,k,v always have same seqlen
-        assert key_states.shape == query_states.shape
-        is_causal = True
-    else:
-        # turn off FA causal mask after first inference autoregressive iteration
-        # only on first autoregressive step q,k,v have same seqlen
-        is_causal = key_states.shape == query_states.shape
-
-    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
-
-    if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
-        # special handling using sample packing
-        qkv = torch.stack(
-            [query_states, key_states, value_states], dim=2
-        )  # [bsz, nh, 3, q_len, hd]
-        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
-        qkv = rearrange(qkv, "b s ... -> (b s) ...")
-
-        output = flash_attn_varlen_qkvpacked_func(
-            qkv,
-            cu_seqlens,
-            max_seqlen,
-            dropout_p=dropout_rate,
-            softmax_scale=None,
-            causal=True,
-        )
-        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-    elif query_states.shape == key_states.shape:
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-        qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
-            query_states,
-            key_states,
-            value_states,
-            qkvpacked=True,
-            # We have disabled _prepare_decoder_attention_mask in LlamaModel
-            # the attention_mask should be the same as the key_padding_mask
-            key_padding_mask=attention_mask,
-            query_padding_mask=(
-                attention_mask[:, -query_states.size(1) :]
-                if attention_mask is not None
-                else None
-            ),
-        )
-        output_unpad = flash_attn_varlen_qkvpacked_func(
-            qkv_unpad,
-            cu_seqlens_q,
-            max_seqlen_q,
-            dropout_p=dropout_rate,
-            softmax_scale=None,
-            causal=is_causal,
-        )
-        output = output_pad_fn(output_unpad)
-    else:
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-        if attention_mask is None or attention_mask.all().item():
-            output = flash_attn_kvpacked_func(
-                query_states,
-                torch.stack([key_states, value_states], 2),
-                dropout_p=dropout_rate,
-                causal=is_causal,
-            )
-        else:
-            (  # pylint: disable=unbalanced-tuple-unpacking
-                q_unpad,
-                kv_unpad,
-                cu_seqlens_q,
-                cu_seqlens_k,
-                max_seqlen_q,
-                max_seqlen_k,
-                _,
-                _,
-                output_pad_fn,
-            ) = generate_qkv(
-                query_states,
-                key_states,
-                value_states,
-                kvpacked=True,
-                key_padding_mask=attention_mask,
-                query_padding_mask=(
-                    attention_mask[:, -query_states.size(1) :]
-                    if attention_mask is not None
-                    else None
-                ),
-            )
-            if q_unpad.dtype != kv_unpad.dtype:
-                kv_unpad = kv_unpad.to(q_unpad.dtype)
-            output_unpad = flash_attn_varlen_kvpacked_func(
-                q_unpad,
-                kv_unpad,
-                cu_seqlens_q,
-                cu_seqlens_k,
-                max_seqlen_q,
-                max_seqlen_k,
-                dropout_p=dropout_rate,
-                softmax_scale=None,
-                causal=is_causal,
-            )
-            output = output_pad_fn(output_unpad)
-
-    attn_output = output
-    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
-        raise ValueError(
-            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
-            f" {attn_output.size()}"
-        )
-    attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
-
-    #
-    # flash-attn v2 end
-    #
-
-    if self.pretraining_tp > 1:
-        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
-        o_proj_slices = self.o_proj.weight.split(
-            self.hidden_size // self.pretraining_tp, dim=1
-        )
-        attn_output = sum(
-            F.linear(attn_output[i], o_proj_slices[i])
-            for i in range(self.pretraining_tp)
-        )
-    else:
-        attn_output = self.o_proj(attn_output)
-
-    return attn_output, None, past_key_value
-
-
-# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
-def generate_qkv(
-    q,
-    k,
-    v,
-    query_padding_mask=None,
-    key_padding_mask=None,
-    kvpacked=False,
-    qkvpacked=False,
-):  # pylint: disable=invalid-name,unnecessary-lambda-assignment
-    """
-    Arguments:
-        q: (batch_size, seqlen_q, nheads, d)
-        k: (batch_size, seqlen_k, nheads_k, d)
-        v: (batch_size, seqlen_k, nheads_k, d)
-        query_padding_mask: (batch_size, seqlen), bool
-        key_padding_mask: (batch_size, seqlen), bool
-    """
-    assert not (kvpacked and qkvpacked)
-    batch_size, seqlen_q, nheads, d = q.shape
-    _, seqlen_k, nheads_k, _ = k.shape
-    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
-    assert v.shape == (batch_size, seqlen_k, nheads_k, d)
-
-    if query_padding_mask is not None:
-        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
-            q, query_padding_mask
-        )
-
-        def output_pad_fn(output_unpad):
-            return pad_input(  # noqa: E731
-                output_unpad, indices_q, batch_size, seqlen_q
-            )
-
-    else:
-        q_unpad = rearrange(q, "b s h d -> (b s) h d")
-        cu_seqlens_q = torch.arange(
-            0,
-            (batch_size + 1) * seqlen_q,
-            step=seqlen_q,
-            dtype=torch.int32,
-            device=q_unpad.device,
-        )
-        max_seqlen_q = seqlen_q
-
-        def output_pad_fn(output_unpad):
-            return rearrange(  # noqa: E731
-                output_unpad, "(b s) h d -> b s h d", b=batch_size
-            )
-
-    if key_padding_mask is not None:
-        k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
-        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
-    else:
-        k_unpad = rearrange(k, "b s h d -> (b s) h d")
-        v_unpad = rearrange(v, "b s h d -> (b s) h d")
-        cu_seqlens_k = torch.arange(
-            0,
-            (batch_size + 1) * seqlen_k,
-            step=seqlen_k,
-            dtype=torch.int32,
-            device=k_unpad.device,
-        )
-        max_seqlen_k = seqlen_k
-
-    if qkvpacked:
-        assert nheads == nheads_k
-        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
-        qkv = torch.stack([q, k, v], dim=2)
-        return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
-
-    if kvpacked:
-        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
-        kv = torch.stack([k, v], dim=2)
-        return (
-            q_unpad,
-            kv_unpad,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            q,
-            kv,
-            output_pad_fn,
-        )
-
-    return (
-        q_unpad,
-        k_unpad,
-        v_unpad,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        max_seqlen_q,
-        max_seqlen_k,
-        q,
-        k,
-        v,
-        output_pad_fn,
-    )
-
-
-def llama_model_forward(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    cache_position: Optional[  # pylint: disable=unused-argument
-        torch.LongTensor
-    ] = None,
-) -> Union[Tuple, BaseModelOutputWithPast]:
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
-    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
-    )
-
-    # retrieve input_ids and inputs_embeds
-    if input_ids is not None and inputs_embeds is not None:
-        raise ValueError(
-            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
-        )
-    if input_ids is not None:
-        batch_size, seq_length = input_ids.shape
-    elif inputs_embeds is not None:
-        batch_size, seq_length, _ = inputs_embeds.shape
-    else:
-        raise ValueError(
-            "You have to specify either decoder_input_ids or decoder_inputs_embeds"
-        )
-
-    seq_length_with_past = seq_length
-    past_key_values_length = 0
-
-    if past_key_values is not None:
-        past_key_values_length = past_key_values[0][0].shape[2]
-        seq_length_with_past = seq_length_with_past + past_key_values_length
-
-    cu_seqlens = None
-    max_seqlen = None
-    if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-        position_ids = torch.arange(
-            past_key_values_length,
-            seq_length + past_key_values_length,
-            dtype=torch.long,
-            device=device,
-        )
-        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
|
|
||||||
position_ids = position_ids.view(-1, seq_length).long()
|
|
||||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
|
||||||
cu_seqlens = cu_seqlens.squeeze()
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
|
||||||
# embed positions
|
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = torch.ones(
|
|
||||||
(batch_size, seq_length_with_past),
|
|
||||||
dtype=torch.bool,
|
|
||||||
device=inputs_embeds.device,
|
|
||||||
)
|
|
||||||
padding_mask = None
|
|
||||||
else:
|
|
||||||
if 0 in attention_mask:
|
|
||||||
padding_mask = attention_mask
|
|
||||||
else:
|
|
||||||
padding_mask = None
|
|
||||||
|
|
||||||
attention_mask = (
|
|
||||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
|
||||||
attention_mask,
|
|
||||||
(batch_size, seq_length),
|
|
||||||
inputs_embeds,
|
|
||||||
past_key_values_length,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
if use_cache:
|
|
||||||
transformers.logger.warning_once(
|
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
|
||||||
)
|
|
||||||
use_cache = False
|
|
||||||
|
|
||||||
# decoder layers
|
|
||||||
all_hidden_states = () if output_hidden_states else None
|
|
||||||
all_self_attns = () if output_attentions else None
|
|
||||||
next_decoder_cache = () if use_cache else None
|
|
||||||
|
|
||||||
for idx, decoder_layer in enumerate(self.layers):
|
|
||||||
if output_hidden_states:
|
|
||||||
all_hidden_states += (hidden_states,)
|
|
||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
|
|
||||||
def create_custom_forward(module):
|
|
||||||
def custom_forward(*inputs):
|
|
||||||
# None for past_key_value
|
|
||||||
return module(
|
|
||||||
*inputs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return custom_forward
|
|
||||||
|
|
||||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
|
||||||
create_custom_forward(decoder_layer),
|
|
||||||
hidden_states,
|
|
||||||
attention_mask,
|
|
||||||
position_ids,
|
|
||||||
past_key_value,
|
|
||||||
output_attentions,
|
|
||||||
None,
|
|
||||||
padding_mask,
|
|
||||||
cu_seqlens,
|
|
||||||
max_seqlen,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
layer_outputs = decoder_layer(
|
|
||||||
hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_value,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
max_seqlen=max_seqlen,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
all_self_attns += (layer_outputs[1],)
|
|
||||||
|
|
||||||
hidden_states = self.norm(hidden_states)
|
|
||||||
|
|
||||||
# add hidden states from the last decoder layer
|
|
||||||
if output_hidden_states:
|
|
||||||
all_hidden_states += (hidden_states,)
|
|
||||||
|
|
||||||
next_cache = next_decoder_cache if use_cache else None
|
|
||||||
if not return_dict:
|
|
||||||
return tuple(
|
|
||||||
v
|
|
||||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
|
||||||
if v is not None
|
|
||||||
)
|
|
||||||
return BaseModelOutputWithPast(
|
|
||||||
last_hidden_state=hidden_states,
|
|
||||||
past_key_values=next_cache,
|
|
||||||
hidden_states=all_hidden_states,
|
|
||||||
attentions=all_self_attns,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
|
||||||
"""
|
|
||||||
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
|
|
||||||
"""
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: Optional[bool] = False,
|
|
||||||
use_cache: Optional[bool] = False,
|
|
||||||
padding_mask: Optional[torch.LongTensor] = None,
|
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
|
||||||
max_seqlen: Optional[torch.Tensor] = None,
|
|
||||||
) -> Tuple[
|
|
||||||
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
|
||||||
]:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
|
||||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
|
||||||
output_attentions (`bool`, *optional*):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
||||||
returned tensors for more detail.
|
|
||||||
use_cache (`bool`, *optional*):
|
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
|
||||||
(see `past_key_values`).
|
|
||||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
|
||||||
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
|
|
||||||
"""
|
|
||||||
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
|
||||||
|
|
||||||
# Self Attention
|
|
||||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_value,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
padding_mask=padding_mask,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
max_seqlen=max_seqlen,
|
|
||||||
)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
# Fully Connected
|
|
||||||
residual = hidden_states
|
|
||||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
outputs += (self_attn_weights,)
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
outputs += (present_key_value,)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
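For orientation, the sample-packing path in the removed forward above hinges on `cu_seqlens` derived from `position_ids`. A minimal sketch of the idea (not the actual `get_cu_seqlens_from_pos_ids` helper; variable names are illustrative), assuming position ids restart at zero for each packed sequence:

```python
import torch

# one packed row holding sequences of lengths 3, 2, and 1;
# each packed sequence restarts its positions at zero
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0]])

# sequence boundaries sit where position_ids == 0; append the total length
starts = (position_ids[0] == 0).nonzero().flatten()
cu_seqlens = torch.cat([starts, torch.tensor([position_ids.size(1)])]).to(torch.int32)
max_seqlen = int(cu_seqlens.diff().max())

print(cu_seqlens)  # tensor([0, 3, 5, 6], dtype=torch.int32)
print(max_seqlen)  # 3
```

This `[0, 3, 5, 6]` vector is exactly the shape of input `flash_attn_varlen_qkvpacked_func` consumes so each packed sequence attends only within its own span.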
@@ -156,6 +156,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:

         return Llama4TextAttention

+    if model_type == "mistral3":
+        from transformers.models.mistral.modeling_mistral import MistralAttention
+
+        return MistralAttention
+
     try:
         # Dynamically import the module and attention class
         module_path = f"transformers.models.{model_type}.modeling_{model_type}"
@@ -390,7 +395,6 @@ def apply_lora_kernel_patches(
         ]
         can_patch_qkv = all(
             hasattr(module, "lora_A")
-            and getattr(module, "base_layer", module).bias is None
             and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
             for module in layer_modules
         )
@@ -400,7 +404,8 @@ def apply_lora_kernel_patches(
             self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
         else:
             LOG.warning_once(
-                "Cannot patch some attention QKV projections - requires LoRA adapters with no bias"
+                "Cannot patch some attention QKV projections - requires LoRA "
+                "adapters and no lora_magnitude_vector (DoRA)"
             )
     if cfg.lora_o_kernel:
         # Output patching
@@ -409,7 +414,6 @@ def apply_lora_kernel_patches(
         ]
         can_patch_o = all(
             hasattr(module, "lora_A")
-            and getattr(module, "base_layer", module).bias is None
             and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
             for module in layer_modules
         )
@@ -418,14 +422,14 @@ def apply_lora_kernel_patches(
             self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
         else:
             LOG.warning_once(
-                "Cannot patch some attention output projection - requires LoRA adapters with no bias"
+                "Cannot patch some attention output projection - requires LoRA "
+                "adapters and no lora_magnitude_vector (DoRA)"
             )
     for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
         if cfg.lora_mlp_kernel:
             # MLP patching
             can_patch_mlp = all(
                 hasattr(proj, "lora_A")
-                and getattr(proj, "base_layer", proj).bias is None
                 and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
                 for proj in (gate_proj, up_proj, down_proj)
             )
@@ -435,7 +439,8 @@ def apply_lora_kernel_patches(
             layer.mlp.forward = types.MethodType(apply_fn, mlp)
         else:
             LOG.warning_once(
-                "Cannot patch some MLP layers - requires LoRA adapters with no bias"
+                "Cannot patch some MLP layers - requires LoRA adapters and no "
+                "lora_magnitude_vector (DoRA)"
             )

     LOG.setLevel(original_level)
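The predicate these hunks converge on is worth isolating: the kernel patches now require a plain LoRA layer and explicitly exclude DoRA, which is detected through `lora_magnitude_vector`. A sketch of the guard on its own (the function name is illustrative, not part of the patch):

```python
def can_patch(module) -> bool:
    # Fused kernels apply only to plain LoRA layers: the module must carry
    # lora_A weights and must not be DoRA (no lora_magnitude_vector).
    return (
        hasattr(module, "lora_A")
        and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
    )
```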
@@ -3,53 +3,14 @@
 # pylint: disable=duplicate-code

 from functools import partial
-from typing import List, Optional, Tuple, Union

-import torch
 import transformers
-from einops import rearrange
-from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
-    flash_attn_kvpacked_func,
-    flash_attn_varlen_kvpacked_func,
-    flash_attn_varlen_qkvpacked_func,
-)
-from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.mistral.modeling_mistral import (
-    MistralAttention as OriginalMistralAttention,
-)
-from transformers.models.mistral.modeling_mistral import (
-    MistralDecoderLayer as OriginalMistralDecoderLayer,
-)
-from transformers.models.mistral.modeling_mistral import (
-    apply_rotary_pos_emb,
-    repeat_kv,
-)

-from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 from axolotl.utils.logging import get_logger

 LOG = get_logger(__name__)


-def replace_mistral_attn_with_flash_attn(
-    packed: Optional[bool] = False,
-):
-    transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
-        _prepare_decoder_attention_mask
-    )
-    transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
-        flashattn_forward
-    )
-    if packed:
-        transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
-            MistralDecoderLayer
-        )
-        transformers.models.mistral.modeling_mistral.MistralModel.forward = (
-            mistral_model_forward
-        )


 def patch_mistral_cross_entropy():
     from flash_attn.losses.cross_entropy import CrossEntropyLoss
@@ -57,604 +18,3 @@ def patch_mistral_cross_entropy():
     transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
         CrossEntropyLoss, inplace_backward=True
     )
-
-
-@torch.jit.script
-def _make_sliding_window_causal_mask(
-    bsz: int,
-    tgt_len: int,
-    dtype: torch.dtype,
-    device: torch.device,
-    past_key_values_length: int = 0,
-    sliding_window: int = 4096,
-):
-    """
-    Make causal mask used for sliding window attention
-    """
-    tensor = torch.full(
-        (tgt_len, tgt_len),
-        fill_value=1,
-        device=device,
-    )
-    mask = torch.tril(tensor, diagonal=0)
-    # make the mask banded to account for sliding window
-    # NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
-    mask = torch.triu(mask, diagonal=-sliding_window + 1)
-    mask = torch.log(mask).to(dtype)
-
-    if past_key_values_length > 0:
-        mask = torch.cat(
-            [
-                torch.zeros(
-                    tgt_len, past_key_values_length, dtype=dtype, device=device
-                ),
-                mask,
-            ],
-            dim=-1,
-        )
-    return mask[None, None, :, :].expand(
-        bsz, 1, tgt_len, tgt_len + past_key_values_length
-    )
-
-
-# Disable the transformation of the attention mask in LlamaModel as the flash attention
-# requires the attention mask to be the same as the key_padding_mask
-def _prepare_decoder_attention_mask(
-    self,
-    attention_mask,
-    input_shape,
-    inputs_embeds,
-    past_key_values_length,
-    sliding_window,
-):  # pylint: disable=unused-argument
-    # [bsz, seq_len]
-    if attention_mask is None or sliding_window is None:
-        return attention_mask
-
-    # NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
-    # Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
-    if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
-        sliding_window_mask = _make_sliding_window_causal_mask(
-            bsz=input_shape[0],
-            tgt_len=input_shape[1],
-            dtype=inputs_embeds.dtype,
-            device=inputs_embeds.device,
-            past_key_values_length=past_key_values_length,
-            sliding_window=sliding_window,
-        )
-        attention_mask = attention_mask + sliding_window_mask
-    else:
-        LOG.info("skipping sliding window mask, not broadcastable with attention mask")
-
-    return attention_mask
-
-
-def flashattn_forward(
-    self: OriginalMistralAttention,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    cu_seqlens: Optional[torch.Tensor] = None,
-    max_seqlen: Optional[torch.Tensor] = None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
-
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
-
-    query_states = query_states.view(
-        bsz, q_len, self.num_heads, self.head_dim
-    ).transpose(1, 2)
-    key_states = key_states.view(
-        bsz, q_len, self.num_key_value_heads, self.head_dim
-    ).transpose(1, 2)
-    value_states = value_states.view(
-        bsz, q_len, self.num_key_value_heads, self.head_dim
-    ).transpose(1, 2)
-
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
-    query_states, key_states = apply_rotary_pos_emb(
-        query_states, key_states, cos, sin, position_ids
-    )
-
-    use_sliding_windows = (
-        getattr(self.config, "sliding_window") is not None
-        and kv_seq_len > self.config.sliding_window
-    )
-
-    if use_sliding_windows:
-        window_size = (self.config.sliding_window, self.config.sliding_window)
-    else:
-        window_size = (-1, -1)
-
-    if past_key_value is not None:
-        # Activate slicing cache only if the config has a value `sliding_windows` attribute
-        if (
-            hasattr(self.config, "sliding_window")
-            and kv_seq_len > self.config.sliding_window
-        ):
-            slicing_tokens = kv_seq_len - self.config.sliding_window
-
-            past_key = past_key_value[0]
-            past_value = past_key_value[1]
-
-            past_key = past_key[:, :, slicing_tokens:, :].contiguous()
-            past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
-            if past_key.shape[-2] != self.config.sliding_window - 1:
-                raise ValueError(
-                    f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
-                    f" {past_key.shape}"
-                )
-
-            past_key_value = (past_key, past_value) if use_cache else None
-
-    if past_key_value is not None:
-        key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-    past_key_value = (key_states, value_states) if use_cache else None
-
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-    if self.training:
-        # during training q,k,v always have same seqlen
-        assert key_states.shape == query_states.shape
-        is_causal = True
-    else:
-        # turn off FA causal mask after first inference autoregressive iteration
-        # only on first autoregressive step q,k,v have same seqlen
-        is_causal = key_states.shape == query_states.shape
-
-    dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
-
-    if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
-        # special handling using sample packing
-        qkv = torch.stack(
-            [query_states, key_states, value_states], dim=2
-        )  # [bsz, nh, 3, q_len, hd]
-        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
-        qkv = rearrange(qkv, "b s ... -> (b s) ...")
-
-        output = flash_attn_varlen_qkvpacked_func(
-            qkv,
-            cu_seqlens,
-            max_seqlen,
-            dropout_p=dropout_rate,
-            softmax_scale=None,
-            causal=True,
-            window_size=window_size,
-        )
-        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-    elif query_states.shape == key_states.shape:
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-        qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
-            query_states,
-            key_states,
-            value_states,
-            qkvpacked=True,
-            # We have disabled _prepare_decoder_attention_mask in LlamaModel
-            # the attention_mask should be the same as the key_padding_mask
-            key_padding_mask=attention_mask,
-            query_padding_mask=(
-                attention_mask[:, -query_states.size(1) :]
-                if attention_mask is not None
-                else None
-            ),
-        )
-        output_unpad = flash_attn_varlen_qkvpacked_func(
-            qkv_unpad,
-            cu_seqlens_q,
-            max_seqlen_q,
-            dropout_p=dropout_rate,
-            softmax_scale=None,
-            causal=is_causal,
-            window_size=window_size,
-        )
-        output = output_pad_fn(output_unpad)
-    else:
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-        if attention_mask is None or attention_mask.all().item():
-            output = flash_attn_kvpacked_func(
-                query_states,
-                torch.stack([key_states, value_states], 2),
-                dropout_p=dropout_rate,
-                causal=is_causal,
-                window_size=window_size,
-            )
-        else:
-            (  # pylint: disable=unbalanced-tuple-unpacking
-                q_unpad,
-                kv_unpad,
-                cu_seqlens_q,
-                cu_seqlens_k,
-                max_seqlen_q,
-                max_seqlen_k,
-                _,
-                _,
-                output_pad_fn,
-            ) = generate_qkv(
-                query_states,
-                key_states,
-                value_states,
-                kvpacked=True,
-                key_padding_mask=attention_mask,
-                query_padding_mask=(
-                    attention_mask[:, -query_states.size(1) :]
-                    if attention_mask is not None
-                    else None
-                ),
-            )
-            if q_unpad.dtype != kv_unpad.dtype:
-                kv_unpad = kv_unpad.to(q_unpad.dtype)
-            output_unpad = flash_attn_varlen_kvpacked_func(
-                q_unpad,
-                kv_unpad,
-                cu_seqlens_q,
-                cu_seqlens_k,
-                max_seqlen_q,
-                max_seqlen_k,
-                dropout_p=dropout_rate,
-                softmax_scale=None,
-                causal=is_causal,
-                window_size=window_size,
-            )
-            output = output_pad_fn(output_unpad)
-
-    attn_output = output
-    if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
-        raise ValueError(
-            f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
-            f" {attn_output.size()}"
-        )
-    attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
-
-    attn_output = self.o_proj(attn_output)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output, attn_weights, past_key_value
-
-
-# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
-def generate_qkv(
-    q,
-    k,
-    v,
-    query_padding_mask=None,
-    key_padding_mask=None,
-    kvpacked=False,
-    qkvpacked=False,
-):  # pylint: disable=invalid-name,unnecessary-lambda-assignment
-    """
-    Arguments:
-        q: (batch_size, seqlen_q, nheads, d)
-        k: (batch_size, seqlen_k, nheads_k, d)
-        v: (batch_size, seqlen_k, nheads_k, d)
-        query_padding_mask: (batch_size, seqlen), bool
-        key_padding_mask: (batch_size, seqlen), bool
-    """
-    assert not (kvpacked and qkvpacked)
-    batch_size, seqlen_q, nheads, d = q.shape
-    _, seqlen_k, nheads_k, _ = k.shape
-    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
-    assert v.shape == (batch_size, seqlen_k, nheads_k, d)
-
-    if query_padding_mask is not None:
-        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
-            q, query_padding_mask
-        )
-
-        def output_pad_fn(output_unpad):
-            return pad_input(  # noqa: E731
-                output_unpad, indices_q, batch_size, seqlen_q
-            )
-
-    else:
-        q_unpad = rearrange(q, "b s h d -> (b s) h d")
-        cu_seqlens_q = torch.arange(
-            0,
-            (batch_size + 1) * seqlen_q,
-            step=seqlen_q,
-            dtype=torch.int32,
-            device=q_unpad.device,
-        )
-        max_seqlen_q = seqlen_q
-
-        def output_pad_fn(output_unpad):
-            return rearrange(  # noqa: E731
-                output_unpad, "(b s) h d -> b s h d", b=batch_size
-            )
-
-    if key_padding_mask is not None:
-        k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
-        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
-    else:
-        k_unpad = rearrange(k, "b s h d -> (b s) h d")
-        v_unpad = rearrange(v, "b s h d -> (b s) h d")
-        cu_seqlens_k = torch.arange(
-            0,
-            (batch_size + 1) * seqlen_k,
-            step=seqlen_k,
-            dtype=torch.int32,
-            device=k_unpad.device,
-        )
-        max_seqlen_k = seqlen_k
-
-    if qkvpacked:
-        assert nheads == nheads_k
-        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
-        qkv = torch.stack([q, k, v], dim=2)
-        return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
-
-    if kvpacked:
-        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
-        kv = torch.stack([k, v], dim=2)
-        return (
-            q_unpad,
-            kv_unpad,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            max_seqlen_q,
-            max_seqlen_k,
-            q,
-            kv,
-            output_pad_fn,
-        )
-
-    return (
-        q_unpad,
-        k_unpad,
-        v_unpad,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        max_seqlen_q,
-        max_seqlen_k,
-        q,
-        k,
-        v,
-        output_pad_fn,
-    )
-
-
-def mistral_model_forward(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    cache_position: Optional[  # pylint: disable=unused-argument
-        torch.LongTensor
-    ] = None,
-) -> Union[Tuple, BaseModelOutputWithPast]:
-    output_attentions = (
-        output_attentions
-        if output_attentions is not None
-        else self.config.output_attentions
-    )
-    output_hidden_states = (
-        output_hidden_states
-        if output_hidden_states is not None
-        else self.config.output_hidden_states
-    )
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-    return_dict = (
-        return_dict if return_dict is not None else self.config.use_return_dict
-    )
-
-    # retrieve input_ids and inputs_embeds
-    if input_ids is not None and inputs_embeds is not None:
-        raise ValueError(
-            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
-        )
-    if input_ids is not None:
-        batch_size, seq_length = input_ids.shape
-    elif inputs_embeds is not None:
-        batch_size, seq_length, _ = inputs_embeds.shape
-    else:
-        raise ValueError(
-            "You have to specify either decoder_input_ids or decoder_inputs_embeds"
-        )
-
-    seq_length_with_past = seq_length
-    past_key_values_length = 0
-
-    if past_key_values is not None:
-        past_key_values_length = past_key_values[0][0].shape[2]
-        seq_length_with_past = seq_length_with_past + past_key_values_length
-
-    cu_seqlens = None
-    max_seqlen = None
-    if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-        position_ids = torch.arange(
-            past_key_values_length,
-            seq_length + past_key_values_length,
-            dtype=torch.long,
-            device=device,
-        )
-        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-    else:
-        position_ids = position_ids.view(-1, seq_length).long()
-        cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
-        cu_seqlens = cu_seqlens.squeeze()
-
-    if inputs_embeds is None:
-        inputs_embeds = self.embed_tokens(input_ids)
-    # embed positions
-    if attention_mask is None:
-        attention_mask = torch.ones(
-            (batch_size, seq_length_with_past),
-            dtype=torch.bool,
-            device=inputs_embeds.device,
-        )
-    attention_mask = (
-        self._prepare_decoder_attention_mask(  # pylint: disable=protected-access
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-            sliding_window=self.config.sliding_window,
-        )
-    )
-
-    hidden_states = inputs_embeds
-
-    if self.gradient_checkpointing and self.training:
-        if use_cache:
-            transformers.logger.warning_once(
-                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-            )
-            use_cache = False
-
-    # decoder layers
-    all_hidden_states = () if output_hidden_states else None
-    all_self_attns = () if output_attentions else None
-    next_decoder_cache = () if use_cache else None
-
-    for idx, decoder_layer in enumerate(self.layers):
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        past_key_value = past_key_values[idx] if past_key_values is not None else None
-
-        if self.gradient_checkpointing and self.training:
-            layer_outputs = (
-                self._gradient_checkpointing_func(  # pylint: disable=protected-access
-                    decoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    position_ids,
-                    past_key_value,
-                    output_attentions,
-                    None,
-                    cu_seqlens,
-                    max_seqlen,
-                )
-            )
-        else:
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen,
-            )
-
-        hidden_states = layer_outputs[0]
-
-        if use_cache:
-            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
-
-        if output_attentions:
-            all_self_attns += (layer_outputs[1],)
-
-    hidden_states = self.norm(hidden_states)
-
-    # add hidden states from the last decoder layer
-    if output_hidden_states:
-        all_hidden_states += (hidden_states,)
-
-    next_cache = next_decoder_cache if use_cache else None
-    if not return_dict:
-        return tuple(
-            v
-            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
-            if v is not None
-        )
-    return BaseModelOutputWithPast(
-        last_hidden_state=hidden_states,
-        past_key_values=next_cache,
-        hidden_states=all_hidden_states,
-        attentions=all_self_attns,
-    )
-
-
-class MistralDecoderLayer(OriginalMistralDecoderLayer):
-    """
-    patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
-    """
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-        cu_seqlens: Optional[torch.Tensor] = None,
-        max_seqlen: Optional[torch.Tensor] = None,
-    ) -> Tuple[
-        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
-    ]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-            cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
-        """
-
-        residual = hidden_states
-
-        hidden_states = self.input_layernorm(hidden_states)
-
-        # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-        )
-        hidden_states = residual + hidden_states
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (present_key_value,)
-
-        return outputs
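As background on the removed `_make_sliding_window_causal_mask`: the `tril`/`triu` pair carves a causal band of width `sliding_window`, and `torch.log` maps the 1/0 band to the 0/-inf additive biases attention expects. A standalone illustration (window of 2 over four positions; not part of the diff):

```python
import torch

tgt_len, sliding_window = 4, 2
mask = torch.tril(torch.full((tgt_len, tgt_len), 1.0), diagonal=0)  # causal: j <= i
mask = torch.triu(mask, diagonal=-sliding_window + 1)  # banded: j >= i - (w - 1)
mask = torch.log(mask)  # 1 -> 0.0, 0 -> -inf

print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [-inf, 0., 0., -inf],
#         [-inf, -inf, 0., 0.]])
```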
@@ -36,6 +36,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "glm",
     "glm4",
     "smollm3",
+    "gpt_oss",
+    "arcee",
 ]

@@ -20,12 +20,15 @@ from ring_flash_attn import ring_flash_attn_func
 from ring_flash_attn.adapters.hf_adapter import check_params
 from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal

-try:
+try:  # pylint: disable=duplicate-code
     from transformers.modeling_flash_attention_utils import _flash_supports_window
 except ImportError:
-    from transformers.modeling_flash_attention_utils import (
-        _flash_supports_window_size as _flash_supports_window,
-    )
+    try:
+        from transformers.modeling_flash_attention_utils import (
+            _flash_supports_window_size as _flash_supports_window,
+        )
+    except ImportError:
+        _flash_supports_window = True

 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

@@ -15,12 +15,15 @@ import torch
 import torch.distributed as dist
 from torch.distributed import DeviceMesh

-try:
+try:  # pylint: disable=duplicate-code
     from transformers.modeling_flash_attention_utils import _flash_supports_window
 except ImportError:
-    from transformers.modeling_flash_attention_utils import (
-        _flash_supports_window_size as _flash_supports_window,
-    )
+    try:
+        from transformers.modeling_flash_attention_utils import (
+            _flash_supports_window_size as _flash_supports_window,
+        )
+    except ImportError:
+        _flash_supports_window = True

 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 from axolotl.utils.logging import get_logger
|||||||
@@ -1,78 +0,0 @@
|
|||||||
"""
|
|
||||||
fix for FSDP2 evals when using torch.compile
|
|
||||||
"""
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
from transformers import Trainer
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
ORIGINAL_TRAINER_CODE = """
|
|
||||||
model.eval()
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATCHED_TRAINER_CODE = """
|
|
||||||
if hasattr(model, "eval") and callable(model.eval):
|
|
||||||
self.model.eval()
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def get_evaluation_loop_code() -> str:
|
|
||||||
training_loop = inspect.getsource(Trainer.evaluation_loop)
|
|
||||||
return training_loop
|
|
||||||
|
|
||||||
|
|
||||||
def check_evaluation_loop_is_patchable() -> bool:
|
|
||||||
eval_loop = get_evaluation_loop_code()
|
|
||||||
eval_loop, _ = detab_code(eval_loop)
|
|
||||||
return ORIGINAL_TRAINER_CODE in eval_loop
|
|
||||||
|
|
||||||
|
|
||||||
def patch_evaluation_loop_for_fsdp2():
|
|
||||||
"""
|
|
||||||
monkeypatch for fixing the eval loop for fsdp2 with torch.compile
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
evaluation_loop = get_evaluation_loop_code()
|
|
||||||
except OSError:
|
|
||||||
return
|
|
||||||
Trainer._original_evaluation_loop = ( # pylint: disable=protected-access
|
|
||||||
evaluation_loop
|
|
||||||
)
|
|
||||||
evaluation_loop, _ = detab_code(evaluation_loop)
|
|
||||||
if ORIGINAL_TRAINER_CODE not in evaluation_loop:
|
|
||||||
return
|
|
||||||
|
|
||||||
evaluation_loop = evaluation_loop.replace(
|
|
||||||
ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE
|
|
||||||
)
|
|
||||||
evaluation_loop = evaluation_loop.replace(
|
|
||||||
"def evaluation_loop(",
|
|
||||||
"def _fixed_evaluation_loop(",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# load imports necessary
|
|
||||||
import transformers.trainer
|
|
||||||
|
|
||||||
items_to_import = []
|
|
||||||
for item in dir(transformers.trainer):
|
|
||||||
if item in evaluation_loop:
|
|
||||||
items_to_import.append(item)
|
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
|
||||||
"from transformers.trainer import ("
|
|
||||||
+ ", ".join(x for x in items_to_import)
|
|
||||||
+ ")",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
exec(evaluation_loop, globals()) # pylint: disable=exec-used # nosec B102
|
|
||||||
LOG.info("patching _inner_training_loop for fsdp optimizer save")
|
|
||||||
Trainer.evaluation_loop = ( # pylint: disable=protected-access
|
|
||||||
_fixed_evaluation_loop # pylint: disable=undefined-variable # noqa: F821
|
|
||||||
)
|
|
||||||
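The deleted module above, like its replacement `trainer_loss_calc.py` later in this diff, relies on a source-rewriting monkeypatch: fetch the method source with `inspect`, string-replace the offending lines, rename, re-`exec`, and rebind. A toy sketch of the mechanism, independent of transformers (run it as a script so `inspect.getsource` can locate the file):

```python
import inspect


def greet() -> str:
    return "hello"


# fetch source, rewrite the target line, rename to avoid clobbering, re-exec
source = inspect.getsource(greet)
source = source.replace('return "hello"', 'return "patched hello"')
source = source.replace("def greet(", "def greet_patched(", 1)
exec(source, globals())  # nosec B102 - defines greet_patched

greet = greet_patched  # noqa: F821
print(greet())  # patched hello
```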
@@ -1,87 +0,0 @@
-"""
-Monkey patch to fix transformers.modeling_flash_attention_utils.
-
-see https://github.com/huggingface/transformers/pull/39653/files
-"""
-
-import sys
-
-import torch
-
-
-def _prepare_from_posids(query, key, value, position_ids):
-    """
-    This function returns necessary arguments to call `flash_attn_varlen_func`.
-    All three query, key, value states will be flattened.
-    Cumulative lengths of each examples in the batch will be extracted from position_ids.
-    NOTE: ideally cumulative lengths should be prepared at the data collator stage
-    Arguments:
-        query (`torch.Tensor`):
-            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
-        key (`torch.Tensor`):
-            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
-        value (`torch.Tensor`):
-            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
-        position_ids (`torch.Tensor`):
-            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
-    Return:
-        query (`torch.Tensor`):
-            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
-        key (`torch.Tensor`):
-            Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
-        value (`torch.Tensor`):
-            Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
-        indices_q (`torch.Tensor`):
-            The indices of non-masked tokens from the flattened input target sequence.
-        (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
-            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
-        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
-            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
-    """
-    query = query.contiguous().view(-1, query.size(-2), query.size(-1))
-    key = key.contiguous().view(-1, key.size(-2), key.size(-1))
-    value = value.contiguous().view(-1, value.size(-2), value.size(-1))
-
-    position_ids = position_ids.flatten()
-    indices_q = torch.arange(
-        position_ids.size(0), device=position_ids.device, dtype=torch.int32
-    )
-
-    cu_seq_lens = torch.cat(
-        (
-            indices_q[position_ids == 0],
-            torch.tensor(
-                position_ids.size(), device=position_ids.device, dtype=torch.int32
-            ),
-        )
-    )
-    # NOTE: With torch compile, this will cause a graph break if you don't set
-    # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
-    # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
-    # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
-    # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
-    # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
-    # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
-    # for some models (e.g. qwen2-vl).
-    max_length = cu_seq_lens.diff().max().item()
-    return (
-        query,
-        key,
-        value,
-        indices_q,
-        (cu_seq_lens, cu_seq_lens),
-        (max_length, max_length),
-    )
-
-
-def patch_prepare_from_posids():
-    import transformers.modeling_flash_attention_utils
-
-    transformers.modeling_flash_attention_utils._prepare_from_posids = (  # pylint: disable=protected-access
-        _prepare_from_posids
-    )
-    setattr(
-        sys.modules["transformers.modeling_flash_attention_utils"],
-        "_prepare_from_posids",
-        _prepare_from_posids,
-    )
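The deleted `_prepare_from_posids` recovers packed-sequence boundaries straight from `position_ids`; its `cu_seq_lens` construction reduces to a two-line trick. A self-contained illustration:

```python
import torch

# packed row: two sequences of lengths 3 and 2, positions restarting at zero
position_ids = torch.tensor([0, 1, 2, 0, 1])
indices_q = torch.arange(position_ids.size(0), dtype=torch.int32)

# sequence starts are where position_ids == 0; append the total token count
cu_seq_lens = torch.cat(
    (indices_q[position_ids == 0], torch.tensor(position_ids.size(), dtype=torch.int32))
)
max_length = cu_seq_lens.diff().max().item()

print(cu_seq_lens)  # tensor([0, 3, 5], dtype=torch.int32)
print(max_length)   # 3
```

Note `max_length` comes from `cu_seq_lens.diff()` rather than from the position ids themselves, since position ids are not always increasing for some models (e.g. qwen2-vl).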
165
src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
Normal file
165
src/axolotl/monkeypatch/transformers/trainer_loss_calc.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
"""
|
||||||
|
Module for patching transformers Trainer loss calculation to use nanmean.
|
||||||
|
|
||||||
|
This is needed for context parallelism since chunks of the input sequences may be fully
|
||||||
|
masked and return NaNs in the loss calculation.
|
||||||
|
|
||||||
|
Also includes a patch for FSDP2 + torch.compile. We need to bundle this together with
|
||||||
|
the other evaluation_loop patch because we can't patch the same code twice without
|
||||||
|
raising an OSError.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from transformers import Trainer
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import detab_code
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
ORIGINAL_EVAL_CODE = {
|
||||||
|
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()',
|
||||||
|
"array": 'metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()',
|
||||||
|
}
|
||||||
|
PATCHED_EVAL_CODE = {
|
||||||
|
"list": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(np.concatenate(all_losses)).item()',
|
||||||
|
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
|
||||||
|
}
|
||||||
|
|
||||||
|
ORIGINAL_FSDP2_CODE = """
|
||||||
|
model.eval()
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_FSDP2_CODE = """
|
||||||
|
if hasattr(model, "eval") and callable(model.eval):
|
||||||
|
self.model.eval()
|
||||||
|
"""
|
||||||
|
|
||||||
|
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
|
||||||
|
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
|
||||||
|
|
||||||
|
|
||||||
|
def check_evaluation_loop_is_patchable() -> bool:
|
||||||
|
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||||
|
return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values())
|
||||||
|
|
||||||
|
|
||||||
|
def check_evaluation_loop_is_fsdp2_patchable() -> bool:
|
||||||
|
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||||
|
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
|
||||||
|
return ORIGINAL_FSDP2_CODE in evaluation_loop_source
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
def patch_evaluation_loop(patch_fsdp2: bool):
|
||||||
|
"""Patch the evaluation_loop method."""
|
||||||
|
# Check if already patched
|
||||||
|
if hasattr(Trainer, "_original_evaluation_loop"):
|
||||||
|
LOG.info("Trainer.evaluation_loop already patched")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if the patterns exist
|
||||||
|
try:
|
||||||
|
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
Trainer.evaluation = evaluation_loop_source
|
||||||
|
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
|
||||||
|
|
||||||
|
# Apply the nanmean patches
|
||||||
|
evaluation_loop_source = evaluation_loop_source.replace(
|
||||||
|
ORIGINAL_EVAL_CODE["list"], PATCHED_EVAL_CODE["list"]
|
||||||
|
)
|
||||||
|
evaluation_loop_source = evaluation_loop_source.replace(
|
||||||
|
ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply FSDP2 eval guard patch if needed
|
||||||
|
if patch_fsdp2 and ORIGINAL_FSDP2_CODE in evaluation_loop_source:
|
||||||
|
+    evaluation_loop_source = evaluation_loop_source.replace(
+        ORIGINAL_FSDP2_CODE, PATCHED_FSDP2_CODE
+    )
+    LOG.info("Applied FSDP2 eval guard patch to evaluation_loop")
+
+    # Rename the function to avoid conflicts
+    evaluation_loop_source = evaluation_loop_source.replace(
+        "def evaluation_loop(",
+        "def axolotl_evaluation_loop(",
+        1,
+    )
+
+    # Get the module for necessary imports
+    module_name = Trainer.__module__
+    module = importlib.import_module(module_name)
+
+    # Import necessary items from the module
+    items_to_import = []
+    for item in dir(module):
+        if item in evaluation_loop_source:
+            items_to_import.append(item)
+
+    # Execute the imports and patched method
+    exec(  # pylint: disable=exec-used  # nosec B102
+        f"from {module_name} import ({', '.join(items_to_import)})",
+        globals(),
+    )
+    exec(evaluation_loop_source, globals())  # pylint: disable=exec-used  # nosec B102
+
+    LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
+    Trainer.evaluation_loop = (
+        axolotl_evaluation_loop  # pylint: disable=undefined-variable  # noqa: F821
+    )
+
+
+def check_maybe_log_save_evaluate_is_patchable() -> bool:
+    maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
+    return ORIGINAL_MAYBE_CODE in maybe_log_source
+
+
+# pylint: disable=protected-access
+def patch_maybe_log_save_evaluate():
+    """Patch the _maybe_log_save_evaluate method."""
+    # Check if already patched
+    if hasattr(Trainer, "_original_maybe_log_save_evaluate"):
+        LOG.info("Trainer._maybe_log_save_evaluate already patched")
+        return
+
+    # Check if the patterns exist
+    try:
+        maybe_log_source = inspect.getsource(Trainer._maybe_log_save_evaluate)
+    except OSError:
+        return
+    Trainer._original_maybe_log_save_evaluate = maybe_log_source
+    maybe_log_source, _ = detab_code(maybe_log_source)
+
+    # Apply the patch
+    maybe_log_source = maybe_log_source.replace(ORIGINAL_MAYBE_CODE, PATCHED_MAYBE_CODE)
+
+    # Rename the function to avoid conflicts
+    maybe_log_source = maybe_log_source.replace(
+        "def _maybe_log_save_evaluate(",
+        "def axolotl_maybe_log_save_evaluate(",
+        1,
+    )
+
+    # Get the module for necessary imports
+    module_name = Trainer.__module__
+    module = importlib.import_module(module_name)
+
+    # Import necessary items from the module
+    items_to_import = []
+    for item in dir(module):
+        if item in maybe_log_source:
+            items_to_import.append(item)
+
+    # Execute the imports and patched method
+    exec(  # pylint: disable=exec-used  # nosec B102
+        f"from {module_name} import ({', '.join(items_to_import)})",
+        globals(),
+    )
+    exec(maybe_log_source, globals())  # pylint: disable=exec-used  # nosec B102
+
+    LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
+    Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate  # pylint: disable=undefined-variable  # noqa: F821
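Note: the pattern above — fetch the method's source with `inspect.getsource`, rewrite it textually, rename the definition, `exec` it, and rebind it on `Trainer` — is easy to misread on first contact. Below is a minimal, self-contained sketch of the same technique with a toy class; all names are hypothetical.

import inspect
import textwrap


class Greeter:
    """Toy stand-in for the patched Trainer."""

    def greet(self) -> str:
        return "hello"


# getsource returns the method still indented at class level, so dedent it
# first (the detab_code call above plays a similar normalization role).
source = textwrap.dedent(inspect.getsource(Greeter.greet))
# Rewrite the body, then rename the definition so the exec'd function
# cannot shadow the original name.
source = source.replace('return "hello"', 'return "hello, patched"')
source = source.replace("def greet(", "def patched_greet(", 1)

namespace: dict = {}
exec(source, namespace)  # trusted, locally generated source only
Greeter.greet = namespace["patched_greet"]

assert Greeter().greet() == "hello, patched"

Renaming before `exec` is what lets the patched function be assigned back onto the class while the original source stays stashed in `_original_maybe_log_save_evaluate`.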
@@ -6,7 +6,7 @@ from typing import Optional
 from PIL import Image, ImageOps
 from PIL.Image import Resampling
 from torch import Tensor, zeros_like
-from transformers import ProcessorMixin, VoxtralProcessor
+from transformers import ProcessorMixin, SmolVLMProcessor, VoxtralProcessor
 from transformers.image_utils import load_image

 from axolotl.utils.dict import remove_none_values
@@ -138,7 +138,7 @@ class ProcessingStrategy:
                 image_key = key
                 break

-        # if the image key exists, add the image to the first message
+        # if the image key exists, add the image to the first user message
         if image_key is not None and processed_example[image_key] is not None:
             # TODO: check if it's normal to be single image only for common datasets
             # From observation, it's usually a list of single image but some datasets may have several columns for images
@@ -179,26 +179,34 @@ class ProcessingStrategy:

             # Look for any image type in the first message
             # some dataset have an {type: "image"} in the first message
+            msg_ind_to_add = None
             ind_to_add = None
-            for i, content in enumerate(
-                processed_example["messages"][0]["content"]
-            ):
-                # Usually datasets created with image columns, don't have it in the messages itself
-                if content["type"] == "image" and all(
-                    k not in content for k in ["image", "url", "path", "base64"]
-                ):
-                    ind_to_add = i
-                    break
+            first_user_idx = None
+
+            for msg_idx, msg_content in enumerate(processed_example["messages"]):
+                if first_user_idx is None and msg_content["role"] == "user":
+                    first_user_idx = msg_idx
+                for i, content in enumerate(
+                    processed_example["messages"][msg_idx]["content"]
+                ):
+                    # Usually datasets created with image columns, don't have it in the messages itself
+                    if content["type"] == "image" and all(
+                        k not in content for k in ["image", "url", "path", "base64"]
+                    ):
+                        msg_ind_to_add = msg_idx
+                        ind_to_add = i
+                        break

             # If an image type is found, add the image to that index
-            if ind_to_add is not None:
-                processed_example["messages"][0]["content"][ind_to_add][
-                    "image"
-                ] = image_value
+            if ind_to_add is not None and msg_ind_to_add is not None:
+                processed_example["messages"][msg_ind_to_add]["content"][
+                    ind_to_add
+                ]["image"] = image_value
             else:
-                # if no image type is found, add it to end of the first message
-                processed_example["messages"][0]["content"].append(
+                # if no image type is found, add it to end of the first user message
+                if first_user_idx is None:
+                    first_user_idx = 0
+                processed_example["messages"][first_user_idx]["content"].append(
                     {
                         "type": "image",
                         "image": image_value,
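As a quick illustration of the placeholder handling in this hunk (data hypothetical): an `{"type": "image"}` entry that carries none of the payload keys is the slot that receives the dataset image, and the search now walks every message rather than only the first one.

def attach_image(example: dict, image_value) -> None:
    # Fill the first bare {"type": "image"} placeholder, as in the hunk above.
    for msg in example["messages"]:
        for content in msg["content"]:
            if content.get("type") == "image" and all(
                k not in content for k in ("image", "url", "path", "base64")
            ):
                content["image"] = image_value
                return


example = {
    "messages": [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": "Describe this."}],
        }
    ]
}
attach_image(example, "img_0001.png")  # a PIL image in the real pipeline
assert example["messages"][0]["content"][0]["image"] == "img_0001.png"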
@@ -395,6 +403,24 @@ class VoxtralProcessingStrategy(ProcessingStrategy):
         return labels


+class SmolVLM2ProcessingStrategy(ProcessingStrategy):
+    """Processing Strategy class for SmolVLM2"""
+
+    def __init__(
+        self,
+        processor: ProcessorMixin,
+        chat_template: Optional[str] = None,
+        image_size: int | tuple[int, int] | None = None,
+        image_resize_algorithm: Resampling | None = None,
+    ):
+        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
+        self.image_token = "<image>"  # nosec
+
+        self.image_token_id = processor.tokenizer.additional_special_tokens_ids[
+            processor.tokenizer.additional_special_tokens.index(self.image_token)
+        ]
+
+
 def get_processing_strategy(
     processor: ProcessorMixin,
     chat_template,
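The `image_token_id` lookup can be sanity-checked in isolation. The snippet below mirrors `SmolVLM2ProcessingStrategy.__init__`; the model id is illustrative, requires Hub access, and assumes the tokenizer registers `<image>` among its additional special tokens.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
image_token = "<image>"
# Same two-step lookup as the strategy above: find the token's position in
# additional_special_tokens, then read the id at the matching position.
image_token_id = tokenizer.additional_special_tokens_ids[
    tokenizer.additional_special_tokens.index(image_token)
]
print(image_token, "->", image_token_id)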
@@ -402,32 +428,43 @@ def get_processing_strategy(
     image_size: int | tuple[int, int] | None = None,
     image_resize_algorithm: Resampling | None = None,
 ):
+    processing_kwargs = {
+        "processor": processor,
+        "chat_template": chat_template,
+        "image_size": image_size,
+        "image_resize_algorithm": image_resize_algorithm,
+    }
+
+    if chat_template_type in [None, "tokenizer_default"] and hasattr(
+        processor.tokenizer, "chat_template"
+    ):
+        processing_kwargs["chat_template"] = processor.tokenizer.chat_template
+
     if chat_template_type == "qwen2_vl":
         return Qwen2VLProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
         )
     if chat_template_type == "gemma3":
         return Gemma3ProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
         )
     if chat_template_type == "gemma3n":
         return Gemma3nProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
        )
-    if chat_template_type in [
-        "llama3_2_vision",
-        "llama4",
-        "llava",
-        "mistral_v7_tekken",
-        "pixtral",
-    ]:
-        return ProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
-        )

     if isinstance(processor, VoxtralProcessor):
         return VoxtralProcessingStrategy(
-            processor, chat_template, image_size, image_resize_algorithm
+            **processing_kwargs,
         )

-    raise ValueError(f"Unsupported chat template type: {chat_template_type}")
+    if isinstance(processor, SmolVLMProcessor):
+        return SmolVLM2ProcessingStrategy(
+            **processing_kwargs,
+        )
+
+    # llama3_2_vision, llama4, llava
+    # mistral_v7_tekken, pixtral, lfm2vl
+    return ProcessingStrategy(
+        **processing_kwargs,
+    )
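A hypothetical call into the refactored factory (import path assumed; keyword arguments used to avoid depending on parameter order not shown in this hunk). The behavioral change worth noting: unmatched processors now fall through to the generic `ProcessingStrategy` instead of raising `ValueError`, and a `None` chat template type picks up `processor.tokenizer.chat_template` when one exists.

from transformers import AutoProcessor

from axolotl.processing_strategies import get_processing_strategy  # path assumed

processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
strategy = get_processing_strategy(
    processor=processor,
    chat_template=None,
    chat_template_type=None,  # falls back to the tokenizer's own template
)
print(type(strategy).__name__)  # expected: SmolVLM2ProcessingStrategy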
@@ -41,7 +41,9 @@ class ChatTemplatePrompter(Prompter):
         field_messages: str = "messages",
         field_system: str = "system",
         field_tools: str = "tools",
+        field_thinking: str = "reasoning_content",
         roles: dict[str, list[str]] | None = None,
+        template_thinking_key: str | None = "reasoning_content",
         chat_template_kwargs: dict[str, Any] | None = None,
         drop_system_message: bool = False,
     ):
@@ -50,8 +52,9 @@ class ChatTemplatePrompter(Prompter):
         message_property_mappings = {
             "role": "role",
             "content": "content",
-            "reasoning_content": "reasoning_content",
         }
+        if template_thinking_key and field_thinking:
+            message_property_mappings[template_thinking_key] = field_thinking

         if roles:
             self.roles = {s: t for t, sources in roles.items() for s in sources}
@@ -74,10 +77,12 @@ class ChatTemplatePrompter(Prompter):
         self.field_messages = field_messages
         self.field_system = field_system
         self.field_tools = field_tools
+        self.field_thinking = field_thinking
         self.tokenizer = tokenizer
         self.processor: ProcessorMixin | None = processor
         self.chat_template = chat_template
         self.chat_template_kwargs = chat_template_kwargs or {}
+        self.template_thinking_key: str = template_thinking_key or "reasoning_content"
         self.max_length = max_length
         self.drop_system_message = drop_system_message

@@ -124,13 +129,21 @@ class ChatTemplatePrompter(Prompter):
                 images=images,
                 return_tensors="pt",
             )
+            if hasattr(batch, "to_dict"):
+                batch = batch.to_dict()
+            else:
+                batch = dict(batch)
+
             # workaround since processor works in batches instead of single examples
+            out = {}
             for k, val in batch.items():
-                if k in ["pixel_values"]:
-                    batch[k] = val.tolist()
+                if hasattr(val, "tolist"):
+                    out[k] = (
+                        val.tolist() if k == "pixel_values" else val.squeeze(0).tolist()
+                    )
                 else:
-                    batch[k] = val.squeeze().tolist()
-            return batch
+                    out[k] = val
+            return out

         return self.tokenizer.apply_chat_template(
             conversation,
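A toy version of the unbatching logic above: the processor returns batch-shaped tensors even for a single example, so everything with a `tolist` method except `pixel_values` is squeezed on the batch dimension, while non-tensor values (the reason for the new `else` branch) pass through untouched. Example data is hypothetical.

import torch

batch = {
    "input_ids": torch.tensor([[1, 2, 3]]),  # shape (1, seq_len)
    "pixel_values": torch.rand(1, 3, 2, 2),  # stays batched
    "num_items": 3,                          # non-tensor passthrough
}

out = {}
for k, val in batch.items():
    if hasattr(val, "tolist"):
        out[k] = val.tolist() if k == "pixel_values" else val.squeeze(0).tolist()
    else:
        out[k] = val

assert out["input_ids"] == [1, 2, 3]
assert out["num_items"] == 3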
@@ -428,10 +441,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
             tokenized_prompt["attention_mask"] = [1] * len(input_ids)
         else:
             input_ids = tokenized_res["input_ids"]
-            tokenized_prompt = tokenized_res
+            tokenized_prompt = dict(tokenized_res)

         if not self.train_on_inputs:
-            user_prompt_len = len(prompt_ids)
+            if isinstance(prompt_ids, dict):
+                user_prompt_len = len(prompt_ids["input_ids"])
+            else:
+                user_prompt_len = len(prompt_ids)
             labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
         else:
             labels = input_ids
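A minimal sketch of the masking this hunk generalizes (token ids hypothetical): prompt tokens receive the label `-100` so only the completion contributes to the loss, and `prompt_ids` may now arrive as a `BatchEncoding`-style dict rather than a bare list.

prompt_ids = {"input_ids": [101, 7592]}   # tokenized user prompt
input_ids = [101, 7592, 2023, 2003, 102]  # prompt + completion

if isinstance(prompt_ids, dict):
    user_prompt_len = len(prompt_ids["input_ids"])
else:
    user_prompt_len = len(prompt_ids)

# -100 is the ignore index of PyTorch's cross-entropy loss.
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
assert labels == [-100, -100, 2023, 2003, 102]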
@@ -742,7 +758,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

                 # get the thinking content
                 thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
-                transformed_message["reasoning_content"] = thinking_content.strip()
+                transformed_message[self.prompter.template_thinking_key] = (
+                    thinking_content.strip()
+                )

                 # take remainder of the content
                 # strip whitespace from beginning of the remainder (thinking tokens)
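In isolation, the extraction around this hunk behaves as below, assuming a ("<think>", "</think>") tag pair in the role of `tpair` and the default thinking key; the change simply makes the destination key configurable.

template_thinking_key = "reasoning_content"  # now taken from the prompter
tpair = ("<think>", "</think>")
content = "<think>check the units first</think>  42 kilometers"

t_start_idx = content.find(tpair[0])
t_end_idx = content.find(tpair[1])

transformed_message = {}
thinking_content = content[t_start_idx + len(tpair[0]) : t_end_idx]
transformed_message[template_thinking_key] = thinking_content.strip()
# remainder of the content, with leading whitespace stripped
transformed_message["content"] = content[t_end_idx + len(tpair[1]) :].lstrip()

assert transformed_message == {
    "reasoning_content": "check the units first",
    "content": "42 kilometers",
}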
@@ -953,6 +971,10 @@ class StrategyLoader:
                 None,
             ),
             "field_messages": dataset_config.get("field_messages", "messages"),
+            "field_thinking": dataset_config.get("field_thinking", "reasoning_content"),
+            "template_thinking_key": dataset_config.get(
+                "template_thinking_key", "reasoning_content"
+            ),
             "roles": dataset_config.get("roles"),
             "drop_system_message": dataset_config.get("drop_system_message", False),
             # we need to add one for detecting sequences with exceeding the `sequence_len` limit.
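For reference, a hypothetical dataset config exercising the two new keys; both default to "reasoning_content" when omitted, so existing configs keep their behavior.

dataset_config = {
    "field_messages": "messages",
    # column in the source dataset that holds the raw reasoning text
    "field_thinking": "thought",
    # key under which the chat template expects that text
    "template_thinking_key": "reasoning_content",
}
assert dataset_config.get("field_thinking", "reasoning_content") == "thought"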
@@ -75,7 +75,7 @@ class PromptTokenizingStrategy(abc.ABC):
     ) -> BatchEncoding:
         empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
         if not prompt:
-            LOG.warning("Empty text requested for tokenization.")
+            LOG.warning_once("Empty text requested for tokenization.")
             return empty

         result = self.tokenizer(
@@ -4,11 +4,14 @@ from __future__ import annotations

 import importlib
 import inspect
+import json
 import os
+import shutil
 import signal
 import sys
 import typing
 import weakref
+from collections import OrderedDict
 from contextlib import ExitStack
 from pathlib import Path
 from typing import Any, Dict
@@ -38,6 +41,7 @@ from axolotl.utils.distributed import cleanup_distributed
 from axolotl.utils.freeze import freeze_layers_except
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import RLType
+from axolotl.utils.train import determine_last_checkpoint
 from axolotl.utils.trainer import setup_trainer

 try:
@@ -46,7 +50,7 @@ except ImportError:
     BetterTransformer = None

 if typing.TYPE_CHECKING:
-    from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+    from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder

 LOG = get_logger(__name__)

@@ -124,32 +128,6 @@ def setup_reference_model(
     return model_ref


-def determine_resume_checkpoint(cfg: DictDefault) -> str | None:
-    """
-    Determine the checkpoint to resume from based on configuration.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-
-    Returns:
-        Path to the checkpoint to resume from, or `None` if not resuming.
-    """
-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
-        possible_checkpoints = [
-            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
-        ]
-        if len(possible_checkpoints) > 0:
-            sorted_paths = sorted(
-                possible_checkpoints,
-                key=lambda path: int(path.split("-")[-1]),
-            )
-            cfg.resume_from_checkpoint = sorted_paths[-1]
-            LOG.info(
-                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
-            )
-    return cfg.resume_from_checkpoint
-
-
 def setup_signal_handler(
     cfg: DictDefault, model: PreTrainedModel, safe_serialization: bool
 ):
@@ -218,6 +196,7 @@ def execute_training(
                 ring_attn_func=cfg.ring_attn_func,
                 heads_k_stride=cfg.heads_k_stride,
                 gather_outputs=cfg.rl is RLType.GRPO,
+                device_mesh=trainer.accelerator.torch_device_mesh,
             )
         )

@@ -274,19 +253,56 @@ def save_trained_model(
         # final model weights have already been saved by `ReLoRACallback.on_train_end`
         return

-    if trainer.is_fsdp_enabled:
+    if trainer.is_fsdp_enabled or cfg.fsdp_config:
         if cfg.fsdp_config or cfg.fsdp:
             if cfg.fsdp_config.final_state_dict_type:
                 state_dict_type = cfg.fsdp_config.final_state_dict_type
             else:
                 state_dict_type = cfg.fsdp_config.state_dict_type
             trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
-            trainer.save_model(cfg.output_dir)
+            trainer.save_model(cfg.output_dir)  # only handles FULL_STATE_DICT
             if state_dict_type == "SHARDED_STATE_DICT":
                 LOG.info(
                     "The final model was saved with a sharded state dict. Please ensure you merge "
                     "the sharded weights with `merge-sharded-fsdp-weights`."
                 )
+                checkpoint_dir = determine_last_checkpoint(cfg, update=False)
+                if (
+                    not (Path(cfg.output_dir) / "model.safetensors.index.json").exists()
+                    and checkpoint_dir
+                ):
+                    # import here to prevent circular import
+                    from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights
+
+                    fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
+                    merged_path = str(Path(cfg.output_dir) / "merged")
+                    merge_fsdp_weights(
+                        checkpoint_dir=str(fsdp_dir),
+                        output_path=merged_path,
+                        safe_serialization=True,
+                    )
+                    trainer.accelerator.wait_for_everyone()
+                    if trainer.accelerator.is_main_process:
+                        # move all files in merged_path to cfg.output_dir
+                        for merged_file in Path(merged_path).iterdir():
+                            shutil.move(str(merged_file), cfg.output_dir)
+                        shutil.rmtree(merged_path)  # remove what should be an empty dir
+            # TODO(wing): see https://github.com/huggingface/transformers/pull/40207
+            # cleanup the FSDP prefix in the model config.json
+            if trainer.accelerator.is_main_process:
+                with open(
+                    Path(cfg.output_dir) / "config.json", "r", encoding="utf-8"
+                ) as config_file_io:
+                    # read the model config as an OrderedDict
+                    config = json.load(config_file_io, object_pairs_hook=OrderedDict)
+                config["architectures"] = [
+                    name.lstrip("FSDP") for name in config["architectures"]
+                ]
+                # write the updated model config back
+                with open(
+                    os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8"
+                ) as config_file_io:
+                    json.dump(config, config_file_io, indent=2)
     elif cfg.deepspeed and is_deepspeed_zero3_enabled():
         # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
         trainer.accelerator.wait_for_everyone()
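If the automatic merge cannot run (for example, the safetensors index already exists or no checkpoint directory is found), the same helper this hunk imports can be invoked by hand; the paths below are illustrative.

from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights

# Merge the sharded FSDP state dict under a checkpoint into a single
# safetensors model, exactly as the patched save path does above.
merge_fsdp_weights(
    checkpoint_dir="outputs/run1/checkpoint-500/pytorch_model_fsdp_0",
    output_path="outputs/run1/merged",
    safe_serialization=True,
)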
@@ -563,9 +579,13 @@ def train(
     setup_model_card(cfg)

     # Execute the training
-    resume_from_checkpoint = determine_resume_checkpoint(cfg)
+    resume_from_checkpoint = determine_last_checkpoint(cfg)
     execute_training(cfg, trainer, resume_from_checkpoint)

+    # clear cache
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
     # Save the trained model and cleanup
     save_trained_model(cfg, trainer, model, safe_serialization)
     create_model_card(cfg, trainer)
@@ -161,6 +161,8 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     Collator for multipack specific to the using the BatchSampler
     """

+    squash_position_ids: bool = False
+
     def __call__(self, features, return_tensors=None):
         if not isinstance(features[0], list):
             features: List[List[dict]] = [features]
@@ -176,6 +178,15 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                         if feature in item
                     ]
                     out_features[i][feature] = np.concatenate(arrays)
+                elif feature == "position_ids" and self.squash_position_ids:
+                    arrays = [
+                        np.array(item[feature]) for item in features_ if feature in item
+                    ]
+                    # concatenate, get total length and create arange of new total position ids
+                    position_ids = np.concatenate(arrays)
+                    total_length = position_ids.shape[0]
+                    position_ids = np.arange(total_length)
+                    out_features[i][feature] = position_ids
                 else:
                     arrays = [
                         np.array(item[feature]) for item in features_ if feature in item
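A standalone illustration of the squash above: packed sequences normally carry per-sequence position ids that restart at 0, and squashing replaces them with a single arange over the packed length.

import numpy as np

# Two packed sequences of lengths 3 and 2, each with restarting position ids.
arrays = [np.array([0, 1, 2]), np.array([0, 1])]

position_ids = np.concatenate(arrays)   # [0 1 2 0 1]
total_length = position_ids.shape[0]
position_ids = np.arange(total_length)  # [0 1 2 3 4]

assert position_ids.tolist() == [0, 1, 2, 3, 4]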
Some files were not shown because too many files have changed in this diff.