Compare commits

30 Commits

feat/wizar...fa3-hopper

| Author | SHA1 | Date |
|---|---|---|
| | 9bdf4b1c23 | |
| | d6f64a3684 | |
| | 0735454782 | |
| | bb6464c4c6 | |
| | 323a9cb153 | |
| | b22150751f | |
| | 8c4bc59bfc | |
| | a064f1c9b4 | |
| | fb5ef6d445 | |
| | 34b68ddaae | |
| | 9a3d0c919b | |
| | bd34d0b861 | |
| | 37220ab90a | |
| | e1b74d710b | |
| | 79daf5b934 | |
| | ddd7c55576 | |
| | 65c6c98a76 | |
| | 4ef2e8293f | |
| | c126d5cd04 | |
| | 9b0be4f15c | |
| | a27b909c5c | |
| | 6cb07b9d12 | |
| | 288653adb6 | |
| | 3a5b495a74 | |
| | f661858fc4 | |
| | c837c4a424 | |
| | c9797de6bb | |
| | 8f8a7afb05 | |
| | 86472715da | |
| | c0a0c7534c | |
.github/workflows/base.yml (11 changes, vendored)

```diff
@@ -47,11 +47,18 @@ jobs:
            pytorch: 2.7.0
            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
          - cuda: "128"
-           cuda_version: 12.6.3
+           cuda_version: 12.8.1
            cudnn_version: ""
            python_version: "3.11"
            pytorch: 2.7.0
            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+         - cuda: "126"
+           cuda_version: 12.6.3
+           cudnn_version: ""
+           python_version: "3.11"
+           pytorch: 2.6.0
+           suffix: "-hopper"
+           torch_cuda_arch_list: "9.0+PTX"
          - cuda: "128"
            cuda_version: 12.8.1
            cudnn_version: ""
@@ -87,7 +94,7 @@ jobs:
           context: .
           file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
           push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+          tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
           labels: ${{ steps.metadata.outputs.labels }}
           build-args: |
             CUDA_VERSION=${{ matrix.cuda_version }}
```
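As a quick sketch of what the amended `tags:` expression produces for the new matrix entry; only the matrix values come from this diff, the metadata tag prefix is an assumption:

```yaml
# illustrative only: new matrix entry -> resulting base-image tag suffix
# cuda: "126", python_version: "3.11", pytorch: 2.6.0, suffix: "-hopper"
# tag: <metadata-tags>-base-py3.11-cu126-2.6.0-hopper
```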
.github/workflows/main.yml (10 changes, vendored)

```diff
@@ -31,6 +31,11 @@ jobs:
            python_version: "3.11"
            pytorch: 2.7.0
            axolotl_extras:
+         - cuda: 128
+           cuda_version: 12.8.1
+           python_version: "3.11"
+           pytorch: 2.7.0
+           axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -94,6 +99,11 @@ jobs:
            python_version: "3.11"
            pytorch: 2.7.0
            axolotl_extras:
+         - cuda: 128
+           cuda_version: 12.8.1
+           python_version: "3.11"
+           pytorch: 2.7.0
+           axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
```
.github/workflows/multi-gpu-e2e.yml (11 changes, vendored)

```diff
@@ -32,21 +32,25 @@ jobs:
            pytorch: 2.6.0
            axolotl_extras: vllm
            num_gpus: 2
-           nightly_build: "true"
+         - cuda: 126
+           cuda_version: 12.6.3
+           python_version: "3.11"
+           pytorch: 2.6.0
+           axolotl_extras:
+           suffix: "-hopper"
+           num_gpus: 2
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"
            pytorch: 2.5.1
            axolotl_extras:
            num_gpus: 2
-           nightly_build: "true"
          - cuda: 126
            cuda_version: 12.6.3
            python_version: "3.11"
            pytorch: 2.7.0
            axolotl_extras:
            num_gpus: 2
-           nightly_build: "true"
     runs-on: [self-hosted, modal]
     timeout-minutes: 120
     steps:
@@ -68,7 +72,6 @@ jobs:
           echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
           echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
-          echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
          echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
       - name: Run tests job on Modal
         run: |
```
.github/workflows/tests.yml (9 changes, vendored)

```diff
@@ -295,6 +295,7 @@ jobs:
           find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;

  docker-e2e-tests-1st:
+    # Run this job first as a gate for running the remainder of the test matrix
    if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
    # this job needs to be run on self-hosted GPU runners...
    runs-on: [self-hosted, modal]
@@ -341,6 +342,8 @@
    # this job needs to be run on self-hosted GPU runners...
    runs-on: [self-hosted, modal]
    timeout-minutes: 90
+    # Only run the remainder of the matrix if the first e2e check passed;
+    # this is to save on wasted compute costs for known failures that get caught in the first run
    needs: [pre-commit, pytest, docker-e2e-tests-1st]

    strategy:
@@ -365,6 +368,12 @@
            pytorch: 2.7.0
            num_gpus: 1
            axolotl_extras:
+         - cuda: 128
+           cuda_version: 12.8.1
+           python_version: "3.11"
+           pytorch: 2.7.0
+           num_gpus: 1
+           axolotl_extras:
    steps:
      - name: Checkout
        uses: actions/checkout@v4
```
```diff
@@ -139,7 +139,8 @@ quartodoc:
       - utils.optimizers.adopt
       - utils.data.pretraining
       - utils.data.sft
-      - utils.gradient_checkpointing.unsloth
+      - utils.gradient_checkpointing.offload_cpu
+      - utils.gradient_checkpointing.offload_disk
   - title: Schemas
     desc: Pydantic data models for Axolotl config
     contents:
```
```diff
@@ -32,6 +32,11 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
     fi

RUN pip install packaging==23.2 setuptools==75.8.0
+RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "126" ] ; then \
+        curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+        pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+        rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+    fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
        pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
    else \
```
```diff
@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
    image=cicd_image,
    gpu=GPU_CONFIG,
    timeout=90 * 60,
-    cpu=8.0,
+    cpu=16.0,
    memory=131072 * N_GPUS,
    volumes=VOLUME_CONFIG,
)
```
```diff
@@ -1,5 +1,5 @@
-ARG CUDA_VERSION="11.8.0"
-ARG CUDNN_VERSION="8"
+ARG CUDA_VERSION="12.4.1"
+ARG CUDNN_VERSION=""
 ARG UBUNTU_VERSION="22.04"
 ARG MAX_JOBS=4

@@ -7,16 +7,16 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION A

 ENV PATH="/root/miniconda3/bin:${PATH}"

-ARG PYTHON_VERSION="3.10"
-ARG PYTORCH_VERSION="2.1.2"
-ARG CUDA="118"
+ARG PYTHON_VERSION="3.11"
+ARG PYTORCH_VERSION="2.5.1"
+ARG CUDA="124"
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"

 ENV PYTHON_VERSION=$PYTHON_VERSION
 ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST

 RUN apt-get update \
-    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
+    && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
     && wget \
         https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
     && mkdir /root/.conda \
@@ -38,6 +38,10 @@ RUN git lfs install --skip-repo && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic==1.10.10

-RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
+RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
+        curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+        pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+        rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
+    elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
         pip3 install flash-attn==2.7.4.post1; \
     fi
```
```diff
@@ -539,7 +539,7 @@ train_on_inputs: false
 # Note that training loss may have an oscillating pattern with this enabled.
 group_by_length: false

-# Whether to use gradient checkpointing. Available options are: true, false, "offload".
+# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
 # https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
 gradient_checkpointing: false
 # additional kwargs to pass to the trainer for gradient checkpointing
```
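A minimal config sketch for the newly documented option; the single line below is the only key this change concerns, the rest of a real config is unchanged:

```yaml
# offload checkpointed activations to disk instead of GPU/CPU memory
gradient_checkpointing: offload_disk
```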
```diff
@@ -633,7 +633,9 @@ weight_decay:
 # adamw hyperparams
 adam_beta1:
 adam_beta2:
+adam_beta3: # only used for CAME Optimizer
 adam_epsilon:
+adam_epsilon2: # only used for CAME Optimizer
 # Gradient clipping max norm
 max_grad_norm:

```
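As a worked example of the two new keys, a hedged config sketch; the optimizer name and the numeric values are assumptions (the numbers mirror the fallback defaults used in the trainer builder further down this diff), not recommendations from the docs:

```yaml
optimizer: came_pytorch   # assumed spelling of the CAME optimizer option
adam_beta1: 0.9
adam_beta2: 0.999
adam_beta3: 0.9999        # only used for CAME
adam_epsilon: 1.0e-30
adam_epsilon2: 1.0e-16    # only used for CAME
```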
````diff
@@ -104,7 +104,7 @@ the `alpaca` dataset format, which has the following format:
 Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
 format them.

-2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
+2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca`
    format):

 ```json
@@ -120,6 +120,12 @@ axolotl train my_training.yml

 ## Common Tasks {#sec-common-tasks}

+::: {.callout-tip}
+
+The same yaml file is used for training, inference, and merging.
+
+:::
+
 ### Testing Your Model {#sec-testing}

 After training, test your model:
@@ -128,6 +134,16 @@ After training, test your model:
 axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
 ```

+More details can be found in [Inference](inference.qmd).
+
+### Using a UI {#sec-ui}
+
+Launch a Gradio interface:
+
+```bash
+axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
+```
+
 ### Preprocessing Data {#sec-preprocessing}

 For large datasets, preprocess first:
@@ -136,14 +152,22 @@ For large datasets, preprocess first:
 axolotl preprocess my_training.yml
 ```

-### Using a UI {#sec-ui}
+Please make sure to set `dataset_prepared_path: ` in your config to set the path to save the prepared dataset.

-Launch a Gradio interface:
+More details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd).
+
+### Merging LoRA weights {#sec-merging-lora}
+
+To merge the LoRA weights back into the base model, run:

 ```bash
-axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
+axolotl merge-lora my_training.yml --lora-model-dir="./outputs/lora-out"
 ```

+The merged model will be saved in the `{output_dir}/merged` directory.
+
+More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).
+
 ## Next Steps {#sec-next-steps}

 Now that you have the basics, you might want to:
@@ -156,6 +180,7 @@ Now that you have the basics, you might want to:
 Check our other guides for details on these topics:

 - [Configuration Guide](config.qmd) - Full configuration options
+- [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources
 - [Dataset Formats](dataset-formats) - Working with different data formats
 - [Multi-GPU Training](multi-gpu.qmd)
 - [Multi-Node Training](multi-node.qmd)
````
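The preprocessing note added above assumes `dataset_prepared_path` is set; a one-line hedged sketch, where the path is an arbitrary example rather than a required value:

```yaml
dataset_prepared_path: ./prepared_data   # any writable directory works
```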
```diff
@@ -342,13 +342,6 @@ def delinearize_llama4(model: str, output: str) -> None:
     do_delinearize_llama4(model, output)


-@cli.command()
-def wizard():
-    from axolotl.cli.wizard import do_wizard
-
-    do_wizard()
-
-
 cli.add_command(lm_eval)


```
```diff
@@ -1,429 +0,0 @@
-"""Wizard for creating yaml configs."""
-
-import click
-import torch
-import yaml
-from packaging import version
-from transformers.training_args import OptimizerNames
-
-from axolotl.cli.art import print_axolotl_text_art
-from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model_config
-from axolotl.utils.schemas.enums import CustomSupportedOptimizers
-
-
-def do_wizard():
-    print_axolotl_text_art()
-
-    # Ask where to save the config
-    cfg = DictDefault({})
-    config_path = click.prompt(
-        "Where do you want to save the config?", type=str, default="config.yaml"
-    )
-
-    # Ask base model
-    base_model = click.prompt("What base model do you want to use?", type=str)
-    cfg["base_model"] = base_model.strip()
-
-    # Ask whether want to enable Vision model
-    # TODO: check if model has vision layers instead of asking user
-    train_vision_model = click.confirm(
-        "If this model has vision layers, do you want to train them?", default=False
-    )
-
-    if train_vision_model:
-        cfg["processor_type"] = "AutoProcessor"
-        cfg["skip_prepare_dataset"] = True
-        cfg["remove_unused_columns"] = False
-        cfg["sample_packing"] = False
-
-    # Ask whether they want to set any advanced model features (custom tokenizer, custom config, etc)
-    advanced_model_features = click.confirm(
-        "Do you want to set any advanced model features? (custom tokenizer, custom config, remote code etc)",
-        default=False,
-    )
-
-    if advanced_model_features:
-        # Ask whether they want to use a custom config
-        base_model_config = click.prompt(
-            "What model config do you want to use? (leave blank for default)",
-            type=str,
-            default="",
-        )
-
-        if base_model_config:
-            cfg["base_model_config"] = base_model_config
-
-        # Ask whether they want to use a specific revision of the model
-        revision_of_model = click.prompt(
-            "What revision of the model do you want to use? (leave blank for default)",
-            type=str,
-            default="",
-        )
-
-        if revision_of_model:
-            cfg["revision_of_model"] = revision_of_model
-
-        # Ask whether they want to use a custom tokenizer
-        tokenizer_config = click.prompt(
-            "What tokenizer do you want to use? (leave blank for default)",
-            type=str,
-            default="",
-        )
-
-        if tokenizer_config:
-            cfg["tokenizer_config"] = tokenizer_config
-
-        # Ask whether they want to use remote code
-        trust_remote_code = click.confirm(
-            "Do you want to use remote code?", default=False
-        )
-
-        if trust_remote_code:
-            cfg["trust_remote_code"] = trust_remote_code
-
-        # Whether to resize token embeddings
-        resize_token_embeddings_to_32x = click.confirm(
-            "Do you want to resize token embeddings to 32x?", default=False
-        )
-
-        if resize_token_embeddings_to_32x:
-            cfg["resize_token_embeddings_to_32x"] = resize_token_embeddings_to_32x
-
-        # Whether to shrink embeddings to len(tokenizer)
-        shrink_embeddings = click.confirm(
-            "Do you want to shrink embeddings to len(tokenizer)?", default=False
-        )
-
-        if shrink_embeddings:
-            cfg["shrink_embeddings"] = shrink_embeddings
-
-        # Whether to skip upcast embeddings
-        embeddings_skip_upcast = click.confirm(
-            "Do you want to skip upcast embeddings?", default=False
-        )
-
-        if embeddings_skip_upcast:
-            cfg["embeddings_skip_upcast"] = embeddings_skip_upcast
-
-        # Whether to random init weights
-        random_init_weights = click.confirm(
-            "Do you want to random init weights?", default=False
-        )
-
-        if random_init_weights:
-            cfg["random_init_weights"] = random_init_weights
-
-    # Get model type
-    config = load_model_config(cfg)
-    model_type = config.model_type
-
-    # Ask sequence length
-    sequence_length = click.prompt("What sequence length do you want to use?", type=int)
-    cfg["sequence_length"] = sequence_length
-
-    # Whether to turn on sample packing
-    if cfg["sample_packing"] is None:
-        cfg["sample_packing"] = click.confirm(
-            "Do you want to turn on sample packing? This will speed up training by packing multiple samples into a single batch.",
-            default=True,
-        )
-
-    if cfg["sample_packing"]:
-        cfg["pad_to_sequence_len"] = True
-
-        # Whether to turn off eval sample packing
-        no_eval_sample_packing = click.confirm(
-            "Do you want to turn off eval sample packing? This will slow down evaluation but is recommended if you are using a small validation set.",
-            default=False,
-        )
-
-        if no_eval_sample_packing:
-            cfg["eval_sample_packing"] = False
-
-    # Hardware check
-    try:
-        is_ampere_or_newer = torch.cuda.get_device_capability()[0] >= 8
-    except RuntimeError:
-        is_ampere_or_newer = False
-    except AssertionError:  # this is raised if no cuda is available
-        is_ampere_or_newer = False
-
-    # Get num gpus
-    try:
-        num_gpus = torch.cuda.device_count()
-    except RuntimeError:
-        num_gpus = 0
-
-    # Get torch version
-    torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
-
-    is_torch_2_6_or_newer = version.parse(torch_version) >= version.parse("2.6.0")
-
-    # Whether to turn on attention
-    opt = ["xformers", "sdp"]
-
-    if is_ampere_or_newer:
-        opt.append("flash")
-
-    if is_torch_2_6_or_newer:
-        opt.append("flex")
-
-    if cfg["sample_packing"]:
-        if "flash" in opt:
-            default_opt = "flash"
-        elif "flex" in opt:
-            default_opt = "flex"
-        else:
-            default_opt = opt[0]
-
-        attention = click.prompt(
-            "Which attention backend do you want to use? Sample packing requires an attention backend to be set.",
-            type=click.Choice(opt),
-            default=default_opt,
-        )
-    else:
-        # non-sample packing supports no attention and S2
-        opt.extend(["none", "s2"])
-
-        attention = click.prompt(
-            "Which attention backend do you want to use?",
-            type=click.Choice(opt),
-            default="none",
-        )
-
-        if attention == "none":
-            attention = None
-
-    # TODO: if xformers, check if FA is installed
-    # TODO: flex doc mentioned requiring seq len to be divisible by 128. Unclear if limitation still exists
-
-    # TODO: requires #2489
-    cfg["attention"] = attention
-
-    # Whether to turn on gradient checkpointing
-    # TODO: need to wait for offload_disk PR to be merged
-    gradient_checkpointing = click.prompt(
-        "Which gradient checkpointing strategy do you want to use?",
-        type=click.Choice(["none", "true", "offload", "offload_disk"]),
-        default="true",
-    )
-
-    if gradient_checkpointing == "none":
-        gradient_checkpointing = False
-    elif gradient_checkpointing == "true":
-        gradient_checkpointing = True
-
-    # Ask whether to set use_reentrant
-    # TODO: get correct defaults based on SFT/RL mode and single/multigpu
-    # use_reentrant = click.confirm(
-    #     "Do you want to set use_reentrant?",
-    #     default=True,
-    # )
-
-    # if use_reentrant:
-    #     cfg["use_reentrant"] = use_reentrant
-
-    # Optimizer
-    cfg["optimizer"] = click.prompt(
-        "Which optimizer do you want to use?",
-        type=click.Choice((OptimizerNames | CustomSupportedOptimizers)),
-        default=OptimizerNames.ADAMW_TORCH_FUSED,
-    )
-
-    cfg["lr_scheduler"] = click.prompt(
-        "Which learning rate scheduler do you want to use?",
-        type=click.Choice(
-            [
-                "cosine",
-                "one_cycle",
-                "rex",
-                "log_sweep",
-                "linear",
-                "cosine_with_restarts",
-                "polynomial",
-                "constant",
-                "constant_with_warmup",
-                "inverse_sqrt",
-                "reduce_lr_on_plateau",
-                "cosine_with_min_lr",
-                "warmup_stable_decay",
-            ]
-        ),
-        default="cosine",
-    )
-
-    # Plugins
-
-    cfg["plugins"] = []
-
-    # Whether to turn on cut cross entropy
-    if is_ampere_or_newer:
-        # Note: This may error if users don't have CCE installed
-        from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
-            CUT_CROSS_ENTROPY_MODEL_MAPPING,
-        )
-
-        if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
-            cut_cross_entropy = click.confirm(
-                "Do you want to turn on cut cross entropy? This will save VRAM if the model has a large vocab size.",
-                default=True,
-            )
-
-            if cut_cross_entropy:
-                cfg["plugins"].append(
-                    "axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin"
-                )
-
-                cfg["cut_cross_entropy"] = True
-
-    use_liger_kernel = click.confirm(
-        "Do you want to use the liger kernel? This will speed up training and save VRAM.",
-        default=True,
-    )
-
-    if use_liger_kernel:
-        cfg["plugins"].append("axolotl.integrations.liger.LigerPlugin")
-
-        cfg["liger_rope"] = click.confirm(
-            "Do you want to enable liger rope?",
-            default=True,
-        )
-
-        cfg["liger_rms_norm"] = click.confirm(
-            "Do you want to enable liger rms norm?",
-            default=True,
-        )
-
-        cfg["liger_glu_activation"] = click.confirm(
-            "Do you want to enable liger glu activation?",
-            default=True,
-        )
-
-        cfg["liger_layer_norm"] = click.confirm(
-            "Do you want to enable liger layer norm?",
-            default=True,
-        )
-
-        if cfg["cut_cross_entropy"] is not True:
-            cfg["liger_fused_linear_cross_entropy"] = click.confirm(
-                "Do you want to enable liger fused linear cross entropy?",
-                default=True,
-            )
-
-    # TODO: lora kernels (but they auto enable via validator already)
-
-    # TODO: is there incompat between torch compile and liger?
-    cfg["torch_compile"] = click.confirm(
-        "Do you want to enable torch compile?",
-        default=True,
-    )
-
-    # Multi-gpu
-    if num_gpus > 1:
-        # Ask whether to use DDP/Deepspeed/FSDP
-        multi_gpu_mode = click.prompt(
-            "Which multi-gpu mode do you want to use?",
-            type=click.Choice(["ddp", "deepspeed", "fsdp"]),
-            default="ddp",
-        )
-
-        if multi_gpu_mode == "deepspeed":
-            # Ask which deepspeed config to use
-            cfg["deepspeed"] = click.prompt(
-                "Which deepspeed config do you want to use? The higher the number, the more VRAM you will save, but the slower it will run.",
-                type=click.Choice(
-                    [
-                        "zero1.json",
-                        "zero1_torch_compile.json",
-                        "zero2.json",
-                        "zero3.json",
-                        "zero3_bf16.json",
-                        "zero3_bf16_cpuoffload_all.json",
-                        "zero3_bf16_cpuoffload_params.json",
-                    ]
-                ),
-                default="zero1.json",
-            )
-        elif multi_gpu_mode == "fsdp":
-            fsdp_version = click.prompt(
-                "Which fsdp version do you want to use?",
-                type=click.Choice([1, 2]),
-                default=2,
-            )
-
-            # TODO: Handle FSDP config
-
-            if fsdp_version == 1:
-                cfg["fsdp"] = ["full_shard", "auto_wrap"]
-
-                # Ask which state dict type to use
-                fsdp_state_dict_type = click.prompt(
-                    "Which fsdp state dict type do you want to use?",
-                    type=click.Choice(["FULL_STATE_DICT", "SHARDED_STATE_DICT"]),
-                    default="FULL_STATE_DICT",
-                )
-
-                fsdp_offload_params = click.confirm(
-                    "Do you want to offload parameters?",
-                    default=True,
-                )
-
-                # TODO: can we load the model class and auto pull a default for this?
-                fsdp_transformer_layer_cls_to_wrap = click.prompt(
-                    "Which transformer layer class to wrap? It is usually the Decoder layer class.",
-                    type=str,
-                )
-
-                # TODO: add other options
-
-                cfg["fsdp_config"] = {
-                    "state_dict_type": fsdp_state_dict_type,
-                    "offload_params": fsdp_offload_params,
-                    "transformer_layer_cls_to_wrap": fsdp_transformer_layer_cls_to_wrap,
-                }
-
-            elif fsdp_version == 2:
-                raise NotImplementedError()
-
-    # Training mode (sft or rl)
-    training_mode = click.prompt(
-        "Which training mode do you want to use?",
-        type=click.Choice(["sft", "rl"]),
-        default="sft",
-    )
-
-    if training_mode == "rl":
-        cfg["rl"] = click.prompt(
-            "Which rl mode do you want to use?",
-            type=click.Choice(["dpo", "ipo", "orpo", "kto", "grpo", "simpo"]),
-        )
-
-        # TODO: handle RL options
-
-    # Whether to use adapter
-
-    # Get batch/grad accu
-
-    # Get learning rate
-
-    # Get weight decay
-
-    # Get max grad norm
-
-    # Get num train epochs
-
-    # Get warmup ratio
-
-    # Get save ratio
-
-    # Get eval ratio
-
-    # Get dataset config
-
-    # Load metric tracker
-
-    # Save config to yaml
-    # TODO: improve output yaml formatting. Need to add comments to help separate sections
-    with open(config_path, "w", encoding="utf-8") as f:
-        yaml.dump(cfg.to_dict(), f, sort_keys=False)
```
```diff
@@ -387,8 +387,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
         if self.cfg.adam_beta2:
             training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
+        if self.cfg.adam_beta3:
+            training_arguments_kwargs["adam_beta3"] = self.cfg.adam_beta3
         if self.cfg.adam_epsilon:
             training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
+        if self.cfg.adam_epsilon2:
+            training_arguments_kwargs["adam_epsilon2"] = self.cfg.adam_epsilon2
         if self.cfg.max_grad_norm:
             training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm

@@ -713,7 +717,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

             beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
             beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
-            beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
+            beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999)
             eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
             eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
             adam_kwargs["betas"] = (beta1, beta2, beta3)
@@ -1170,7 +1174,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.eval_dataset:
             trainer_kwargs["eval_dataset"] = self.eval_dataset
         if self.cfg.adapter and self.peft_config:
-            trainer_kwargs["peft_config"] = self.peft_config
+            if self.cfg.rl is not RLType.GRPO:
+                trainer_kwargs["peft_config"] = self.peft_config
         if self.cfg.precompute_ref_log_probs is not None:
             trainer_kwargs["precompute_ref_log_probs"] = (
                 self.cfg.precompute_ref_log_probs
```
```diff
@@ -3,7 +3,6 @@
 # pylint: disable=too-many-lines,duplicate-code,protected-access,no-member

 import warnings
-from contextlib import nullcontext
 from typing import Any

 import datasets
@@ -14,7 +13,7 @@ from accelerate.utils import (
     broadcast_object_list,
     gather,
     gather_object,
-    is_peft_model,
+    is_peft_available,
 )
 from datasets import Dataset, IterableDataset
 from torch import nn
@@ -30,15 +29,13 @@ from transformers import (
     TrainerCallback,
 )
 from transformers.trainer_utils import seed_worker
-from transformers.utils import is_peft_available
 from trl import GRPOTrainer
 from trl.data_utils import (
     apply_chat_template,
     is_conversational,
     maybe_apply_chat_template,
 )
-from trl.extras.profiling import profiling_context, profiling_decorator
-from trl.import_utils import is_deepspeed_available
+from trl.extras.profiling import profiling_context
 from trl.models import unwrap_model_for_generation
 from trl.trainer.grpo_config import GRPOConfig
 from trl.trainer.grpo_trainer import RewardFunc, nanstd
@@ -52,62 +49,12 @@ if is_peft_available():
     # pylint: disable=unused-import
     from peft import PeftConfig

-if is_deepspeed_available():
-    import deepspeed
-
-
 class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
     """Extend the base GRPOTrainer for axolotl helpers"""

     _tag_names = ["trl", "grpo", "axolotl"]

-    @profiling_decorator
-    def _move_model_to_vllm(self):
-        # For DeepSpeed ZeRO-3, we need to gather all parameters before operations
-        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
-        zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
-        gather_if_zero3 = (
-            deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
-        )
-
-        if is_peft_model(self.model):
-            # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
-            # adapters in a sharded manner is not supported.
-            with gather_if_zero3(list(self.model.parameters())):
-                self.model.merge_adapter()
-
-                # Update vLLM weights while parameters are gathered
-                for name, param in self.model.named_parameters():
-                    # When using PEFT, we need to recover the original parameter name and discard some parameters
-                    name = (
-                        name.removeprefix("base_model.model.")
-                        .removeprefix("base_model.model.")
-                        .replace(".base_layer", "")
-                    )
-                    if self.model.prefix in name:
-                        continue
-                    # When module to save, remove its prefix and discard the original module
-                    if "original_module" in name:
-                        continue
-                    name = name.replace("modules_to_save.default.", "")
-
-                    if self.accelerator.is_main_process:
-                        self.vllm_client.update_named_param(name, param.data)
-
-                # Unmerge adapters while parameters are still gathered
-                self.model.unmerge_adapter()
-                # Parameters will automatically be repartitioned when exiting the context
-        else:
-            # For non-PEFT models, simply gather and update each parameter individually.
-            for name, param in self.model.named_parameters():
-                with gather_if_zero3([param]):
-                    if self.accelerator.is_main_process:
-                        self.vllm_client.update_named_param(name, param.data)
-
-        # Reset cache on main process
-        if self.accelerator.is_main_process:
-            self.vllm_client.reset_prefix_cache()
-
-
 class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
     """Extend the base GRPOTrainer for sequence parallelism handling"""
```
```diff
@@ -227,6 +227,19 @@ class AxolotlTrainingMixins:
         },
     )
+    adam_beta3: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "The beta3 hyperparameter used in some optimizers such as CAME"
+        },
+    )
+    adam_epsilon2: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
+        },
+    )

     # multi-modal section

     image_size: int | tuple[int, int] | None = field(
```
```diff
@@ -20,25 +20,15 @@ from cut_cross_entropy.transformers.utils import (
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.cohere.modeling_cohere import (
-    _CONFIG_FOR_DOC,
-    COHERE_INPUTS_DOCSTRING,
     KwargsForCausalLM,
 )
 from transformers.processing_utils import Unpack
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
 from transformers.utils.deprecation import deprecate_kwarg

 _PATCH_OPTS: PatchOptions | None = None


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward(
     self,
     input_ids: torch.LongTensor | None = None,
```
```diff
@@ -17,25 +17,15 @@ from cut_cross_entropy.transformers.utils import (
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.gemma.modeling_gemma import (
-    _CONFIG_FOR_DOC,
-    GEMMA_INPUTS_DOCSTRING,
     KwargsForCausalLM,
 )
 from transformers.processing_utils import Unpack
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
 from transformers.utils.deprecation import deprecate_kwarg

 _PATCH_OPTS: PatchOptions | None = None


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward(
     self,
     input_ids: torch.LongTensor | None = None,
```
```diff
@@ -20,15 +20,11 @@ from torch import nn
 from transformers.cache_utils import Cache, HybridCache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.gemma3.modeling_gemma3 import (
-    _CONFIG_FOR_DOC,
-    GEMMA3_INPUTS_DOCSTRING,
     Gemma3CausalLMOutputWithPast,
     logger,
 )
 from transformers.utils import (
-    add_start_docstrings_to_model_forward,
     is_torchdynamo_compiling,
-    replace_return_docstrings,
 )
 from transformers.utils.deprecation import deprecate_kwarg

@@ -38,10 +34,6 @@ _PATCH_OPTS: PatchOptions | None = None


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward(
     self,
     input_ids: torch.LongTensor | None = None,
@@ -170,10 +162,6 @@ def cce_forward(


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward_multimodal(
     self,
     input_ids: torch.LongTensor | None = None,
```
```diff
@@ -19,15 +19,9 @@ from transformers.modeling_outputs import (
     CausalLMOutputWithPast,
 )
 from transformers.models.llama.modeling_llama import (
-    _CONFIG_FOR_DOC,
-    LLAMA_INPUTS_DOCSTRING,
     KwargsForCausalLM,
 )
 from transformers.processing_utils import Unpack
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
 from transformers.utils.deprecation import deprecate_kwarg
 from transformers.utils.generic import can_return_tuple

@@ -36,10 +30,6 @@ _PATCH_OPTS: PatchOptions | None = None

 @can_return_tuple
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward(
     self,
     input_ids: Optional[torch.LongTensor] = None,
```
```diff
@@ -16,22 +16,12 @@ from torch import nn
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.llama4.modeling_llama4 import (
-    _CONFIG_FOR_DOC,
-    LLAMA4_INPUTS_DOCSTRING,
     Llama4CausalLMOutputWithPast,
 )
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)

 _PATCH_OPTS: PatchOptions | None = None


-@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward(
     self,
     input_ids: torch.LongTensor | None = None,
@@ -160,9 +150,6 @@ def cce_forward(
     )


-@replace_return_docstrings(
-    output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward_multimodal(
     self,
     input_ids: torch.LongTensor | None = None,  # type: ignore
```
```diff
@@ -19,15 +19,11 @@ from transformers.models.mistral3.modeling_mistral3 import (
     Mistral3CausalLMOutputWithPast,
 )
 from transformers.models.mistral.modeling_mistral import (
-    _CONFIG_FOR_DOC,
-    MISTRAL_INPUTS_DOCSTRING,
     KwargsForCausalLM,
 )
 from transformers.processing_utils import Unpack
 from transformers.utils import (
-    add_start_docstrings_to_model_forward,
     is_torchdynamo_compiling,
-    replace_return_docstrings,
 )
 from transformers.utils.deprecation import deprecate_kwarg

@@ -35,10 +31,6 @@ _PATCH_OPTS: PatchOptions | None = None


 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward(
     self,
     input_ids: torch.LongTensor | None = None,
```
```diff
@@ -13,16 +13,10 @@ from cut_cross_entropy.transformers.utils import (
     apply_lce,
 )
 from transformers.models.qwen2_moe.modeling_qwen2_moe import (
-    _CONFIG_FOR_DOC,
-    QWEN2MOE_INPUTS_DOCSTRING,
     MoeCausalLMOutputWithPast,
     MoeModelOutputWithPast,
     load_balancing_loss_func,
 )
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
 from transformers.utils.deprecation import deprecate_kwarg
 from transformers.utils.generic import can_return_tuple

@@ -31,10 +25,6 @@ _PATCH_OPTS: PatchOptions | None = None

 @can_return_tuple
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def forward(
     self,
     input_ids: Optional[torch.LongTensor] = None,
```
```diff
@@ -14,22 +14,12 @@ from cut_cross_entropy.transformers.utils import (
 )
 from torch.nn import CrossEntropyLoss
 from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-    _CONFIG_FOR_DOC,
-    QWEN2_VL_INPUTS_DOCSTRING,
     Qwen2VLCausalLMOutputWithPast,
 )
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)

 _PATCH_OPTS: PatchOptions | None = None


-@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def cce_forward_multimodal(
     self,
     input_ids: Optional[torch.LongTensor] = None,
```
```diff
@@ -12,20 +12,13 @@ from cut_cross_entropy.transformers.utils import (
     TransformersModelT,
     apply_lce,
 )
-from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.qwen3_moe.modeling_qwen3_moe import (
-    _CONFIG_FOR_DOC,
-    QWEN3_MOE_INPUTS_DOCSTRING,
     KwargsForCausalLM,
     MoeCausalLMOutputWithPast,
     MoeModelOutputWithPast,
     load_balancing_loss_func,
 )
 from transformers.processing_utils import Unpack
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
 from transformers.utils.deprecation import deprecate_kwarg
 from transformers.utils.generic import can_return_tuple

@@ -34,10 +27,6 @@ _PATCH_OPTS: PatchOptions | None = None

 @can_return_tuple
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def forward(
     self,
     input_ids: Optional[torch.LongTensor] = None,
```
src/axolotl/integrations/liger/models/__init__.py (0 changes, new file)
@@ -14,10 +14,6 @@ from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import CausalLMOutputWithPast
 
 
-# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
-# @replace_return_docstrings(
-#     output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-# )
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -13,21 +13,11 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import MoeCausalLMOutputWithPast
 from transformers.models.jamba.modeling_jamba import (
-    _CONFIG_FOR_DOC,
-    JAMBA_INPUTS_DOCSTRING,
     HybridMambaAttentionDynamicCache,
     load_balancing_loss_func,
 )
-from transformers.utils import (
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
 
 
-@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -7,24 +7,16 @@ from typing import Optional, Tuple, Union
 import torch
 from transformers.cache_utils import Cache
 from transformers.models.gemma3.modeling_gemma3 import (
-    _CONFIG_FOR_DOC,
-    GEMMA3_INPUTS_DOCSTRING,
     Gemma3CausalLMOutputWithPast,
     logger,
 )
 from transformers.utils import (
-    add_start_docstrings_to_model_forward,
     is_torchdynamo_compiling,
-    replace_return_docstrings,
 )
 from transformers.utils.deprecation import deprecate_kwarg
 
 
 @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
-@replace_return_docstrings(
-    output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
-)
 def new_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -1,6 +1,7 @@
 """MLFlow module for trainer callbacks"""
 
 import logging
+import os
 from shutil import copyfile
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING
@@ -16,6 +17,11 @@ if TYPE_CHECKING:
 LOG = logging.getLogger("axolotl.callbacks")
 
 
+def should_log_artifacts() -> bool:
+    truths = ["TRUE", "1", "YES"]
+    return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths
+
+
 class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
     # pylint: disable=duplicate-code
     """Callback to save axolotl config to mlflow"""
@@ -32,13 +38,18 @@ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
     ):
         if is_main_process():
             try:
-                with NamedTemporaryFile(
-                    mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
-                ) as temp_file:
-                    copyfile(self.axolotl_config_path, temp_file.name)
-                    mlflow.log_artifact(temp_file.name, artifact_path="")
-                LOG.info(
-                    "The Axolotl config has been saved to the MLflow artifacts."
-                )
+                if should_log_artifacts():
+                    with NamedTemporaryFile(
+                        mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
+                    ) as temp_file:
+                        copyfile(self.axolotl_config_path, temp_file.name)
+                        mlflow.log_artifact(temp_file.name, artifact_path="")
+                    LOG.info(
+                        "The Axolotl config has been saved to the MLflow artifacts."
+                    )
+                else:
+                    LOG.info(
+                        "Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)"
+                    )
             except (FileNotFoundError, ConnectionError) as err:
                 LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
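
With should_log_artifacts() in place, uploading the Axolotl config as an MLflow artifact is opt-in via the HF_MLFLOW_LOG_ARTIFACTS environment variable ("TRUE", "1", or "YES", case-insensitive). A minimal, hedged sketch of opting in before training starts:

    # Sketch only: enable MLflow artifact upload of the saved axolotl config.
    # Any value other than TRUE/1/YES (case-insensitive) keeps artifact logging off.
    import os

    os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "1"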
@@ -72,6 +72,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
     data_set = data_set.map(
         ds_transform_fn,
         desc="Mapping RL Dataset",
+        num_proc=cfg.dataset_processes,
         **map_kwargs,
     )
 
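
The added num_proc argument makes the RL dataset map honor the existing dataset_processes setting. A minimal sketch in the DictDefault style used by the tests later in this diff (the value 8 is illustrative, not from the source):

    # Sketch only; dataset_processes feeds the num_proc= argument added above.
    from axolotl.utils.dict import DictDefault

    cfg = DictDefault({"dataset_processes": 8})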
@@ -484,7 +484,7 @@ def get_dataset_wrapper(
     }
 
     LOG.info(
-        f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
+        f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
     )
 
     if (
@@ -5,8 +5,11 @@ from functools import partial
 
 from packaging import version
 
-from axolotl.utils.gradient_checkpointing.unsloth import (
-    Unsloth_Offloaded_Gradient_Checkpointer,
+from axolotl.utils.gradient_checkpointing.offload_cpu import (
+    CPU_Offloaded_Gradient_Checkpointer,
+)
+from axolotl.utils.gradient_checkpointing.offload_disk import (
+    Disco,
 )
 
 transformers_version = version.parse(importlib.metadata.version("transformers"))
@@ -26,12 +29,31 @@ def hf_grad_checkpoint_offload_wrapper(
     decoder_layer, *args, use_reentrant=None
 ):  # pylint: disable=unused-argument
     if uses_gc_layers(decoder_layer):
-        return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+        return CPU_Offloaded_Gradient_Checkpointer.apply(
             decoder_layer,
             *args,
         )
 
-    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+    return CPU_Offloaded_Gradient_Checkpointer.apply(
+        (
+            decoder_layer.func.__self__
+            if isinstance(decoder_layer, partial)
+            else decoder_layer.__self__
+        ),
+        *args,
+    )
+
+
+def hf_grad_checkpoint_disk_offload_wrapper(
+    decoder_layer, *args, use_reentrant=None
+):  # pylint: disable=unused-argument
+    if uses_gc_layers(decoder_layer):
+        return Disco.apply(
+            decoder_layer,
+            *args,
+        )
+
+    return Disco.apply(
         (
             decoder_layer.func.__self__
             if isinstance(decoder_layer, partial)
@@ -1,4 +1,4 @@
-"""Unsloth checkpointing"""
+"""CPU offloaded checkpointing"""
 
 # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
 #
@@ -26,7 +26,7 @@ else:
     torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
 
 
-class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
+class CPU_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
     torch.autograd.Function
 ):
     """
src/axolotl/utils/gradient_checkpointing/offload_disk.py (new file, 531 lines added)

"""
DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
"""

# Copyright 2025 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import atexit
import concurrent.futures
import logging
import os
import queue
import shutil
import tempfile
import threading
import time
import uuid
from collections import deque
from concurrent.futures import Future
from typing import Dict

import torch

torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")

# Setup logger
logger = logging.getLogger(__name__)


class DiskOffloadManager:
    """
    Manages offloaded tensors and handles prefetching in a separate thread.
    Includes synchronization to prevent race conditions.
    """

    def __init__(
        self,
        prefetch_size: int = 3,
        prefetch_to_gpu: bool = True,
        save_workers: int = 4,
    ):
        """
        Args:
            prefetch_size: Maximum number of tensors to prefetch in the background.
            prefetch_to_gpu: Whether to prefetch tensors directly to GPU memory.
            save_workers: Maximum number of concurrent save operations.
        """
        self.temp_dir = tempfile.mkdtemp(prefix="disco_")

        # Track tensor paths and their status
        self.tensor_paths: deque = deque()  # Ordered history of tensor paths (LIFO)
        self.file_locks: Dict[str, threading.Lock] = (
            {}
        )  # Maps file_path -> threading.Lock()
        # Maps file_path -> status ("saving", "ready", "prefetching", "loaded", "deleted")
        self.file_status: Dict[str, str] = {}

        self.max_prefetch = prefetch_size
        self.prefetch_to_gpu = prefetch_to_gpu

        # Thread synchronization
        self.manager_lock = threading.RLock()  # Used for thread-safe operations

        # Prefetch queue and cache
        self.prefetch_queue: queue.Queue = queue.Queue()
        self.prefetch_cache: Dict[str, torch.Tensor] = {}  # Maps file_path -> tensor

        # Save queue and thread pool
        self.save_queue: queue.Queue = queue.Queue()
        self.save_pool = concurrent.futures.ThreadPoolExecutor(max_workers=save_workers)
        self.save_futures: Dict[str, Future] = {}
        self.save_semaphore = threading.Semaphore(
            save_workers * 2
        )  # Limit concurrent save operations

        # Start prefetch worker thread
        self.stop_event = threading.Event()
        # start multiple threads for prefetching
        self.prefetch_worker_count = 2
        self.prefetch_workers = []
        for _ in range(self.prefetch_worker_count):
            worker = threading.Thread(target=self._prefetch_worker, daemon=True)
            worker.start()
            self.prefetch_workers.append(worker)

        # Start save worker thread
        self.save_worker = threading.Thread(target=self._save_worker, daemon=True)
        self.save_worker.start()
        self.idx = 0

        atexit.register(self.cleanup)

    def _save_worker(self):
        """Background thread that processes the save queue"""
        while not self.stop_event.is_set():
            try:
                save_item = self.save_queue.get(timeout=0.5)
                if save_item is None:
                    continue

                tensor, file_path = save_item

                # Submit the save task to the thread pool
                future = self.save_pool.submit(
                    self._save_tensor_to_disk, tensor, file_path
                )
                with self.manager_lock:
                    self.save_futures[file_path] = future

                self.save_queue.task_done()

            except queue.Empty:
                time.sleep(0.01)  # Small sleep to prevent CPU spinning
                continue

    def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str):
        """Actually save the tensor to disk"""
        try:
            # Save tensor to disk
            cpu_tensor = tensor.detach().cpu()
            torch.save(cpu_tensor, file_path)
            del cpu_tensor

            with self.manager_lock:
                # Mark file as ready
                self.file_status[file_path] = "ready"

            # Release semaphore
            self.save_semaphore.release()

            return True
        except FileNotFoundError as e:
            logger.error(f"Error saving tensor to {file_path}: {e}")
            with self.manager_lock:
                self.file_status[file_path] = "error"

            # Release semaphore
            self.save_semaphore.release()

            return False

    def _prefetch_worker(self):
        """Background thread that loads tensors from disk ahead of time"""
        while not self.stop_event.is_set():
            try:
                file_path = self.prefetch_queue.get(timeout=0.5)
                if file_path is None:
                    continue

                # Check if file is available and not already in cache
                with self.manager_lock:
                    if (
                        file_path not in self.file_status
                        or self.file_status[file_path] == "deleted"
                    ):
                        self.prefetch_queue.task_done()
                    if file_path in self.prefetch_cache:
                        self.prefetch_queue.task_done()
                        continue

                    # If file is still being saved, wait for it
                    if (
                        self.file_status[file_path] == "saving"
                        and file_path in self.save_futures
                    ):
                        # Re-queue this prefetch request with a little delay
                        self.prefetch_queue.task_done()
                        time.sleep(0.1)
                        self.prefetch_queue.put(file_path)
                        continue

                    # Mark file as being prefetched
                    self.file_status[file_path] = "prefetching"

                # Load tensor from disk and store in cache
                try:
                    if os.path.exists(file_path):
                        if self.prefetch_to_gpu:
                            tensor = torch.load(
                                file_path,
                                map_location=torch.device("cuda"),
                                weights_only=True,
                            )
                        else:
                            tensor = torch.load(file_path, weights_only=True)

                        with self.manager_lock:
                            self.prefetch_cache[file_path] = tensor
                            self.file_status[file_path] = "ready"
                    else:
                        with self.manager_lock:
                            if self.file_status.get(file_path) != "deleted":
                                logger.warning(
                                    f"Prefetch error: File not found {file_path}"
                                )
                                self.file_status[file_path] = "missing"

                except FileNotFoundError as e:
                    with self.manager_lock:
                        if self.file_status.get(file_path) != "deleted":
                            logger.warning(f"Prefetch error for {file_path}: {e}")
                            self.file_status[file_path] = "error"

                self.prefetch_queue.task_done()

            except queue.Empty:
                time.sleep(0.01)  # Small sleep to prevent CPU spinning
                continue

    def save_tensor(self, tensor: torch.Tensor):
        """Save tensor to disk asynchronously and return file path with thread-safe operations"""
        # Generate unique file path
        self.idx += 1
        file_path: str = os.path.join(
            self.temp_dir, f"{self.idx:06d}-{uuid.uuid4()}.pt"
        )

        with self.manager_lock:
            # Mark file as being saved
            self.file_locks[file_path] = threading.Lock()
            self.file_status[file_path] = "saving"
            # Add to history
            self.tensor_paths.append(file_path)

        # Acquire semaphore to limit concurrent save operations
        self.save_semaphore.acquire()  # pylint: disable=consider-using-with
        # Queue tensor for saving in background
        self.save_queue.put((tensor.detach(), file_path))

        return file_path

    def wait_for_save(self, file_path, timeout=None) -> None:
        """Wait for a tensor to be saved to disk"""
        start_time = time.time()
        while timeout is None or time.time() - start_time < timeout:
            with self.manager_lock:
                if self.file_status.get(file_path) == "ready":
                    return
                if self.file_status.get(file_path) in ["error", "missing", "deleted"]:
                    return

                if file_path in self.save_futures:
                    future = self.save_futures[file_path]
                    if future.done():
                        return

            # Small sleep to prevent CPU spinning
            time.sleep(0.01)

        # Timeout
        logger.warning(f"Timeout waiting for tensor to be saved: {file_path}")
        return

    def load_tensor(self, file_path, target_device="cuda"):
        """Load tensor from disk or prefetch cache with proper synchronization"""
        # Wait for tensor to be saved if it's still in progress
        self.wait_for_save(file_path)

        tensor = None

        # Try to get from cache first
        with self.manager_lock:
            # Check if tensor is already in cache
            if file_path in self.prefetch_cache:
                tensor = self.prefetch_cache[file_path]
                del self.prefetch_cache[file_path]
                self.file_status[file_path] = "loaded"

        if tensor is not None:
            # Ensure tensor is on correct device
            if target_device != "cpu" and tensor.device.type == "cpu":
                tensor = tensor.to(target_device, non_blocking=True)
            return tensor

        # If not in cache, load directly from disk
        try:
            if not os.path.exists(file_path):
                logger.error(f"File not found for loading: {file_path}")
                raise FileNotFoundError(f"File not found: {file_path}")

            tensor = torch.load(file_path, weights_only=True)

            with self.manager_lock:
                self.file_status[file_path] = "loaded"

            if target_device != "cpu":
                tensor = tensor.to(target_device, non_blocking=True)

            return tensor

        except Exception as e:
            logger.error(f"Error loading tensor from {file_path}: {e}")
            raise

    def _safe_delete_file(self, file_path):
        """Safely delete a file with proper synchronization"""
        with self.manager_lock:
            # Make sure any save operation is completed
            if file_path in self.save_futures:
                future = self.save_futures[file_path]
                try:
                    if not future.done():
                        future.cancel()
                    del self.save_futures[file_path]
                except FileNotFoundError as e:
                    logger.warning(
                        f"Error canceling save operation for {file_path}: {e}"
                    )

            # Only delete if file exists and is not being prefetched
            status = self.file_status.get(file_path)
            if status in ["ready", "loaded", "error", "missing"]:
                try:
                    if os.path.exists(file_path):
                        os.remove(file_path)
                    self.file_status[file_path] = "deleted"
                    return True
                except FileNotFoundError as e:
                    logger.warning(f"Error deleting file {file_path}: {e}")
        return False

    def trigger_prefetch(self, n=None):
        """Trigger prefetching of the next N tensors with proper synchronization"""
        if n is None:
            n = self.max_prefetch

        prefetch_paths = []
        with self.manager_lock:
            # Find files that are ready to be prefetched (not already in cache or being prefetched)
            for path in reversed(self.tensor_paths):
                if (
                    path not in self.prefetch_cache
                    and self.file_status.get(path) == "ready"
                ):
                    prefetch_paths.append(path)
                    if len(prefetch_paths) >= n:
                        break

        # Queue files for prefetching
        for path in prefetch_paths:
            self.prefetch_queue.put(path)

    def cleanup_tensor(self, file_path: str):
        """Clean up a specific tensor file after it's been used"""
        with self.manager_lock:
            if file_path in self.tensor_paths:
                self.tensor_paths.remove(file_path)

            # Remove from prefetch cache if present
            if file_path in self.prefetch_cache:
                del self.prefetch_cache[file_path]

            # Remove from save futures if present
            if file_path in self.save_futures:
                future = self.save_futures[file_path]
                if not future.done():
                    future.cancel()
                del self.save_futures[file_path]

        # Try to delete the file
        self._safe_delete_file(file_path)

    def cleanup(self):
        """Clean up all temp files and stop prefetch thread with proper synchronization"""
        self.stop_event.set()

        # Cancel all pending save operations
        with self.manager_lock:
            for _, future in self.save_futures.items():
                if not future.done():
                    future.cancel()
            self.save_futures.clear()

        # Drain the save queue
        while not self.save_queue.empty():
            try:
                self.save_queue.get_nowait()
                self.save_queue.task_done()
            except queue.Empty:
                break

        # Shutdown the save pool
        self.save_pool.shutdown(wait=False)

        # Join the save worker thread
        if self.save_worker.is_alive():
            self.save_worker.join(timeout=2.0)

        # Join the prefetch worker threads
        for thread in self.prefetch_workers:
            if thread.is_alive():
                thread.join(timeout=2.0)

        # Clear cache and remove all temporary files
        with self.manager_lock:
            self.prefetch_cache.clear()
            paths_to_delete = list(self.tensor_paths)
            self.tensor_paths.clear()

        # Delete all temporary files
        for path in paths_to_delete:
            self._safe_delete_file(path)

        # Remove temp directory
        try:
            if os.path.exists(self.temp_dir):
                shutil.rmtree(self.temp_dir, ignore_errors=True)
        except FileNotFoundError as e:
            logger.warning(f"Error removing temporary directory {self.temp_dir}: {e}")


class Disco(torch.autograd.Function):
    """
    Disco: DIsk-based Storage and Checkpointing with Optimized prefetching
    Advanced disk-based gradient checkpointer with prefetching.
    """

    # Shared manager instance across all checkpointing operations
    _manager = None

    @staticmethod
    def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4):
        """Get or create the offload manager"""
        if Disco._manager is None:
            Disco._manager = DiskOffloadManager(
                prefetch_size=prefetch_size,
                prefetch_to_gpu=prefetch_to_gpu,
                save_workers=save_workers,
            )
        return Disco._manager

    @staticmethod
    @torch_cuda_amp_custom_fwd
    def forward(
        ctx,
        forward_function,
        hidden_states,
        *args,
        prefetch_size=1,
        prefetch_to_gpu=True,
        save_workers=4,
    ):
        """Forward pass that offloads activations to disk asynchronously"""
        # Get or create the manager
        manager = Disco.get_instance(
            prefetch_size=prefetch_size,
            prefetch_to_gpu=prefetch_to_gpu,
            save_workers=save_workers,
        )

        # Save tensor to disk asynchronously
        file_path = manager.save_tensor(hidden_states)

        # Run forward pass immediately without waiting for save to complete
        with torch.no_grad():
            output = forward_function(hidden_states, *args)

        # Store what we need for backward
        ctx.save_for_backward(torch.tensor([0]))  # Dummy tensor
        ctx.file_path = file_path
        ctx.forward_function = forward_function
        ctx.args = args

        return output

    @staticmethod
    @torch_cuda_amp_custom_bwd
    def backward(ctx, *grad_outputs):
        """Backward pass that loads activations from disk with prefetching"""
        # Get the manager
        manager = Disco._manager

        # Trigger prefetching for future tensors
        # This happens at the start of backward, so should have time to complete
        manager.trigger_prefetch()

        # Load hidden states from disk or prefetch cache
        file_path = ctx.file_path
        try:
            # Ensure the file is saved before we try to load it
            manager.wait_for_save(file_path)

            hidden_states = manager.load_tensor(file_path)
            hidden_states.requires_grad = True

            # Compute gradients
            with torch.enable_grad():
                output = ctx.forward_function(hidden_states, *ctx.args)

            # Handle tuple outputs properly
            if isinstance(output, tuple):
                if len(grad_outputs) == len(output):
                    torch.autograd.backward(output, grad_outputs)
                else:
                    torch.autograd.backward(output, grad_outputs[0])
            else:
                torch.autograd.backward(output, grad_outputs[0])

            # Clean up the file after we're done with it
            manager.cleanup_tensor(file_path)

            return (
                (
                    None,  # forward_function
                    hidden_states.grad,  # hidden_states grad
                )
                + (None,) * len(ctx.args)  # for each arg
                + (
                    None,  # prefetch_size
                    None,  # prefetch_to_gpu
                    None,  # save_workers
                )
            )

        except Exception as e:
            logger.error(f"Error in backward pass: {e}")
            # Clean up the file even on error
            manager.cleanup_tensor(file_path)
            raise
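
A minimal, hedged usage sketch of the Disco checkpointer above: the linear layer and input tensor below are illustrative stand-ins (not names from the repository), and a CUDA device is assumed because the manager prefetches with map_location="cuda". Disco.apply takes the layer's forward callable first, writes the input activation to the manager's temp directory during the forward pass, and reloads it to recompute gradients during backward.

    # Sketch under assumptions: toy layer and input, CUDA available; only Disco comes from this file.
    import torch

    layer = torch.nn.Linear(16, 16).cuda()  # stand-in for a decoder layer
    hidden_states = torch.randn(2, 16, device="cuda", requires_grad=True)

    out = Disco.apply(layer, hidden_states)  # forward pass; activation offloaded to disk
    out.sum().backward()                     # backward pass; activation reloaded and recomputed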
@@ -70,7 +70,10 @@ from axolotl.utils.distributed import (
     is_local_main_process,
     is_main_process,
 )
-from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
+from axolotl.utils.gradient_checkpointing import (
+    hf_grad_checkpoint_disk_offload_wrapper,
+    hf_grad_checkpoint_offload_wrapper,
+)
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
 from axolotl.utils.schemas.enums import RLType
@@ -620,8 +623,55 @@ class ModelLoader:
 
         if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
             transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
+        if self.cfg.gradient_checkpointing == "offload_disk":
+            transformers.modeling_utils.checkpoint = (
+                hf_grad_checkpoint_disk_offload_wrapper
+            )
 
         if self.cfg.flash_attention:
+            use_fa3 = False
+            if self.cfg.use_flash_attention_3 is True:
+                use_fa3 = True
+            elif self.cfg.use_flash_attention_3 == "auto":
+                if torch.cuda.get_device_capability() >= (9, 0):
+                    # FA3 is only available on Hopper GPUs and newer
+                    use_fa3 = True
+                if not importlib.util.find_spec("flash_attn_interface"):
+                    use_fa3 = False
+            if use_fa3 and not importlib.util.find_spec("flash_attn_interface"):
+                # this can happen when use_flash_attention_3 is explicitly set to True
+                # and flash_attn_interface is not installed
+                raise ModuleNotFoundError(
+                    "Please install the flash_attn_interface library to use Flash Attention 3.x"
+                )
+            if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
+                from flash_attn_interface import flash_attn_func as flash_attn_func_v3
+                from flash_attn_interface import (
+                    flash_attn_varlen_func as flash_attn_varlen_func_v3,
+                )
+
+                def flash_attn_func_v3_wrapper(*args, **kwargs):
+                    kwargs.pop("dropout_p", None)
+                    if "softmax_scale" in kwargs and len(args) >= 4:
+                        # if softmax_scale is provided, then the 3rd position is dropout_p that we need to drop
+                        args = (*args[:3],) + args[4:]
+                    return flash_attn_func_v3(*args, **kwargs)[0]
+
+                def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
+                    kwargs.pop("dropout_p", None)
+                    if "softmax_scale" in kwargs and len(args) >= 4:
+                        # if softmax_scale is provided, then the 3rd position is dropout_p that we need to drop
+                        args = (*args[:3],) + args[4:]
+                    return flash_attn_varlen_func_v3(*args, **kwargs)[0]
+
+                transformers.modeling_flash_attention_utils.flash_attn_func = (
+                    flash_attn_func_v3_wrapper
+                )
+                transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
+                    flash_attn_varlen_func_v3_wrapper
+                )
+                LOG.info("Switched to Flash Attention v3")
+
             self.patch_attention()
 
         if self.cfg.sample_packing and self.cfg.s2_attention:
@@ -692,6 +742,7 @@ class ModelLoader:
 
             patch_mllama()
 
+        # TODO deprecate soon
         if self.model_config.model_type == "btlm":
             from axolotl.monkeypatch.btlm_attn_hijack_flash import (
                 replace_btlm_attn_with_flash_attn,
@@ -699,6 +750,7 @@ class ModelLoader:
 
             replace_btlm_attn_with_flash_attn(self.cfg.base_model)
 
+        # TODO deprecate soon
        if (
            self.model_config.model_type == "stablelm_epoch"
            and self.cfg.sample_packing
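
A hedged configuration sketch for the Flash Attention 3 branch above, in the DictDefault style of the e2e tests in this diff: with "auto", FA3 is only selected on compute capability 9.0+ (Hopper or newer) and only when flash_attn_interface is importable; an explicit True raises ModuleNotFoundError if that package is missing; False or unset keeps the existing flash attention path.

    # Sketch only; the keys are the ones referenced by the ModelLoader hunk above.
    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "flash_attention": True,
            "use_flash_attention_3": "auto",  # or True / False
        }
    )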
@@ -83,7 +83,6 @@ class AxolotlInputConfig(
     # optionally shrink the embeddings when the tokenizer vocab size is smaller
     shrink_embeddings: bool | None = None
     embeddings_skip_upcast: bool | None = None
-    random_init_weights: bool | None = None
 
     rl: RLType | None = None
     trl: TRLConfig | None = Field(
@@ -179,7 +178,7 @@ class AxolotlInputConfig(
 
     # torch_dtype: torch.dtype | None
 
-    gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
+    gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field(
         default=False
     )
     gradient_checkpointing_kwargs: dict[str, Any] | None = None
@@ -234,6 +233,7 @@ class AxolotlInputConfig(
     flash_attn_fuse_qkv: bool | None = None
     flash_attn_fuse_mlp: bool | None = None
     flash_optimum: bool | None = None
+    use_flash_attention_3: Literal["auto"] | bool | None = None
 
     eager_attention: bool | None = None
 
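
A hedged configuration sketch for the gradient_checkpointing values accepted after this schema change: "offload" routes activations through the CPU_Offloaded_Gradient_Checkpointer, "offload_disk" routes them through the Disco disk checkpointer added above, and plain True keeps standard gradient checkpointing.

    # Sketch only; the accepted literal values come from the schema hunk above.
    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "gradient_checkpointing": "offload_disk",  # or True / "offload"
        }
    )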
@@ -421,6 +421,7 @@ def temp_dir():
 
 @pytest.fixture(scope="function", autouse=True)
 def cleanup_monkeypatches():
+    import transformers.modeling_flash_attention_utils
     from transformers import Trainer
     from transformers.models.llama.modeling_llama import (  # LlamaFlashAttention2,
         LlamaAttention,
@@ -434,6 +435,19 @@ def cleanup_monkeypatches():
         Trainer._inner_training_loop  # pylint: disable=protected-access
     )
     original_trainer_training_step = Trainer.training_step
+    original_fa_func = None
+    original_fa_varlen_func = None
+    if (
+        importlib.util.find_spec("flash_attn")
+        and hasattr(transformers.modeling_flash_attention_utils, "flash_attn_func")
+        and hasattr(
+            transformers.modeling_flash_attention_utils, "flash_attn_varlen_func"
+        )
+    ):
+        original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
+        original_fa_varlen_func = (
+            transformers.modeling_flash_attention_utils.flash_attn_varlen_func
+        )
     # monkey patches can happen inside the tests
     yield
     # Reset LlamaFlashAttention2 forward
@@ -444,6 +458,11 @@ def cleanup_monkeypatches():
         original_trainer_inner_training_loop
     )
     Trainer.training_step = original_trainer_training_step
+    if original_fa_func:
+        transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
+        transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
+            original_fa_varlen_func
+        )
 
     # Reset other known monkeypatches
     modules_to_reset: list[tuple[str, list[str]]] = [
@@ -458,6 +477,7 @@ def cleanup_monkeypatches():
         ("transformers.trainer",),
         ("transformers", ["Trainer"]),
         ("transformers.loss.loss_utils",),
+        ("transformers.modeling_flash_attention_utils",),
     ]
     for module_name_tuple in modules_to_reset:
         module_name = module_name_tuple[0]
@@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
         """
     )
 
-@pytest.mark.skip(reason="flaky test")
 @pytest.mark.parametrize(
     "num_gpus",
     [1, 2],
@@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
             "NCCL_P2P_LEVEL": "LOC",
             **current_env,
             "CUDA_VISIBLE_DEVICES": "1",
-            "VLLM_DISABLE_COMPILE_CACHE": "1",
-            # "VLLM_USE_V1": "0",
         }
         vllm_process = start_vllm(
             cfg.base_model,
@@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
         finally:
             recursive_kill(vllm_process)
 
-@pytest.mark.skip(reason="flaky test")
 @pytest.mark.parametrize(
     "num_gpus",
     [1, 2],
@@ -325,8 +321,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
             "NCCL_P2P_LEVEL": "LOC",  # nccl can be brittle, assume P2P isn't reliable
             **current_env,
             "CUDA_VISIBLE_DEVICES": "1",
-            "VLLM_DISABLE_COMPILE_CACHE": "1",
-            # "VLLM_USE_V1": "0",
         }
         vllm_process = start_vllm(
             cfg.base_model,
@@ -101,7 +101,13 @@ class TestMultiGPULlama:
         "gradient_accumulation_steps",
         [1, 2],
     )
-    def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
+    @pytest.mark.parametrize(
+        "use_flash_attention_3",
+        [False, "auto"],
+    )
+    def test_lora_ddp_packed(
+        self, temp_dir, gradient_accumulation_steps, use_flash_attention_3
+    ):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
@@ -138,6 +144,7 @@ class TestMultiGPULlama:
                 "flash_attention": True,
                 "use_tensorboard": True,
                 "bf16": True,
+                "use_flash_attention_3": use_flash_attention_3,
             }
         )
 
@@ -26,10 +26,15 @@ class TestActivationCheckpointing:
     E2E tests for activation checkpointing
     """
 
+    @pytest.mark.parametrize(
+        "gradient_checkpointing",
+        ["offload", "offload_disk"],
+    )
     def test_activation_checkpointing_offload(
         self,
         temp_dir,
         fix_checkpoint_after_test,  # pylint: disable=unused-argument,redefined-outer-name
+        gradient_checkpointing,
     ):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
@@ -64,7 +69,7 @@ class TestActivationCheckpointing:
                 "sample_packing": True,
                 "bf16": True,
                 "save_safetensors": True,
-                "gradient_checkpointing": "offload",
+                "gradient_checkpointing": gradient_checkpointing,
             }
         )
 
@@ -4,7 +4,6 @@ E2E tests for packed training
 
 import logging
 import os
-import unittest
 
 from transformers.utils import is_torch_bf16_gpu_available
 
@@ -14,18 +13,17 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import check_tensorboard, with_temp_dir
+from .utils import check_tensorboard
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
 
 
-class TestPackedLlama(unittest.TestCase):
+class TestPackedLlama:
     """
     Test case for Packed training of llama models
     """
 
-    @with_temp_dir
     def test_loss_packed(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(