Compare commits
19 Commits
update-vll
...
torch_tens
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
41664c7c4c | ||
|
|
9a8073e73d | ||
|
|
7fb8441e0e | ||
|
|
4dc5910e1c | ||
|
|
fb7bc9250d | ||
|
|
d6e4a611e5 | ||
|
|
eb662557a7 | ||
|
|
03b2a113fe | ||
|
|
9b95a625ab | ||
|
|
c370d0795c | ||
|
|
76aeb16156 | ||
|
|
7c5ea0010f | ||
|
|
c6d69d5c1b | ||
|
|
4ff96a2526 | ||
|
|
89e99eaaa7 | ||
|
|
6ed501f6dc | ||
|
|
8c6a6ea6eb | ||
|
|
78bff4925e | ||
|
|
b237c8a3f3 |
8
.github/workflows/base.yml
vendored
8
.github/workflows/base.yml
vendored
@@ -29,11 +29,11 @@ jobs:
|
|||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.6.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base"
|
||||||
- cuda: "124"
|
- cuda: "126"
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
@@ -43,7 +43,7 @@ jobs:
|
|||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.7.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-base"
|
dockerfile: "Dockerfile-base"
|
||||||
- cuda: "126"
|
- cuda: "126"
|
||||||
|
|||||||
20
.github/workflows/main.yml
vendored
20
.github/workflows/main.yml
vendored
@@ -15,15 +15,15 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 126
|
||||||
|
cuda_version: 12.6.3
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.7.0
|
||||||
axolotl_extras: vllm
|
axolotl_extras: vllm
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
@@ -82,17 +82,17 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
|
- cuda: 126
|
||||||
|
cuda_version: 12.6.3
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.7.0
|
||||||
|
axolotl_extras:
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
|
|||||||
7
.github/workflows/multi-gpu-e2e.yml
vendored
7
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -33,13 +33,6 @@ jobs:
|
|||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.1
|
|
||||||
axolotl_extras:
|
|
||||||
num_gpus: 2
|
|
||||||
nightly_build: "true"
|
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
|
|||||||
11
.github/workflows/nightlies.yml
vendored
11
.github/workflows/nightlies.yml
vendored
@@ -12,11 +12,6 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -68,10 +63,10 @@ jobs:
|
|||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
|||||||
17
.github/workflows/preview-docs.yml
vendored
17
.github/workflows/preview-docs.yml
vendored
@@ -28,6 +28,8 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Check out repository
|
- name: Check out repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ github.event.pull_request.head.sha }}
|
||||||
|
|
||||||
- name: Set up Quarto
|
- name: Set up Quarto
|
||||||
uses: quarto-dev/quarto-actions/setup@v2
|
uses: quarto-dev/quarto-actions/setup@v2
|
||||||
@@ -50,10 +52,11 @@ jobs:
|
|||||||
|
|
||||||
- name: Netlify Publish
|
- name: Netlify Publish
|
||||||
uses: nwtgck/actions-netlify@v3.0
|
uses: nwtgck/actions-netlify@v3.0
|
||||||
|
id: netlify
|
||||||
with:
|
with:
|
||||||
publish-dir: './_site'
|
publish-dir: './_site'
|
||||||
enable-pull-request-comment: true
|
enable-pull-request-comment: false
|
||||||
enable-github-deployment: true
|
enable-github-deployment: false
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
deploy-message: "Deployed On Netlify"
|
deploy-message: "Deployed On Netlify"
|
||||||
github-deployment-environment: 'preview'
|
github-deployment-environment: 'preview'
|
||||||
@@ -61,3 +64,13 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}
|
NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}
|
||||||
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
|
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
|
||||||
|
|
||||||
|
- name: Update PR with preview link
|
||||||
|
if: ${{ steps.netlify.outcome == 'success' }}
|
||||||
|
uses: marocchino/sticky-pull-request-comment@v2
|
||||||
|
with:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
message: |
|
||||||
|
📖 **Documentation Preview**: ${{ steps.netlify.outputs.deploy-url }}
|
||||||
|
|
||||||
|
Deployed on Netlify from commit ${{ github.event.pull_request.head.sha }}
|
||||||
|
|||||||
8
.github/workflows/tests-nightly.yml
vendored
8
.github/workflows/tests-nightly.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
|||||||
max-parallel: 2
|
max-parallel: 2
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
|
pytorch_version: ["2.6.0", "2.7.0"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -80,9 +80,9 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v tests/patched/
|
pytest -v --durations=10 tests/patched/
|
||||||
pytest -v tests/cli/
|
pytest -v --durations=10 tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
30
.github/workflows/tests.yml
vendored
30
.github/workflows/tests.yml
vendored
@@ -52,7 +52,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
|
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -102,9 +102,9 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
|
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
|
||||||
pytest -v tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
pytest -v tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
uses: codecov/codecov-action@v5
|
uses: codecov/codecov-action@v5
|
||||||
@@ -125,7 +125,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
|
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -175,9 +175,9 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v tests/patched/
|
pytest -v --durations=10 tests/patched/
|
||||||
pytest -v tests/cli/
|
pytest -v --durations=10 tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -198,7 +198,7 @@ jobs:
|
|||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.7.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
@@ -252,18 +252,6 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras: llmcompressor
|
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.1
|
|
||||||
num_gpus: 1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 126
|
|
||||||
cuda_version: 12.6.3
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.7.1
|
|
||||||
num_gpus: 1
|
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
|
|||||||
@@ -97,7 +97,7 @@
|
|||||||
# # 'no_input_format' cannot include {input}
|
# # 'no_input_format' cannot include {input}
|
||||||
# no_input_format: "{instruction} "
|
# no_input_format: "{instruction} "
|
||||||
|
|
||||||
# # For `completion` datsets only, uses the provided field instead of `text` column
|
# # For `completion` datasets only, uses the provided field instead of `text` column
|
||||||
# field:
|
# field:
|
||||||
|
|
||||||
# # Axolotl attempts to save the dataset as an arrow after packing the data together so
|
# # Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ Features:
|
|||||||
|
|
||||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||||
- Python 3.11
|
- Python 3.11
|
||||||
- PyTorch ≥2.5.1
|
- PyTorch ≥2.6.0
|
||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ df_template = template_env.get_template("Dockerfile.jinja")
|
|||||||
df_args = {
|
df_args = {
|
||||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
|
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
|
||||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
|
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
|
||||||
"CUDA": os.environ.get("CUDA", "124"),
|
"CUDA": os.environ.get("CUDA", "126"),
|
||||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ df_template = template_env.get_template(dockerfile)
|
|||||||
df_args = {
|
df_args = {
|
||||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
|
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.6.0"),
|
||||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
|
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu126-2.6.0"),
|
||||||
"CUDA": os.environ.get("CUDA", "124"),
|
"CUDA": os.environ.get("CUDA", "126"),
|
||||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||||
|
|||||||
@@ -187,6 +187,7 @@ Instead of passing `tools` via the system prompt, an alternative method would be
|
|||||||
"role": "assistant", // call the function via assistant
|
"role": "assistant", // call the function via assistant
|
||||||
"tool_calls": [
|
"tool_calls": [
|
||||||
{
|
{
|
||||||
|
"id": "...", // required only for mistral
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "...",
|
"name": "...",
|
||||||
@@ -199,6 +200,7 @@ Instead of passing `tools` via the system prompt, an alternative method would be
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
|
"tool_call_id": "...", // required only for mistral
|
||||||
"name": "...",
|
"name": "...",
|
||||||
"content": "..."
|
"content": "..."
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -34,9 +34,9 @@ Tags examples:
|
|||||||
|
|
||||||
- `main-base-py3.11-cu128-2.7.1`
|
- `main-base-py3.11-cu128-2.7.1`
|
||||||
- `main-base-py3.11-cu126-2.7.1`
|
- `main-base-py3.11-cu126-2.7.1`
|
||||||
|
- `main-base-py3.11-cu126-2.7.0`
|
||||||
- `main-base-py3.11-cu126-2.6.0`
|
- `main-base-py3.11-cu126-2.6.0`
|
||||||
- `main-base-py3.11-cu124-2.6.0`
|
- `main-base-py3.11-cu124-2.6.0`
|
||||||
- `main-base-py3.11-cu124-2.5.1`
|
|
||||||
|
|
||||||
## Main
|
## Main
|
||||||
|
|
||||||
@@ -76,12 +76,12 @@ Tags examples:
|
|||||||
|
|
||||||
- `main-py3.11-cu128-2.7.1`
|
- `main-py3.11-cu128-2.7.1`
|
||||||
- `main-py3.11-cu126-2.7.1`
|
- `main-py3.11-cu126-2.7.1`
|
||||||
|
- `main-py3.11-cu126-2.7.0`
|
||||||
- `main-py3.11-cu126-2.6.0`
|
- `main-py3.11-cu126-2.6.0`
|
||||||
- `main-py3.11-cu124-2.6.0`
|
- `main-py3.11-cu124-2.6.0`
|
||||||
- `main-py3.11-cu124-2.5.1`
|
|
||||||
- `main-latest`
|
- `main-latest`
|
||||||
- `main-20250303-py3.11-cu124-2.6.0`
|
- `main-20250303-py3.11-cu124-2.6.0`
|
||||||
- `main-20250303-py3.11-cu124-2.5.1`
|
- `main-20250303-py3.11-cu126-2.6.0`
|
||||||
- `0.10.1`
|
- `0.10.1`
|
||||||
|
|
||||||
## Cloud
|
## Cloud
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
|||||||
|
|
||||||
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
||||||
- Python ≥3.11
|
- Python ≥3.11
|
||||||
- PyTorch ≥2.5.1
|
- PyTorch ≥2.6.0
|
||||||
|
|
||||||
## Installation Methods {#sec-installation-methods}
|
## Installation Methods {#sec-installation-methods}
|
||||||
|
|
||||||
|
|||||||
@@ -23,8 +23,6 @@ Axolotl supports several methods for multi-GPU training:
|
|||||||
|
|
||||||
## DeepSpeed {#sec-deepspeed}
|
## DeepSpeed {#sec-deepspeed}
|
||||||
|
|
||||||
DeepSpeed is the recommended approach for multi-GPU training due to its stability and performance. It provides various optimization levels through ZeRO stages.
|
|
||||||
|
|
||||||
### Configuration {#sec-deepspeed-config}
|
### Configuration {#sec-deepspeed-config}
|
||||||
|
|
||||||
Add to your YAML config:
|
Add to your YAML config:
|
||||||
@@ -32,7 +30,6 @@ Add to your YAML config:
|
|||||||
```{.yaml}
|
```{.yaml}
|
||||||
deepspeed: deepspeed_configs/zero1.json
|
deepspeed: deepspeed_configs/zero1.json
|
||||||
```
|
```
|
||||||
|
|
||||||
### Usage {#sec-deepspeed-usage}
|
### Usage {#sec-deepspeed-usage}
|
||||||
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
@@ -66,9 +63,75 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
|
|||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
## FSDP {#sec-fsdp}
|
::: {.callout-tip}
|
||||||
|
|
||||||
### Basic FSDP Configuration {#sec-fsdp-config}
|
Using ZeRO Stage 3 with Single-GPU training
|
||||||
|
|
||||||
|
ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
|
||||||
|
`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
|
||||||
|
FSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2}
|
||||||
|
|
||||||
|
To migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and
|
||||||
|
also follow the config field mapping below to update field names.
|
||||||
|
|
||||||
|
#### Config mapping
|
||||||
|
|
||||||
|
FSDP1 | FSDP2
|
||||||
|
-------- | --------
|
||||||
|
fsdp_sharding_strategy | reshard_after_forward
|
||||||
|
fsdp_backward_prefetch_policy | **REMOVED**
|
||||||
|
fsdp_backward_prefetch | **REMOVED**
|
||||||
|
fsdp_forward_prefetch | **REMOVED**
|
||||||
|
fsdp_sync_module_states | **REMOVED**
|
||||||
|
fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
|
||||||
|
fsdp_state_dict_type | state_dict_type
|
||||||
|
fsdp_use_orig_params | **REMOVED**
|
||||||
|
|
||||||
|
|
||||||
|
For example, if you were using the following FSDP1 config:
|
||||||
|
|
||||||
|
```{.yaml}
|
||||||
|
fsdp_version: 1
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
```
|
||||||
|
|
||||||
|
You can migrate to the following FSDP2 config:
|
||||||
|
|
||||||
|
```{.yaml}
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_config:
|
||||||
|
offload_params: false
|
||||||
|
cpu_ram_efficient_loading: true
|
||||||
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
|
||||||
|
state_dict_type: FULL_STATE_DICT
|
||||||
|
reshard_after_forward: true
|
||||||
|
```
|
||||||
|
|
||||||
|
### FSDP1 (deprecated) {#sec-fsdp-config}
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
|
||||||
|
Using `fsdp` to configure FSDP is deprecated and will be removed in an upcoming release of Axolotl. Please use `fsdp_config` as above instead.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
```{.yaml}
|
```{.yaml}
|
||||||
fsdp:
|
fsdp:
|
||||||
@@ -80,6 +143,7 @@ fsdp_config:
|
|||||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Sequence parallelism {#sec-sequence-parallelism}
|
## Sequence parallelism {#sec-sequence-parallelism}
|
||||||
|
|
||||||
We support sequence parallelism (SP) via the
|
We support sequence parallelism (SP) via the
|
||||||
|
|||||||
@@ -40,13 +40,13 @@ use_cpu: false
|
|||||||
|
|
||||||
Configure your model to use FSDP in the Axolotl yaml. For example:
|
Configure your model to use FSDP in the Axolotl yaml. For example:
|
||||||
```yaml
|
```yaml
|
||||||
fsdp:
|
fsdp_version: 2
|
||||||
- full_shard
|
|
||||||
- auto_wrap
|
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
fsdp_offload_params: true
|
offload_params: true
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
state_dict_type: FULL_STATE_DICT
|
||||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
|
reshard_after_forward: true
|
||||||
```
|
```
|
||||||
|
|
||||||
All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.
|
All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ feedback. Various methods include, but not limited to:
|
|||||||
- [Kahneman-Tversky Optimization (KTO)](#kto)
|
- [Kahneman-Tversky Optimization (KTO)](#kto)
|
||||||
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
|
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
|
||||||
- [Group Relative Policy Optimization (GRPO)](#grpo)
|
- [Group Relative Policy Optimization (GRPO)](#grpo)
|
||||||
- Proximal Policy Optimization (PPO) (not yet supported in axolotl, if you're interested in contributing, please reach out!)
|
|
||||||
|
|
||||||
|
|
||||||
## RLHF using Axolotl
|
## RLHF using Axolotl
|
||||||
@@ -275,8 +274,7 @@ rl: dpo
|
|||||||
datasets:
|
datasets:
|
||||||
- path: ...
|
- path: ...
|
||||||
split: train
|
split: train
|
||||||
type: user_defined.default
|
type:
|
||||||
|
|
||||||
field_prompt: "prompt"
|
field_prompt: "prompt"
|
||||||
field_system: "system"
|
field_system: "system"
|
||||||
field_chosen: "chosen"
|
field_chosen: "chosen"
|
||||||
@@ -476,8 +474,7 @@ rl: kto
|
|||||||
datasets:
|
datasets:
|
||||||
- path: ...
|
- path: ...
|
||||||
split: train
|
split: train
|
||||||
type: user_defined.default
|
type:
|
||||||
|
|
||||||
field_prompt: "prompt"
|
field_prompt: "prompt"
|
||||||
field_system: "system"
|
field_system: "system"
|
||||||
field_completion: "completion"
|
field_completion: "completion"
|
||||||
|
|||||||
5
examples/archived/README.md
Normal file
5
examples/archived/README.md
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Archived Examples
|
||||||
|
|
||||||
|
This directory contains examples that are no longer maintained and may no longer be functional.
|
||||||
|
|
||||||
|
We keep them around for archival purposes in case they are useful to others.
|
||||||
70
examples/devstral/README.md
Normal file
70
examples/devstral/README.md
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
# Finetune Devstral with Axolotl
|
||||||
|
|
||||||
|
Devstral Small is a 24B parameter open-source model from MistralAI found on HuggingFace [Devstral-Small-2505](https://huggingface.co/mistralai/Devstral-Small-2505) and [Devstral-Small-2507](https://huggingface.co/mistralai/Devstral-Small-2507). `Devstral-Small-2507` is the latest version of the model and has [function calling](https://mistralai.github.io/mistral-common/usage/tools/) support.
|
||||||
|
|
||||||
|
This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking.
|
||||||
|
|
||||||
|
The model was fine-tuned on top of [Mistral-Small-3.1](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503) without the vision layer and has a context of up to 128k tokens.
|
||||||
|
|
||||||
|
Thanks to the team at MistralAI for giving us early access to prepare for this release.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). Devstral support is currently only available on nightly, so you need to either install from main or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||||
|
|
||||||
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Ensure you have PyTorch installed (PyTorch 2.6.0+)
|
||||||
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
|
cd axolotl
|
||||||
|
|
||||||
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Run the finetuning example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
axolotl train examples/devstral/devstral-small-qlora.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
This config uses about 21GB VRAM.
|
||||||
|
|
||||||
|
Let us know how it goes. Happy finetuning! 🚀
|
||||||
|
|
||||||
|
### TIPS
|
||||||
|
|
||||||
|
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||||
|
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
|
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||||
|
- Learn how to use function calling with Axolotl at [docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use).
|
||||||
|
|
||||||
|
## Optimization Guides
|
||||||
|
|
||||||
|
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||||
|
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||||
|
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||||
|
- [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy)
|
||||||
|
- [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels)
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
||||||
|
|
||||||
|
In addition, we do not support overriding tokens yet.
|
||||||
|
|
||||||
|
## Related Resources
|
||||||
|
|
||||||
|
- [MistralAI Devstral Blog](https://mistral.ai/news/devstral)
|
||||||
|
- [MistralAI Devstral 1.1 Blog](https://mistral.ai/news/devstral-2507)
|
||||||
|
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||||
|
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
- [Axolotl Website](https://axolotl.ai)
|
||||||
|
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||||
|
|
||||||
|
|
||||||
|
## Future Work
|
||||||
|
|
||||||
|
- Add parity to Preference Tuning, RL, Multi-modal, etc.
|
||||||
|
- Add parity to other tokenizer configs like overriding tokens.
|
||||||
64
examples/devstral/devstral-small-qlora.yml
Normal file
64
examples/devstral/devstral-small-qlora.yml
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
base_model: mistralai/Devstral-Small-2507
|
||||||
|
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
# Enable to use mistral-common tokenizer
|
||||||
|
tokenizer_use_mistral_common: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
|
type: chat_template
|
||||||
|
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.1
|
||||||
|
output_dir: ./outputs/qlora-out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_ratio: 0.05
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
weight_decay: 0.0
|
||||||
|
special_tokens:
|
||||||
7
examples/lfm2/README.md
Normal file
7
examples/lfm2/README.md
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# Liquid Foundation Models 2
|
||||||
|
|
||||||
|
LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
|
||||||
|
```
|
||||||
48
examples/lfm2/lfm2-350m-fft.yaml
Normal file
48
examples/lfm2/lfm2-350m-fft.yaml
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
base_model: LiquidAI/LFM2-350M
|
||||||
|
|
||||||
|
chunked_cross_entropy: true
|
||||||
|
|
||||||
|
chat_template: tokenizer_default
|
||||||
|
eot_tokens:
|
||||||
|
- "<|im_end|>"
|
||||||
|
datasets:
|
||||||
|
- path: mlabonne/FineTome-100k
|
||||||
|
type: chat_template
|
||||||
|
split: train[:20%]
|
||||||
|
field_messages: conversations
|
||||||
|
message_field_role: from
|
||||||
|
message_field_content: value
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 4
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_fused
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 5e-5
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 2
|
||||||
|
saves_per_epoch: 1
|
||||||
|
|
||||||
|
weight_decay: 0.0
|
||||||
@@ -18,16 +18,10 @@ git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
|||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
|
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Download the example config:
|
2. Run the finetuning example:
|
||||||
|
|
||||||
```bash
|
|
||||||
axolotl fetch examples
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Run the finetuning example:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
axolotl train examples/magistral/magistral-small-qlora.yaml
|
axolotl train examples/magistral/magistral-small-qlora.yaml
|
||||||
@@ -42,7 +36,7 @@ Let us know how it goes. Happy finetuning! 🚀
|
|||||||
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
||||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||||
|
|
||||||
## Optimization Guides
|
## Optimization Guides
|
||||||
|
|
||||||
@@ -54,7 +48,7 @@ Let us know how it goes. Happy finetuning! 🚀
|
|||||||
|
|
||||||
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
||||||
|
|
||||||
The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
|
In addition, we do not support overriding tokens yet.
|
||||||
|
|
||||||
## Related Resources
|
## Related Resources
|
||||||
|
|
||||||
|
|||||||
@@ -68,4 +68,4 @@ schedulefree==1.4.1
|
|||||||
axolotl-contribs-lgpl==0.0.6
|
axolotl-contribs-lgpl==0.0.6
|
||||||
axolotl-contribs-mit==0.0.3
|
axolotl-contribs-mit==0.0.3
|
||||||
|
|
||||||
mistral-common==1.6.3
|
mistral-common==1.7.0
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"'
|
||||||
)
|
)
|
||||||
|
|||||||
7
setup.py
7
setup.py
@@ -66,8 +66,11 @@ def parse_requirements(extras_require_map):
|
|||||||
|
|
||||||
if (major, minor) >= (2, 7):
|
if (major, minor) >= (2, 7):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
|
if patch == 0:
|
||||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
_install_requires.append("xformers==0.0.30")
|
||||||
|
else:
|
||||||
|
_install_requires.append("xformers==0.0.31.post1")
|
||||||
|
extras_require_map["vllm"] = ["vllm>=0.9.0"]
|
||||||
elif (major, minor) >= (2, 6):
|
elif (major, minor) >= (2, 6):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append(
|
_install_requires.append(
|
||||||
|
|||||||
@@ -4,4 +4,4 @@ import pkgutil
|
|||||||
|
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||||
|
|
||||||
__version__ = "0.11.0.dev"
|
__version__ = "0.12.0.dev"
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
|
migrate_fsdp_config,
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
normalize_config,
|
normalize_config,
|
||||||
validate_config,
|
validate_config,
|
||||||
@@ -226,6 +227,7 @@ def load_cfg(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
migrate_fsdp_config(cfg)
|
||||||
prepare_optim_env(cfg)
|
prepare_optim_env(cfg)
|
||||||
prepare_opinionated_env(cfg)
|
prepare_opinionated_env(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|||||||
@@ -109,6 +109,13 @@ def ray_train_func(kwargs: dict):
|
|||||||
# initialize accelerator before model instantiation
|
# initialize accelerator before model instantiation
|
||||||
Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
|
Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)
|
||||||
|
|
||||||
|
# Register plugins in Ray workers
|
||||||
|
if cfg.get("plugins"):
|
||||||
|
from axolotl.cli.config import plugin_set_cfg, prepare_plugins
|
||||||
|
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
plugin_set_cfg(cfg)
|
||||||
|
|
||||||
kwargs["cfg"] = cfg
|
kwargs["cfg"] = cfg
|
||||||
|
|
||||||
do_train(**kwargs)
|
do_train(**kwargs)
|
||||||
|
|||||||
@@ -501,6 +501,10 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if self.cfg.reward_model or self.cfg.rl:
|
if self.cfg.reward_model or self.cfg.rl:
|
||||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
|
if self.cfg.fsdp_config or self.cfg.fsdp:
|
||||||
|
training_args_kwargs["fsdp_config"] = self.cfg.fsdp_config
|
||||||
|
training_args_kwargs["fsdp"] = self.cfg.fsdp if self.cfg.fsdp else True
|
||||||
|
|
||||||
self._configure_reporting(training_args_kwargs)
|
self._configure_reporting(training_args_kwargs)
|
||||||
self._configure_hub_parameters(training_args_kwargs)
|
self._configure_hub_parameters(training_args_kwargs)
|
||||||
self._configure_scheduler(training_args_kwargs)
|
self._configure_scheduler(training_args_kwargs)
|
||||||
|
|||||||
@@ -151,14 +151,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
|
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
|
||||||
total_num_steps
|
total_num_steps
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.fsdp:
|
|
||||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
|
||||||
if self.cfg.fsdp_config:
|
|
||||||
training_arguments_kwargs["fsdp_config"] = {
|
|
||||||
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.cfg.adapter == "qlora":
|
if self.cfg.adapter == "qlora":
|
||||||
training_arguments_kwargs["qlora"] = True
|
training_arguments_kwargs["qlora"] = True
|
||||||
|
|
||||||
|
|||||||
@@ -208,7 +208,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
callbacks=self.get_callbacks(),
|
callbacks=self.get_callbacks(),
|
||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
)
|
)
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp_config or self.cfg.fsdp:
|
||||||
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
|
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
|
||||||
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
|
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
|
||||||
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
|
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
|
||||||
@@ -218,21 +218,3 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer.add_callback(callback)
|
trainer.add_callback(callback)
|
||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
|
|
||||||
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
|
||||||
"""
|
|
||||||
HF Factory class for PPO Trainer
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_callbacks(self):
|
|
||||||
callbacks = super().get_callbacks()
|
|
||||||
return callbacks
|
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
|
||||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
|
||||||
return callbacks
|
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
|
||||||
# TODO: build PPOConfig
|
|
||||||
raise NotImplementedError("PPO trainer builder is not implemented yet.")
|
|
||||||
|
|||||||
@@ -14,5 +14,4 @@ from .trl import (
|
|||||||
AxolotlORPOTrainer,
|
AxolotlORPOTrainer,
|
||||||
AxolotlPRMTrainer,
|
AxolotlPRMTrainer,
|
||||||
AxolotlRewardTrainer,
|
AxolotlRewardTrainer,
|
||||||
TRLPPOTrainer,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
"""Module for TRL PPO trainer"""
|
"""Module for TRL RL trainers"""
|
||||||
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
from trl import (
|
from trl import (
|
||||||
CPOTrainer,
|
CPOTrainer,
|
||||||
KTOTrainer,
|
KTOTrainer,
|
||||||
ORPOTrainer,
|
ORPOTrainer,
|
||||||
PPOTrainer,
|
|
||||||
PRMTrainer,
|
PRMTrainer,
|
||||||
RewardTrainer,
|
RewardTrainer,
|
||||||
)
|
)
|
||||||
@@ -16,64 +13,6 @@ from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, Optimizer
|
|||||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
class TRLPPOTrainer(PPOTrainer):
|
|
||||||
"""Wrapper for TRL PPO trainer to handle customizations"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "ppo"]
|
|
||||||
|
|
||||||
def train(
|
|
||||||
self,
|
|
||||||
reward_pipe,
|
|
||||||
resume_from_checkpoint=None, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
generation_kwargs = {
|
|
||||||
"min_length": -1,
|
|
||||||
"top_k": 0.0,
|
|
||||||
"top_p": 1.0,
|
|
||||||
"do_sample": True,
|
|
||||||
"pad_token_id": self.tokenizer.eos_token_id,
|
|
||||||
"max_new_tokens": 32,
|
|
||||||
}
|
|
||||||
sent_kwargs = {
|
|
||||||
"return_all_scores": True,
|
|
||||||
"function_to_apply": "none",
|
|
||||||
"batch_size": 16,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, batch in tqdm(enumerate(self.dataloader)):
|
|
||||||
query_tensors = batch["input_ids"]
|
|
||||||
|
|
||||||
# generate model response
|
|
||||||
response_tensors, ref_response_tensors = self.generate(
|
|
||||||
query_tensors,
|
|
||||||
return_prompt=False,
|
|
||||||
generate_ref_response=True,
|
|
||||||
**generation_kwargs,
|
|
||||||
)
|
|
||||||
batch["response"] = self.tokenizer.batch_decode(response_tensors)
|
|
||||||
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
|
|
||||||
|
|
||||||
# Compute sentiment score
|
|
||||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
|
||||||
pipe_outputs = reward_pipe(texts, **sent_kwargs)
|
|
||||||
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
|
|
||||||
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
|
|
||||||
ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
|
|
||||||
ref_rewards = [
|
|
||||||
torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
|
|
||||||
]
|
|
||||||
batch["ref_rewards"] = ref_rewards
|
|
||||||
|
|
||||||
# Run PPO step
|
|
||||||
stats = self.step(query_tensors, response_tensors, rewards)
|
|
||||||
self.log_stats(
|
|
||||||
stats,
|
|
||||||
batch,
|
|
||||||
rewards,
|
|
||||||
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(
|
class AxolotlORPOTrainer(
|
||||||
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
|
RngLoaderMixin, SchedulerMixin, OptimizerMixin, OptimizerInitMixin, ORPOTrainer
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -48,13 +48,6 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
features = dataset.features.keys()
|
features = dataset.features.keys()
|
||||||
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
|
||||||
|
|
||||||
# Disable multiprocessing if the tokenizer doesn't support it (e.g., mistral_common)
|
|
||||||
if not getattr(self.prompt_tokenizer, "supports_multiprocessing", True):
|
|
||||||
LOG.info(
|
|
||||||
"Disabling multiprocessing for tokenizer as it doesn't support it (e.g., mistral_common)"
|
|
||||||
)
|
|
||||||
num_proc = 1
|
|
||||||
|
|
||||||
map_kwargs = {}
|
map_kwargs = {}
|
||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
map_kwargs["batched"] = True
|
map_kwargs["batched"] = True
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@622068a"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@865b899"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ kd_ce_alpha: 0.1
|
|||||||
kd_alpha: 0.9
|
kd_alpha: 0.9
|
||||||
kd_temperature: 1.0
|
kd_temperature: 1.0
|
||||||
|
|
||||||
torch_compile: True # torch>=2.5.1, recommended to reduce vram
|
torch_compile: True # torch>=2.6.0, recommended to reduce vram
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: ...
|
- path: ...
|
||||||
|
|||||||
@@ -122,9 +122,9 @@ def load_lora(
|
|||||||
rank = int(os.environ.get("LOCAL_RANK", 0))
|
rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
|
||||||
if (
|
if (
|
||||||
cfg.fsdp
|
cfg.fsdp_config
|
||||||
and cfg.adapter
|
and cfg.adapter
|
||||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
and cfg.fsdp_config.cpu_ram_efficient_loading
|
||||||
and rank != 0
|
and rank != 0
|
||||||
):
|
):
|
||||||
setup_quantized_meta_for_peft(model)
|
setup_quantized_meta_for_peft(model)
|
||||||
@@ -152,9 +152,9 @@ def load_lora(
|
|||||||
"Exception caught during model.print_trainable_parameters(): %s", exc
|
"Exception caught during model.print_trainable_parameters(): %s", exc
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
cfg.fsdp
|
cfg.fsdp_config
|
||||||
and cfg.adapter
|
and cfg.adapter
|
||||||
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
and cfg.fsdp_config.cpu_ram_efficient_loading
|
||||||
and rank != 0
|
and rank != 0
|
||||||
):
|
):
|
||||||
setup_quantized_peft_meta_for_training(model)
|
setup_quantized_peft_meta_for_training(model)
|
||||||
|
|||||||
@@ -140,10 +140,15 @@ class ModelLoader:
|
|||||||
"""Check if flash attention is installed."""
|
"""Check if flash attention is installed."""
|
||||||
return find_spec("flash_attn") is not None
|
return find_spec("flash_attn") is not None
|
||||||
|
|
||||||
@cached_property
|
@property
|
||||||
def qlora_fsdp(self):
|
def is_fsdp_enabled(self):
|
||||||
|
"""Property that determines if FSDP is enabled."""
|
||||||
|
return self.cfg.fsdp_config is not None or self.cfg.fsdp is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_qlora_and_fsdp_enabled(self):
|
||||||
"""Property that determines if FSDP with QLoRA is enabled."""
|
"""Property that determines if FSDP with QLoRA is enabled."""
|
||||||
return self.cfg.fsdp and self.cfg.adapter == "qlora"
|
return self.is_fsdp_enabled and self.cfg.adapter == "qlora"
|
||||||
|
|
||||||
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
|
def load(self) -> tuple[PreTrainedModel | PeftModelForCausalLM, PeftConfig | None]:
|
||||||
"""Load and prepare the model with all configurations and patches.
|
"""Load and prepare the model with all configurations and patches.
|
||||||
@@ -189,15 +194,15 @@ class ModelLoader:
|
|||||||
# Handle PeftModel if needed
|
# Handle PeftModel if needed
|
||||||
if (
|
if (
|
||||||
isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM))
|
isinstance(self.model, (peft.PeftModel, peft.PeftModelForCausalLM))
|
||||||
and not self.qlora_fsdp
|
and not self.is_qlora_and_fsdp_enabled
|
||||||
):
|
):
|
||||||
self.model = self.model.merge_and_unload()
|
self.model = self.model.merge_and_unload()
|
||||||
|
|
||||||
self._resize_token_embeddings()
|
self._resize_token_embeddings()
|
||||||
self._adjust_model_config()
|
self._adjust_model_config()
|
||||||
self._log_memory_usage()
|
|
||||||
self._configure_embedding_dtypes()
|
self._configure_embedding_dtypes()
|
||||||
self._configure_qat()
|
self._configure_qat()
|
||||||
|
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
|
||||||
|
|
||||||
def _resize_token_embeddings(self):
|
def _resize_token_embeddings(self):
|
||||||
"""Resize token embeddings if needed."""
|
"""Resize token embeddings if needed."""
|
||||||
@@ -251,22 +256,13 @@ class ModelLoader:
|
|||||||
):
|
):
|
||||||
self.model.config.eos_token_id = self.tokenizer.eos_token_id
|
self.model.config.eos_token_id = self.tokenizer.eos_token_id
|
||||||
|
|
||||||
def _log_memory_usage(self):
|
|
||||||
"""Log device memory usage after model load."""
|
|
||||||
if hasattr(self.model, "device") and self.model.device.type in (
|
|
||||||
"cuda",
|
|
||||||
"mps",
|
|
||||||
"npu",
|
|
||||||
):
|
|
||||||
log_gpu_memory_usage(LOG, "after model load", self.model.device)
|
|
||||||
|
|
||||||
def _configure_embedding_dtypes(self):
|
def _configure_embedding_dtypes(self):
|
||||||
"""Configure embedding module dtypes."""
|
"""Configure embedding module dtypes."""
|
||||||
# Get embedding modules
|
# Get embedding modules
|
||||||
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
|
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
|
||||||
|
|
||||||
# Initial dtype conversion
|
# Initial dtype conversion
|
||||||
if not self.cfg.fsdp:
|
if not self.is_fsdp_enabled:
|
||||||
# We don't run this during FSDP because this will leave mixed and bfloat16
|
# We don't run this during FSDP because this will leave mixed and bfloat16
|
||||||
# dtypes in the model which FSDP doesn't like
|
# dtypes in the model which FSDP doesn't like
|
||||||
if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
|
if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
|
||||||
@@ -282,7 +278,7 @@ class ModelLoader:
|
|||||||
self._set_z3_leaf_modules()
|
self._set_z3_leaf_modules()
|
||||||
|
|
||||||
# Apply gradient checkpointing if needed
|
# Apply gradient checkpointing if needed
|
||||||
needs_fa2_dtype = self.cfg.adapter or self.cfg.fsdp
|
needs_fa2_dtype = self.cfg.adapter or self.is_fsdp_enabled
|
||||||
if self.cfg.adapter in ["lora", "qlora"]:
|
if self.cfg.adapter in ["lora", "qlora"]:
|
||||||
needs_fa2_dtype = True
|
needs_fa2_dtype = True
|
||||||
if self.cfg.gradient_checkpointing:
|
if self.cfg.gradient_checkpointing:
|
||||||
@@ -298,10 +294,12 @@ class ModelLoader:
|
|||||||
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
|
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
|
||||||
(
|
(
|
||||||
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
|
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
|
||||||
and not self.qlora_fsdp
|
and not self.is_qlora_and_fsdp_enabled
|
||||||
)
|
)
|
||||||
|
or (
|
||||||
# CCE requires embedding layers to be in fp16/bf16 for backward pass
|
# CCE requires embedding layers to be in fp16/bf16 for backward pass
|
||||||
or self.cfg.cut_cross_entropy
|
self.cfg.cut_cross_entropy
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if should_convert:
|
if should_convert:
|
||||||
@@ -357,7 +355,6 @@ class ModelLoader:
|
|||||||
and not (self.cfg.rl and self.cfg.load_in_4bit)
|
and not (self.cfg.rl and self.cfg.load_in_4bit)
|
||||||
and not skip_move_to_device
|
and not skip_move_to_device
|
||||||
):
|
):
|
||||||
# TODO: validate this conditional
|
|
||||||
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
|
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
|
||||||
|
|
||||||
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||||
@@ -430,7 +427,17 @@ class ModelLoader:
|
|||||||
|
|
||||||
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
|
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
|
||||||
|
|
||||||
if not is_deepspeed_zero3_enabled():
|
is_ds_zero3 = is_deepspeed_zero3_enabled()
|
||||||
|
|
||||||
|
# FSDP requires control over device placement, so don't set device_map when FSDP is enabled
|
||||||
|
if self.is_fsdp_enabled:
|
||||||
|
# For QLoRA + FSDP, we still need to set device_map (placing the whole model on this process's LOCAL_RANK device) for proper initialization
|
||||||
|
if self.is_qlora_and_fsdp_enabled:
|
||||||
|
self.model_kwargs["device_map"] = {
|
||||||
|
"": int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
}
|
||||||
|
# For other FSDP cases, don't set device_map at all
|
||||||
|
elif not is_ds_zero3:
|
||||||
self.model_kwargs["device_map"] = device_map
|
self.model_kwargs["device_map"] = device_map
|
||||||
|
|
||||||
cur_device = get_device_type()
|
cur_device = get_device_type()
|
||||||
@@ -499,7 +506,7 @@ class ModelLoader:
|
|||||||
"bnb_4bit_quant_storage": torch.bfloat16,
|
"bnb_4bit_quant_storage": torch.bfloat16,
|
||||||
}
|
}
|
||||||
if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
|
if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
|
||||||
self.cfg.deepspeed or self.cfg.fsdp
|
self.cfg.deepspeed or self.is_fsdp_enabled
|
||||||
):
|
):
|
||||||
# for some reason, this causes the loss to be off by an order of magnitude
|
# for some reason, this causes the loss to be off by an order of magnitude
|
||||||
# but deepspeed needs this still in bfloat16
|
# but deepspeed needs this still in bfloat16
|
||||||
@@ -604,9 +611,21 @@ class ModelLoader:
|
|||||||
def _build_model(self) -> bool:
|
def _build_model(self) -> bool:
|
||||||
"""Load model, with load strategy depending on config."""
|
"""Load model, with load strategy depending on config."""
|
||||||
skip_move_to_device = False
|
skip_move_to_device = False
|
||||||
|
if self.is_fsdp_enabled:
|
||||||
|
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||||
|
skip_move_to_device = True
|
||||||
|
# Don't delete device_map for QLoRA + FSDP - it was set correctly in _set_device_map
|
||||||
if (
|
if (
|
||||||
self.qlora_fsdp
|
"device_map" in self.model_kwargs
|
||||||
and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
and not self.is_qlora_and_fsdp_enabled
|
||||||
|
):
|
||||||
|
del self.model_kwargs["device_map"]
|
||||||
|
elif self.is_qlora_and_fsdp_enabled:
|
||||||
|
skip_move_to_device = True
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.is_qlora_and_fsdp_enabled
|
||||||
|
and self.cfg.fsdp_config.cpu_ram_efficient_loading
|
||||||
and (
|
and (
|
||||||
self.cfg.model_config_type == "dbrx"
|
self.cfg.model_config_type == "dbrx"
|
||||||
or self.cfg.qlora_sharded_model_loading
|
or self.cfg.qlora_sharded_model_loading
|
||||||
@@ -632,12 +651,6 @@ class ModelLoader:
|
|||||||
and not self.cfg.trust_remote_code
|
and not self.cfg.trust_remote_code
|
||||||
and not self.cfg.gptq
|
and not self.cfg.gptq
|
||||||
):
|
):
|
||||||
# TODO: Do we need to open this up for all models?
|
|
||||||
if self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
|
|
||||||
skip_move_to_device = True
|
|
||||||
if "device_map" in self.model_kwargs:
|
|
||||||
del self.model_kwargs["device_map"]
|
|
||||||
|
|
||||||
# Please don't remove underscore binding without reading the fn docstring.
|
# Please don't remove underscore binding without reading the fn docstring.
|
||||||
_ = self._configure_zero3_memory_efficient_loading()
|
_ = self._configure_zero3_memory_efficient_loading()
|
||||||
|
|
||||||
@@ -691,8 +704,7 @@ class ModelLoader:
|
|||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
**self.model_kwargs,
|
**self.model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
elif self.cfg.gptq:
|
||||||
if self.cfg.gptq:
|
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
@@ -700,18 +712,8 @@ class ModelLoader:
|
|||||||
**self.model_kwargs,
|
**self.model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if (
|
|
||||||
self.cfg.fsdp
|
|
||||||
and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
|
|
||||||
):
|
|
||||||
# disabling either of these two still leads to VRAM spike before setting back down
|
|
||||||
skip_move_to_device = True
|
|
||||||
if "device_map" in self.model_kwargs:
|
|
||||||
del self.model_kwargs["device_map"]
|
|
||||||
|
|
||||||
# Please don't remove underscore binding without reading the fn docstring.
|
# Please don't remove underscore binding without reading the fn docstring.
|
||||||
_ = self._configure_zero3_memory_efficient_loading()
|
_ = self._configure_zero3_memory_efficient_loading()
|
||||||
|
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
@@ -753,8 +755,8 @@ class ModelLoader:
|
|||||||
skip_prepare_model_for_kbit_training = True
|
skip_prepare_model_for_kbit_training = True
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.qlora_fsdp
|
self.is_qlora_and_fsdp_enabled
|
||||||
or (self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
|
or (self.is_fsdp_enabled and self.cfg.fsdp_config.cpu_ram_efficient_loading)
|
||||||
or is_deepspeed_zero3_enabled()
|
or is_deepspeed_zero3_enabled()
|
||||||
):
|
):
|
||||||
# Make sure everything is in the same dtype
|
# Make sure everything is in the same dtype
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import importlib.util
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
|
|
||||||
import addict
|
import addict
|
||||||
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import PretrainedConfig, PreTrainedModel
|
from transformers import PretrainedConfig, PreTrainedModel
|
||||||
|
|
||||||
@@ -93,10 +94,14 @@ class PatchManager:
|
|||||||
|
|
||||||
def _apply_fsdp_patches(self):
|
def _apply_fsdp_patches(self):
|
||||||
"""Apply patches for FSDP configurations."""
|
"""Apply patches for FSDP configurations."""
|
||||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
|
if self.cfg.fsdp_config and str(self.cfg.fsdp_version) == "2":
|
||||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp2
|
||||||
|
|
||||||
patch_accelerate_fsdp2()
|
patch_accelerate_fsdp2()
|
||||||
|
if self.cfg.rl:
|
||||||
|
from axolotl.monkeypatch.trainer.trl import patch_trl_prepare_fsdp2
|
||||||
|
|
||||||
|
patch_trl_prepare_fsdp2()
|
||||||
|
|
||||||
# if self.cfg.fsdp_config:
|
# if self.cfg.fsdp_config:
|
||||||
# # see transformers#39152
|
# # see transformers#39152
|
||||||
@@ -165,10 +170,25 @@ class PatchManager:
|
|||||||
"""Apply patches for gradient checkpointing."""
|
"""Apply patches for gradient checkpointing."""
|
||||||
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
||||||
from axolotl.monkeypatch.gradient_checkpointing import (
|
from axolotl.monkeypatch.gradient_checkpointing import (
|
||||||
|
CheckpointFunctionWithCPUOffload,
|
||||||
hf_grad_checkpoint_offload_wrapper,
|
hf_grad_checkpoint_offload_wrapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
if (
|
||||||
|
self.cfg.gradient_checkpointing_kwargs
|
||||||
|
and "use_reentrant" in self.cfg.gradient_checkpointing_kwargs
|
||||||
|
and self.cfg.gradient_checkpointing_kwargs["use_reentrant"] is False
|
||||||
|
):
|
||||||
|
transformers.modeling_utils.checkpoint = (
|
||||||
|
hf_grad_checkpoint_offload_wrapper
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
transformers.modeling_utils.checkpoint.CheckpointFunction = (
|
||||||
|
CheckpointFunctionWithCPUOffload
|
||||||
|
)
|
||||||
|
torch.utils.checkpoint.CheckpointFunction = (
|
||||||
|
CheckpointFunctionWithCPUOffload
|
||||||
|
)
|
||||||
if self.cfg.gradient_checkpointing == "offload_disk":
|
if self.cfg.gradient_checkpointing == "offload_disk":
|
||||||
from axolotl.monkeypatch.gradient_checkpointing import (
|
from axolotl.monkeypatch.gradient_checkpointing import (
|
||||||
hf_grad_checkpoint_disk_offload_wrapper,
|
hf_grad_checkpoint_disk_offload_wrapper,
|
||||||
|
|||||||
@@ -195,9 +195,11 @@ def ensure_dtype(model: PreTrainedModel, dtype: torch.dtype = torch.bfloat16):
|
|||||||
bias_mismatch = module.bias.dtype != dtype
|
bias_mismatch = module.bias.dtype != dtype
|
||||||
|
|
||||||
if weight_mismatch:
|
if weight_mismatch:
|
||||||
print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
|
LOG.debug(
|
||||||
|
f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}"
|
||||||
|
)
|
||||||
if bias_mismatch:
|
if bias_mismatch:
|
||||||
print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
|
LOG.debug(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
|
||||||
if weight_mismatch or bias_mismatch:
|
if weight_mismatch or bias_mismatch:
|
||||||
module.to(dtype)
|
module.to(dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -2,102 +2,65 @@
|
|||||||
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration, and saving full state dicts
|
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration, and saving full state dicts
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import functools
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def fsdp2_load_full_state_dict(
    _accelerator, model: torch.nn.Module, full_sd: dict, offload_to_cpu: bool = False
):
    """
    Loads the full state dict (could be only on rank 0) into the sharded model. This is done by broadcasting the
    parameters from rank 0 to all other ranks. This function modifies the model in-place.

    Note that `full_sd` is consumed: every entry is replaced with `None` once it has been
    distributed, so the caller's reference to the full tensor no longer keeps it alive.

    Args:
        _accelerator (`Accelerator`):
            The accelerator instance. Unused here; kept so the signature stays compatible with the
            accelerate function this replaces.
        model (`torch.nn.Module`):
            The model to load the state dict into, expected to be on meta device or a VRAM spike can occur
        full_sd (`dict`): The full state dict to load, can only be on rank 0
        offload_to_cpu (`bool`, defaults to `False`):
            When `True`, each parameter is moved to CPU right after it has been sharded.

    Returns:
        `torch.nn.Module`: The same `model`, with the sharded state dict assigned.

    Raises:
        KeyError: If `full_sd` contains a key that is missing from `model.state_dict()`.
    """
    import time

    from torch.distributed.tensor import distribute_tensor

    LOG.info("Broadcasting full state dict to all ranks...")

    start_time = time.time()

    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    for param_name, full_tensor in full_sd.items():
        # Fail with an explicit message instead of an AttributeError on `None.dtype`
        # when the checkpoint holds a key the (sharded) model does not know about.
        try:
            sharded_meta_param = meta_sharded_sd[param_name]
        except KeyError as err:
            raise KeyError(
                f"'{param_name}' is present in the full state dict but missing from the model state dict"
            ) from err

        # Cast first, then move to the GPU, so the transfer already uses the target dtype.
        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(torch.device("cuda"))
        if hasattr(sharded_meta_param, "device_mesh"):
            # The meta parameter is a DTensor: shard the full tensor with the same mesh and
            # placements, using rank 0 as the source of the data.
            sharded_param = distribute_tensor(
                full_tensor,
                sharded_meta_param.device_mesh,
                sharded_meta_param.placements,
                src_data_rank=0,
            )
        else:
            sharded_param = full_tensor

        if offload_to_cpu:
            sharded_param = sharded_param.cpu()

        sharded_sd[param_name] = nn.Parameter(sharded_param)
        # Drop both references to the full tensor before handling the next parameter
        # (only values are replaced, so mutating `full_sd` while iterating is safe).
        del full_tensor
        full_sd[param_name] = None

    # `assign=True` because the model's own params are placeholders on the meta device
    model.load_state_dict(sharded_sd, assign=True, strict=True)
    end_time = time.time()
    LOG.debug(
        f"Time taken to load full state dict: {(end_time - start_time):.2f} seconds"
    )
    log_gpu_memory_usage(LOG, "Memory usage after broadcasting full state dict", 0)
    return model
|
||||||
|
|
||||||
|
|
||||||
@@ -191,17 +154,195 @@ def get_state_dict(self, model, unwrap=True):
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def patch_accelerate_fsdp2():
|
def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
|
||||||
import accelerate
|
"""Helper function to process LoRA modules for FSDP2."""
|
||||||
from accelerate.utils import fsdp_utils
|
from torch.distributed.fsdp import fully_shard
|
||||||
|
|
||||||
fsdp_utils.fsdp2_load_full_state_dict = fsdp2_load_full_state_dict
|
log_bias_dtype_mismatch = False
|
||||||
setattr(
|
|
||||||
sys.modules["accelerate.utils.fsdp_utils"],
|
# Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
|
||||||
"fsdp2_load_full_state_dict",
|
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
|
||||||
fsdp2_load_full_state_dict,
|
if module.base_layer.bias is not None:
|
||||||
|
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
|
||||||
|
log_bias_dtype_mismatch = True
|
||||||
|
module.base_layer.bias.data = module.base_layer.bias.data.to(
|
||||||
|
module.base_layer.weight.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for active_adapter in module.active_adapters:
|
||||||
|
if module.lora_A:
|
||||||
|
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
|
||||||
|
if module.lora_B:
|
||||||
|
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
|
||||||
|
if module.lora_embedding_A:
|
||||||
|
fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
|
||||||
|
if module.lora_embedding_B:
|
||||||
|
fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
|
||||||
|
if module.lora_magnitude_vector:
|
||||||
|
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
|
||||||
|
return log_bias_dtype_mismatch
|
||||||
|
|
||||||
|
|
||||||
|
def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
    """Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model.

    Args:
        accelerator (`Accelerator`): The accelerator instance
        model (`torch.nn.Module`): The model to prepare

    Returns:
        `torch.nn.Module`: Prepared model
    """
    from accelerate.utils import get_module_children_bottom_up, is_compiled_module
    from accelerate.utils.fsdp_utils import fsdp2_prepare_auto_wrap_policy
    from accelerate.utils.modeling import get_non_persistent_buffers
    from peft import PeftModel
    from peft.tuners.lora import LoraLayer
    from torch.distributed.fsdp import (
        CPUOffloadPolicy,
        FSDPModule,
        MixedPrecisionPolicy,
        fully_shard,
    )

    # Nothing to do when the model (or the module wrapped by torch.compile) is already sharded
    is_type_fsdp = isinstance(model, FSDPModule) or (
        is_compiled_module(model)
        and isinstance(model._orig_mod, FSDPModule)  # pylint: disable=protected-access
    )
    if is_type_fsdp:
        return model

    fsdp2_plugin = accelerator.state.fsdp_plugin

    # Captured before any move to the meta device so the real weights can be loaded back later
    original_sd = model.state_dict()

    # We set `auto_wrap_policy` to `functools.partial` to avoid creating it again
    # This is because `apply_activation_checkpointing` can reuse this function
    fsdp2_plugin.set_auto_wrap_policy(model)

    if fsdp2_plugin.activation_checkpointing:
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            CheckpointImpl,
            apply_activation_checkpointing,
            checkpoint_wrapper,
        )

        # Apply activation checkpointing before applying `fully_shard`
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=functools.partial(
                checkpoint_wrapper,
                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            ),
            auto_wrap_policy=fsdp2_plugin.auto_wrap_policy,
        )

    fsdp2_kwargs = {
        "reshard_after_forward": fsdp2_plugin.reshard_after_forward,
        "offload_policy": fsdp2_plugin.cpu_offload,
        # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
        "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
    }

    # this is a temporary fix whereby loading models with bnb params cannot be moved from
    # GPU to a meta device with FSDP2 because torch operations don't return the original class type
    # bypassing the move to meta will still cause the VRAM spike, but at least it still will load
    model_has_params4bit = any(
        param.__class__.__name__ == "Params4bit"
        for _, param in model.named_parameters()
    )

    if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
        # Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
        # For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
        # If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.empty`), `fully_shard` would move it to GPU
        # Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike

        # We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device
        # Also, these buffers aren't getting sharded by default
        # We get the FQNs of all non-persistent buffers, to re-register them after
        non_persistent_buffer_fqns = get_non_persistent_buffers(
            model, recurse=True, fqns=True
        )
        original_non_persistent_buffers = copy.deepcopy(
            {k: v for k, v in model.named_buffers() if k in non_persistent_buffer_fqns}
        )
        # We move the model to meta device, as then sharding happens on meta device
        model = model.to(torch.device("meta"))
        # We need to re-tie the weights, not exactly sure why, but if we don't do this, reference to `lm_head/embed_tokens` stay hanging -> more VRAM usage
        # We assume `transformers` models have a `tie_weights` method if they support it
        if hasattr(model, "tie_weights"):
            model.tie_weights()

    is_peft_model = isinstance(model, PeftModel)

    auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
    log_bias_dtype_mismatch = False
    if auto_wrap_policy is not None:
        # Children are visited bottom-up; the last entry is skipped here and the
        # whole model is sharded by the `fully_shard(model, ...)` call below.
        for module in get_module_children_bottom_up(model)[:-1]:
            if is_peft_model and isinstance(module, LoraLayer):
                module_log_bias_mismatch = _process_lora_module_for_fsdp(
                    module, fsdp2_kwargs
                )
                log_bias_dtype_mismatch |= module_log_bias_mismatch
            if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
                fully_shard(module, **fsdp2_kwargs)

    fully_shard(model, **fsdp2_kwargs)

    if log_bias_dtype_mismatch:
        LOG.warning(
            "Bias dtype mismatch detected in LoRA base linear layer. Bias parameters have been cast to weight dtype."
        )

    if fsdp2_plugin.cpu_ram_efficient_loading:
        offload_to_cpu = isinstance(fsdp2_plugin.cpu_offload, CPUOffloadPolicy)
        fsdp2_load_full_state_dict(
            accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
        )

    if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
        # We re-register the buffers, as they may not be in the state_dict
        for fqn, buffer_tensor in original_non_persistent_buffers.items():
            buffer_tensor = buffer_tensor.to(accelerator.device)

            if "." in fqn:
                parent_fqn, local_buffer_name = fqn.rsplit(".", 1)
                parent_module = model.get_submodule(parent_fqn)
            else:
                local_buffer_name = fqn
                parent_module = model

            parent_module.register_buffer(
                local_buffer_name, buffer_tensor, persistent=False
            )

        # We need to tie the weights again, as call to `load_full_state_dict` breaks the tie
        # Needs to be called both here and above
        # removing this call makes the model have a slightly different loss
        # removing the call above leads to extra memory usage as explained in the comment above
        if hasattr(model, "tie_weights"):
            model.tie_weights()
    return model
|
||||||
|
|
||||||
|
|
||||||
|
def patch_accelerate_fsdp2():
|
||||||
|
import accelerate
|
||||||
|
|
||||||
|
accelerate.accelerator.fsdp2_prepare_model = fsdp2_prepare_model
|
||||||
accelerate.Accelerator.get_state_dict = get_state_dict
|
accelerate.Accelerator.get_state_dict = get_state_dict
|
||||||
setattr(
|
setattr(
|
||||||
sys.modules["accelerate"],
|
sys.modules["accelerate"],
|
||||||
|
|||||||
@@ -6,6 +6,10 @@ from typing import Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def patch_flex_wrapper(**flex_attn_compile_kwargs):
|
def patch_flex_wrapper(**flex_attn_compile_kwargs):
|
||||||
# TODO remove this patch when transformers#37285 is merged and in a release
|
# TODO remove this patch when transformers#37285 is merged and in a release
|
||||||
@@ -46,10 +50,15 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
|
|||||||
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
|
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
|
||||||
# see https://github.com/pytorch/pytorch/issues/146260 for training
|
# see https://github.com/pytorch/pytorch/issues/146260 for training
|
||||||
self.training = training
|
self.training = training
|
||||||
|
LOG.info(
|
||||||
|
"Compiling flex attention with kwargs: %s. This may take a while...",
|
||||||
|
flex_attn_compile_kwargs,
|
||||||
|
)
|
||||||
self._compiled_flex_attention = torch.compile(
|
self._compiled_flex_attention = torch.compile(
|
||||||
flex_attention,
|
flex_attention,
|
||||||
**flex_attn_compile_kwargs,
|
**flex_attn_compile_kwargs,
|
||||||
)
|
)
|
||||||
|
LOG.info("Flex attention compiled successfully.")
|
||||||
self._is_flex_compiled = True
|
self._is_flex_compiled = True
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ from functools import partial
|
|||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import (
|
from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import ( # noqa: F401
|
||||||
|
CheckpointFunctionWithCPUOffload,
|
||||||
CPU_Offloaded_Gradient_Checkpointer,
|
CPU_Offloaded_Gradient_Checkpointer,
|
||||||
)
|
)
|
||||||
from axolotl.monkeypatch.gradient_checkpointing.offload_disk import (
|
from axolotl.monkeypatch.gradient_checkpointing.offload_disk import (
|
||||||
|
|||||||
@@ -13,8 +13,24 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import inspect
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from torch.utils.checkpoint import (
|
||||||
|
_get_autocast_kwargs,
|
||||||
|
_get_device_module,
|
||||||
|
_infer_device_type,
|
||||||
|
check_backward_validity,
|
||||||
|
detach_variable,
|
||||||
|
get_device_states,
|
||||||
|
set_device_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
# support different pytorch versions
|
||||||
|
has_device_type = "device_type" in inspect.signature(set_device_states).parameters
|
||||||
|
|
||||||
torch_version = version.parse(torch.__version__)
|
torch_version = version.parse(torch.__version__)
|
||||||
|
|
||||||
@@ -60,3 +76,153 @@ class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
|||||||
) + (
|
) + (
|
||||||
None,
|
None,
|
||||||
) * len(ctx.args)
|
) * len(ctx.args)
|
||||||
|
|
||||||
|
|
||||||
|
# Copyright 2025 Snowflake Inc.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# https://github.com/snowflakedb/ArcticTraining/blob/main/arctic_training/monkey_patches.py
|
||||||
|
class CheckpointFunctionWithCPUOffload(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
This is a torch/utils/checkpoint.py CheckpointFunction monkey patch that offloads the first tensor to cpu during forward and back to cuda during backward. This allows significant memory savings when using a very long seqlen. e.g. for llama 8b at 100k it's 24GB saved per gpu: `((100_000*4096)*2*32/2**30)`
|
||||||
|
In the case of a very long seqlen 100k+ the copying to/from cpu overhead is not big, because dense quadratic attention compute will dominate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, run_function, preserve_rng_state, *args):
|
||||||
|
check_backward_validity(args)
|
||||||
|
ctx.run_function = run_function
|
||||||
|
ctx.preserve_rng_state = preserve_rng_state
|
||||||
|
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
|
||||||
|
ctx.device_type = _infer_device_type(*args)
|
||||||
|
ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
|
||||||
|
ctx.device_type
|
||||||
|
)
|
||||||
|
if preserve_rng_state:
|
||||||
|
ctx.fwd_cpu_state = torch.get_rng_state()
|
||||||
|
# Don't eagerly initialize the cuda context by accident.
|
||||||
|
# (If the user intends that the context is initialized later, within their
|
||||||
|
# run_function, we SHOULD actually stash the cuda state here. Unfortunately,
|
||||||
|
# we have no way to anticipate this will happen before we run the function.)
|
||||||
|
ctx.had_device_in_fwd = False
|
||||||
|
device_module = _get_device_module(ctx.device_type)
|
||||||
|
if getattr(device_module, "_initialized", False):
|
||||||
|
ctx.had_device_in_fwd = True
|
||||||
|
ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)
|
||||||
|
|
||||||
|
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
|
||||||
|
# to be filled out during the backward.
|
||||||
|
ctx.inputs = []
|
||||||
|
ctx.tensor_indices = []
|
||||||
|
tensor_inputs = []
|
||||||
|
# x = None
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
if torch.is_tensor(arg):
|
||||||
|
# cpu-offload
|
||||||
|
# we don't want the 2nd tensor - usually it's a shared 4D attn mask which is huge [seq,seq]
|
||||||
|
# upstream could accept a list of arg indices to offload
|
||||||
|
if i == 0:
|
||||||
|
# print(f"{arg.shape=}")
|
||||||
|
ctx.x_device = arg.device
|
||||||
|
ctx.x_requires_grad = arg.requires_grad
|
||||||
|
t = arg.detach().cpu()
|
||||||
|
else:
|
||||||
|
t = arg
|
||||||
|
tensor_inputs.append(t)
|
||||||
|
ctx.tensor_indices.append(i)
|
||||||
|
ctx.inputs.append(None)
|
||||||
|
else:
|
||||||
|
ctx.inputs.append(arg)
|
||||||
|
|
||||||
|
ctx.save_for_backward(*tensor_inputs)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = run_function(*args)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, *args):
|
||||||
|
if (
|
||||||
|
not torch.autograd._is_checkpoint_valid() # pylint: disable=protected-access
|
||||||
|
):
|
||||||
|
raise RuntimeError(
|
||||||
|
"When use_reentrant=True, torch.utils.checkpoint is incompatible"
|
||||||
|
" with .grad() or passing an `inputs` parameter to .backward()."
|
||||||
|
" To resolve this error, you can either set use_reentrant=False,"
|
||||||
|
" or call .backward() without passing the `inputs` argument."
|
||||||
|
)
|
||||||
|
# Copy the list to avoid modifying original list.
|
||||||
|
inputs = list(ctx.inputs)
|
||||||
|
tensor_indices = ctx.tensor_indices
|
||||||
|
tensors = ctx.saved_tensors
|
||||||
|
|
||||||
|
# Fill in inputs with appropriate saved tensors.
|
||||||
|
for i, idx in enumerate(tensor_indices):
|
||||||
|
if i == 0:
|
||||||
|
t = (
|
||||||
|
tensors[i]
|
||||||
|
.to(ctx.x_device)
|
||||||
|
.detach()
|
||||||
|
.requires_grad_(ctx.x_requires_grad)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
t = tensors[i]
|
||||||
|
inputs[idx] = t
|
||||||
|
|
||||||
|
# Stash the surrounding rng state, and mimic the state that was
|
||||||
|
# present at this time during forward. Restore the surrounding state
|
||||||
|
# when we're done.
|
||||||
|
rng_devices = []
|
||||||
|
if ctx.preserve_rng_state and ctx.had_device_in_fwd:
|
||||||
|
rng_devices = ctx.fwd_devices
|
||||||
|
with torch.random.fork_rng(
|
||||||
|
devices=rng_devices,
|
||||||
|
enabled=ctx.preserve_rng_state,
|
||||||
|
device_type=ctx.device_type,
|
||||||
|
):
|
||||||
|
if ctx.preserve_rng_state:
|
||||||
|
torch.set_rng_state(ctx.fwd_cpu_state)
|
||||||
|
if ctx.had_device_in_fwd:
|
||||||
|
if has_device_type:
|
||||||
|
# newer pytorch (as early as 2.7)
|
||||||
|
set_device_states(
|
||||||
|
ctx.fwd_devices,
|
||||||
|
ctx.fwd_device_states,
|
||||||
|
device_type=ctx.device_type,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# older pytorch (at least 2.4)
|
||||||
|
set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
|
||||||
|
detached_inputs = detach_variable(tuple(inputs))
|
||||||
|
|
||||||
|
device_autocast_ctx = (
|
||||||
|
torch.amp.autocast(
|
||||||
|
device_type=ctx.device_type, **ctx.device_autocast_kwargs
|
||||||
|
)
|
||||||
|
if torch.amp.is_autocast_available(ctx.device_type)
|
||||||
|
else contextlib.nullcontext()
|
||||||
|
)
|
||||||
|
with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
|
||||||
|
outputs = ctx.run_function(*detached_inputs)
|
||||||
|
|
||||||
|
if isinstance(outputs, torch.Tensor):
|
||||||
|
outputs = (outputs,)
|
||||||
|
|
||||||
|
# run backward() with only tensor that requires grad
|
||||||
|
outputs_with_grad = []
|
||||||
|
args_with_grad = []
|
||||||
|
for i in range(len(outputs)): # pylint: disable=consider-using-enumerate
|
||||||
|
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
|
||||||
|
outputs_with_grad.append(outputs[i])
|
||||||
|
args_with_grad.append(args[i])
|
||||||
|
if len(outputs_with_grad) == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
"none of output has requires_grad=True, this checkpoint() is not necessary"
|
||||||
|
)
|
||||||
|
torch.autograd.backward(outputs_with_grad, args_with_grad)
|
||||||
|
grads = tuple(
|
||||||
|
inp.grad if isinstance(inp, torch.Tensor) else None
|
||||||
|
for inp in detached_inputs
|
||||||
|
)
|
||||||
|
|
||||||
|
return (None, None) + grads
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"deepseek_v3",
|
"deepseek_v3",
|
||||||
"glm",
|
"glm",
|
||||||
"glm4",
|
"glm4",
|
||||||
|
"smollm3",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Monkeypatch for Tiled MLP implementation"""
|
"""Monkeypatch for Tiled MLP implementation"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -29,12 +30,15 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
|
|||||||
|
|
||||||
mlp_forward = torch.compile(generic_mlp_forward)
|
mlp_forward = torch.compile(generic_mlp_forward)
|
||||||
|
|
||||||
|
is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
|
||||||
|
|
||||||
def tiled_mlp_forward(self, x):
|
def tiled_mlp_forward(self, x):
|
||||||
input_shape = x.shape
|
input_shape = x.shape
|
||||||
seqlen = input_shape[-2]
|
seqlen = input_shape[-2]
|
||||||
hidden = input_shape[-1]
|
hidden = input_shape[-1]
|
||||||
if cfg_num_shards is None:
|
if cfg_num_shards is None:
|
||||||
num_shards = math.ceil(seqlen / hidden)
|
num_shards = math.ceil(seqlen / hidden)
|
||||||
|
if is_distributed:
|
||||||
num_shards_tensor = torch.tensor(num_shards, device=x.device)
|
num_shards_tensor = torch.tensor(num_shards, device=x.device)
|
||||||
dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
|
dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
|
||||||
num_shards = num_shards_tensor.item()
|
num_shards = num_shards_tensor.item()
|
||||||
|
|||||||
13
src/axolotl/monkeypatch/trainer/trl.py
Normal file
13
src/axolotl/monkeypatch/trainer/trl.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""Monkeypatch for TRL trainer FSDP preparation."""
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_fsdp(model, accelerator):
    """Drop-in replacement for TRL's ``prepare_fsdp`` that delegates to axolotl's FSDP2 preparation.

    Note the argument order: this function takes ``(model, accelerator)`` while
    ``fsdp2_prepare_model`` takes ``(accelerator, model)``.
    """
    # Imported lazily so the fsdp2 monkeypatch module is only loaded when actually needed.
    from axolotl.monkeypatch.accelerate import fsdp2 as fsdp2_patch

    prepared_model = fsdp2_patch.fsdp2_prepare_model(accelerator, model)
    return prepared_model
|
||||||
|
|
||||||
|
|
||||||
|
def patch_trl_prepare_fsdp2():
    """Replace ``trl.models.utils.prepare_fsdp`` with this module's ``prepare_fsdp``."""
    import trl.models.utils as trl_model_utils

    setattr(trl_model_utils, "prepare_fsdp", prepare_fsdp)
|
||||||
@@ -681,13 +681,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
transformed_message = self.transform_message(message)
|
transformed_message = self.transform_message(message)
|
||||||
|
|
||||||
turn = {
|
turn = transformed_message
|
||||||
**transformed_message,
|
|
||||||
"training": message.get(self.prompter.message_field_training),
|
training = message.get(self.prompter.message_field_training)
|
||||||
"training_detail": message.get(
|
training_detail = message.get(self.prompter.message_field_training_detail)
|
||||||
self.prompter.message_field_training_detail
|
if training is not None:
|
||||||
),
|
turn["training"] = training
|
||||||
}
|
if training_detail is not None:
|
||||||
|
turn["training_detail"] = training_detail
|
||||||
|
|
||||||
turns.append(turn)
|
turns.append(turn)
|
||||||
|
|
||||||
@@ -859,15 +860,6 @@ class MistralStrategy(ChatTemplateStrategy):
|
|||||||
# TODO: address this in the future with mistral-specific checks
|
# TODO: address this in the future with mistral-specific checks
|
||||||
# self._validate_eot_and_eos_tokens()
|
# self._validate_eot_and_eos_tokens()
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_multiprocessing(self) -> bool:
|
|
||||||
"""
|
|
||||||
Whether this tokenizing strategy supports multiprocessing.
|
|
||||||
mistral_common tokenizers cannot be pickled for multiprocessing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def find_first_eot_token(self, input_ids, start_idx):
|
def find_first_eot_token(self, input_ids, start_idx):
|
||||||
"""Find the first EOT token in the input_ids starting from start_idx."""
|
"""Find the first EOT token in the input_ids starting from start_idx."""
|
||||||
# mistral-common tokenizer does not support eot_tokens
|
# mistral-common tokenizer does not support eot_tokens
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
|
|||||||
system=sample[field_system], prompt=sample[field_prompt]
|
system=sample[field_system], prompt=sample[field_prompt]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
|
sample["prompt"] = prompt_format.format(prompt=sample[field_prompt])
|
||||||
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
|
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
|
||||||
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
|
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
|
||||||
return sample
|
return sample
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user