Compare commits
9 Commits
lora_kerne
...
nd_paralle
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bc2bc688d8 | ||
|
|
b3c04dd9fe | ||
|
|
972c719d38 | ||
|
|
2c1cb8b300 | ||
|
|
cca207eec4 | ||
|
|
9a2da4d9f0 | ||
|
|
8fe4758e94 | ||
|
|
8c641fdcb4 | ||
|
|
5c74bebfd0 |
4
.github/workflows/base.yml
vendored
4
.github/workflows/base.yml
vendored
@@ -17,7 +17,7 @@ on:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-base:
|
build-base:
|
||||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
timeout-minutes: 480
|
timeout-minutes: 480
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: ubuntu-latest-m
|
runs-on: ubuntu-latest-m
|
||||||
@@ -108,7 +108,7 @@ jobs:
|
|||||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||||
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
|
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
|
||||||
build-base-uv:
|
build-base-uv:
|
||||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
timeout-minutes: 480
|
timeout-minutes: 480
|
||||||
runs-on: ubuntu-latest-m
|
runs-on: ubuntu-latest-m
|
||||||
strategy:
|
strategy:
|
||||||
|
|||||||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -3,7 +3,6 @@ on:
|
|||||||
# check on PRs, and manual triggers
|
# check on PRs, and manual triggers
|
||||||
merge_group:
|
merge_group:
|
||||||
pull_request:
|
pull_request:
|
||||||
types: [opened, synchronize, reopened, ready_for_review]
|
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- '**.py'
|
||||||
- 'requirements.txt'
|
- 'requirements.txt'
|
||||||
@@ -17,7 +16,6 @@ jobs:
|
|||||||
pre-commit:
|
pre-commit:
|
||||||
name: pre-commit
|
name: pre-commit
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: ${{ !github.event.pull_request.draft }}
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
|
|||||||
6
.github/workflows/multi-gpu-e2e.yml
vendored
6
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -21,7 +21,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test-axolotl-multigpu:
|
test-axolotl-multigpu:
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@@ -37,14 +37,14 @@ jobs:
|
|||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.0
|
pytorch: 2.7.0
|
||||||
axolotl_extras:
|
axolotl_extras: vllm
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.7.1
|
pytorch: 2.7.1
|
||||||
axolotl_extras: vllm
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
|
|||||||
6
.github/workflows/preview-docs.yml
vendored
6
.github/workflows/preview-docs.yml
vendored
@@ -2,7 +2,7 @@ name: Preview
|
|||||||
on:
|
on:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
pull_request:
|
pull_request:
|
||||||
types: [opened, synchronize, reopened, ready_for_review]
|
types: [opened, synchronize, reopened]
|
||||||
|
|
||||||
# Run the workflow only when one of these files changes
|
# Run the workflow only when one of these files changes
|
||||||
paths:
|
paths:
|
||||||
@@ -25,7 +25,6 @@ permissions:
|
|||||||
jobs:
|
jobs:
|
||||||
preview:
|
preview:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: ${{ !github.event.pull_request.draft }}
|
|
||||||
steps:
|
steps:
|
||||||
- name: Check out repository
|
- name: Check out repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -53,7 +52,6 @@ jobs:
|
|||||||
|
|
||||||
- name: Netlify Publish
|
- name: Netlify Publish
|
||||||
uses: nwtgck/actions-netlify@v3.0
|
uses: nwtgck/actions-netlify@v3.0
|
||||||
if: ${{ secrets.NETLIFY_AUTH_TOKEN != '' }}
|
|
||||||
id: netlify
|
id: netlify
|
||||||
with:
|
with:
|
||||||
publish-dir: './_site'
|
publish-dir: './_site'
|
||||||
@@ -68,7 +66,7 @@ jobs:
|
|||||||
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
|
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
|
||||||
|
|
||||||
- name: Update PR with preview link
|
- name: Update PR with preview link
|
||||||
if: ${{ steps.netlify.outcome == 'success' && secrets.NETLIFY_AUTH_TOKEN != '' }}
|
if: ${{ steps.netlify.outcome == 'success' }}
|
||||||
uses: marocchino/sticky-pull-request-comment@v2
|
uses: marocchino/sticky-pull-request-comment@v2
|
||||||
with:
|
with:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|||||||
9
.github/workflows/tests.yml
vendored
9
.github/workflows/tests.yml
vendored
@@ -13,7 +13,6 @@ on:
|
|||||||
- 'cicd/cicd.sh'
|
- 'cicd/cicd.sh'
|
||||||
- 'cicd/Dockerfile.jinja'
|
- 'cicd/Dockerfile.jinja'
|
||||||
pull_request:
|
pull_request:
|
||||||
types: [opened, synchronize, reopened, ready_for_review]
|
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- '**.py'
|
||||||
- 'requirements.txt'
|
- 'requirements.txt'
|
||||||
@@ -35,7 +34,6 @@ jobs:
|
|||||||
pre-commit:
|
pre-commit:
|
||||||
name: pre-commit
|
name: pre-commit
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: ${{ !github.event.pull_request.draft }}
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
@@ -49,7 +47,6 @@ jobs:
|
|||||||
pytest:
|
pytest:
|
||||||
name: PyTest
|
name: PyTest
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: ${{ !github.event.pull_request.draft }}
|
|
||||||
# needs: [preload-cache]
|
# needs: [preload-cache]
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
@@ -124,7 +121,6 @@ jobs:
|
|||||||
pytest-sdist:
|
pytest-sdist:
|
||||||
name: PyTest from Source Dist
|
name: PyTest from Source Dist
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: ${{ !github.event.pull_request.draft }}
|
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@@ -189,7 +185,7 @@ jobs:
|
|||||||
|
|
||||||
docker-e2e-tests-1st:
|
docker-e2e-tests-1st:
|
||||||
# Run this job first as a gate for running the remainder of the test matrix
|
# Run this job first as a gate for running the remainder of the test matrix
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 120
|
timeout-minutes: 120
|
||||||
@@ -239,7 +235,7 @@ jobs:
|
|||||||
modal run cicd.e2e_tests
|
modal run cicd.e2e_tests
|
||||||
|
|
||||||
docker-e2e-tests:
|
docker-e2e-tests:
|
||||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 120
|
timeout-minutes: 120
|
||||||
@@ -293,7 +289,6 @@ jobs:
|
|||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 90
|
timeout-minutes: 90
|
||||||
needs: [docker-e2e-tests]
|
needs: [docker-e2e-tests]
|
||||||
if: ${{ !github.event.pull_request.draft }}
|
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
@@ -119,15 +119,14 @@ datasets:
|
|||||||
|
|
||||||
## Dataset Processing
|
## Dataset Processing
|
||||||
|
|
||||||
| Option | Default | Description |
|
| Option | Default | Description |
|
||||||
| --------------------------------- | -------------------------- | ----------------------------------- |
|
| ----------------------------- | -------------------------- | --------------------------------- |
|
||||||
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
|
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
|
||||||
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
|
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
|
||||||
| `dataset_processes` | `4` | Number of preprocessing processes |
|
| `dataset_processes` | `4` | Number of preprocessing processes |
|
||||||
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
|
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
|
||||||
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
|
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
|
||||||
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |
|
| `dataset_exact_deduplication` | `true` | Deduplicate datasets |
|
||||||
| `dataset_exact_deduplication` | `true` | Deduplicate datasets |
|
|
||||||
|
|
||||||
## LoRA Configuration
|
## LoRA Configuration
|
||||||
|
|
||||||
|
|||||||
21
README.md
21
README.md
@@ -25,7 +25,6 @@
|
|||||||
|
|
||||||
## 🎉 Latest Updates
|
## 🎉 Latest Updates
|
||||||
|
|
||||||
- 2025/07: TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP has been added to support Arctic Long Sequence Training (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
|
||||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
||||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||||
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||||
@@ -80,20 +79,6 @@ docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
|||||||
|
|
||||||
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
#### Cloud Providers
|
|
||||||
|
|
||||||
<details>
|
|
||||||
|
|
||||||
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
|
||||||
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=github&utm_medium=developer_community&utm_campaign=template_launch_axolotl&utm_content=readme)
|
|
||||||
- [PRIME Intellect](https://app.primeintellect.ai/dashboard/create-cluster?image=axolotl&location=Cheapest&security=Cheapest&show_spot=true)
|
|
||||||
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl)
|
|
||||||
- [Novita](https://novita.ai/gpus-console?templateId=311)
|
|
||||||
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
|
|
||||||
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
### Your First Fine-tune
|
### Your First Fine-tune
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -135,6 +120,12 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
|
|||||||
|
|
||||||
## ❤️ Sponsors
|
## ❤️ Sponsors
|
||||||
|
|
||||||
|
Thank you to our sponsors who help make Axolotl possible:
|
||||||
|
|
||||||
|
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) - Modal lets you run
|
||||||
|
jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale,
|
||||||
|
fine-tune large language models, run protein folding simulations, and much more.
|
||||||
|
|
||||||
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
|
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
|
||||||
|
|
||||||
## 📜 License
|
## 📜 License
|
||||||
|
|||||||
@@ -19,7 +19,5 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
|
|||||||
--cov-append \
|
--cov-append \
|
||||||
--cov-report=xml:multigpu-coverage.xml
|
--cov-report=xml:multigpu-coverage.xml
|
||||||
|
|
||||||
# Upload coverage to Codecov if CODECOV_TOKEN is available
|
# Upload coverage to Codecov
|
||||||
if [ -n "$CODECOV_TOKEN" ]; then
|
codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true
|
||||||
codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true
|
|
||||||
fi
|
|
||||||
|
|||||||
@@ -9,15 +9,13 @@ ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
|||||||
EXPOSE 8888
|
EXPOSE 8888
|
||||||
EXPOSE 22
|
EXPOSE 22
|
||||||
|
|
||||||
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
|
COPY scripts/cloud-entrypoint-term.sh /root/cloud-entrypoint.sh
|
||||||
COPY scripts/motd /etc/motd
|
COPY scripts/motd /etc/motd
|
||||||
|
|
||||||
RUN pip install jupyterlab notebook ipywidgets && \
|
RUN pip install jupyterlab notebook ipywidgets && \
|
||||||
jupyter lab clean
|
jupyter lab clean
|
||||||
RUN apt update && \
|
RUN apt install --yes --no-install-recommends openssh-server tmux sudo && \
|
||||||
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \
|
pip3 install -U --no-cache-dir grpcio ray[default]==2.9.3 && \
|
||||||
rm -rf /var/cache/apt/archives && \
|
|
||||||
rm -rf /var/lib/apt/lists/* && \
|
|
||||||
mkdir -p ~/.ssh && \
|
mkdir -p ~/.ssh && \
|
||||||
chmod 700 ~/.ssh && \
|
chmod 700 ~/.ssh && \
|
||||||
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
|
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
|
||||||
|
|||||||
@@ -136,7 +136,3 @@ description: Frequently asked questions
|
|||||||
> dynamic: false
|
> dynamic: false
|
||||||
> mode: max-autotune-no-cudagraphs
|
> mode: max-autotune-no-cudagraphs
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`
|
|
||||||
|
|
||||||
> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.
|
|
||||||
|
|||||||
@@ -124,13 +124,10 @@ For providers supporting Docker:
|
|||||||
|
|
||||||
- Use `axolotlai/axolotl-cloud:main-latest`
|
- Use `axolotlai/axolotl-cloud:main-latest`
|
||||||
- Available on:
|
- Available on:
|
||||||
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||||
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)
|
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
|
||||||
- [PRIME Intellect](https://app.primeintellect.ai/dashboard/create-cluster?image=axolotl&location=Cheapest&security=Cheapest&show_spot=true)
|
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||||
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl)
|
- [Novita](https://novita.ai/gpus-console?templateId=311)
|
||||||
- [Novita](https://novita.ai/gpus-console?templateId=311)
|
|
||||||
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
|
|
||||||
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
|
||||||
|
|
||||||
### Google Colab {#sec-colab}
|
### Google Colab {#sec-colab}
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file:
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# Set to a divisor (> 1) of the number of GPUs available
|
# Set to a divisor (> 1) of the number of GPUs available
|
||||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
context_parallel_size: 4 # Split sequences across 4 GPUs
|
||||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||||
@@ -30,7 +30,7 @@ heads_k_stride: 1
|
|||||||
ring_attn_func:
|
ring_attn_func:
|
||||||
```
|
```
|
||||||
|
|
||||||
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
The `context_parallel_size` should be a divisor of the total number of GPUs. For example:
|
||||||
|
|
||||||
- With 8 GPUs, valid values would be 2, 4, or 8
|
- With 8 GPUs, valid values would be 2, 4, or 8
|
||||||
- With 4 GPUs, valid values would be 2 or 4
|
- With 4 GPUs, valid values would be 2 or 4
|
||||||
@@ -66,7 +66,7 @@ sequence_len: 8192
|
|||||||
|
|
||||||
...
|
...
|
||||||
|
|
||||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
|
||||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||||
heads_k_stride: 1
|
heads_k_stride: 1
|
||||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||||
@@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
|
|||||||
|
|
||||||
## Effect on Batch Size
|
## Effect on Batch Size
|
||||||
|
|
||||||
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:
|
||||||
|
|
||||||
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
|
||||||
- The number of batches processed per step decreases
|
- The number of batches processed per step decreases
|
||||||
|
|
||||||
For example:
|
For example:
|
||||||
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
|
||||||
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||||
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
# Arctic Long Sequence Training (ALST)
|
|
||||||
|
|
||||||
Arctic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization
|
|
||||||
techniques. It is a combination of:
|
|
||||||
- TiledMLP: Leverage tiling over the sequence dimension on MLP layers to reduce memory usage
|
|
||||||
- Tiled Loss: Using optimized loss functions like Liger-Kernel or Cut Cross Entropy to reduce memory usage
|
|
||||||
- Activation Offloading: Offload activations to CPU RAM to reduce memory usage
|
|
||||||
|
|
||||||
For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
base_model: meta-llama/Llama-3.1-8B
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: togethercomputer/Long-Data-Collections
|
|
||||||
type: completion
|
|
||||||
field: text
|
|
||||||
data_files:
|
|
||||||
- pretrain/rp_sub.jsonl.zst
|
|
||||||
- path: princeton-nlp/TextbookChapters
|
|
||||||
type: completion
|
|
||||||
field: chapter
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
sequence_len: 500_000
|
|
||||||
min_sample_len: 200_000
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
tiled_mlp: true
|
|
||||||
sequence_parallel_degree: 8
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 2e-5
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
activation_offloading: legacy
|
|
||||||
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_steps: 100
|
|
||||||
saves_per_epoch: 1
|
|
||||||
evals_per_epoch: 2
|
|
||||||
weight_decay: 0.0
|
|
||||||
special_tokens:
|
|
||||||
pad_token: <|end_of_text|>
|
|
||||||
|
|
||||||
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json
|
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
@@ -1,59 +0,0 @@
|
|||||||
base_model: meta-llama/Llama-3.1-8B
|
|
||||||
# Automatically upload checkpoint and final model to HF
|
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: togethercomputer/Long-Data-Collections
|
|
||||||
type: completion
|
|
||||||
field: text
|
|
||||||
data_files:
|
|
||||||
- pretrain/rp_sub.jsonl.zst
|
|
||||||
- path: princeton-nlp/TextbookChapters
|
|
||||||
type: completion
|
|
||||||
field: chapter
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.0
|
|
||||||
output_dir: ./outputs/out
|
|
||||||
|
|
||||||
sequence_len: 500_000
|
|
||||||
min_sample_len: 200_000
|
|
||||||
sample_packing: true
|
|
||||||
|
|
||||||
tiled_mlp: true
|
|
||||||
context_parallel_size: 8
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
|
||||||
micro_batch_size: 1
|
|
||||||
num_epochs: 1
|
|
||||||
optimizer: adamw_torch_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 2e-5
|
|
||||||
|
|
||||||
bf16: auto
|
|
||||||
tf32: true
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
activation_offloading: legacy
|
|
||||||
|
|
||||||
resume_from_checkpoint:
|
|
||||||
logging_steps: 1
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_steps: 100
|
|
||||||
saves_per_epoch: 1
|
|
||||||
evals_per_epoch: 2
|
|
||||||
weight_decay: 0.0
|
|
||||||
special_tokens:
|
|
||||||
pad_token: <|end_of_text|>
|
|
||||||
|
|
||||||
fsdp_version: 2
|
|
||||||
fsdp_config:
|
|
||||||
offload_params: false # offloading is currently not compatible with SP + torchao optimizer
|
|
||||||
state_dict_type: SHARDED_STATE_DICT
|
|
||||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
|
||||||
reshard_after_forward: true
|
|
||||||
|
|
||||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
|
||||||
@@ -9,6 +9,7 @@ liger_rms_norm: true
|
|||||||
liger_glu_activation: true
|
liger_glu_activation: true
|
||||||
liger_fused_linear_cross_entropy: true
|
liger_fused_linear_cross_entropy: true
|
||||||
|
|
||||||
|
|
||||||
chat_template: llama3
|
chat_template: llama3
|
||||||
datasets:
|
datasets:
|
||||||
- path: mlabonne/FineTome-100k
|
- path: mlabonne/FineTome-100k
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ lora_model_dir:
|
|||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|
||||||
|
|
||||||
lora_r: 16
|
lora_r: 16
|
||||||
lora_alpha: 32
|
lora_alpha: 32
|
||||||
# Currently, we don't support dropout with our custom Triton kernels
|
# Currently, we don't support dropout with our custom Triton kernels
|
||||||
|
|||||||
@@ -13,13 +13,13 @@ packaging==23.2
|
|||||||
|
|
||||||
huggingface_hub>=0.33.0
|
huggingface_hub>=0.33.0
|
||||||
peft==0.16.0
|
peft==0.16.0
|
||||||
transformers==4.54.0
|
transformers @ git+https://github.com/huggingface/transformers.git@82603b6cc284dbdf2b7a7cf070feb6a2c3bb53cf
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.9.0
|
accelerate @ git+https://github.com/SalmanMohammadi/accelerate.git@device_mesh_parallelism_config
|
||||||
datasets==4.0.0
|
datasets==4.0.0
|
||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.19.1
|
trl==0.19.1
|
||||||
hf_xet==1.1.5
|
hf_xet==1.1.2
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
@@ -62,10 +62,10 @@ langdetect==1.0.9
|
|||||||
immutabledict==4.2.0
|
immutabledict==4.2.0
|
||||||
antlr4-python3-runtime==4.13.2
|
antlr4-python3-runtime==4.13.2
|
||||||
|
|
||||||
torchao==0.12.0
|
torchao==0.10.0
|
||||||
schedulefree==1.4.1
|
schedulefree==1.4.1
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.6
|
axolotl-contribs-lgpl==0.0.6
|
||||||
axolotl-contribs-mit==0.0.3
|
axolotl-contribs-mit==0.0.3
|
||||||
|
|
||||||
mistral-common==1.8.3
|
mistral-common==1.7.0
|
||||||
|
|||||||
@@ -13,8 +13,6 @@
|
|||||||
|
|
||||||
Welcome to the axolotl cloud image! If you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:
|
Welcome to the axolotl cloud image! If you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:
|
||||||
|
|
||||||
Need help with your post-training workloads? Reach out to us at contact@axolotl.ai for assistance.
|
|
||||||
|
|
||||||
```
|
```
|
||||||
cd /workspace
|
cd /workspace
|
||||||
rm -rf /workspace/axolotl
|
rm -rf /workspace/axolotl
|
||||||
|
|||||||
16
setup.py
16
setup.py
@@ -68,10 +68,9 @@ def parse_requirements(extras_require_map):
|
|||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
_install_requires.append("xformers==0.0.30")
|
_install_requires.append("xformers==0.0.30")
|
||||||
# vllm 0.9.x is incompatible with latest transformers
|
|
||||||
extras_require_map.pop("vllm")
|
|
||||||
else:
|
else:
|
||||||
_install_requires.append("xformers==0.0.31")
|
_install_requires.append("xformers==0.0.31.post1")
|
||||||
|
extras_require_map["vllm"] = ["vllm>=0.9.0"]
|
||||||
elif (major, minor) >= (2, 6):
|
elif (major, minor) >= (2, 6):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers==0.0.29.post3")
|
_install_requires.append("xformers==0.0.29.post3")
|
||||||
@@ -85,9 +84,7 @@ def parse_requirements(extras_require_map):
|
|||||||
else:
|
else:
|
||||||
_install_requires.append("xformers>=0.0.28.post3")
|
_install_requires.append("xformers>=0.0.28.post3")
|
||||||
_install_requires.pop(_install_requires.index(autoawq_version))
|
_install_requires.pop(_install_requires.index(autoawq_version))
|
||||||
extras_require_map.pop("vllm")
|
|
||||||
elif (major, minor) >= (2, 4):
|
elif (major, minor) >= (2, 4):
|
||||||
extras_require_map.pop("vllm")
|
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers>=0.0.27")
|
_install_requires.append("xformers>=0.0.27")
|
||||||
@@ -117,10 +114,10 @@ def get_package_version():
|
|||||||
|
|
||||||
|
|
||||||
extras_require = {
|
extras_require = {
|
||||||
"flash-attn": ["flash-attn==2.8.2"],
|
"flash-attn": ["flash-attn==2.8.0.post2"],
|
||||||
"ring-flash-attn": [
|
"ring-flash-attn": [
|
||||||
"flash-attn==2.8.2",
|
"flash-attn==2.8.0.post2",
|
||||||
"ring-flash-attn>=0.1.7",
|
"ring-flash-attn>=0.1.5",
|
||||||
"yunchang==0.6.0",
|
"yunchang==0.6.0",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
@@ -154,12 +151,13 @@ extras_require = {
|
|||||||
"ray[train]",
|
"ray[train]",
|
||||||
],
|
],
|
||||||
"vllm": [
|
"vllm": [
|
||||||
"vllm==0.10.0",
|
"vllm==0.7.2",
|
||||||
],
|
],
|
||||||
"llmcompressor": [
|
"llmcompressor": [
|
||||||
"llmcompressor==0.5.1",
|
"llmcompressor==0.5.1",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||||
extras_require
|
extras_require
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
|||||||
load_in_8bit=False,
|
load_in_8bit=False,
|
||||||
load_in_4bit=False,
|
load_in_4bit=False,
|
||||||
flash_attention=False,
|
flash_attention=False,
|
||||||
sequence_parallel_degree=None,
|
context_parallel_size=None,
|
||||||
deepspeed=None,
|
deepspeed=None,
|
||||||
fsdp=None,
|
fsdp=None,
|
||||||
fsdp_config=None,
|
fsdp_config=None,
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ import torch
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
|
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
@@ -434,8 +435,18 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||||
|
|
||||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||||
|
use_configured_state = True
|
||||||
if self.cfg.accelerator_config:
|
if self.cfg.accelerator_config:
|
||||||
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
|
use_configured_state = self.cfg.accelerator_config.pop(
|
||||||
|
"use_configured_state", use_configured_state
|
||||||
|
)
|
||||||
|
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||||
|
use_configured_state=use_configured_state, **self.cfg.accelerator_config
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
training_args_kwargs["accelerator_config"] = AcceleratorConfig(
|
||||||
|
use_configured_state=True,
|
||||||
|
)
|
||||||
|
|
||||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||||
if self.cfg.activation_offloading is True:
|
if self.cfg.activation_offloading is True:
|
||||||
@@ -500,7 +511,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||||
|
|
||||||
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
||||||
training_args_kwargs["average_tokens_across_devices"] = False
|
|
||||||
|
|
||||||
if self.cfg.eval_batch_size:
|
if self.cfg.eval_batch_size:
|
||||||
training_args_kwargs["per_device_eval_batch_size"] = (
|
training_args_kwargs["per_device_eval_batch_size"] = (
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.rl is RLType.GRPO:
|
if self.cfg.rl is RLType.GRPO:
|
||||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
sequence_parallel=self.cfg.context_parallel_size > 1
|
||||||
)
|
)
|
||||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||||
|
|
||||||
|
|||||||
@@ -82,8 +82,8 @@ class GRPOStrategy:
|
|||||||
grpo_args_kwargs["log_completions"] = trl.log_completions
|
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||||
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
grpo_args_kwargs["num_completions_to_print"] = trl.num_completions_to_print
|
||||||
|
|
||||||
if cfg.sequence_parallel_degree > 1:
|
if cfg.context_parallel_size > 1:
|
||||||
grpo_args_kwargs["sequence_parallel_degree"] = cfg.sequence_parallel_degree
|
grpo_args_kwargs["context_parallel_size"] = cfg.context_parallel_size
|
||||||
|
|
||||||
if trl.reward_weights:
|
if trl.reward_weights:
|
||||||
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||||
|
|||||||
@@ -13,4 +13,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
|||||||
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||||
"""Axolotl GRPO Config for GRPO training"""
|
"""Axolotl GRPO Config for GRPO training"""
|
||||||
|
|
||||||
sequence_parallel_degree: int | None = None
|
context_parallel_size: int | None = None
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
|||||||
- Data is properly distributed across SP groups.
|
- Data is properly distributed across SP groups.
|
||||||
|
|
||||||
In the table below, the values represent dataset indices. Each SP group has
|
In the table below, the values represent dataset indices. Each SP group has
|
||||||
`sequence_parallel_degree = 2` GPUs working together on the same data. There are 2
|
`context_parallel_size = 2` GPUs working together on the same data. There are 2
|
||||||
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
SP groups (SP0 and SP1), with `world_size = 4` total GPUs.
|
||||||
|
|
||||||
Sequence Parallel Groups
|
Sequence Parallel Groups
|
||||||
@@ -45,7 +45,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
|||||||
rank: Rank of current process.
|
rank: Rank of current process.
|
||||||
batch_size: Number of samples per batch.
|
batch_size: Number of samples per batch.
|
||||||
repeat_count: How many times to repeat the full sampling process.
|
repeat_count: How many times to repeat the full sampling process.
|
||||||
sequence_parallel_degree: Number of ranks in a sequence parallel group.
|
context_parallel_size: Number of ranks in a sequence parallel group.
|
||||||
shuffle: Whether to shuffle the dataset.
|
shuffle: Whether to shuffle the dataset.
|
||||||
seed: Random seed for shuffling.
|
seed: Random seed for shuffling.
|
||||||
drop_last: Whether to drop the last incomplete batch.
|
drop_last: Whether to drop the last incomplete batch.
|
||||||
@@ -59,7 +59,7 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
|||||||
rank: int,
|
rank: int,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
repeat_count: int = 1,
|
repeat_count: int = 1,
|
||||||
sequence_parallel_degree: int = 1,
|
context_parallel_size: int = 1,
|
||||||
shuffle: bool = True,
|
shuffle: bool = True,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
drop_last: bool = False,
|
drop_last: bool = False,
|
||||||
@@ -77,9 +77,9 @@ class SequenceParallelRepeatRandomSampler(Sampler):
|
|||||||
self.rank = rank
|
self.rank = rank
|
||||||
|
|
||||||
# Sequence parallelism parameters
|
# Sequence parallelism parameters
|
||||||
self.sequence_parallel_degree = sequence_parallel_degree
|
self.context_parallel_size = context_parallel_size
|
||||||
self.num_sp_groups = world_size // sequence_parallel_degree
|
self.num_sp_groups = world_size // context_parallel_size
|
||||||
self.sp_group_id = rank // sequence_parallel_degree
|
self.sp_group_id = rank // context_parallel_size
|
||||||
|
|
||||||
# Adjust dataset size for distributed sampling
|
# Adjust dataset size for distributed sampling
|
||||||
self.num_samples = len(self.dataset)
|
self.num_samples = len(self.dataset)
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
|
|
||||||
# Get number of SP groups (number of processes divided by SP degree)
|
# Get number of SP groups (number of processes divided by SP degree)
|
||||||
num_processes = self.accelerator.num_processes
|
num_processes = self.accelerator.num_processes
|
||||||
num_sp_groups = num_processes // self.args.sequence_parallel_degree
|
num_sp_groups = num_processes // self.args.context_parallel_size
|
||||||
|
|
||||||
# Calculate batch size per SP group (not per process)
|
# Calculate batch size per SP group (not per process)
|
||||||
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
|
sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups
|
||||||
@@ -130,7 +130,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
|
|
||||||
if self.num_generations not in possible_values:
|
if self.num_generations not in possible_values:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), "
|
f"With sequence parallelism (degree {self.args.context_parallel_size}), "
|
||||||
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) "
|
||||||
f"must be evenly divisible by the number of generations per prompt "
|
f"must be evenly divisible by the number of generations per prompt "
|
||||||
f"({self.num_generations}). Given the current eval batch size, "
|
f"({self.num_generations}). Given the current eval batch size, "
|
||||||
@@ -167,9 +167,9 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
rank=self.rank,
|
rank=self.rank,
|
||||||
batch_size=effective_batch_size
|
batch_size=effective_batch_size
|
||||||
// self.num_generations
|
// self.num_generations
|
||||||
// self.args.sequence_parallel_degree,
|
// self.args.context_parallel_size,
|
||||||
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
|
repeat_count=self.num_iterations * self.args.gradient_accumulation_steps,
|
||||||
sequence_parallel_degree=self.args.sequence_parallel_degree,
|
context_parallel_size=self.args.context_parallel_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
seed=self.args.seed,
|
seed=self.args.seed,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
@@ -235,7 +235,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
# TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation
|
||||||
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
# if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e.,
|
||||||
# slice each batch along the sequence dimension).
|
# slice each batch along the sequence dimension).
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.context_parallel_size > 1:
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
# Otherwise prepare with accelerator
|
# Otherwise prepare with accelerator
|
||||||
@@ -308,18 +308,18 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
||||||
all_prompts_text = gather_object(prompts_text)
|
all_prompts_text = gather_object(prompts_text)
|
||||||
if self.accelerator.is_main_process:
|
if self.accelerator.is_main_process:
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.context_parallel_size > 1:
|
||||||
# Calculate sequence parallel group information
|
# Calculate sequence parallel group information
|
||||||
world_size = self.accelerator.num_processes
|
world_size = self.accelerator.num_processes
|
||||||
sequence_parallel_degree = self.args.sequence_parallel_degree
|
context_parallel_size = self.args.context_parallel_size
|
||||||
num_sp_groups = world_size // sequence_parallel_degree
|
num_sp_groups = world_size // context_parallel_size
|
||||||
|
|
||||||
# Since processes in the same SP group have the same prompts, we need to ensure
|
# Since processes in the same SP group have the same prompts, we need to ensure
|
||||||
# we only take one copy of each prompt from each SP group
|
# we only take one copy of each prompt from each SP group
|
||||||
ordered_set_of_prompts = []
|
ordered_set_of_prompts = []
|
||||||
for sp_group_id in range(num_sp_groups):
|
for sp_group_id in range(num_sp_groups):
|
||||||
# Get the first process from each SP group (typically the group leader)
|
# Get the first process from each SP group (typically the group leader)
|
||||||
group_leader_rank = sp_group_id * sequence_parallel_degree
|
group_leader_rank = sp_group_id * context_parallel_size
|
||||||
|
|
||||||
# Extract prompts from this SP group, accounting for num_generations duplicates
|
# Extract prompts from this SP group, accounting for num_generations duplicates
|
||||||
# We only need prompts from one rank in each SP group
|
# We only need prompts from one rank in each SP group
|
||||||
@@ -335,7 +335,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
|
||||||
# prompt individually.
|
# prompt individually.
|
||||||
ordered_set_of_prompts = all_prompts_text[
|
ordered_set_of_prompts = all_prompts_text[
|
||||||
:: self.num_generations * self.args.sequence_parallel_degree
|
:: self.num_generations * self.args.context_parallel_size
|
||||||
]
|
]
|
||||||
|
|
||||||
with profiling_context(self, "vLLM.generate"):
|
with profiling_context(self, "vLLM.generate"):
|
||||||
@@ -352,14 +352,14 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
completion_ids = [None] * (
|
completion_ids = [None] * (
|
||||||
len(all_prompts_text) // self.args.sequence_parallel_degree
|
len(all_prompts_text) // self.args.context_parallel_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# Broadcast the completions from the main process to all processes
|
# Broadcast the completions from the main process to all processes
|
||||||
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
||||||
|
|
||||||
# Determine the appropriate slice based on sequence parallelism
|
# Determine the appropriate slice based on sequence parallelism
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.context_parallel_size > 1:
|
||||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||||
|
|
||||||
@@ -583,7 +583,7 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
|||||||
advantages = advantages / (std_grouped_rewards + 1e-4)
|
advantages = advantages / (std_grouped_rewards + 1e-4)
|
||||||
|
|
||||||
# Slice to keep only the local part of the data
|
# Slice to keep only the local part of the data
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.context_parallel_size > 1:
|
||||||
# Calculate SP group ID (which group of ranks this rank belongs to)
|
# Calculate SP group ID (which group of ranks this rank belongs to)
|
||||||
sp_group_id = self.accelerator.process_index // self.local_world_size
|
sp_group_id = self.accelerator.process_index // self.local_world_size
|
||||||
|
|
||||||
|
|||||||
@@ -4,22 +4,13 @@ Trainer mixin for activation checkpointing w offloading
|
|||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
from peft import PeftModel
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||||
apply_activation_checkpointing,
|
apply_activation_checkpointing,
|
||||||
)
|
)
|
||||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||||
from transformers import GradientCheckpointingLayer, Trainer
|
from transformers import GradientCheckpointingLayer, Trainer
|
||||||
from trl.models.activation_offloading import (
|
from trl.models.activation_offloading import get_act_offloading_ctx_manager
|
||||||
NoOpManager,
|
|
||||||
OffloadActivations,
|
|
||||||
get_act_offloading_ctx_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ActivationOffloadingMixin(Trainer):
|
class ActivationOffloadingMixin(Trainer):
|
||||||
@@ -30,14 +21,9 @@ class ActivationOffloadingMixin(Trainer):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
if self.args.activation_offloading:
|
if self.args.activation_offloading:
|
||||||
if isinstance(self.model, PeftModel):
|
self.activation_offload_context = get_act_offloading_ctx_manager(
|
||||||
self.activation_offload_context = get_lora_act_offloading_ctx_manager(
|
self.model, use_streams=True
|
||||||
self.model, use_streams=True
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.activation_offload_context = get_act_offloading_ctx_manager(
|
|
||||||
self.model, use_streams=True
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.activation_offload_context = contextlib.nullcontext()
|
self.activation_offload_context = contextlib.nullcontext()
|
||||||
|
|
||||||
@@ -49,169 +35,3 @@ class ActivationOffloadingMixin(Trainer):
|
|||||||
def ac_wrap_hf_model(model: nn.Module, **kwargs):
|
def ac_wrap_hf_model(model: nn.Module, **kwargs):
|
||||||
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
|
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
|
||||||
apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
|
apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def get_lora_act_offloading_ctx_manager(
|
|
||||||
model: nn.Module,
|
|
||||||
use_pin_memory: bool = True,
|
|
||||||
use_streams: bool = True,
|
|
||||||
min_offload_size: int = 1024,
|
|
||||||
max_fwd_stash_size: int = 5,
|
|
||||||
warn_if_no_head: bool = True,
|
|
||||||
) -> OffloadActivations:
|
|
||||||
"""
|
|
||||||
Returns the activation offloading context manager for the model. All but the last output Linear in every step will
|
|
||||||
be offloaded.
|
|
||||||
|
|
||||||
If activation offloading is enabled, we return the OffloadActivations context manager. If activation offloading is
|
|
||||||
disabled, we return a NoOpManager context manager.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model (`nn.Module`):
|
|
||||||
Model to wrap with the activation offloading context manager.
|
|
||||||
use_pin_memory (`bool`, *optional*, defaults to `True`):
|
|
||||||
Whether to offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to
|
|
||||||
be moved back onto GPU more quickly but is a limited resource.
|
|
||||||
use_streams (`bool`, *optional*, defaults to `True`):
|
|
||||||
Whether to use streams for performance optimization where the communications get overlapped with the
|
|
||||||
computation. Requires a torch build after torch-2.5.0.
|
|
||||||
min_offload_size (`int`, *optional*, defaults to `1024`):
|
|
||||||
Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we
|
|
||||||
do not want to waste bandwidth and resources moving it to CPU and back.
|
|
||||||
max_fwd_stash_size (`int`, *optional*, defaults to `5`):
|
|
||||||
Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during
|
|
||||||
the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow
|
|
||||||
more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping
|
|
||||||
alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing
|
|
||||||
runtime.
|
|
||||||
warn_if_no_head (`bool`, *optional*, defaults to `True`):
|
|
||||||
Whether to warn if no output head is detected. If set to `False`, no warning will be raised if no output
|
|
||||||
head is detected.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`contextlib.ContextDecorator`:
|
|
||||||
Activation offloading context manager for the model.
|
|
||||||
"""
|
|
||||||
# pylint: disable=unnecessary-dunder-call
|
|
||||||
activations_handling_ctx = OffloadActivations(
|
|
||||||
use_pin_memory=use_pin_memory,
|
|
||||||
use_streams=use_streams,
|
|
||||||
min_offload_size=min_offload_size,
|
|
||||||
max_fwd_stash_size=max_fwd_stash_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Below is our hack to disable offloading the last output Linear in every
|
|
||||||
# step, as the cost for offloading the activation and then soon after bringing
|
|
||||||
# it back is expensive.
|
|
||||||
output_head_detected = False
|
|
||||||
noop_ctx = NoOpManager()
|
|
||||||
|
|
||||||
# Try to get the actual model if it's wrapped
|
|
||||||
unwrapped_model = model
|
|
||||||
if hasattr(unwrapped_model, "module"):
|
|
||||||
unwrapped_model = unwrapped_model.module
|
|
||||||
# check for PEFT models
|
|
||||||
if hasattr(unwrapped_model, "base_model") and hasattr(
|
|
||||||
unwrapped_model, "peft_config"
|
|
||||||
):
|
|
||||||
unwrapped_model = unwrapped_model.base_model
|
|
||||||
|
|
||||||
# Check for different types of output heads
|
|
||||||
if hasattr(unwrapped_model, "output"):
|
|
||||||
if isinstance(unwrapped_model.output, nn.Module):
|
|
||||||
unwrapped_model.output.register_forward_pre_hook(
|
|
||||||
lambda *args: noop_ctx.__enter__()
|
|
||||||
)
|
|
||||||
unwrapped_model.output.register_forward_hook(
|
|
||||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
|
||||||
)
|
|
||||||
output_head_detected = True
|
|
||||||
elif hasattr(unwrapped_model.output, "linear") and isinstance(
|
|
||||||
unwrapped_model.output.linear, nn.Module
|
|
||||||
):
|
|
||||||
unwrapped_model.output.linear.register_forward_pre_hook(
|
|
||||||
lambda *args: noop_ctx.__enter__()
|
|
||||||
)
|
|
||||||
unwrapped_model.output.linear.register_forward_hook(
|
|
||||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
|
||||||
)
|
|
||||||
output_head_detected = True
|
|
||||||
|
|
||||||
# Check for HuggingFace model output heads
|
|
||||||
elif hasattr(unwrapped_model, "lm_head"):
|
|
||||||
unwrapped_model.lm_head.register_forward_pre_hook(
|
|
||||||
lambda *args: noop_ctx.__enter__()
|
|
||||||
)
|
|
||||||
unwrapped_model.lm_head.register_forward_hook(
|
|
||||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
|
||||||
)
|
|
||||||
output_head_detected = True
|
|
||||||
|
|
||||||
# Check for decoder-based models
|
|
||||||
elif hasattr(unwrapped_model, "decoder"):
|
|
||||||
decoder = unwrapped_model.decoder
|
|
||||||
if hasattr(decoder, "output"):
|
|
||||||
decoder.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
|
||||||
decoder.output.register_forward_hook(
|
|
||||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
|
||||||
)
|
|
||||||
output_head_detected = True
|
|
||||||
# Some models have lm_head in the decoder
|
|
||||||
elif hasattr(decoder, "lm_head"):
|
|
||||||
decoder.lm_head.register_forward_pre_hook(
|
|
||||||
lambda *args: noop_ctx.__enter__()
|
|
||||||
)
|
|
||||||
decoder.lm_head.register_forward_hook(
|
|
||||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
|
||||||
)
|
|
||||||
output_head_detected = True
|
|
||||||
|
|
||||||
# Check for transformer models with final layer norm
|
|
||||||
elif hasattr(unwrapped_model, "final_layer_norm") or hasattr(
|
|
||||||
unwrapped_model, "ln_f"
|
|
||||||
):
|
|
||||||
final_norm = (
|
|
||||||
getattr(unwrapped_model, "final_layer_norm", None) or unwrapped_model.ln_f
|
|
||||||
)
|
|
||||||
final_norm.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
|
||||||
final_norm.register_forward_hook(
|
|
||||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
|
||||||
)
|
|
||||||
output_head_detected = True
|
|
||||||
|
|
||||||
# Check for models with head module
|
|
||||||
elif hasattr(unwrapped_model, "head") and isinstance(
|
|
||||||
unwrapped_model.head, nn.Module
|
|
||||||
):
|
|
||||||
unwrapped_model.head.register_forward_pre_hook(
|
|
||||||
lambda *args: noop_ctx.__enter__()
|
|
||||||
)
|
|
||||||
unwrapped_model.head.register_forward_hook(
|
|
||||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
|
||||||
)
|
|
||||||
output_head_detected = True
|
|
||||||
|
|
||||||
if not output_head_detected and warn_if_no_head:
|
|
||||||
LOG.warning(
|
|
||||||
"During activation offloading, no output head was detected. If your model has an output head, it will be "
|
|
||||||
"offloaded. This usually greatly slows training, given the large vocabulary size. To change this "
|
|
||||||
"behavior, set your output head as model.output and make it an nn.Module. You can disable this warning by "
|
|
||||||
"passing `warn_if_no_head=False`."
|
|
||||||
)
|
|
||||||
|
|
||||||
for name, module in unwrapped_model.named_modules():
|
|
||||||
# Disable offloading for any Liger modules
|
|
||||||
if "liger" in name.lower():
|
|
||||||
module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
|
||||||
module.register_forward_hook(
|
|
||||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
|
||||||
)
|
|
||||||
# disable offloading for any submodules to fix LoRA training
|
|
||||||
if name.endswith("._checkpoint_wrapped_module"):
|
|
||||||
for _, sub_module in module.named_modules():
|
|
||||||
sub_module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
|
||||||
sub_module.register_forward_hook(
|
|
||||||
lambda *args: noop_ctx.__exit__(), always_call=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return activations_handling_ctx
|
|
||||||
|
|||||||
@@ -13,9 +13,11 @@ class CheckpointSaveMixin(Trainer):
|
|||||||
def _save_optimizer_and_scheduler(self, output_dir):
|
def _save_optimizer_and_scheduler(self, output_dir):
|
||||||
try:
|
try:
|
||||||
super()._save_optimizer_and_scheduler(output_dir)
|
super()._save_optimizer_and_scheduler(output_dir)
|
||||||
except NotImplementedError as exc:
|
except (NotImplementedError, KeyError) as exc:
|
||||||
LOG.warning(
|
# TODO: fix fsdp2 optimizer saving
|
||||||
|
LOG.warning_once(
|
||||||
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
|
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
|
||||||
"Optimizer and scheduler states were not saved - resuming from checkpoints "
|
"Optimizer and scheduler states were not saved - resuming from checkpoints "
|
||||||
"for this training run will not be possible."
|
"for this training run will not be possible.",
|
||||||
|
main_process_only=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,8 +16,6 @@
|
|||||||
Module for handling LIGER input arguments.
|
Module for handling LIGER input arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -30,13 +28,13 @@ class LigerArgs(BaseModel):
|
|||||||
Input args for LIGER.
|
Input args for LIGER.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
liger_rope: Optional[bool] = None
|
liger_rope: bool | None = None
|
||||||
liger_rms_norm: Optional[bool] = None
|
liger_rms_norm: bool | None = None
|
||||||
liger_layer_norm: Optional[bool] = None
|
liger_layer_norm: bool | None = None
|
||||||
liger_swiglu: Optional[bool] = None
|
liger_swiglu: bool | None = None
|
||||||
liger_glu_activation: Optional[bool] = None
|
liger_glu_activation: bool | None = None
|
||||||
liger_cross_entropy: Optional[bool] = None
|
liger_cross_entropy: bool | None = None
|
||||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
liger_fused_linear_cross_entropy: bool | None = None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -57,12 +55,18 @@ class LigerArgs(BaseModel):
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_tiled_mlp_conflict(cls, data):
|
def check_tiled_mlp_conflict(cls, data):
|
||||||
if (
|
if data.get("liger_glu_activation") is True and data.get("tiled_mlp") is True:
|
||||||
data.get("liger_glu_activation") is True
|
|
||||||
and data.get("tiled_mlp") is True
|
|
||||||
and not data.get("tiled_mlp_use_original_mlp")
|
|
||||||
):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
|
"You cannot have both `liger_glu_activation` and `tiled_mlp` set."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_liger_rms_norm_tensor_parallel(cls, data):
|
||||||
|
if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"`liger_rms_norm` is incompatible with tensor parallelism, "
|
||||||
|
"see https://github.com/linkedin/Liger-Kernel/issues/826"
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|||||||
@@ -102,8 +102,8 @@ def matmul_lora(
|
|||||||
del W
|
del W
|
||||||
|
|
||||||
if A is not None:
|
if A is not None:
|
||||||
A, B = A.t().to(dtype), B.t().to(dtype)
|
A, B = A.t(), B.t()
|
||||||
out += (X @ A) @ (s * B)
|
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
|
||||||
|
|
||||||
return out.view(batch, seq_len, -1) if reshape else out
|
return out.view(batch, seq_len, -1) if reshape else out
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ import peft
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
import transformers.modeling_utils
|
import transformers.modeling_utils
|
||||||
from accelerate import init_empty_weights
|
from accelerate import PartialState, init_empty_weights
|
||||||
|
from accelerate.utils.dataclasses import ParallelismConfig
|
||||||
from peft import (
|
from peft import (
|
||||||
PeftConfig,
|
PeftConfig,
|
||||||
PeftMixedModel,
|
PeftMixedModel,
|
||||||
@@ -51,6 +52,7 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
get_device_count,
|
get_device_count,
|
||||||
get_device_type,
|
get_device_type,
|
||||||
|
get_world_size,
|
||||||
)
|
)
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model_quant
|
||||||
@@ -162,7 +164,6 @@ class ModelLoader:
|
|||||||
|
|
||||||
# Build the model
|
# Build the model
|
||||||
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
||||||
self.patch_manager.apply_post_plugin_pre_model_load_patches()
|
|
||||||
skip_move_to_device = self._build_model()
|
skip_move_to_device = self._build_model()
|
||||||
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
||||||
|
|
||||||
@@ -183,6 +184,7 @@ class ModelLoader:
|
|||||||
|
|
||||||
def _apply_pre_model_load_setup(self):
|
def _apply_pre_model_load_setup(self):
|
||||||
"""Apply patches and setup configurations before model loading."""
|
"""Apply patches and setup configurations before model loading."""
|
||||||
|
self._set_parallel_config()
|
||||||
self._set_auto_model_loader()
|
self._set_auto_model_loader()
|
||||||
self._set_device_map_config()
|
self._set_device_map_config()
|
||||||
if self.cfg.revision_of_model:
|
if self.cfg.revision_of_model:
|
||||||
@@ -390,6 +392,52 @@ class ModelLoader:
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
def _set_parallel_config(self):
|
||||||
|
"""Set parallelism configuration (DP, FSDP, TP, CP) in PartialState/Accelerator"""
|
||||||
|
dp_replicate_size = get_world_size()
|
||||||
|
pc_kwargs = {}
|
||||||
|
if self.cfg.dp_shard_size and self.cfg.dp_shard_size > 1:
|
||||||
|
pc_kwargs["dp_shard_size"] = self.cfg.dp_shard_size
|
||||||
|
dp_replicate_size = dp_replicate_size // self.cfg.dp_shard_size
|
||||||
|
if self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1:
|
||||||
|
pc_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
||||||
|
dp_replicate_size = dp_replicate_size // self.cfg.tensor_parallel_size
|
||||||
|
if self.cfg.context_parallel_size and self.cfg.context_parallel_size > 1:
|
||||||
|
pc_kwargs["cp_size"] = self.cfg.context_parallel_size
|
||||||
|
dp_replicate_size = dp_replicate_size // self.cfg.context_parallel_size
|
||||||
|
if dp_replicate_size > 1:
|
||||||
|
pc_kwargs["dp_replicate_size"] = dp_replicate_size
|
||||||
|
|
||||||
|
parallelism_config = ParallelismConfig(
|
||||||
|
**pc_kwargs,
|
||||||
|
)
|
||||||
|
mesh_dim_names, mesh_shape = parallelism_config.get_mesh()
|
||||||
|
device_mesh = torch.distributed.init_device_mesh(
|
||||||
|
"cuda", mesh_shape, mesh_dim_names=mesh_dim_names
|
||||||
|
)
|
||||||
|
submeshes = [
|
||||||
|
tuple(parallelism_config.dp_dim_names),
|
||||||
|
tuple(parallelism_config.dp_shard_cp_dim_names),
|
||||||
|
tuple(parallelism_config.dp_cp_dim_names),
|
||||||
|
]
|
||||||
|
submesh_names = [
|
||||||
|
# create a submesh which is only used for distributing data across data parallel dims (no comms)
|
||||||
|
"dp",
|
||||||
|
# create a submesh which is used *just* for FSDP parameter gathering/scattering
|
||||||
|
# and gradients reduce-scattering
|
||||||
|
"dp_shard_cp",
|
||||||
|
# create a submesh which is used for correctly reducing loss across data replica/context parallel
|
||||||
|
"dp_cp",
|
||||||
|
]
|
||||||
|
for submesh, submesh_name in zip(submeshes, submesh_names):
|
||||||
|
if submesh:
|
||||||
|
device_mesh[submesh]._flatten( # pylint: disable=protected-access
|
||||||
|
submesh_name
|
||||||
|
)
|
||||||
|
|
||||||
|
PartialState().parallelism_config = parallelism_config
|
||||||
|
PartialState().device_mesh = device_mesh
|
||||||
|
|
||||||
def _set_auto_model_loader(self):
|
def _set_auto_model_loader(self):
|
||||||
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
"""Set `self.auto_model_loader`. Defaults to `transformers.AutoModelForCausalLM`
|
||||||
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
(set at `__init__`). When using a multimodal model, `self.auto_model_loader`
|
||||||
@@ -622,6 +670,14 @@ class ModelLoader:
|
|||||||
def _build_model(self) -> bool:
|
def _build_model(self) -> bool:
|
||||||
"""Load model, with load strategy depending on config."""
|
"""Load model, with load strategy depending on config."""
|
||||||
skip_move_to_device = False
|
skip_move_to_device = False
|
||||||
|
|
||||||
|
if self.cfg.tensor_parallel_size > 1:
|
||||||
|
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
|
||||||
|
self.model_kwargs["tp_plan"] = "auto"
|
||||||
|
self.model_kwargs["device_mesh"] = PartialState().device_mesh
|
||||||
|
if "device_map" in self.model_kwargs:
|
||||||
|
del self.model_kwargs["device_map"] # not compatible with `tp_plan`
|
||||||
|
|
||||||
if self.is_fsdp_enabled:
|
if self.is_fsdp_enabled:
|
||||||
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
if self.cfg.fsdp_config.cpu_ram_efficient_loading:
|
||||||
skip_move_to_device = True
|
skip_move_to_device = True
|
||||||
|
|||||||
@@ -66,9 +66,6 @@ class PatchManager:
|
|||||||
self._apply_self_attention_lora_patch()
|
self._apply_self_attention_lora_patch()
|
||||||
self._apply_gemma3_conditional_generation_forward_patch()
|
self._apply_gemma3_conditional_generation_forward_patch()
|
||||||
self._apply_sequence_parallel_patches()
|
self._apply_sequence_parallel_patches()
|
||||||
|
|
||||||
def apply_post_plugin_pre_model_load_patches(self):
|
|
||||||
"""Apply post plugin-pre_model_load load patches based on config."""
|
|
||||||
self._apply_tiled_mlp(self.cfg.model_config_type)
|
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
@@ -264,20 +261,18 @@ class PatchManager:
|
|||||||
|
|
||||||
def _apply_sequence_parallel_patches(self):
|
def _apply_sequence_parallel_patches(self):
|
||||||
"""Apply sequence parallelism patches."""
|
"""Apply sequence parallelism patches."""
|
||||||
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
if self.cfg.context_parallel_size and self.cfg.context_parallel_size > 1:
|
||||||
from axolotl.monkeypatch.ring_attn.patch import (
|
from axolotl.monkeypatch.ring_attn.patch import (
|
||||||
patch_prepare_data_loader,
|
patch_prepare_data_loader,
|
||||||
patch_prepare_device_mesh,
|
patch_prepare_device_mesh,
|
||||||
)
|
)
|
||||||
|
|
||||||
patch_prepare_data_loader()
|
patch_prepare_data_loader()
|
||||||
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
|
patch_prepare_device_mesh(self.cfg.context_parallel_size, self.cfg.fsdp)
|
||||||
|
|
||||||
def _apply_tiled_mlp(self, model_type: str):
|
def _apply_tiled_mlp(self, model_type: str):
|
||||||
if self.cfg.tiled_mlp:
|
if self.cfg.tiled_mlp:
|
||||||
from axolotl.monkeypatch.tiled_mlp import (
|
from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
|
||||||
patch_tiled_mlp,
|
|
||||||
)
|
|
||||||
|
|
||||||
patch_tiled_mlp(
|
patch_tiled_mlp(
|
||||||
model_type,
|
model_type,
|
||||||
|
|||||||
@@ -221,7 +221,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
transformer_auto_wrap_policy,
|
transformer_auto_wrap_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We need the `auto_wrap_policy` original type to create a custom policy function for sharding
|
# We need the `auto_wrap_policy` original type to create a custom policy function for sharding
|
||||||
# This is because `fully_shard` doesn't support old auto wrap policies, rather we have to imitate the behaviour
|
# This is because `fully_shard` doesn't support old auto wrap policies, rather we have to imitate the behaviour
|
||||||
if fsdp2_plugin.auto_wrap_policy is transformer_auto_wrap_policy:
|
if fsdp2_plugin.auto_wrap_policy is transformer_auto_wrap_policy:
|
||||||
pass # auto_wrap_policy_type = "transformer"
|
pass # auto_wrap_policy_type = "transformer"
|
||||||
@@ -254,6 +254,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
"offload_policy": fsdp2_plugin.cpu_offload,
|
"offload_policy": fsdp2_plugin.cpu_offload,
|
||||||
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
||||||
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
||||||
|
"mesh": accelerator.torch_device_mesh[tuple(accelerator.parallelism_config.model_shard_dim_names)],
|
||||||
}
|
}
|
||||||
|
|
||||||
model_has_params4bit = False
|
model_has_params4bit = False
|
||||||
|
|||||||
@@ -18,15 +18,10 @@ import transformers
|
|||||||
import transformers.modeling_flash_attention_utils
|
import transformers.modeling_flash_attention_utils
|
||||||
from ring_flash_attn import ring_flash_attn_func
|
from ring_flash_attn import ring_flash_attn_func
|
||||||
from ring_flash_attn.adapters.hf_adapter import check_params
|
from ring_flash_attn.adapters.hf_adapter import check_params
|
||||||
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
|
from transformers.modeling_flash_attention_utils import (
|
||||||
|
_flash_supports_window_size,
|
||||||
try:
|
is_flash_attn_greater_or_equal,
|
||||||
from transformers.modeling_flash_attention_utils import _flash_supports_window
|
)
|
||||||
except ImportError:
|
|
||||||
from transformers.modeling_flash_attention_utils import (
|
|
||||||
_flash_supports_window_size as _flash_supports_window,
|
|
||||||
)
|
|
||||||
|
|
||||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
|
||||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||||
@@ -117,7 +112,7 @@ def create_flash_attn_forward_varlen_llama3(
|
|||||||
|
|
||||||
# Handle sliding window
|
# Handle sliding window
|
||||||
use_sliding_windows = (
|
use_sliding_windows = (
|
||||||
_flash_supports_window
|
_flash_supports_window_size
|
||||||
and sliding_window is not None
|
and sliding_window is not None
|
||||||
and key_states.shape[1] > sliding_window
|
and key_states.shape[1] > sliding_window
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -162,14 +162,14 @@ def create_ring_flash_attention_forward(
|
|||||||
|
|
||||||
|
|
||||||
def register_ring_attn(
|
def register_ring_attn(
|
||||||
sequence_parallel_degree: int,
|
context_parallel_size: int,
|
||||||
heads_k_stride: int | None,
|
heads_k_stride: int | None,
|
||||||
ring_attn_func: RingAttnFunc | None,
|
ring_attn_func: RingAttnFunc | None,
|
||||||
):
|
):
|
||||||
"""Create ring attention group and substitute flash attn with ring flash attn.
|
"""Create ring attention group and substitute flash attn with ring flash attn.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sequence_parallel_degree: Sequence parallelism factor.
|
context_parallel_size: Sequence parallelism factor.
|
||||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||||
`varlen_llama3` `ring_flash_attn` implementation.
|
`varlen_llama3` `ring_flash_attn` implementation.
|
||||||
ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
|
ring_attn_func: `ring_flash_attn` ring attention implementation. If sample
|
||||||
@@ -182,25 +182,25 @@ def register_ring_attn(
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Enabling ring attention sequence parallelism: "
|
"Enabling ring attention sequence parallelism: "
|
||||||
f"each sequence will be processed across {sequence_parallel_degree} GPUs"
|
f"each sequence will be processed across {context_parallel_size} GPUs"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert sequence_parallel_degree <= world_size, (
|
assert context_parallel_size <= world_size, (
|
||||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
f"context_parallel_size ({context_parallel_size}) "
|
||||||
f"must be less than or equal to world_size ({world_size})"
|
f"must be less than or equal to world_size ({world_size})"
|
||||||
)
|
)
|
||||||
assert world_size % sequence_parallel_degree == 0, (
|
assert world_size % context_parallel_size == 0, (
|
||||||
f"sequence_parallel_degree ({sequence_parallel_degree}) "
|
f"context_parallel_size ({context_parallel_size}) "
|
||||||
f"must evenly divide world_size ({world_size})"
|
f"must evenly divide world_size ({world_size})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assign ranks to sequence parallel groups
|
# Assign ranks to sequence parallel groups
|
||||||
group_assignments = {}
|
group_assignments = {}
|
||||||
for i in range(world_size // sequence_parallel_degree):
|
for i in range(world_size // context_parallel_size):
|
||||||
ring_attn_ranks = list(
|
ring_attn_ranks = list(
|
||||||
range(
|
range(
|
||||||
i * sequence_parallel_degree,
|
i * context_parallel_size,
|
||||||
(i + 1) * sequence_parallel_degree,
|
(i + 1) * context_parallel_size,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
|
||||||
@@ -299,12 +299,12 @@ def patch_prepare_data_loader():
|
|||||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
||||||
|
|
||||||
|
|
||||||
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
|
def patch_prepare_device_mesh(context_parallel_size: int, fsdp: bool = False):
|
||||||
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
||||||
that includes sequence parallelism with the specified degree.
|
that includes sequence parallelism with the specified degree.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sequence_parallel_degree: The degree of sequence parallelism to use.
|
context_parallel_size: The degree of sequence parallelism to use.
|
||||||
fsdp: Whether to use FSDP.
|
fsdp: Whether to use FSDP.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -323,8 +323,8 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False)
|
|||||||
# Create device mesh with sequence parallelism
|
# Create device mesh with sequence parallelism
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
mesh_shape = (
|
mesh_shape = (
|
||||||
world_size // sequence_parallel_degree,
|
world_size // context_parallel_size,
|
||||||
sequence_parallel_degree,
|
context_parallel_size,
|
||||||
)
|
)
|
||||||
device_ids = list(range(world_size))
|
device_ids = list(range(world_size))
|
||||||
|
|
||||||
@@ -344,5 +344,5 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False)
|
|||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Successfully patched Accelerator._prepare_device_mesh "
|
"Successfully patched Accelerator._prepare_device_mesh "
|
||||||
f"with sequence_parallel_degree={sequence_parallel_degree}"
|
f"with context_parallel_size={context_parallel_size}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,12 +12,8 @@ from axolotl.utils.logging import get_logger
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
|
def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
|
||||||
from deepspeed.runtime.sequence_parallel.ulysses_sp import (
|
from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
|
||||||
TiledMLP as DeepSpeedTiledMLP,
|
|
||||||
)
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.tiled_mlp.base import TiledMLP
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Dynamically import the module and MLP class
|
# Dynamically import the module and MLP class
|
||||||
@@ -40,7 +36,6 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
|
|||||||
is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
|
is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
|
||||||
|
|
||||||
def tiled_mlp_forward(self, x):
|
def tiled_mlp_forward(self, x):
|
||||||
# pylint: disable=protected-access
|
|
||||||
input_shape = x.shape
|
input_shape = x.shape
|
||||||
seqlen = input_shape[-2]
|
seqlen = input_shape[-2]
|
||||||
hidden = input_shape[-1]
|
hidden = input_shape[-1]
|
||||||
@@ -53,23 +48,14 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
|
|||||||
else:
|
else:
|
||||||
num_shards = cfg_num_shards
|
num_shards = cfg_num_shards
|
||||||
|
|
||||||
if not self._compute_params:
|
if not self._compute_params: # pylint: disable=protected-access
|
||||||
self._compute_params = [p for p in self.parameters() if p.requires_grad]
|
self._compute_params = [ # pylint: disable=protected-access
|
||||||
|
p for p in self.parameters() if p.requires_grad
|
||||||
|
]
|
||||||
|
|
||||||
compute_params = self._compute_params
|
compute_params = self._compute_params # pylint: disable=protected-access
|
||||||
if not self._tiled_mlp_dist_impl:
|
|
||||||
if (
|
|
||||||
self._compute_params
|
|
||||||
and any(
|
|
||||||
hasattr(p, "ds_id") or hasattr(p, "param_idx_in_group")
|
|
||||||
for p in self._compute_params
|
|
||||||
)
|
|
||||||
) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
|
|
||||||
self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
|
|
||||||
else:
|
|
||||||
self._tiled_mlp_dist_impl = TiledMLP
|
|
||||||
|
|
||||||
down_res = self._tiled_mlp_dist_impl.apply(
|
down_res = TiledMLP.apply(
|
||||||
mlp_forward,
|
mlp_forward,
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
@@ -80,7 +66,6 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
|
|||||||
|
|
||||||
mlp_cls.forward = tiled_mlp_forward
|
mlp_cls.forward = tiled_mlp_forward
|
||||||
mlp_cls._compute_params = [] # pylint: disable=protected-access
|
mlp_cls._compute_params = [] # pylint: disable=protected-access
|
||||||
mlp_cls._tiled_mlp_dist_impl = None # pylint: disable=protected-access
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Successfully monkey-patched TiledMLP for model_type: {model_type}",
|
f"Successfully monkey-patched TiledMLP for model_type: {model_type}",
|
||||||
main_process_only=True,
|
main_process_only=True,
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
"""
|
|
||||||
TiledMLP monkey patches
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .patch import (
|
|
||||||
patch_tiled_mlp,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"patch_tiled_mlp",
|
|
||||||
]
|
|
||||||
@@ -1,153 +0,0 @@
|
|||||||
"""
|
|
||||||
TiledMLP support for DDP, FSDP, and single GPU
|
|
||||||
"""
|
|
||||||
|
|
||||||
import threading
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class TiledMLP(torch.autograd.Function):
|
|
||||||
"""
|
|
||||||
TiledMLP implementation using gradient hooks
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(
|
|
||||||
ctx,
|
|
||||||
fn,
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
shards,
|
|
||||||
compute_params,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
ctx.fn = fn
|
|
||||||
ctx.self = self
|
|
||||||
ctx.shards = shards
|
|
||||||
ctx.compute_params = [p for p in compute_params if p.requires_grad]
|
|
||||||
ctx.save_for_backward(x)
|
|
||||||
|
|
||||||
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
|
|
||||||
with torch.no_grad():
|
|
||||||
output_shards = [fn(self, x_shard) for x_shard in x_shards]
|
|
||||||
output_unsharded = torch.cat(output_shards, dim=1)
|
|
||||||
|
|
||||||
return output_unsharded
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, *grads) -> torch.Tensor:
|
|
||||||
fn = ctx.fn
|
|
||||||
(x,) = ctx.saved_tensors
|
|
||||||
self = ctx.self
|
|
||||||
shards = ctx.shards
|
|
||||||
compute_params = ctx.compute_params
|
|
||||||
|
|
||||||
x_requires_grad = x.requires_grad
|
|
||||||
x = x.detach()
|
|
||||||
x.requires_grad_(x_requires_grad)
|
|
||||||
|
|
||||||
incoming_grad = grads[0]
|
|
||||||
x_grad = torch.zeros_like(x)
|
|
||||||
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
|
|
||||||
|
|
||||||
# Create a gradient accumulator for parameters
|
|
||||||
grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)
|
|
||||||
|
|
||||||
shard_step = x_shards[0].numel()
|
|
||||||
for i, x_shard in enumerate(x_shards):
|
|
||||||
x_shard.requires_grad_(x_requires_grad)
|
|
||||||
|
|
||||||
shard_offset = i * shard_step
|
|
||||||
x_shard.grad = (
|
|
||||||
x_grad.view(-1)
|
|
||||||
.narrow(0, shard_offset, x_shard.numel())
|
|
||||||
.view_as(x_shard)
|
|
||||||
)
|
|
||||||
incoming_grad_shard = (
|
|
||||||
incoming_grad.view(-1)
|
|
||||||
.narrow(0, shard_offset, x_shard.numel())
|
|
||||||
.view_as(x_shard)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Install hooks for this shard
|
|
||||||
is_last_shard = i + 1 == shards
|
|
||||||
grad_accumulator.install_hooks(is_last_shard)
|
|
||||||
|
|
||||||
with torch.enable_grad():
|
|
||||||
output = fn(self, x_shard)
|
|
||||||
torch.autograd.backward(output, incoming_grad_shard)
|
|
||||||
|
|
||||||
# Clean up hooks
|
|
||||||
grad_accumulator.cleanup()
|
|
||||||
del grad_accumulator
|
|
||||||
|
|
||||||
return (None, None, x_grad, None, None)
|
|
||||||
|
|
||||||
|
|
||||||
class GradientAccumulator:
|
|
||||||
"""
|
|
||||||
Manual gradient accumulator for TiledMLP with configurable precision
|
|
||||||
Accumulates in specified dtype and rescales the gradient at the end
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
params: List[torch.nn.Parameter],
|
|
||||||
total_shards: int,
|
|
||||||
dtype: torch.dtype | None = None,
|
|
||||||
):
|
|
||||||
self.params = params
|
|
||||||
self.total_shards = total_shards
|
|
||||||
self.grad_accumulation_dtype = dtype or torch.float32
|
|
||||||
self.accumulated_grads = {}
|
|
||||||
self.hooks = []
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
self.gradient_scale = 1.0 / total_shards
|
|
||||||
|
|
||||||
# Initialize accumulated gradients in the specified dtype
|
|
||||||
for param in self.params:
|
|
||||||
if param.grad is not None:
|
|
||||||
self.accumulated_grads[param] = param.grad.to(
|
|
||||||
self.grad_accumulation_dtype
|
|
||||||
)
|
|
||||||
param.grad = None
|
|
||||||
else:
|
|
||||||
self.accumulated_grads[param] = torch.zeros_like(
|
|
||||||
param, dtype=self.grad_accumulation_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
def install_hooks(self, is_last_shard: bool):
|
|
||||||
"""Install gradient hooks that accumulate gradients in higher precision"""
|
|
||||||
|
|
||||||
def create_hook(param):
|
|
||||||
def hook(grad):
|
|
||||||
with self.lock:
|
|
||||||
grad_to_accum_dtype = grad.to(self.grad_accumulation_dtype)
|
|
||||||
scaled_grad = grad_to_accum_dtype * self.gradient_scale
|
|
||||||
|
|
||||||
if param in self.accumulated_grads:
|
|
||||||
self.accumulated_grads[param] += scaled_grad
|
|
||||||
else:
|
|
||||||
self.accumulated_grads[param] = scaled_grad.clone()
|
|
||||||
|
|
||||||
# Only assign the averaged gradient on the last shard
|
|
||||||
if is_last_shard:
|
|
||||||
param.grad = self.accumulated_grads[param].to(param.dtype)
|
|
||||||
return param.grad
|
|
||||||
return None
|
|
||||||
|
|
||||||
return hook
|
|
||||||
|
|
||||||
# Install hooks on all parameters
|
|
||||||
for param in self.params:
|
|
||||||
if param.requires_grad:
|
|
||||||
hook = param.register_hook(create_hook(param))
|
|
||||||
self.hooks.append(hook)
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""Remove all installed hooks"""
|
|
||||||
for hook in self.hooks:
|
|
||||||
hook.remove()
|
|
||||||
self.hooks.clear()
|
|
||||||
del self.accumulated_grads
|
|
||||||
@@ -115,11 +115,8 @@ def setup_reference_model(
|
|||||||
LOG.debug("Passing model_ref: None to RL trainer")
|
LOG.debug("Passing model_ref: None to RL trainer")
|
||||||
model_ref = None # explicit setting to None
|
model_ref = None # explicit setting to None
|
||||||
else:
|
else:
|
||||||
reference_model: bool = True
|
|
||||||
if cfg.rl == RLType.GRPO and cfg.trl.beta == 0:
|
|
||||||
reference_model = False
|
|
||||||
# load the model again for model_ref/baseline
|
# load the model again for model_ref/baseline
|
||||||
model_loader = ModelLoader(cfg, tokenizer, reference_model=reference_model)
|
model_loader = ModelLoader(cfg, tokenizer, reference_model=True)
|
||||||
model_ref, _ = model_loader.load()
|
model_ref, _ = model_loader.load()
|
||||||
return model_ref
|
return model_ref
|
||||||
|
|
||||||
@@ -205,7 +202,7 @@ def execute_training(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.sequence_parallel_degree > 1:
|
if cfg.context_parallel_size > 1:
|
||||||
models = [trainer.model]
|
models = [trainer.model]
|
||||||
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
if hasattr(trainer, "ref_model") and trainer.ref_model:
|
||||||
models.append(trainer.ref_model)
|
models.append(trainer.ref_model)
|
||||||
@@ -213,7 +210,7 @@ def execute_training(
|
|||||||
stack.enter_context(
|
stack.enter_context(
|
||||||
SequenceParallelContextManager(
|
SequenceParallelContextManager(
|
||||||
models=models,
|
models=models,
|
||||||
sequence_parallel_degree=cfg.sequence_parallel_degree,
|
context_parallel_size=cfg.context_parallel_size,
|
||||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
ring_attn_func=cfg.ring_attn_func,
|
ring_attn_func=cfg.ring_attn_func,
|
||||||
heads_k_stride=cfg.heads_k_stride,
|
heads_k_stride=cfg.heads_k_stride,
|
||||||
|
|||||||
@@ -27,11 +27,7 @@ from transformers import (
|
|||||||
TrainerState,
|
TrainerState,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import (
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||||
PREFIX_CHECKPOINT_DIR,
|
|
||||||
IntervalStrategy,
|
|
||||||
SaveStrategy,
|
|
||||||
)
|
|
||||||
from trl.models import unwrap_model_for_generation
|
from trl.models import unwrap_model_for_generation
|
||||||
|
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
@@ -867,16 +863,10 @@ class GCCallback(TrainerCallback):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
def on_train_begin(
|
|
||||||
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
self._gc()
|
|
||||||
|
|
||||||
def on_step_begin(
|
def on_step_begin(
|
||||||
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
# pylint: disable=consider-using-in
|
if self.next_gc_on_begin_step == state.global_step:
|
||||||
if self.next_gc_on_begin_step == state.global_step or state.global_step == 0:
|
|
||||||
self._gc()
|
self._gc()
|
||||||
|
|
||||||
def on_step_end(
|
def on_step_end(
|
||||||
@@ -889,17 +879,6 @@ class GCCallback(TrainerCallback):
|
|||||||
self.next_gc_on_begin_step = state.global_step + 1
|
self.next_gc_on_begin_step = state.global_step + 1
|
||||||
elif self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
|
elif self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
|
||||||
self._gc()
|
self._gc()
|
||||||
elif (
|
|
||||||
args.save_strategy == SaveStrategy.STEPS
|
|
||||||
and state.save_steps > 0
|
|
||||||
and state.global_step % state.save_steps == 0
|
|
||||||
):
|
|
||||||
# gc on save steps in case anything is loaded to CPU RAM like offloaded tensors
|
|
||||||
self._gc()
|
|
||||||
elif state.global_step >= state.max_steps:
|
|
||||||
if args.save_strategy == SaveStrategy.STEPS:
|
|
||||||
# gc on save steps in case anything is loaded to CPU RAM like offloaded tensors
|
|
||||||
self._gc()
|
|
||||||
|
|
||||||
def on_epoch_end(
|
def on_epoch_end(
|
||||||
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ class SequenceParallelContextManager:
|
|||||||
Args:
|
Args:
|
||||||
models: List of models to apply sequence parallelism to pre- and post- forward
|
models: List of models to apply sequence parallelism to pre- and post- forward
|
||||||
hooks.
|
hooks.
|
||||||
sequence_parallel_degree: Number of processes to split sequences over.
|
context_parallel_size: Number of processes to split sequences over.
|
||||||
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
gradient_accumulation_steps: Number of steps to accumulate gradients over.
|
||||||
ring_attn_func: Which ring attention function to use. Currently unused.
|
ring_attn_func: Which ring attention function to use. Currently unused.
|
||||||
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
heads_k_stride: Sequence parallelism K head stride size. Passed through to
|
||||||
@@ -179,14 +179,14 @@ class SequenceParallelContextManager:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
models: list[nn.Module],
|
models: list[nn.Module],
|
||||||
sequence_parallel_degree: int,
|
context_parallel_size: int,
|
||||||
gradient_accumulation_steps: int,
|
gradient_accumulation_steps: int,
|
||||||
ring_attn_func: RingAttnFunc,
|
ring_attn_func: RingAttnFunc,
|
||||||
heads_k_stride: int | None,
|
heads_k_stride: int | None,
|
||||||
gather_outputs: bool,
|
gather_outputs: bool,
|
||||||
):
|
):
|
||||||
self.models = models
|
self.models = models
|
||||||
self.sequence_parallel_degree = sequence_parallel_degree
|
self.context_parallel_size = context_parallel_size
|
||||||
self.gradient_accumulation_steps = gradient_accumulation_steps
|
self.gradient_accumulation_steps = gradient_accumulation_steps
|
||||||
self.ring_attn_func = ring_attn_func
|
self.ring_attn_func = ring_attn_func
|
||||||
self.heads_k_stride = heads_k_stride
|
self.heads_k_stride = heads_k_stride
|
||||||
@@ -231,7 +231,7 @@ class SequenceParallelContextManager:
|
|||||||
def _register_ring_attn(self):
|
def _register_ring_attn(self):
|
||||||
# Initialize ring attn for sequence parallelism
|
# Initialize ring attn for sequence parallelism
|
||||||
register_ring_attn(
|
register_ring_attn(
|
||||||
sequence_parallel_degree=self.sequence_parallel_degree,
|
context_parallel_size=self.context_parallel_size,
|
||||||
heads_k_stride=self.heads_k_stride,
|
heads_k_stride=self.heads_k_stride,
|
||||||
ring_attn_func=self.ring_attn_func,
|
ring_attn_func=self.ring_attn_func,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -46,8 +46,7 @@ class FileLockLoader:
|
|||||||
def _increment_counter(self):
|
def _increment_counter(self):
|
||||||
"""Safely increment the process counter."""
|
"""Safely increment the process counter."""
|
||||||
if self.counter_path.exists():
|
if self.counter_path.exists():
|
||||||
counter_content = self.counter_path.read_text().strip()
|
count = int(self.counter_path.read_text().strip())
|
||||||
count = int(counter_content) if counter_content else 0
|
|
||||||
else:
|
else:
|
||||||
count = 0
|
count = 0
|
||||||
self.counter_path.write_text(str(count + 1))
|
self.counter_path.write_text(str(count + 1))
|
||||||
@@ -55,11 +54,10 @@ class FileLockLoader:
|
|||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
"""Clean up ready flag when last process is done."""
|
"""Clean up ready flag when last process is done."""
|
||||||
with FileLock(str(self.lock_file_path)):
|
with FileLock(str(self.lock_file_path)):
|
||||||
counter_content = self.counter_path.read_text().strip()
|
count = int(self.counter_path.read_text().strip())
|
||||||
count = int(counter_content) if counter_content else 0
|
|
||||||
count -= 1
|
count -= 1
|
||||||
|
|
||||||
if count <= 0:
|
if count == 0:
|
||||||
# Last process cleans everything up
|
# Last process cleans everything up
|
||||||
self.ready_flag_path.unlink(missing_ok=True)
|
self.ready_flag_path.unlink(missing_ok=True)
|
||||||
self.counter_path.unlink(missing_ok=True)
|
self.counter_path.unlink(missing_ok=True)
|
||||||
|
|||||||
@@ -543,12 +543,6 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
|||||||
|
|
||||||
return ds.shuffle(seed=cfg.seed)
|
return ds.shuffle(seed=cfg.seed)
|
||||||
|
|
||||||
# If enabled, shuffle each dataset independently before merging.
|
|
||||||
# This allows curriculum learning strategies to be applied at the dataset level.
|
|
||||||
if cfg.shuffle_before_merging_datasets:
|
|
||||||
LOG.info("Shuffling each dataset individually before merging...")
|
|
||||||
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
|
|
||||||
|
|
||||||
LOG.info("Merging datasets...")
|
LOG.info("Merging datasets...")
|
||||||
merged_dataset = concatenate_datasets(datasets)
|
merged_dataset = concatenate_datasets(datasets)
|
||||||
|
|
||||||
|
|||||||
@@ -179,12 +179,6 @@ class AxolotlInputConfig(
|
|||||||
"description": "If false, the datasets will not be shuffled and will keep their original order in `datasets`. The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true."
|
"description": "If false, the datasets will not be shuffled and will keep their original order in `datasets`. The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
shuffle_before_merging_datasets: bool | None = Field(
|
|
||||||
default=False,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "If true, each dataset in `datasets` will be shuffled before merging. This allows curriculum learning strategies to be applied at the dataset level. Default is false."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
dataset_prepared_path: str | None = Field(
|
dataset_prepared_path: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -603,7 +597,7 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
|
|
||||||
tiled_mlp_use_original_mlp: bool | None = Field(
|
tiled_mlp_use_original_mlp: bool | None = Field(
|
||||||
default=True,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama."
|
"description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama."
|
||||||
},
|
},
|
||||||
@@ -650,7 +644,19 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dp_shard_size: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of devices to shard across. If not set, will use all available devices."
|
||||||
|
},
|
||||||
|
)
|
||||||
sequence_parallel_degree: int | None = Field(
|
sequence_parallel_degree: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Deprecated: use `context_parallel_size` instead"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
context_parallel_size: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."
|
"description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."
|
||||||
|
|||||||
0
src/axolotl/utils/schemas/distributed.py
Normal file
0
src/axolotl/utils/schemas/distributed.py
Normal file
@@ -512,6 +512,19 @@ class TrainingValidationMixin:
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_tiled_mlp_deepspeed(cls, data):
|
||||||
|
capabilities = data.get("capabilities")
|
||||||
|
n_gpu = 0
|
||||||
|
if capabilities and capabilities.get("n_gpu", 0) >= 1:
|
||||||
|
n_gpu = capabilities.get("n_gpu", 0)
|
||||||
|
if data.get("tiled_mlp", False) and (n_gpu > 1 and not data.get("deepspeed")):
|
||||||
|
raise ValueError(
|
||||||
|
"tiled_mlp requires deepspeed ZeRO to be enabled for multi-gpu"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class LoRAValidationMixin:
|
class LoRAValidationMixin:
|
||||||
"""Validation methods related to LoRA/QLoRA configuration."""
|
"""Validation methods related to LoRA/QLoRA configuration."""
|
||||||
@@ -673,7 +686,7 @@ class RLValidationMixin:
|
|||||||
data.get("rl") == "grpo"
|
data.get("rl") == "grpo"
|
||||||
and data.get("trl", {})
|
and data.get("trl", {})
|
||||||
and data.get("trl").get("use_liger_loss")
|
and data.get("trl").get("use_liger_loss")
|
||||||
and data.get("sequence_parallel_degree", 1) > 1
|
and data.get("context_parallel_size", 1) > 1
|
||||||
):
|
):
|
||||||
raise ValueError("GRPO + SP + Liger not currently supported")
|
raise ValueError("GRPO + SP + Liger not currently supported")
|
||||||
return data
|
return data
|
||||||
@@ -900,31 +913,30 @@ class OptimizationValidationMixin:
|
|||||||
def check_tensor_parallel_size_update_ds_json(cls, data):
|
def check_tensor_parallel_size_update_ds_json(cls, data):
|
||||||
tensor_parallel_size = data.get("tensor_parallel_size")
|
tensor_parallel_size = data.get("tensor_parallel_size")
|
||||||
if tensor_parallel_size is not None and tensor_parallel_size > 1:
|
if tensor_parallel_size is not None and tensor_parallel_size > 1:
|
||||||
if not data.get("deepspeed"):
|
if data.get("deepspeed"):
|
||||||
raise ValueError(
|
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
|
||||||
"Tensor parallelism (TP) is only supported with DeepSpeed"
|
ds_config = json.load(ds_fin)
|
||||||
)
|
should_save = False
|
||||||
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
|
if "tensor_parallel" not in ds_config:
|
||||||
ds_config = json.load(ds_fin)
|
ds_config["tensor_parallel"] = {
|
||||||
should_save = False
|
"autotp_size": tensor_parallel_size
|
||||||
if "tensor_parallel" not in ds_config:
|
}
|
||||||
ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
|
should_save = True
|
||||||
should_save = True
|
if (
|
||||||
if (
|
|
||||||
"gather_16bit_weights_on_model_save"
|
|
||||||
not in ds_config["zero_optimization"]
|
|
||||||
):
|
|
||||||
ds_config["zero_optimization"][
|
|
||||||
"gather_16bit_weights_on_model_save"
|
"gather_16bit_weights_on_model_save"
|
||||||
] = True
|
not in ds_config["zero_optimization"]
|
||||||
should_save = True
|
):
|
||||||
if should_save:
|
ds_config["zero_optimization"][
|
||||||
temp_dir = tempfile.mkdtemp()
|
"gather_16bit_weights_on_model_save"
|
||||||
with open(
|
] = True
|
||||||
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
|
should_save = True
|
||||||
) as ds_fout:
|
if should_save:
|
||||||
json.dump(ds_config, ds_fout, indent=4)
|
temp_dir = tempfile.mkdtemp()
|
||||||
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
|
with open(
|
||||||
|
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
|
||||||
|
) as ds_fout:
|
||||||
|
json.dump(ds_config, ds_fout, indent=4)
|
||||||
|
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -1091,10 +1103,16 @@ class ModelCompatibilityValidationMixin:
|
|||||||
"`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true` or `activation_offloading: legacy`"
|
"`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true` or `activation_offloading: legacy`"
|
||||||
)
|
)
|
||||||
self.gradient_checkpointing = True
|
self.gradient_checkpointing = True
|
||||||
LOG.warning(
|
if self.adapter and "lora" in self.adapter:
|
||||||
"`offload` now uses a new stream implementation; to use the previous implementation, use `activation_offloading: legacy`"
|
LOG.warning(
|
||||||
)
|
"offloading with CUDA streams is not supported for LoRA adapters, using the `activation_offloading: legacy` implementation."
|
||||||
self.activation_offloading = True
|
)
|
||||||
|
self.activation_offloading = "legacy"
|
||||||
|
else:
|
||||||
|
LOG.warning(
|
||||||
|
"`offload` uses a new stream implementation; to use the previous implementation, use `activation_offloading: legacy`"
|
||||||
|
)
|
||||||
|
self.activation_offloading = True
|
||||||
if self.gradient_checkpointing == "offload_disk":
|
if self.gradient_checkpointing == "offload_disk":
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`"
|
"`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`"
|
||||||
@@ -1103,6 +1121,19 @@ class ModelCompatibilityValidationMixin:
|
|||||||
self.activation_offloading = "disk"
|
self.activation_offloading = "disk"
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_activation_offloading_w_lora(self):
|
||||||
|
if (
|
||||||
|
self.activation_offloading is True
|
||||||
|
and self.adapter
|
||||||
|
and "lora" in self.adapter
|
||||||
|
):
|
||||||
|
LOG.warning(
|
||||||
|
"activation_offloading with CUDA streams is not supported for LoRA adapters. Setting `activation_offloading: legacy`"
|
||||||
|
)
|
||||||
|
self.activation_offloading = "legacy"
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_activation_offloading_wo_gc(self):
|
def check_activation_offloading_wo_gc(self):
|
||||||
if self.activation_offloading and not self.gradient_checkpointing:
|
if self.activation_offloading and not self.gradient_checkpointing:
|
||||||
@@ -1203,13 +1234,13 @@ class ComplexValidationMixin:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_sequence_parallel_degree(self):
|
def check_context_parallel_size(self):
|
||||||
if not self.sequence_parallel_degree:
|
if not self.context_parallel_size:
|
||||||
self.sequence_parallel_degree = 1
|
self.context_parallel_size = 1
|
||||||
elif self.sequence_parallel_degree > 1:
|
elif self.context_parallel_size > 1:
|
||||||
if not self.flash_attention:
|
if not self.flash_attention:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"flash_attention: true must be set with sequence_parallel_degree > 1"
|
"flash_attention: true must be set with context_parallel_size > 1"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.sample_packing and self.micro_batch_size > 1:
|
if self.sample_packing and self.micro_batch_size > 1:
|
||||||
@@ -1222,14 +1253,14 @@ class ComplexValidationMixin:
|
|||||||
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
import ring_flash_attn # noqa: F401 # pylint:disable=unused-import
|
||||||
except ImportError as exception:
|
except ImportError as exception:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"sequence_parallel_degree > 1 but ring_flash_attn is not installed. "
|
"context_parallel_size > 1 but ring_flash_attn is not installed. "
|
||||||
"Please install it with `pip install axolotl[ring-flash-attn] "
|
"Please install it with `pip install axolotl[ring-flash-attn] "
|
||||||
"or `pip install ring-flash-attn>=0.1.4`."
|
"or `pip install ring-flash-attn>=0.1.4`."
|
||||||
) from exception
|
) from exception
|
||||||
|
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Sequence parallelism (SP) is enabled with "
|
"Sequence parallelism (SP) is enabled with "
|
||||||
f"sequence_parallel_degree={self.sequence_parallel_degree}. "
|
f"context_parallel_size={self.context_parallel_size}. "
|
||||||
"Please note that logged losses may differ slightly to the non-SP "
|
"Please note that logged losses may differ slightly to the non-SP "
|
||||||
"losses due to transformers Trainer implementation details. "
|
"losses due to transformers Trainer implementation details. "
|
||||||
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
"Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
|
||||||
@@ -1240,7 +1271,7 @@ class ComplexValidationMixin:
|
|||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_ring_attn_func(self):
|
def validate_ring_attn_func(self):
|
||||||
if getattr(self, "sequence_parallel_degree", 1) == 1:
|
if getattr(self, "context_parallel_size", 1) == 1:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
if self.ring_attn_func is not None:
|
if self.ring_attn_func is not None:
|
||||||
|
|||||||
@@ -442,7 +442,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.sequence_parallel_degree
|
* cfg.context_parallel_size
|
||||||
* cfg.tensor_parallel_size
|
* cfg.tensor_parallel_size
|
||||||
)
|
)
|
||||||
LOG.debug(
|
LOG.debug(
|
||||||
@@ -484,7 +484,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
math.floor(
|
math.floor(
|
||||||
data_loader_len
|
data_loader_len
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.sequence_parallel_degree
|
* cfg.context_parallel_size
|
||||||
* cfg.tensor_parallel_size
|
* cfg.tensor_parallel_size
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -511,7 +511,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
math.ceil(
|
math.ceil(
|
||||||
len(train_dataset)
|
len(train_dataset)
|
||||||
* cfg.num_epochs
|
* cfg.num_epochs
|
||||||
* cfg.sequence_parallel_degree
|
* cfg.context_parallel_size
|
||||||
* cfg.tensor_parallel_size
|
* cfg.tensor_parallel_size
|
||||||
/ cfg.batch_size
|
/ cfg.batch_size
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ def fixture_base_cfg():
|
|||||||
"dataloader_num_workers": 1,
|
"dataloader_num_workers": 1,
|
||||||
"dataloader_pin_memory": True,
|
"dataloader_pin_memory": True,
|
||||||
"dataloader_prefetch_factor": 2,
|
"dataloader_prefetch_factor": 2,
|
||||||
"sequence_parallel_degree": 1,
|
"context_parallel_size": 1,
|
||||||
"tensor_parallel_size": 1,
|
"tensor_parallel_size": 1,
|
||||||
# Dtype
|
# Dtype
|
||||||
"fp16": False,
|
"fp16": False,
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ class TestSequenceParallelism:
|
|||||||
"logging_steps": 1,
|
"logging_steps": 1,
|
||||||
"weight_decay": 0.0,
|
"weight_decay": 0.0,
|
||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
"sequence_parallel_degree": 2,
|
"context_parallel_size": 2,
|
||||||
"ring_attn_func": ring_attn_func,
|
"ring_attn_func": ring_attn_func,
|
||||||
"save_first_step": False,
|
"save_first_step": False,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -298,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"sequence_parallel_degree": 2,
|
"context_parallel_size": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ class TestRingAttention:
|
|||||||
|
|
||||||
# Call register_ring_attn with size 4
|
# Call register_ring_attn with size 4
|
||||||
register_ring_attn(
|
register_ring_attn(
|
||||||
sequence_parallel_degree=4,
|
context_parallel_size=4,
|
||||||
heads_k_stride=1,
|
heads_k_stride=1,
|
||||||
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
|
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
|
||||||
)
|
)
|
||||||
@@ -156,24 +156,24 @@ class TestConfigValidation:
|
|||||||
[
|
[
|
||||||
# Valid configuration
|
# Valid configuration
|
||||||
(
|
(
|
||||||
{"sequence_parallel_degree": 2, "flash_attention": True},
|
{"context_parallel_size": 2, "flash_attention": True},
|
||||||
{"sequence_parallel_degree": 2, "flash_attention": True},
|
{"context_parallel_size": 2, "flash_attention": True},
|
||||||
True,
|
True,
|
||||||
None,
|
None,
|
||||||
),
|
),
|
||||||
# Default sequence_parallel_degree
|
# Default context_parallel_size
|
||||||
({}, {"sequence_parallel_degree": 1}, True, None),
|
({}, {"context_parallel_size": 1}, True, None),
|
||||||
# Invalid: sequence_parallel_degree > 1 without flash_attention
|
# Invalid: context_parallel_size > 1 without flash_attention
|
||||||
(
|
(
|
||||||
{"sequence_parallel_degree": 2, "flash_attention": False},
|
{"context_parallel_size": 2, "flash_attention": False},
|
||||||
None,
|
None,
|
||||||
False,
|
False,
|
||||||
"flash_attention: true must be set",
|
"flash_attention: true must be set",
|
||||||
),
|
),
|
||||||
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
|
# Invalid: context_parallel_size > 1 with sample_packing and micro_batch_size > 1
|
||||||
(
|
(
|
||||||
{
|
{
|
||||||
"sequence_parallel_degree": 2,
|
"context_parallel_size": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
@@ -186,13 +186,13 @@ class TestConfigValidation:
|
|||||||
# Valid: Basic GRPO config
|
# Valid: Basic GRPO config
|
||||||
(
|
(
|
||||||
{
|
{
|
||||||
"sequence_parallel_degree": 2,
|
"context_parallel_size": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"trl": {"use_liger_loss": True},
|
"trl": {"use_liger_loss": True},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"sequence_parallel_degree": 2,
|
"context_parallel_size": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"trl": TRLConfig(use_liger_loss=True),
|
"trl": TRLConfig(use_liger_loss=True),
|
||||||
@@ -204,7 +204,7 @@ class TestConfigValidation:
|
|||||||
(
|
(
|
||||||
{
|
{
|
||||||
"rl": "grpo",
|
"rl": "grpo",
|
||||||
"sequence_parallel_degree": 2,
|
"context_parallel_size": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"trl": {"use_liger_loss": True},
|
"trl": {"use_liger_loss": True},
|
||||||
@@ -262,7 +262,7 @@ class TestConfigValidation:
|
|||||||
|
|
||||||
# Apply updates to base config
|
# Apply updates to base config
|
||||||
cfg = base_cfg | {
|
cfg = base_cfg | {
|
||||||
"sequence_parallel_degree": 2,
|
"context_parallel_size": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sample_packing": sample_packing,
|
"sample_packing": sample_packing,
|
||||||
}
|
}
|
||||||
@@ -282,7 +282,7 @@ class TestConfigValidation:
|
|||||||
|
|
||||||
# Invalid configuration with invalid ring_attn_func
|
# Invalid configuration with invalid ring_attn_func
|
||||||
cfg = base_cfg | {
|
cfg = base_cfg | {
|
||||||
"sequence_parallel_degree": 2,
|
"context_parallel_size": 2,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"ring_attn_func": "INVALID_FUNC",
|
"ring_attn_func": "INVALID_FUNC",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,83 +0,0 @@
|
|||||||
"""
|
|
||||||
E2E tests for activation offloading
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
|
||||||
from axolotl.train import train
|
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
|
|
||||||
|
|
||||||
class TestActivationOffloading:
|
|
||||||
"""
|
|
||||||
E2E test cases for activation offloading
|
|
||||||
"""
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"adapter",
|
|
||||||
["lora", "qlora", None],
|
|
||||||
)
|
|
||||||
def test_activation_offloading(
|
|
||||||
self,
|
|
||||||
temp_dir,
|
|
||||||
adapter,
|
|
||||||
):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"val_set_size": 0.0,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
"eos_token": "<|im_end|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"chat_template": "chatml",
|
|
||||||
"path": "mlabonne/FineTome-100k",
|
|
||||||
"type": "chat_template",
|
|
||||||
"split": "train[:10%]",
|
|
||||||
"field_messages": "conversations",
|
|
||||||
"message_field_role": "from",
|
|
||||||
"message_field_content": "value",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 2,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"flash_attention": True,
|
|
||||||
"sample_packing": True,
|
|
||||||
"bf16": "auto",
|
|
||||||
"save_safetensors": True,
|
|
||||||
"gradient_checkpointing": True,
|
|
||||||
"activation_offloading": True,
|
|
||||||
"save_first_step": False,
|
|
||||||
"lora_r": 8,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if adapter == "lora":
|
|
||||||
cfg["adapter"] = "lora"
|
|
||||||
if adapter == "qlora":
|
|
||||||
cfg["adapter"] = "qlora"
|
|
||||||
cfg["load_in_4bit"] = True
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
dataset_meta = load_datasets(cfg=cfg)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
|
||||||
@@ -21,6 +21,62 @@ class TestActivationOffloading:
|
|||||||
assert cfg.gradient_checkpointing is True
|
assert cfg.gradient_checkpointing is True
|
||||||
assert cfg.activation_offloading is True
|
assert cfg.activation_offloading is True
|
||||||
|
|
||||||
|
def test_gc_converts_offload_w_lora(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing="offload",
|
||||||
|
adapter="lora",
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading == "legacy"
|
||||||
|
|
||||||
|
def test_gc_converts_offload_w_qlora(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing="offload",
|
||||||
|
adapter="qlora",
|
||||||
|
load_in_4bit=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading == "legacy"
|
||||||
|
|
||||||
|
def test_ac_impl_changes_w_lora(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
activation_offloading=True,
|
||||||
|
adapter="lora",
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading == "legacy"
|
||||||
|
|
||||||
|
def test_ac_impl_changes_w_qlora(self, min_base_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
activation_offloading=True,
|
||||||
|
adapter="qlora",
|
||||||
|
load_in_4bit=True,
|
||||||
|
)
|
||||||
|
| min_base_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
assert cfg.gradient_checkpointing is True
|
||||||
|
assert cfg.activation_offloading == "legacy"
|
||||||
|
|
||||||
def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg):
|
def test_ac_offload_impl_noop_wo_adapter(self, min_base_cfg):
|
||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
|
|||||||
Reference in New Issue
Block a user