From c7b67906149d22f707b60ca8162007a1e4d48fa5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 12 May 2025 10:51:18 -0400 Subject: [PATCH 1/7] Various fixes for CI, save_only_model for RL, prevent packing multiprocessing deadlocks (#2661) * lean mistral ft tests, remove e2e torch 2.4.1 test * make sure to pass save_only_model for RL * more tests to make ci leaner, add cleanup to modal ci * fix module for import in e2e tests * use mp spawn to prevent deadlocks with packing * make sure cleanup shell script is executable when cloned out --- .github/workflows/tests.yml | 46 ++++++++++++-- cicd/__init__.py | 0 cicd/cicd.sh | 2 +- cicd/cleanup.py | 19 ++++++ cicd/cleanup.sh | 6 ++ cicd/e2e_tests.py | 65 +------------------ cicd/single_gpu.py | 66 ++++++++++++++++++++ src/axolotl/core/trainer_builder.py | 2 + src/axolotl/utils/samplers/multipack.py | 35 +++++++++-- tests/e2e/patched/test_4d_multipack_llama.py | 12 ++-- tests/e2e/patched/test_mistral_samplepack.py | 12 ++-- tests/e2e/patched/test_mixtral_samplepack.py | 12 ++-- tests/e2e/patched/test_phi_multipack.py | 12 ++-- 13 files changed, 190 insertions(+), 99 deletions(-) create mode 100644 cicd/__init__.py create mode 100644 cicd/cleanup.py create mode 100755 cicd/cleanup.sh create mode 100644 cicd/single_gpu.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2671cfd33..56d41a0d2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -335,12 +335,6 @@ jobs: pytorch: 2.6.0 num_gpus: 1 axolotl_extras: llmcompressor - - cuda: 124 - cuda_version: 12.4.1 - python_version: "3.11" - pytorch: 2.4.1 - num_gpus: 1 - axolotl_extras: - cuda: 124 cuda_version: 12.4.1 python_version: "3.11" @@ -377,3 +371,43 @@ jobs: - name: Run tests job on Modal run: | modal run cicd.e2e_tests + + docker-e2e-cleanup: + runs-on: [self-hosted, modal] + timeout-minutes: 90 + needs: [docker-e2e-tests] + + strategy: + fail-fast: false + matrix: + include: + - cuda: 124 + cuda_version: 12.4.1 + 
python_version: "3.11" + pytorch: 2.6.0 + num_gpus: 1 + axolotl_extras: vllm + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Install Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install Modal + run: | + python -m pip install --upgrade pip + pip install modal==0.71.8 jinja2 + - name: Update env vars + run: | + echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV + echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV + echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV + echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV + echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV + echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV + echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV + echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV + - name: Run tests job on Modal + run: | + modal run cicd.cleanup diff --git a/cicd/__init__.py b/cicd/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/cicd/cicd.sh b/cicd/cicd.sh index 86cc4fa96..65ee8699d 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -18,7 +18,7 @@ pytest -v --durations=10 \ --cov-append # Run patched tests excluding lora kernels with coverage append -pytest -v --durations=10 \ +pytest --full-trace -vvv --durations=10 \ --ignore=tests/e2e/patched/lora_kernels \ /workspace/axolotl/tests/e2e/patched \ --cov=axolotl \ diff --git a/cicd/cleanup.py b/cicd/cleanup.py new file mode 100644 index 000000000..007489993 --- /dev/null +++ b/cicd/cleanup.py @@ -0,0 +1,19 @@ +"""Modal app to run axolotl GPU cleanup""" + +from .single_gpu import VOLUME_CONFIG, app, cicd_image, run_cmd + + +@app.function( + image=cicd_image, + timeout=60 * 60, + cpu=8.0, + memory=131072, + volumes=VOLUME_CONFIG, +) +def cleanup(): + run_cmd("./cicd/cleanup.sh", "/workspace/axolotl") + + +@app.local_entrypoint() +def main(): + cleanup.remote() diff --git 
a/cicd/cleanup.sh b/cicd/cleanup.sh new file mode 100755 index 000000000..4ea851bb4 --- /dev/null +++ b/cicd/cleanup.sh @@ -0,0 +1,6 @@ +#!/bin/bash +set -e + +# cleanup old cache files for datasets processing and intermediate mappings +find /workspace/data/huggingface-cache/hub/datasets -name "cache-*" -type f -mtime +1 -exec rm {} \; +find /workspace/data/huggingface-cache/hub/datasets -name "*.lock" -type f -mtime +1 -exec rm {} \; diff --git a/cicd/e2e_tests.py b/cicd/e2e_tests.py index 998f8c35d..2bc8ca072 100644 --- a/cicd/e2e_tests.py +++ b/cicd/e2e_tests.py @@ -1,69 +1,6 @@ """Modal app to run axolotl GPU tests""" -# pylint: disable=duplicate-code - -import os -import pathlib -import tempfile - -import jinja2 -import modal -from jinja2 import select_autoescape -from modal import App, Image - -cicd_path = pathlib.Path(__file__).parent.resolve() - -template_loader = jinja2.FileSystemLoader(searchpath=cicd_path) -template_env = jinja2.Environment( - loader=template_loader, autoescape=select_autoescape() -) -df_template = template_env.get_template("Dockerfile.jinja") - -df_args = { - "AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""), - "AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""), - "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"), - "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"), - "CUDA": os.environ.get("CUDA", "121"), - "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"), - "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""), - "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""), - "CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""), - "HF_HOME": "/workspace/data/huggingface-cache/hub", -} - -dockerfile_contents = df_template.render(**df_args) - -temp_dir = tempfile.mkdtemp() -with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f: - f.write(dockerfile_contents) - -cicd_image = Image.from_dockerfile( - pathlib.Path(temp_dir) / "Dockerfile", - context_mount=None, - force_build=True, 
- gpu="A10G", -).env(df_args) - -app = App("Axolotl CI/CD", secrets=[]) - -hf_cache_volume = modal.Volume.from_name( - "axolotl-ci-hf-hub-cache", create_if_missing=True -) -VOLUME_CONFIG = { - "/workspace/data/huggingface-cache/hub": hf_cache_volume, -} - -N_GPUS = int(os.environ.get("N_GPUS", 1)) -GPU_CONFIG = modal.gpu.L40S(count=N_GPUS) - - -def run_cmd(cmd: str, run_folder: str): - import subprocess # nosec - - # Propagate errors from subprocess. - if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec - exit(exit_code) # pylint: disable=consider-using-sys-exit +from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd @app.function( diff --git a/cicd/single_gpu.py b/cicd/single_gpu.py new file mode 100644 index 000000000..d46d970cf --- /dev/null +++ b/cicd/single_gpu.py @@ -0,0 +1,66 @@ +"""Modal app to run axolotl GPU tests""" + +# pylint: disable=duplicate-code + +import os +import pathlib +import tempfile + +import jinja2 +import modal +from jinja2 import select_autoescape +from modal import App, Image + +cicd_path = pathlib.Path(__file__).parent.resolve() + +template_loader = jinja2.FileSystemLoader(searchpath=cicd_path) +template_env = jinja2.Environment( + loader=template_loader, autoescape=select_autoescape() +) +df_template = template_env.get_template("Dockerfile.jinja") + +df_args = { + "AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""), + "AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""), + "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"), + "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"), + "CUDA": os.environ.get("CUDA", "121"), + "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"), + "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""), + "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""), + "CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""), + "HF_HOME": "/workspace/data/huggingface-cache/hub", +} + +dockerfile_contents = df_template.render(**df_args) + +temp_dir = 
tempfile.mkdtemp() +with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f: + f.write(dockerfile_contents) + +cicd_image = Image.from_dockerfile( + pathlib.Path(temp_dir) / "Dockerfile", + context_mount=None, + force_build=True, + gpu="A10G", +).env(df_args) + +app = App("Axolotl CI/CD", secrets=[]) + +hf_cache_volume = modal.Volume.from_name( + "axolotl-ci-hf-hub-cache", create_if_missing=True +) +VOLUME_CONFIG = { + "/workspace/data/huggingface-cache/hub": hf_cache_volume, +} + +N_GPUS = int(os.environ.get("N_GPUS", 1)) +GPU_CONFIG = modal.gpu.L40S(count=N_GPUS) + + +def run_cmd(cmd: str, run_folder: str): + import subprocess # nosec + + # Propagate errors from subprocess. + if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec + exit(exit_code) # pylint: disable=consider-using-sys-exit diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 5cb397b28..670561ede 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1057,6 +1057,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): # default to saving each epoch if not defined training_args_kwargs["save_strategy"] = "epoch" + training_args_kwargs["save_only_model"] = self.cfg.save_only_model + if self.cfg.dataset_processes: training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index c38313c7c..2df2d9e19 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -6,7 +6,7 @@ into fixed-capacity batches to optimize memory usage and training throughput. 
import logging import math from concurrent.futures import ProcessPoolExecutor -from multiprocessing import cpu_count +from multiprocessing import cpu_count, get_context from typing import Iterable, Union import numba @@ -126,6 +126,7 @@ def pack_parallel( bin_size: int, num_processes: int | None = None, safe_mode: bool = True, + mp_start_method: str | None = "spawn", ): """ Pack sequences into bins using parallel processing @@ -137,7 +138,9 @@ def pack_parallel( bin_size: Maximum number of bins to use num_processes: Number of parallel processes to use safe_mode: If True, use a more conservative packing approach - + mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver'). + 'spawn' is often safer with Numba/PyTorch. + Set to None to use system default. Returns: List of bins, where each bin contains indices of sequences assigned to it """ @@ -154,9 +157,33 @@ def pack_parallel( # Process groups in parallel all_bins = [] - with ProcessPoolExecutor(max_workers=num_processes) as executor: - for group_bins in executor.map(_process_group, tasks): + + mp_ctx = None + if mp_start_method: + try: + mp_ctx = get_context(mp_start_method) + except ValueError: + LOG.warning( + f"Failed to get multiprocessing context '{mp_start_method}'. " + f"Falling back to default. 
Available: {get_context().get_all_start_methods()}" + ) + mp_ctx = ( + None # Fallback to default context if specified one is not available + ) + + if num_processes == 1: + LOG.debug("Using single process for pack_parallel, running sequentially.") + for task_args in tasks: + group_bins = _process_group(task_args) all_bins.extend(group_bins) + else: + # Use ProcessPoolExecutor only if num_processes > 1 + # Pass mp_context if available + with ProcessPoolExecutor( + max_workers=num_processes, mp_context=mp_ctx + ) as executor: + for group_bins in executor.map(_process_group, tasks): + all_bins.extend(group_bins) return all_bins diff --git a/tests/e2e/patched/test_4d_multipack_llama.py b/tests/e2e/patched/test_4d_multipack_llama.py index 270956883..12dd51c13 100644 --- a/tests/e2e/patched/test_4d_multipack_llama.py +++ b/tests/e2e/patched/test_4d_multipack_llama.py @@ -57,9 +57,9 @@ class Test4dMultipackLlama(unittest.TestCase): "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", - "max_steps": 20, - "save_steps": 10, - "eval_steps": 10, + "max_steps": 5, + "save_steps": 3, + "eval_steps": 4, "fp16": True, } ) @@ -105,9 +105,9 @@ class Test4dMultipackLlama(unittest.TestCase): "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", - "max_steps": 20, - "save_steps": 10, - "eval_steps": 10, + "max_steps": 5, + "save_steps": 3, + "eval_steps": 4, "fp16": True, } ) diff --git a/tests/e2e/patched/test_mistral_samplepack.py b/tests/e2e/patched/test_mistral_samplepack.py index ccfeb3d63..fe8fafb19 100644 --- a/tests/e2e/patched/test_mistral_samplepack.py +++ b/tests/e2e/patched/test_mistral_samplepack.py @@ -57,9 +57,9 @@ class TestMistral(unittest.TestCase): "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", - "max_steps": 20, - "save_steps": 10, - "eval_steps": 10, + "max_steps": 5, + "save_steps": 3, + "eval_steps": 4, "bf16": "auto", } ) @@ -99,9 +99,9 @@ class 
TestMistral(unittest.TestCase): "learning_rate": 0.00001, "optimizer": "adamw_torch_fused", "lr_scheduler": "cosine", - "max_steps": 20, - "save_steps": 10, - "eval_steps": 10, + "max_steps": 5, + "save_steps": 3, + "eval_steps": 4, "bf16": "auto", } ) diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index f035b1f28..ebc2ba092 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ b/tests/e2e/patched/test_mixtral_samplepack.py @@ -54,9 +54,9 @@ class TestMixtral(unittest.TestCase): "learning_rate": 0.00001, "optimizer": "adamw_bnb_8bit", "lr_scheduler": "cosine", - "max_steps": 20, - "save_steps": 10, - "eval_steps": 10, + "max_steps": 5, + "save_steps": 3, + "eval_steps": 4, "bf16": "auto", } ) @@ -93,9 +93,9 @@ class TestMixtral(unittest.TestCase): "learning_rate": 0.00001, "optimizer": "adamw_bnb_8bit", "lr_scheduler": "cosine", - "max_steps": 20, - "save_steps": 10, - "eval_steps": 10, + "max_steps": 5, + "save_steps": 3, + "eval_steps": 4, "bf16": "auto", } ) diff --git a/tests/e2e/patched/test_phi_multipack.py b/tests/e2e/patched/test_phi_multipack.py index c42ed8baf..d8130d119 100644 --- a/tests/e2e/patched/test_phi_multipack.py +++ b/tests/e2e/patched/test_phi_multipack.py @@ -56,9 +56,9 @@ class TestPhiMultipack(unittest.TestCase): "learning_rate": 0.00001, "optimizer": "adamw_bnb_8bit", "lr_scheduler": "cosine", - "max_steps": 20, - "eval_steps": 10, - "save_steps": 10, + "max_steps": 5, + "eval_steps": 3, + "save_steps": 4, "bf16": "auto", } ) @@ -108,9 +108,9 @@ class TestPhiMultipack(unittest.TestCase): "learning_rate": 0.00001, "optimizer": "adamw_bnb_8bit", "lr_scheduler": "cosine", - "max_steps": 20, - "eval_steps": 10, - "save_steps": 10, + "max_steps": 5, + "eval_steps": 3, + "save_steps": 4, "bf16": "auto", } ) From f34eef546a9171d5de1d4779c3ca75e275f45b1e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 12 May 2025 14:17:25 -0400 Subject: [PATCH 2/7] update doc and use P2P=LOC for 
brittle grpo test (#2649) * update doc and skip brittle grpo test * fix the path to run the multigpu tests * increase timeout, use LOC instead of NVL * typo * use hf cache from s3 backed cloudfront * mark grpo as flaky test dues to vllm start --- .github/workflows/multi-gpu-e2e.yml | 2 +- .github/workflows/tests.yml | 224 +++++++++++++++------------ cicd/e2e_tests.py | 2 +- codecov.yml | 2 +- docs/config.qmd | 1 + tests/e2e/multigpu/solo/test_grpo.py | 10 +- 6 files changed, 131 insertions(+), 110 deletions(-) diff --git a/.github/workflows/multi-gpu-e2e.yml b/.github/workflows/multi-gpu-e2e.yml index ffb3577ea..8c7692d13 100644 --- a/.github/workflows/multi-gpu-e2e.yml +++ b/.github/workflows/multi-gpu-e2e.yml @@ -3,7 +3,7 @@ name: docker-multigpu-tests-biweekly on: pull_request: paths: - - 'tests/e2e/multigpu/*.py' + - 'tests/e2e/multigpu/**.py' - 'requirements.txt' - 'setup.py' - 'pyproject.toml' diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 56d41a0d2..c296e2314 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -44,96 +44,102 @@ jobs: env: SKIP: no-commit-to-branch - preload-cache: - name: Preload HF cache - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python_version: ["3.11"] - pytorch_version: ["2.6.0"] - timeout-minutes: 20 - - env: - AXOLOTL_IS_CI_CACHE_PRELOAD: "1" - - steps: - - name: Check out repository code - uses: actions/checkout@v4 - - - name: Restore HF cache - id: hf-cache-restore - uses: actions/cache/restore@v4 - with: - path: | - /home/runner/.cache/huggingface/hub/datasets--* - /home/runner/.cache/huggingface/hub/models--* - key: ${{ runner.os }}-hf-hub-cache-v2 - - - name: Setup Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python_version }} - cache: 'pip' # caching pip dependencies - - - name: upgrade pip - run: | - pip3 install --upgrade pip - pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel - - - name: Install PyTorch 
- run: | - pip3 install torch==${{ matrix.pytorch_version }} - - - name: Install dependencies - run: | - pip3 show torch - pip3 install --no-build-isolation -U -e . - python scripts/unsloth_install.py | sh - python scripts/cutcrossentropy_install.py | sh - pip3 install -r requirements-dev.txt -r requirements-tests.txt - - - name: Make sure PyTorch version wasn't clobbered - run: | - python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__" - - - name: Ensure axolotl CLI was installed - run: | - axolotl --help - - - name: Pre-Download dataset fixture - run: | - huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures - - - name: Run tests - run: | - pytest -v tests/conftest.py - - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v5 - with: - token: ${{ secrets.CODECOV_TOKEN }} - files: ./coverage.xml - flags: unittests,pytorch-${{ matrix.pytorch_version }} - fail_ci_if_error: false - - - name: cleanup pip cache - run: | - find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; - - - name: Save HF cache - id: hf-cache - uses: actions/cache/save@v4 - with: - path: | - /home/runner/.cache/huggingface/hub/datasets--* - /home/runner/.cache/huggingface/hub/models--* - key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }} +# preload-cache: +# name: Preload HF cache +# runs-on: ubuntu-latest +# strategy: +# fail-fast: false +# matrix: +# python_version: ["3.11"] +# pytorch_version: ["2.6.0"] +# timeout-minutes: 20 +# +# env: +# AXOLOTL_IS_CI_CACHE_PRELOAD: "1" +# +# steps: +# - name: Check out repository code +# uses: actions/checkout@v4 +# +# - name: Restore HF cache +# id: hf-cache-restore +# uses: actions/cache/restore@v4 +# with: +# path: | +# /home/runner/.cache/huggingface/hub/datasets--* +# /home/runner/.cache/huggingface/hub/models--* +# key: ${{ runner.os }}-hf-hub-cache-v2 +# +# - name: Restore Cache from S3 +# id: hf-cache-restore-s3 +# run: | +# mkdir -p 
/home/runner/.cache/huggingface/hub +# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd +# +# - name: Setup Python +# uses: actions/setup-python@v5 +# with: +# python-version: ${{ matrix.python_version }} +# cache: 'pip' # caching pip dependencies +# +# - name: upgrade pip +# run: | +# pip3 install --upgrade pip +# pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel +# +# - name: Install PyTorch +# run: | +# pip3 install torch==${{ matrix.pytorch_version }} +# +# - name: Install dependencies +# run: | +# pip3 show torch +# pip3 install --no-build-isolation -U -e . +# python scripts/unsloth_install.py | sh +# python scripts/cutcrossentropy_install.py | sh +# pip3 install -r requirements-dev.txt -r requirements-tests.txt +# +# - name: Make sure PyTorch version wasn't clobbered +# run: | +# python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__" +# +# - name: Ensure axolotl CLI was installed +# run: | +# axolotl --help +# +# - name: Pre-Download dataset fixture +# run: | +# huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures +# +# - name: Run tests +# run: | +# pytest -v tests/conftest.py +# +# - name: Upload coverage to Codecov +# uses: codecov/codecov-action@v5 +# with: +# token: ${{ secrets.CODECOV_TOKEN }} +# files: ./coverage.xml +# flags: unittests,pytorch-${{ matrix.pytorch_version }} +# fail_ci_if_error: false +# +# - name: cleanup pip cache +# run: | +# find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; +# +# - name: Save HF cache +# id: hf-cache +# uses: actions/cache/save@v4 +# with: +# path: | +# /home/runner/.cache/huggingface/hub/datasets--* +# /home/runner/.cache/huggingface/hub/models--* +# key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }} pytest: name: PyTest runs-on: ubuntu-latest - needs: [preload-cache] +# needs: [preload-cache] strategy: 
fail-fast: false matrix: @@ -145,14 +151,20 @@ jobs: - name: Check out repository code uses: actions/checkout@v4 - - name: Restore HF cache - id: hf-cache-restore - uses: actions/cache/restore@v4 - with: - path: | - /home/runner/.cache/huggingface/hub/datasets--* - /home/runner/.cache/huggingface/hub/models--* - key: ${{ runner.os }}-hf-hub-cache-v2 +# - name: Restore HF cache +# id: hf-cache-restore +# uses: actions/cache/restore@v4 +# with: +# path: | +# /home/runner/.cache/huggingface/hub/datasets--* +# /home/runner/.cache/huggingface/hub/models--* +# key: ${{ runner.os }}-hf-hub-cache-v2 + + - name: Restore Cache from S3 + id: hf-cache-restore-s3 + run: | + mkdir -p /home/runner/.cache/huggingface/hub + curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd - name: Setup Python uses: actions/setup-python@v5 @@ -210,7 +222,7 @@ jobs: pytest-sdist: name: PyTest from Source Dist runs-on: ubuntu-latest - needs: [preload-cache] +# needs: [preload-cache] strategy: fail-fast: false matrix: @@ -222,14 +234,20 @@ jobs: - name: Check out repository code uses: actions/checkout@v4 - - name: Restore HF cache - id: hf-cache-restore - uses: actions/cache/restore@v4 - with: - path: | - /home/runner/.cache/huggingface/hub/datasets--* - /home/runner/.cache/huggingface/hub/models--* - key: ${{ runner.os }}-hf-hub-cache-v2 +# - name: Restore HF cache +# id: hf-cache-restore +# uses: actions/cache/restore@v4 +# with: +# path: | +# /home/runner/.cache/huggingface/hub/datasets--* +# /home/runner/.cache/huggingface/hub/models--* +# key: ${{ runner.os }}-hf-hub-cache-v2 + + - name: Restore Cache from S3 + id: hf-cache-restore-s3 + run: | + mkdir -p /home/runner/.cache/huggingface/hub + curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd - name: Setup Python uses: actions/setup-python@v5 diff --git 
a/cicd/e2e_tests.py b/cicd/e2e_tests.py index 2bc8ca072..ce9c605c7 100644 --- a/cicd/e2e_tests.py +++ b/cicd/e2e_tests.py @@ -6,7 +6,7 @@ from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd @app.function( image=cicd_image, gpu=GPU_CONFIG, - timeout=60 * 60, + timeout=90 * 60, # 90 min cpu=8.0, memory=131072, volumes=VOLUME_CONFIG, diff --git a/codecov.yml b/codecov.yml index c85268b4c..2741b1758 100644 --- a/codecov.yml +++ b/codecov.yml @@ -19,7 +19,7 @@ coverage: if_no_uploads: error if_not_found: success if_ci_failed: error - only_pulls: false + only_pulls: true flags: null paths: null patch: diff --git a/docs/config.qmd b/docs/config.qmd index 1cff9e6f4..eba9f4881 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -505,6 +505,7 @@ save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of eac save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps save_total_limit: # Checkpoints saved at a time +save_only_model: # Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints. # Maximum number of iterations to train for. It precedes num_epochs which means that # if both are set, num_epochs will not be guaranteed. 
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps diff --git a/tests/e2e/multigpu/solo/test_grpo.py b/tests/e2e/multigpu/solo/test_grpo.py index a34d4b3f8..a1eade531 100644 --- a/tests/e2e/multigpu/solo/test_grpo.py +++ b/tests/e2e/multigpu/solo/test_grpo.py @@ -166,6 +166,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): """ ) + @pytest.mark.skip(reason="flaky test") @pytest.mark.parametrize( "num_gpus", [1, 2], @@ -227,7 +228,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): current_env = os.environ.copy() env = { - "NCCL_P2P_LEVEL": "NVL", + "NCCL_P2P_LEVEL": "LOC", **current_env, "CUDA_VISIBLE_DEVICES": "1", "VLLM_DISABLE_COMPILE_CACHE": "1", @@ -257,7 +258,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): f"{get_torch_dist_unique_port()}", ], env={ - "NCCL_P2P_LEVEL": "NVL", + "NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env, }, @@ -265,6 +266,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): finally: recursive_kill(vllm_process) + @pytest.mark.skip(reason="flaky test") @pytest.mark.parametrize( "num_gpus", [1, 2], @@ -320,7 +322,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): current_env = os.environ.copy() env = { - "NCCL_P2P_LEVEL": "NVL", # nccl can be brittle, assume P2P isn't reliable + "NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable **current_env, "CUDA_VISIBLE_DEVICES": "1", "VLLM_DISABLE_COMPILE_CACHE": "1", @@ -350,7 +352,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs): f"{get_torch_dist_unique_port()}", ], env={ - "NCCL_P2P_LEVEL": "NVL", + "NCCL_P2P_LEVEL": "LOC", "NCCL_DEBUG": "INFO", **current_env, }, From 526ddb886d2f01e9482c69face7550fb052e4ad2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 12 May 2025 14:18:42 -0400 Subject: [PATCH 3/7] guard on deleting secrets from env (#2653) [skip ci] --- .runpod/src/handler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.runpod/src/handler.py b/.runpod/src/handler.py 
index 21073dff4..740c1ed1f 100644 --- a/.runpod/src/handler.py +++ b/.runpod/src/handler.py @@ -57,8 +57,10 @@ async def handler(job): logger.info("Training Complete.") # Cleanup - del os.environ["WANDB_API_KEY"] - del os.environ["HF_TOKEN"] + if "WANDB_API_KEY" in os.environ: + del os.environ["WANDB_API_KEY"] + if "HF_TOKEN" in os.environ: + del os.environ["HF_TOKEN"] runpod.serverless.start({"handler": handler, "return_aggregate_stream": True}) From 67c4ea9c7c55873070a9ec81f8fda708df6314b1 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 13 May 2025 03:23:53 +0700 Subject: [PATCH 4/7] fix: disable auto lora kernel if dropout nonzero (#2655) [skip ci] * fix: disable auto lora kernel if dropout nonzero * Add comment from PR feedback --------- Co-authored-by: Wing Lian --- src/axolotl/utils/schemas/config.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 9db374409..cd9891e04 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1345,6 +1345,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): ): return data + # Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks + if data.get("lora_dropout") != 0: + return data + # Check multi-GPU compatibility capabilities = data.get("capabilities") is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1 From 80304c26a70e21ed8522fdbd53bcb290f9c6b7d3 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 12 May 2025 17:52:40 -0400 Subject: [PATCH 5/7] SP GRPO support + batch SP fixes (#2643) * ctx manager for SP * updates * update * further simplifying * simplifying * simplifying * reorg * batch api HF adapter for ring-flash-attn; cleanup and improvements * update * adding all batch ring-flash-attn methods via single adapter * fix * fixes for batch API funcs, simplify * fix * grpo sp support * progress * stronger subclassing of TRL GRPO 
trainer; custom distributed sampler * subclassing constructor * progress * finalizing SP + GRPO trainer * minimize diffs to GRPO trainer * remove (most of) the custom GRPO trainer logic * debug * debug * update * update * update * progress * cleanup * cleanup * minor changes * update * update * update * small changes * updates * cleanup; torch.compile ring_flash_attn functions to prevent numerical instability; lint * spacing * cleanup; log in pydantic model config only on main process * remove comment * fix sp sampler, update to latest upstream code, doc * add docs * update quartodoc autodoc contents * fix, simplifications * fixes + simplifications * review comments * lint * removing main process only logs in favor of #2608 * fixes, additional smoke test * updatse * more tests * update * fix grad accum bug (sort of) * lint, tests * todo --- _quarto.yml | 17 +- docs/sequence_parallelism.qmd | 4 +- src/axolotl/common/datasets.py | 3 +- src/axolotl/core/trainer_builder.py | 83 ++- src/axolotl/core/trainers/__init__.py | 2 +- src/axolotl/core/trainers/base.py | 4 +- src/axolotl/core/trainers/dpo/__init__.py | 11 +- src/axolotl/core/trainers/grpo/__init__.py | 47 +- src/axolotl/core/trainers/grpo/args.py | 4 +- src/axolotl/core/trainers/grpo/sampler.py | 172 +++++ src/axolotl/core/trainers/grpo/trainer.py | 653 +++++++++++++++++- src/axolotl/core/trainers/mixins/__init__.py | 2 +- .../core/trainers/mixins/sequence_parallel.py | 228 +----- src/axolotl/core/training_args.py | 2 +- .../attention/ring_attn/__init__.py | 1 - .../attention/ring_attn/adapters/batch.py | 14 +- .../monkeypatch/attention/ring_attn/patch.py | 20 +- src/axolotl/train.py | 55 +- src/axolotl/utils/ctx_managers/__init__.py | 6 + .../utils/ctx_managers/sequence_parallel.py | 335 +++++++++ src/axolotl/utils/data/rl.py | 25 +- src/axolotl/utils/data/sft.py | 10 +- src/axolotl/utils/models.py | 3 +- src/axolotl/utils/schemas/config.py | 30 +- src/axolotl/utils/schemas/enums.py | 23 +- 
tests/e2e/multigpu/patched/test_sp.py | 19 +- tests/e2e/patched/test_sp.py | 130 ++-- 27 files changed, 1448 insertions(+), 455 deletions(-) create mode 100644 src/axolotl/core/trainers/grpo/sampler.py create mode 100644 src/axolotl/utils/ctx_managers/__init__.py create mode 100644 src/axolotl/utils/ctx_managers/sequence_parallel.py diff --git a/_quarto.yml b/_quarto.yml index 17e121c15..463f76b34 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -48,8 +48,23 @@ quartodoc: contents: - core.trainers.base - core.trainers.trl + - core.trainers.mamba + - core.trainers.relora - core.trainers.dpo.trainer - core.trainers.grpo.trainer + - core.trainers.grpo.sampler + - core.trainers.utils + - title: Mixins + desc: Mixin classes for augmenting trainers + contents: + - core.trainers.mixins.optimizer + - core.trainers.mixins.rng_state_loader + - core.trainers.mixins.scheduler + - core.trainers.mixins.sequence_parallel + - title: Context Managers + desc: Context managers for altering trainer behaviors + contents: + - utils.ctx_managers.sequence_parallel - title: Prompt Strategies desc: Prompt formatting strategies contents: @@ -86,7 +101,7 @@ quartodoc: - kernels.swiglu - kernels.quantize - kernels.utils - - title: MonkeyPatches + - title: Monkey Patches desc: Runtime patches for model optimizations contents: - monkeypatch.llama_attn_hijack_flash diff --git a/docs/sequence_parallelism.qmd b/docs/sequence_parallelism.qmd index 20739333a..1bff17ce9 100644 --- a/docs/sequence_parallelism.qmd +++ b/docs/sequence_parallelism.qmd @@ -3,8 +3,6 @@ title: Sequence Parallelism description: Train with long sequences split across multiple GPUs. --- -# Sequence Parallelism - Sequence parallelism is a technique that splits sequences across multiple GPUs, allowing you to train with very long sequences that wouldn't fit on a single GPU. 
Each GPU processes a different portion of the sequence, and the results are aggregated @@ -27,7 +25,7 @@ To enable sequence parallelism, add the following to your configuration file: sequence_parallel_degree: 4 # Split sequences across 4 GPUs # Optional; strides across the key dimension. Larger values use more memory but should make training faster. heads_k_stride: 1 -# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to +# Optional; one of "varlen_llama3" or "batch_ring". Defaults to # "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise. ring_attn_func: ``` diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index 9dd62f0f7..f944cbd6a 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -14,6 +14,7 @@ from axolotl.utils.data import prepare_dataset from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_processor, load_tokenizer +from axolotl.utils.schemas.enums import RLType from axolotl.utils.tokenization import check_dataset_labels LOG = logging.getLogger(__name__) @@ -133,7 +134,7 @@ def load_preference_datasets( total_num_steps: Optional[int] = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) ) - if cfg.rl == "grpo": + if cfg.rl is RLType.GRPO: total_num_steps = None if cli_args.debug or cfg.debug: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 670561ede..99ab397c7 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -87,7 +87,7 @@ from axolotl.utils.collators import ( ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.models import ensure_dtype -from axolotl.utils.schemas.enums import CustomSupportedOptimizers +from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType try: import torch._dynamo 
# pylint: disable=ungrouped-imports @@ -353,7 +353,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["warmup_steps"] = warmup_steps training_arguments_kwargs["logging_steps"] = logging_steps - if self.cfg.seed: + if self.cfg.seed is not None: training_arguments_kwargs["seed"] = self.cfg.seed if self.cfg.gradient_checkpointing: @@ -547,8 +547,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): report_to = [] if self.cfg.use_wandb: report_to.append("wandb") - if self.cfg.wandb_name: - training_arguments_kwargs["run_name"] = self.cfg.wandb_name if self.cfg.use_mlflow: report_to.append("mlflow") if self.cfg.use_tensorboard: @@ -821,14 +819,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): data_collator_kwargs = { "padding": True, # True/"longest" is the default } + multiple = 64 if self.cfg.pad_to_sequence_len: - data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil( - self.cfg.sequence_len / 64 + data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil( + self.cfg.sequence_len / multiple ) else: # A100 is best at 64, while others at 8. 
Let's use the larger so we don't have to check # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html - data_collator_kwargs["pad_to_multiple_of"] = 64 + data_collator_kwargs["pad_to_multiple_of"] = multiple if self.cfg.reward_model: data_collator_kwargs["max_length"] = self.cfg.sequence_len @@ -1034,6 +1033,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase): training_args_kwargs["dataloader_prefetch_factor"] = ( self.cfg.dataloader_prefetch_factor ) + + if self.cfg.seed is not None: + training_args_kwargs["seed"] = self.cfg.seed + if self.cfg.gradient_checkpointing: training_args_kwargs["gradient_checkpointing"] = ( self.cfg.gradient_checkpointing @@ -1076,9 +1079,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.use_wandb: training_args_kwargs["run_name"] = self.cfg.wandb_name + training_args_kwargs["sequence_parallel_degree"] = ( + self.cfg.sequence_parallel_degree + ) + training_args_cls = None blocklist_args_kwargs = [] - if self.cfg.rl == "simpo": + if self.cfg.rl is RLType.SIMPO: training_args_cls = AxolotlCPOConfig training_args_kwargs["loss_type"] = "simpo" training_args_kwargs["max_length"] = self.cfg.sequence_len @@ -1086,13 +1093,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.cpo_alpha is not None: training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha - elif self.cfg.rl == "orpo": + elif self.cfg.rl is RLType.ORPO: training_args_cls = AxolotlORPOConfig training_args_kwargs["max_length"] = self.cfg.sequence_len if self.cfg.max_prompt_len: training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len - elif self.cfg.rl == "kto": + elif self.cfg.rl is RLType.KTO: training_args_cls = AxolotlKTOConfig training_args_kwargs["desirable_weight"] = ( @@ -1106,14 +1113,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.max_prompt_len: training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len - elif self.cfg.rl == "grpo": + elif self.cfg.rl is RLType.GRPO: 
training_args_cls = GRPOStrategy.get_training_args_class() training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg)) blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs() else: training_args_cls = AxolotlDPOConfig - if self.cfg.rl == "ipo": + if self.cfg.rl is RLType.IPO: training_args_kwargs["loss_type"] = "ipo" training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["max_completion_length"] = None @@ -1156,33 +1163,35 @@ class HFRLTrainerBuilder(TrainerBuilderBase): def build(self, total_num_steps): training_args = self.build_training_arguments(total_num_steps) - dpo_trainer_kwargs = {} - if self.cfg.rl == "ipo": + trainer_kwargs = {} + if self.cfg.rl is RLType.IPO: if self.cfg.dpo_label_smoothing: - dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing + trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing if self.eval_dataset: - dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset + trainer_kwargs["eval_dataset"] = self.eval_dataset if self.cfg.adapter and self.peft_config: - dpo_trainer_kwargs["peft_config"] = self.peft_config + trainer_kwargs["peft_config"] = self.peft_config if self.cfg.precompute_ref_log_probs is not None: - dpo_trainer_kwargs["precompute_ref_log_probs"] = ( + trainer_kwargs["precompute_ref_log_probs"] = ( self.cfg.precompute_ref_log_probs ) - if self.cfg.rl == "grpo": - trainer_cls = GRPOStrategy.get_trainer_class() + if self.cfg.rl is RLType.GRPO: + trainer_cls = GRPOStrategy.get_trainer_class( + sequence_parallel=self.cfg.sequence_parallel_degree > 1 + ) trainer_cls_args = [self.model] trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg)) - dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg)) - elif self.cfg.rl in ["dpo", "ipo"]: + trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg)) + elif self.cfg.rl in [RLType.DPO, RLType.IPO]: trainer_cls = DPOStrategy.get_trainer_class() trainer_cls_args = [self.model, 
self.model_ref] - elif self.cfg.rl == "orpo": + elif self.cfg.rl is RLType.ORPO: trainer_cls = AxolotlORPOTrainer trainer_cls_args = [self.model] - elif self.cfg.rl in ["kto"]: + elif self.cfg.rl is RLType.KTO: trainer_cls = AxolotlKTOTrainer trainer_cls_args = [self.model] - elif self.cfg.rl in ["simpo"]: + elif self.cfg.rl is RLType.SIMPO: trainer_cls = AxolotlCPOTrainer trainer_cls_args = [self.model] else: @@ -1190,33 +1199,33 @@ class HFRLTrainerBuilder(TrainerBuilderBase): sig = inspect.signature(trainer_cls) if "tokenizer" in sig.parameters.keys(): - dpo_trainer_kwargs["tokenizer"] = self.tokenizer + trainer_kwargs["tokenizer"] = self.tokenizer else: - dpo_trainer_kwargs["processing_class"] = self.tokenizer + trainer_kwargs["processing_class"] = self.tokenizer if self.cfg.datasets is not None and ( trainer_cls is DPOStrategy.get_trainer_class() ): - dpo_trainer_kwargs["dataset_tags"] = [ + trainer_kwargs["dataset_tags"] = [ d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() ] - dpo_trainer = trainer_cls( + trainer = trainer_cls( *trainer_cls_args, args=training_args, train_dataset=self.train_dataset, callbacks=self.get_callbacks(), - **dpo_trainer_kwargs, + **trainer_kwargs, ) if self.cfg.fsdp: - ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype) - if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model: - ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype) + ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype) + if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model: + ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype) - dpo_trainer = self.hook_post_create_trainer(dpo_trainer) - for callback in self.get_post_trainer_create_callbacks(dpo_trainer): - dpo_trainer.add_callback(callback) + trainer = self.hook_post_create_trainer(trainer) + for callback in self.get_post_trainer_create_callbacks(trainer): + trainer.add_callback(callback) - return dpo_trainer + return trainer class 
HFPPOTrainerBuilder(TrainerBuilderBase): diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py index 32a889af9..2cdc9c195 100644 --- a/src/axolotl/core/trainers/__init__.py +++ b/src/axolotl/core/trainers/__init__.py @@ -5,7 +5,7 @@ from .base import AxolotlTrainer from .dpo.trainer import AxolotlDPOTrainer -from .grpo.trainer import AxolotlGRPOTrainer +from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer from .mamba import AxolotlMambaTrainer from .relora import ReLoRATrainer from .trl import ( diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index ab9735adc..2f0ce6894 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -373,15 +373,13 @@ class AxolotlTrainer( num_items_in_batch=num_items_in_batch, ) - loss = super().compute_loss( + return super().compute_loss( model, inputs, return_outputs=return_outputs, num_items_in_batch=num_items_in_batch, ) - return loss - @staticmethod def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): concatenated_batch = {} diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py index 2d6835cf7..603fdf0b6 100644 --- a/src/axolotl/core/trainers/dpo/__init__.py +++ b/src/axolotl/core/trainers/dpo/__init__.py @@ -1,14 +1,11 @@ -""" -DPO Specific Strategy for training -""" +"""DPO Specific Strategy for training""" from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer +from axolotl.utils.schemas.enums import RLType class DPOStrategy: - """ - Strategy for DPO training - """ + """Strategy for DPO training""" @classmethod def get_trainer_class(cls): @@ -23,7 +20,7 @@ class DPOStrategy: @classmethod def set_training_args_kwargs(cls, cfg): training_args_kwargs = {} - if cfg.rl == "ipo": + if cfg.rl is RLType.IPO: training_args_kwargs["loss_type"] = "ipo" training_args_kwargs["max_length"] = cfg.sequence_len 
training_args_kwargs["max_completion_length"] = None diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py index 078fdcb22..f4685893b 100644 --- a/src/axolotl/core/trainers/grpo/__init__.py +++ b/src/axolotl/core/trainers/grpo/__init__.py @@ -1,37 +1,41 @@ -""" -GRPO Specific Strategy for training -""" +"""GRPO Specific Strategy for training""" import importlib import inspect import logging +from typing import Any from trl.trainer.grpo_trainer import RewardFunc -from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer +from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig +from axolotl.core.trainers.grpo.trainer import ( + AxolotlGRPOSequenceParallelTrainer, + AxolotlGRPOTrainer, +) +from axolotl.utils.dict import DictDefault from axolotl.utils.schemas.trl import TRLConfig -LOG = logging.getLogger("axolotl") +LOG = logging.getLogger(__name__) class GRPOStrategy: - """ - Strategy for GRPO training - """ + """Strategy for GRPO training""" @classmethod - def get_trainer_class(cls): + def get_trainer_class( + cls, sequence_parallel: bool + ) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]: + if sequence_parallel: + return AxolotlGRPOSequenceParallelTrainer return AxolotlGRPOTrainer @classmethod - def get_training_args_class(cls): - from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig - + def get_training_args_class(cls) -> type[AxolotlGRPOConfig]: return AxolotlGRPOConfig @classmethod - def set_training_args_kwargs(cls, cfg): - grpo_args_kwargs = {} + def set_training_args_kwargs(cls, cfg: DictDefault) -> dict[str, Any]: + grpo_args_kwargs: dict[str, Any] = {} if not hasattr(cfg, "trl") or not cfg.trl: return grpo_args_kwargs @@ -40,8 +44,8 @@ class GRPOStrategy: if trl.use_vllm: grpo_args_kwargs["use_vllm"] = trl.use_vllm - grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host - grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or 
trl.vllm.port + grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host # type: ignore[attr-defined] + grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port # type: ignore[attr-defined] if trl.vllm_server_timeout: grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout if trl.vllm_guided_decoding_regex: @@ -102,17 +106,18 @@ class GRPOStrategy: return grpo_args_kwargs @classmethod - def set_trainer_args(cls, cfg): + def set_trainer_args(cls, cfg: DictDefault) -> list[Any]: trainer_args = [] if cfg.trl and cfg.trl.reward_funcs: reward_funcs = [] for reward_func_fqn in cfg.trl.reward_funcs: reward_funcs.append(cls.get_reward_func(reward_func_fqn)) trainer_args.append(reward_funcs) + return trainer_args @classmethod - def set_trainer_kwargs(cls, cfg): + def set_trainer_kwargs(cls, cfg: DictDefault) -> dict[str, Any]: trainer_kwargs = {} if cfg.trl and cfg.trl.reward_processing_classes: trainer_kwargs["reward_processing_classes"] = ( @@ -126,7 +131,7 @@ class GRPOStrategy: return None @classmethod - def get_blocklist_args_kwargs(cls): + def get_blocklist_args_kwargs(cls) -> list[str]: return ["dataset_num_proc"] @classmethod @@ -137,13 +142,13 @@ class GRPOStrategy: Args: reward_func_fqn (str): Fully qualified name of the reward function (e.g. r1_grpo.gsm8k_transform), or a HF hub path to the reward model. - Raises: - ValueError: If the reward function does not accept at least two arguments. Returns: RewardFunc: A callable that accepts prompts and completions and returns rewards, or a path to a reward model. + Raises: + ValueError: If the reward function does not accept at least two arguments. 
""" try: # use importlib to dynamically load the reward function from the module diff --git a/src/axolotl/core/trainers/grpo/args.py b/src/axolotl/core/trainers/grpo/args.py index 5460edca9..76be88c89 100644 --- a/src/axolotl/core/trainers/grpo/args.py +++ b/src/axolotl/core/trainers/grpo/args.py @@ -11,6 +11,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins @dataclass class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig): - """ - Axolotl GRPO Config for GRPO training - """ + """Axolotl GRPO Config for GRPO training""" diff --git a/src/axolotl/core/trainers/grpo/sampler.py b/src/axolotl/core/trainers/grpo/sampler.py new file mode 100644 index 000000000..ebc6e19e2 --- /dev/null +++ b/src/axolotl/core/trainers/grpo/sampler.py @@ -0,0 +1,172 @@ +"""Repeat random sampler (similar to the one implemented in +https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py) that adds +sequence parallelism functionality; i.e., duplicating data across ranks in the same +sequence parallel group. +""" + +from typing import Iterator, Sized + +import torch +from torch.utils.data import Sampler + + +class SequenceParallelRepeatRandomSampler(Sampler): + """Sampler for GRPO training with sequence parallelism. + + This sampler ensures: + - Ranks in the same sequence parallel (SP) group receive identical data. + - Each index is repeated multiple times for sampling different completions. + - Entire batches are repeated for reuse in multiple updates. + - Data is properly distributed across SP groups. + + In the table below, the values represent dataset indices. Each SP group has + `sequence_parallel_degree = 2` GPUs working together on the same data. There are 2 + SP groups (SP0 and SP1), with `world_size = 4` total GPUs. 
+ + Sequence Parallel Groups + | SP0 | SP1 | + | GPU 0 | GPU 1 | GPU 2 | GPU 3 | + global_step step <---> mini_repeat_count=3 + <----------> batch_size=2 per SP group + grad_accum=2 ▲ ▲ 0 0 [0 0 0 1 1 1] [2 2 2 3 3 3] <- SP groups get different data + ▼ | 0 1 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Same data for each SP group GPU + | + | 1 2 [0 0 0 1 1 1] [2 2 2 3 3 3] <- Repeat same indices for iterations + num_iterations=2 ▼ 1 3 [0 0 0 1 1 1] [2 2 2 3 3 3] <- When using gradient accumulation + + 2 4 [4 4 4 5 5 5] [6 6 6 7 7 7] <- New batch of data indices + 2 5 [4 4 4 5 5 5] [6 6 6 7 7 7] + ... + + Args: + dataset: Dataset to sample from. + mini_repeat_count: How many times to repeat each sample immediately. + world_size: Total number of processes. + rank: Rank of current process. + batch_size: Number of samples per batch. + repeat_count: How many times to repeat the full sampling process. + sequence_parallel_degree: Number of ranks in a sequence parallel group. + shuffle: Whether to shuffle the dataset. + seed: Random seed for shuffling. + drop_last: Whether to drop the last incomplete batch. 
+ """ + + def __init__( + self, + dataset: Sized, + mini_repeat_count: int, + world_size: int, + rank: int, + batch_size: int = 1, + repeat_count: int = 1, + sequence_parallel_degree: int = 1, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ): + self.dataset = dataset + self.mini_repeat_count = mini_repeat_count + self.batch_size = batch_size + self.repeat_count = repeat_count + self.shuffle = shuffle + self.seed = seed + self.drop_last = drop_last + self.epoch = 0 + + self.world_size = world_size + self.rank = rank + + # Sequence parallelism parameters + self.sequence_parallel_degree = sequence_parallel_degree + self.num_sp_groups = world_size // sequence_parallel_degree + self.sp_group_id = rank // sequence_parallel_degree + + # Adjust dataset size for distributed sampling + self.num_samples = len(self.dataset) + self.total_size = self.num_samples + + # Calculate effective number of samples per SP group + if ( + self.drop_last + and self.total_size % (self.num_sp_groups * self.batch_size) != 0 + ): + # Drop last incomplete batch if drop_last is True + self.num_samples_per_sp_group = ( + self.total_size // self.batch_size // self.num_sp_groups + ) * self.batch_size + else: + # Round up to include last batch if drop_last is False + self.num_samples_per_sp_group = ( + (self.total_size + self.batch_size * self.num_sp_groups - 1) + // (self.batch_size * self.num_sp_groups) + * self.batch_size + ) + + if shuffle: + self.generator = torch.Generator() + self.generator.manual_seed(seed) + + def __iter__(self) -> Iterator[int]: + """Creates iterator over dataset indices. + + Returns: + Iterator that yields indices into the dataset. 
+ """ + # Deterministically shuffle based on epoch and seed + if self.shuffle: + indices = torch.randperm( + self.num_samples, generator=self.generator + ).tolist() + else: + indices = list(range(self.num_samples)) + + # Add extra samples to make it evenly divisible by batch_size + if len(indices) % self.batch_size != 0: + padding = indices[: self.batch_size - len(indices) % self.batch_size] + indices += padding + + # Subsample based on SP group ID + # Each SP group gets distinct batches of data + batch_indices = [] + for i in range(0, len(indices), self.batch_size * self.num_sp_groups): + start_idx = i + self.sp_group_id * self.batch_size + end_idx = min(start_idx + self.batch_size, len(indices)) + if start_idx < len(indices): + for j in range(self.batch_size): + if start_idx + j < end_idx: + batch_indices.append(indices[start_idx + j]) + + # Make sure batch_indices is exactly batch_size * num_batches_per_sp_group + if self.drop_last: + num_batches_per_sp_group = self.num_samples_per_sp_group // self.batch_size + target_len = self.batch_size * num_batches_per_sp_group + if len(batch_indices) > target_len: + batch_indices = batch_indices[:target_len] + + # Apply the GRPO repeat pattern + final_indices = [] + for _ in range(self.repeat_count): + for idx in batch_indices: + for _ in range(self.mini_repeat_count): + final_indices.append(idx) + + return iter(final_indices) + + def __len__(self) -> int: + """Returns the total length of the iterable including repetitions. + + Returns: + Total number of samples. + """ + # Total length including all repetitions + return ( + self.num_samples_per_sp_group * self.mini_repeat_count * self.repeat_count + ) + + def set_epoch(self, epoch: int) -> None: + """Sets the epoch for this sampler. + + Args: + epoch: Epoch number to use for shuffling. 
+ """ + self.epoch = epoch diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py index 25aafa6a7..bc3d140b1 100644 --- a/src/axolotl/core/trainers/grpo/trainer.py +++ b/src/axolotl/core/trainers/grpo/trainer.py @@ -1,23 +1,63 @@ -""" -Axolotl GRPO trainer -""" +"""Axolotl GRPO trainers (with and without sequence parallelism handling)""" +# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member + +import warnings from contextlib import nullcontext +from typing import Any -from accelerate.utils import is_deepspeed_available, is_peft_model +import datasets +import torch +import torch.distributed as dist +import torch.utils.data +from accelerate.utils import ( + broadcast_object_list, + gather, + gather_object, + is_peft_model, +) +from datasets import Dataset, IterableDataset +from torch import nn +from torch.utils.data import ( + BatchSampler, + DataLoader, + Sampler, +) +from transformers import ( + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainerCallback, +) +from transformers.trainer_utils import seed_worker +from transformers.utils import is_peft_available from trl import GRPOTrainer -from trl.extras.profiling import profiling_decorator +from trl.data_utils import ( + apply_chat_template, + is_conversational, + maybe_apply_chat_template, +) +from trl.extras.profiling import profiling_context, profiling_decorator +from trl.import_utils import is_deepspeed_available +from trl.models import unwrap_model_for_generation +from trl.trainer.grpo_config import GRPOConfig +from trl.trainer.grpo_trainer import RewardFunc, nanstd +from trl.trainer.utils import pad +from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin +from axolotl.monkeypatch.attention.ring_attn.patch import get_ring_attn_group + +if is_peft_available(): + # pylint: disable=unused-import + from peft import PeftConfig if 
is_deepspeed_available(): import deepspeed class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): - """ - Extend the base GRPOTrainer for axolotl helpers - """ + """Extend the base GRPOTrainer for axolotl helpers""" _tag_names = ["trl", "grpo", "axolotl"] @@ -67,3 +107,600 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer): # Reset cache on main process if self.accelerator.is_main_process: self.vllm_client.reset_prefix_cache() + + +class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): + """Extend the base GRPOTrainer for sequence parallelism handling""" + + def __init__( + self, + model: str | PreTrainedModel, + reward_funcs: RewardFunc | list[RewardFunc], + args: GRPOConfig | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: ( + Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None + ) = None, + processing_class: PreTrainedTokenizerBase | None = None, + reward_processing_classes: ( + PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None + ) = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[ + torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None + ] = (None, None), + peft_config: "PeftConfig | None" = None, + ): + # First call the superclass constructor with all arguments + super().__init__( + model=model, + reward_funcs=reward_funcs, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + reward_processing_classes=reward_processing_classes, + callbacks=callbacks, + optimizers=optimizers, + peft_config=peft_config, + ) + + # Get number of SP groups (number of processes divided by SP degree) + num_processes = self.accelerator.num_processes + num_sp_groups = num_processes // self.args.sequence_parallel_degree + + # Calculate batch size per SP group (not per process) + sp_group_batch_size = self.args.per_device_train_batch_size * num_sp_groups + 
possible_values = [ + n_gen + for n_gen in range(2, sp_group_batch_size + 1) + if (sp_group_batch_size) % n_gen == 0 + ] + + if self.num_generations not in possible_values: + raise ValueError( + f"The batch size per SP group ({num_sp_groups} x " + f"{self.args.per_device_train_batch_size}) must be evenly divisible by " + f"the number of generations per prompt ({self.num_generations}). Given " + "the current configuration, the valid values for the number of " + f"generations are: {possible_values}." + ) + + if self.args.eval_strategy != "no": + # If sequence parallelism is enabled, calculate batch size per SP group + sp_group_eval_batch_size = args.per_device_eval_batch_size * num_sp_groups # type: ignore[union-attr] + possible_values = [ + n_gen + for n_gen in range(2, sp_group_eval_batch_size + 1) + if (sp_group_eval_batch_size) % n_gen == 0 + ] + + if self.num_generations not in possible_values: + raise ValueError( + f"With sequence parallelism (degree {self.args.sequence_parallel_degree}), " + f"the eval batch size per SP group ({num_sp_groups} x {self.args.per_device_eval_batch_size}) " + f"must be evenly divisible by the number of generations per prompt " + f"({self.num_generations}). Given the current eval batch size, " + f"the valid values for the number of generations are: {possible_values}." 
+ ) + + # Initialize the SP group + self.sp_group = get_ring_attn_group() + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.local_rank = dist.get_rank(group=self.sp_group) + self.local_world_size = dist.get_world_size(group=self.sp_group) + + def _get_train_sampler(self) -> Sampler: + effective_batch_size = ( + self.args.per_device_train_batch_size + * self.world_size + * self.args.gradient_accumulation_steps + ) + + return SequenceParallelRepeatRandomSampler( + dataset=self.train_dataset, + mini_repeat_count=self.num_generations, + world_size=self.world_size, + rank=self.rank, + batch_size=effective_batch_size + // self.num_generations + // self.args.sequence_parallel_degree, + repeat_count=self.num_iterations * self.args.gradient_accumulation_steps, + sequence_parallel_degree=self.args.sequence_parallel_degree, + shuffle=True, + seed=self.args.seed, + drop_last=True, + ) + + def _create_dataloader_params(self, is_eval=False, custom_batch_size=None): + """Create common dataloader parameters for train or eval.""" + batch_size = custom_batch_size or ( + self.args.eval_batch_size if is_eval else self._train_batch_size + ) + + params = { + "batch_size": batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } + + # Add persistent workers only for training + if not is_eval and hasattr(self.args, "dataloader_persistent_workers"): + params["persistent_workers"] = self.args.dataloader_persistent_workers + + # Add prefetch factor if specified + if self.args.dataloader_prefetch_factor: + params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + return params + + def _prepare_dataloader( + self, dataset, sampler, is_eval=False, custom_batch_size=None + ): + """Prepare a dataloader with the given dataset and sampler.""" + # Get base parameters + dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size) + + # Add sampler 
configuration + if not isinstance(dataset, torch.utils.data.IterableDataset): + if isinstance(sampler, BatchSampler): + # batch_size and batch_sampler are mutually exclusive + dataloader_params["batch_sampler"] = sampler + del dataloader_params["batch_size"] + else: + dataloader_params["sampler"] = sampler + dataloader_params["drop_last"] = self.args.dataloader_drop_last + + if not is_eval: + dataloader_params["worker_init_fn"] = seed_worker + + # Create the dataloader + dataloader = DataLoader(dataset, **dataloader_params) + + if self.args.sample_packing and ( + (not is_eval and not self.args.pretraining) + or (is_eval and self.args.eval_sample_packing is not False) + ): + self.accelerator.even_batches = False + + # Return unprepared dataloader if using sequence parallelism + # TODO(djsaunde): We might be able to use `accelerate`'s dataloader preparation + # if we use `dispatch_batches` and `slice_fn_for_dispatch` properly (i.e., + # slice each batch along the sequence dimension). + if self.args.sequence_parallel_degree > 1: + return dataloader + + # Otherwise prepare with accelerator + return self.accelerator.prepare_data_loader(dataloader) + + def get_train_dataloader(self) -> DataLoader: + """Get dataloader for training""" + train_dataset = self.train_dataset + # pylint: disable=access-member-before-definition + data_collator = self.data_collator # type: ignore + + # Handle dataset preprocessing + if isinstance(train_dataset, datasets.Dataset): + # Add debug print before any modifications + if self.args.sample_packing and not self.args.pretraining: + train_dataset = train_dataset.remove_columns(["length"]) + if not self.args.sample_packing or self.args.pretraining: + train_dataset = self._remove_unused_columns( + train_dataset, description="training" + ) + else: + self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init + data_collator, + description="training", + ) + + # Get sampler and create dataloader + 
sampler = self._get_train_sampler() + dataloader = self._prepare_dataloader(train_dataset, sampler, is_eval=False) + + return dataloader + + def _generate_and_score_completions( + self, inputs: list[dict[str, torch.Tensor | Any]] + ) -> dict[str, torch.Tensor | Any]: + device = self.accelerator.device + mode = "eval" if self.control.should_evaluate else "train" + + prompts = [x["prompt"] for x in inputs] + prompts_text = [ + maybe_apply_chat_template(example, self.processing_class)["prompt"] + for example in inputs + ] + prompt_inputs = self.processing_class( + text=prompts_text, + return_tensors="pt", + padding=True, + padding_side="left", + add_special_tokens=False, + ) + prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs) + prompt_ids, prompt_mask = ( + prompt_inputs["input_ids"], + prompt_inputs["attention_mask"], + ) + + if self.max_prompt_length is not None: + prompt_ids = prompt_ids[:, -self.max_prompt_length :] + prompt_mask = prompt_mask[:, -self.max_prompt_length :] + + # Generate completions using either vLLM or regular generation + if self.args.use_vllm: + # First, have main process load weights if needed + # pylint: disable=access-member-before-definition + if self.state.global_step != self._last_loaded_step: # type: ignore[has-type] + self._move_model_to_vllm() + # pylint: disable=attribute-defined-outside-init + self._last_loaded_step = self.state.global_step + + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + all_prompts_text = gather_object(prompts_text) + if self.accelerator.is_main_process: + if self.args.sequence_parallel_degree > 1: + # Calculate sequence parallel group information + world_size = self.accelerator.num_processes + sequence_parallel_degree = self.args.sequence_parallel_degree + num_sp_groups = world_size // sequence_parallel_degree + + # Since processes in the same SP group have the same prompts, we need to ensure + # we only take one copy of each prompt from each 
SP group + ordered_set_of_prompts = [] + for sp_group_id in range(num_sp_groups): + # Get the first process from each SP group (typically the group leader) + group_leader_rank = sp_group_id * sequence_parallel_degree + + # Extract prompts from this SP group, accounting for num_generations duplicates + # We only need prompts from one rank in each SP group + group_prompts = all_prompts_text[ + group_leader_rank + * len(prompts_text) : (group_leader_rank + 1) + * len(prompts_text) : self.num_generations + ] + + ordered_set_of_prompts.extend(group_prompts) + else: + # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate + # num_generations outputs for each one. This is faster than generating outputs for each duplicate + # prompt individually. + ordered_set_of_prompts = all_prompts_text[ + :: self.num_generations * self.args.sequence_parallel_degree + ] + + with profiling_context(self, "vLLM.generate"): + completion_ids = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + guided_decoding_regex=self.guided_decoding_regex, + ) + else: + completion_ids = [None] * ( + len(all_prompts_text) // self.args.sequence_parallel_degree + ) + + # Broadcast the completions from the main process to all processes + completion_ids = broadcast_object_list(completion_ids, from_process=0) + + # Determine the appropriate slice based on sequence parallelism + if self.args.sequence_parallel_degree > 1: + # Calculate SP group ID (which group of ranks this rank belongs to) + sp_group_id = self.accelerator.process_index // self.local_world_size + + # Calculate the start index for this SP group + sp_group_start = sp_group_id * len(prompts) * self.local_world_size + + # All ranks 
in the same SP group get the same data slice + process_slice = slice( + sp_group_start, + sp_group_start + len(prompts), + ) + completion_ids = completion_ids[process_slice] + else: + # Original behavior for non-sequence parallel case + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + completion_ids = completion_ids[process_slice] + + # Pad the completions, and concatenate them with the prompts + completion_ids = [ + torch.tensor(ids, device=device) for ids in completion_ids + ] + completion_ids = pad( + completion_ids, padding_value=self.processing_class.pad_token_id + ) + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + else: + # Regular generation path + with unwrap_model_for_generation( + self.model_wrapped, + self.accelerator, + gather_deepspeed3_params=self.args.ds3_gather_for_generation, + ) as unwrapped_model: + prompt_completion_ids = unwrapped_model.generate( + prompt_ids, + attention_mask=prompt_mask, + generation_config=self.generation_config, + ) + + # Compute prompt length and extract completion ids + prompt_length = prompt_ids.size(1) + prompt_ids = prompt_completion_ids[:, :prompt_length] + completion_ids = prompt_completion_ids[:, prompt_length:] + + # Mask everything after the first EOS token + is_eos = completion_ids == self.processing_class.eos_token_id + eos_idx = torch.full( + (is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device + ) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand( + is_eos.size(0), -1 + ) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.args.mask_truncated_completions: + truncated_completions = ~is_eos.any(dim=1) + completion_mask = ( + completion_mask * 
(~truncated_completions).unsqueeze(1).int() + ) + + # Concatenate prompt_mask with completion_mask for logit computation + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + + logits_to_keep = completion_ids.size( + 1 + ) # we only need to compute the logits for the completion tokens + batch_size = ( + self.args.per_device_train_batch_size + if mode == "train" + else self.args.per_device_eval_batch_size + ) + + with torch.no_grad(): + # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's + # computation here, and use per_token_logps.detach() instead. + if self.num_iterations > 1: + old_per_token_logps = self._get_per_token_logps( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + ) + else: + old_per_token_logps = None + + if self.beta == 0.0: + ref_per_token_logps = None + elif self.ref_model is not None: + ref_per_token_logps = self._get_per_token_logps( + self.ref_model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + ) + else: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ref_per_token_logps = self._get_per_token_logps( + self.model, + prompt_completion_ids, + attention_mask, + logits_to_keep, + batch_size, + ) + + # Decode the generated completions + completions_text = self.processing_class.batch_decode( + completion_ids, skip_special_tokens=True + ) + if is_conversational(inputs[0]): + completions = [] + for prompt, completion in zip(prompts, completions_text): + bootstrap = ( + prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" + ) + completions.append( + [{"role": "assistant", "content": bootstrap + completion}] + ) + else: + completions = completions_text + + rewards_per_func = torch.zeros( + len(prompts), len(self.reward_funcs), device=device + ) + for i, (reward_func, reward_processing_class, reward_func_name) in enumerate( + zip( + self.reward_funcs, + 
self.reward_processing_classes, + self.reward_func_names, + ) + ): + with profiling_context(self, reward_func_name): + if isinstance( + reward_func, nn.Module + ): # Module instead of PretrainedModel for compat with compiled models + if is_conversational(inputs[0]): + messages = [ + {"messages": p + c} for p, c in zip(prompts, completions) + ] + texts = [ + apply_chat_template(x, reward_processing_class)["text"] + for x in messages + ] + else: + texts = [p + c for p, c in zip(prompts, completions)] + reward_inputs = reward_processing_class( + text=texts, + return_tensors="pt", + padding=True, + padding_side="right", + add_special_tokens=False, + ) + reward_inputs = Trainer._prepare_inputs(self, reward_inputs) + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[ + :, 0 + ] # Shape (B*G,) + else: + # Repeat all input columns (but "prompt" and "completion") to match the number of generations + keys = [ + key for key in inputs[0] if key not in ["prompt", "completion"] + ] + reward_kwargs = { + key: [example[key] for example in inputs] for key in keys + } + output_reward_func = reward_func( + prompts=prompts, completions=completions, **reward_kwargs + ) + # Convert None values to NaN + output_reward_func = [ + reward if reward is not None else torch.nan + for reward in output_reward_func + ] + + rewards_per_func[:, i] = torch.tensor( + output_reward_func, dtype=torch.float32, device=device + ) + + # If all reward functions return None for a given row, issue a detailed warning + if torch.isnan(rewards_per_func).all(dim=1).any(): + nan_row_idx = ( + torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] + ) + row_reward_kwargs = { + key: value[nan_row_idx] for key, value in reward_kwargs.items() + } + row_reward_kwargs["prompt"] = prompts[nan_row_idx] + row_reward_kwargs["completion"] = completions[nan_row_idx] + warnings.warn( + f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. 
" + "Please ensure that at least one reward function returns a valid reward." + ) + + # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the + # completions may be distributed across processes + rewards_per_func = gather(rewards_per_func) + + # Apply weights to each reward function's output and sum + rewards = ( + rewards_per_func * self.reward_weights.to(device).unsqueeze(0) + ).nansum(dim=1) + + # Compute grouped-wise rewards + mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) + std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1) + + # Normalize the rewards to compute the advantages + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave( + self.num_generations, dim=0 + ) + std_grouped_rewards = std_grouped_rewards.repeat_interleave( + self.num_generations, dim=0 + ) + advantages = rewards - mean_grouped_rewards + if self.args.scale_rewards: + advantages = advantages / (std_grouped_rewards + 1e-4) + + # Slice to keep only the local part of the data + if self.args.sequence_parallel_degree > 1: + # Calculate SP group ID (which group of ranks this rank belongs to) + sp_group_id = self.accelerator.process_index // self.local_world_size + + # Calculate the start index for this SP group + sp_group_start = sp_group_id * len(prompts) * self.local_world_size + + # All ranks in the same SP group get the same data slice + process_slice = slice( + sp_group_start, + sp_group_start + len(prompts), + ) + else: + # Original behavior for non-sequence parallel case + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + advantages = advantages[process_slice] + + # Log the metrics + if mode == "train": + self._total_train_tokens += ( + self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item() + ) + self._metrics[mode]["num_tokens"] = [self._total_train_tokens] + + # log completion 
lengths, mean, min, max + agg_completion_mask = self.accelerator.gather_for_metrics( + completion_mask.sum(1) + ) + self._metrics[mode]["completions/mean_length"].append( + agg_completion_mask.float().mean().item() + ) + self._metrics[mode]["completions/min_length"].append( + agg_completion_mask.float().min().item() + ) + self._metrics[mode]["completions/max_length"].append( + agg_completion_mask.float().max().item() + ) + + # identify sequences that terminated with EOS and log their lengths + agg_terminated_with_eos = self.accelerator.gather_for_metrics(is_eos.any(dim=1)) + term_completion_mask = agg_completion_mask[agg_terminated_with_eos] + clipped_completions_ratio = 1 - len(term_completion_mask) / len( + agg_completion_mask + ) + self._metrics[mode]["completions/clipped_ratio"].append( + clipped_completions_ratio + ) + if len(term_completion_mask) == 0: + # edge case where no completed sequences are found + term_completion_mask = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append( + term_completion_mask.float().mean().item() + ) + self._metrics[mode]["completions/min_terminated_length"].append( + term_completion_mask.float().min().item() + ) + self._metrics[mode]["completions/max_terminated_length"].append( + term_completion_mask.float().max().item() + ) + + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) + for i, reward_func_name in enumerate(self.reward_func_names): + mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) + std_rewards = nanstd(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards) + self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item()) + self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item()) + + # Log prompt and completion texts + 
self._textual_logs["prompt"].extend(gather_object(prompts_text)) + self._textual_logs["completion"].extend(gather_object(completions_text)) + for i, name in enumerate(self.reward_func_names): + self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) + + return { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "advantages": advantages, + "old_per_token_logps": old_per_token_logps, + "ref_per_token_logps": ref_per_token_logps, + } diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py index 6e4b3e4d0..44751b465 100644 --- a/src/axolotl/core/trainers/mixins/__init__.py +++ b/src/axolotl/core/trainers/mixins/__init__.py @@ -6,4 +6,4 @@ from .optimizer import OptimizerMixin from .rng_state_loader import RngLoaderMixin from .scheduler import SchedulerMixin -from .sequence_parallel import SequenceParallelContextManager, SequenceParallelMixin +from .sequence_parallel import SequenceParallelMixin diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py index 362acb88e..0f30458cd 100644 --- a/src/axolotl/core/trainers/mixins/sequence_parallel.py +++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py @@ -1,85 +1,13 @@ -""" -Module for Axolotl trainer sequence parallelism mixin and training context manager -""" +"""Module for Axolotl trainer sequence parallelism mixin""" -import functools -import logging - -import torch import torch.distributed as dist from datasets import Dataset -from torch import nn from torch.utils.data import DistributedSampler, Sampler -from torch.utils.hooks import RemovableHandle from axolotl.monkeypatch.attention.ring_attn import ( - RingAttnFunc, get_ring_attn_group, - update_ring_attn_params, ) -LOG = logging.getLogger(__name__) - - -def apply_sequence_parallelism( - batch: dict[str, torch.Tensor], - local_rank: int, - 
local_world_size: int, - ring_attn_func: RingAttnFunc, -) -> dict[str, torch.Tensor]: - """ - Apply sequence parallelism slicing to a batch. - - Args: - batch: Batch dictionary (e.g., input_ids, attention_mask, etc.) - local_rank: Local rank in the sequence parallel group - local_world_size: World size of the sequence parallel group - ring_attn_func: The ring attention function to use - - Returns: - Sliced batch dictionary. - """ - # Update ring attention params if needed - if batch.get("position_ids") is not None: - update_ring_attn_params(position_ids=batch["position_ids"]) - - # Slice batch for sequence parallel processing - total_seq_len = batch["input_ids"].size(1) - for key in batch: - if ( - key in batch - and isinstance(batch[key], torch.Tensor) - and batch[key].dim() > 1 - and batch[key].size(1) == total_seq_len - ): - - if ring_attn_func in [ - RingAttnFunc.VARLEN_LLAMA3, - RingAttnFunc.BATCH_RING, - ]: - # Split in sequential fashion and grab this rank's chunk - batch[key] = ( - batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous() - ) - elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG: - chunks = batch[key].chunk(2 * local_world_size, dim=1) - - # Take rank's chunk and opposing chunk for zigzag pattern - selected_chunks = [ - chunks[local_rank], - chunks[2 * local_world_size - local_rank - 1], - ] - batch[key] = torch.cat(selected_chunks, dim=1).contiguous() - elif ring_attn_func is RingAttnFunc.BATCH_STRIPE: - # Split into striped data and stack - tensor = torch.stack( - batch[key].split(local_world_size, dim=1), - dim=1, - ).transpose(1, 2) - batch[key] = tensor[:, local_rank].contiguous() - - return batch - class SequenceParallelMixin: """ @@ -157,157 +85,3 @@ class SequenceParallelMixin: return self._create_sequence_parallel_sampler( eval_dataset, shuffle=False, is_eval=True ) - - -class SequenceParallelContextManager: - """ - Context manager for sequence parallelism operations. 
- - This class provides a context that will automatically apply sequence parallelism - during model forward passes using a pre-forward hook, and gather outputs from - across the sequence parallelism group using a post-forward hook. - """ - - def __init__( - self, - model: nn.Module, - sequence_parallel_degree: int, - ring_attn_func: RingAttnFunc, - ): - self.model = model - self.sequence_parallel_degree = sequence_parallel_degree - self.ring_attn_func = ring_attn_func - self.process_group = get_ring_attn_group() - - # Initialize sequence parallel group details - self.local_rank = dist.get_rank(self.process_group) - self.local_world_size = dist.get_world_size(self.process_group) - - # Will store hook handles for removal - self.hook_handles: list[RemovableHandle] = [] - - # Create a partially applied version of the apply_sequence_parallelism function - # with pre-configured params - self.apply_sequence_parallelism = functools.partial( - apply_sequence_parallelism, - local_rank=self.local_rank, - local_world_size=self.local_world_size, - ring_attn_func=self.ring_attn_func, - ) - - def __enter__(self): - # Forward pre-hook to apply sequence parallelism - def sequence_parallel_pre_hook(_, args, kwargs): - # Apply sequence parallelism to kwargs - kwargs = self.apply_sequence_parallelism(batch=kwargs) - return args, kwargs - - # Forward post-hook to gather outputs - def sequence_parallel_post_hook(_, __, output): - # Gather the sharded outputs - return self.gather_outputs(output) - - # Register both hooks - self.hook_handles.append( - self.model.register_forward_pre_hook( - sequence_parallel_pre_hook, with_kwargs=True - ) - ) - self.hook_handles.append( - self.model.register_forward_hook(sequence_parallel_post_hook) - ) - - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - # Remove all hooks - for handle in self.hook_handles: - handle.remove() - self.hook_handles = [] - - def gather_outputs(self, output): - """Gather sharded outputs from all ranks and 
reconstruct the full tensor.""" - # Handle different output formats (dict, tensor, etc.) - if isinstance(output, dict): - gathered_output = {} - for key, value in output.items(): - if isinstance(value, torch.Tensor) and value.dim() > 1: - # Gather logits or other sequence-sharded tensors - gathered_value = self.gather_tensor(value) - gathered_output[key] = gathered_value - else: - gathered_value = value.clone() - dist.all_reduce( - gathered_value, op=dist.ReduceOp.SUM, group=self.process_group - ) - gathered_output[key] = gathered_value - return gathered_output - if isinstance(output, torch.Tensor): - return self.gather_tensor(output) - - return output - - def gather_tensor(self, tensor): - """Gather a sharded tensor from all ranks.""" - # Prepare tensors for all_gather - world_size = self.local_world_size - - # Create list to store tensors from all ranks - gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)] - - # All-gather operation - dist.all_gather(gathered_tensors, tensor, group=self.process_group) - - # Concatenate along sequence dimension (typically dim=1) - if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]: - # Simple concatenation for standard sharding - return torch.cat(gathered_tensors, dim=1) - - if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG: - # Each rank has a pattern of (rank, world_size*2-rank-1) - reconstituted_tensors = [None] * (world_size * 2) - - # First, split each gathered tensor into its two chunks - for rank, gathered_tensor in enumerate(gathered_tensors): - # Each tensor contains two chunks in the sequence dimension - chunk_size = gathered_tensor.size(1) // 2 - chunk1, chunk2 = gathered_tensor.split(chunk_size, dim=1) - - # Place chunks in their original positions - reconstituted_tensors[rank] = chunk1 - reconstituted_tensors[world_size * 2 - rank - 1] = chunk2 - - # Concatenate the reconstituted tensors in the correct order - return torch.cat(reconstituted_tensors, dim=1) - - # 
Otherwise, RingAttnFunc.BATCH_STRIPE - # In striping, each rank has every world_size-th slice - batch_size = tensor.size(0) - hidden_dim = tensor.size(-1) - - # First, determine the full sequence length - total_seq_len = 0 - for t in gathered_tensors: - total_seq_len += t.size(1) - - # Create a tensor to hold the unstriped result - result = torch.zeros( - batch_size, - total_seq_len, - hidden_dim, - dtype=tensor.dtype, - device=tensor.device, - ) - - # For each rank's tensor, distribute its slices to the correct positions - for rank, gathered_tensor in enumerate(gathered_tensors): - # The rank's tensor contains every world_size-th slice - # starting from its rank position - seq_len = gathered_tensor.size(1) - for i in range(seq_len): - # Calculate the position in the full tensor - pos = i * world_size + rank - if pos < total_seq_len: - result[:, pos] = gathered_tensor[:, i] - - return result diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 3fe32f507..0b14e7661 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -9,7 +9,7 @@ from PIL.Image import Resampling from transformers import TrainingArguments from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig -from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc +from axolotl.utils.schemas.enums import RingAttnFunc @dataclass diff --git a/src/axolotl/monkeypatch/attention/ring_attn/__init__.py b/src/axolotl/monkeypatch/attention/ring_attn/__init__.py index 055607e92..a50ad456e 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn/__init__.py +++ b/src/axolotl/monkeypatch/attention/ring_attn/__init__.py @@ -4,7 +4,6 @@ # flake8: noqa from .patch import ( - RingAttnFunc, get_ring_attn_group, register_ring_attn, set_ring_attn_group, diff --git a/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py b/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py index a88c9f6f1..e556ba5e3 100644 --- 
a/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py +++ b/src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py @@ -16,11 +16,7 @@ import torch import torch.distributed as dist import transformers import transformers.modeling_flash_attention_utils -from ring_flash_attn import ( - ring_flash_attn_func, - stripe_flash_attn_func, - zigzag_ring_flash_attn_func, -) +from ring_flash_attn import ring_flash_attn_func from ring_flash_attn.adapters.hf_adapter import check_params from transformers.modeling_flash_attention_utils import ( _flash_supports_window_size, @@ -28,12 +24,12 @@ from transformers.modeling_flash_attention_utils import ( ) from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc +from axolotl.utils.schemas.enums import RingAttnFunc RING_ATTN_FUNC_MAPPING = { - RingAttnFunc.BATCH_RING: ring_flash_attn_func, - RingAttnFunc.BATCH_ZIGZAG: zigzag_ring_flash_attn_func, - RingAttnFunc.BATCH_STRIPE: stripe_flash_attn_func, + RingAttnFunc.BATCH_RING: torch.compile(ring_flash_attn_func), + # RingAttnFunc.BATCH_ZIGZAG: torch.compile(zigzag_ring_flash_attn_func), + # RingAttnFunc.BATCH_STRIPE: torch.compile(stripe_flash_attn_func), } diff --git a/src/axolotl/monkeypatch/attention/ring_attn/patch.py b/src/axolotl/monkeypatch/attention/ring_attn/patch.py index fa03bd174..8cbba338a 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn/patch.py +++ b/src/axolotl/monkeypatch/attention/ring_attn/patch.py @@ -6,13 +6,12 @@ package, specifically the `hf_adapter.substitute_hf_flash_attn` function to patc their sequence parallel version of Flash Attention 2. 
""" -from enum import Enum - import torch import torch.distributed as dist from accelerate.logging import get_logger from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids +from axolotl.utils.schemas.enums import RingAttnFunc LOG = get_logger(__name__) @@ -41,17 +40,6 @@ def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None): RING_ATTN_GROUP = ring_attn_group -class RingAttnFunc(str, Enum): - """Enum class for supported `ring-flash-attn` implementations""" - - # VARLEN_RING = "varlen_ring" - # VARLEN_ZIGZAG = "varlen_zigzag" - VARLEN_LLAMA3 = "varlen_llama3" - BATCH_RING = "batch_ring" - BATCH_ZIGZAG = "batch_zigzag" - BATCH_STRIPE = "batch_stripe" - - def register_ring_attn( sequence_parallel_degree: int, heads_k_stride: int | None, @@ -117,11 +105,7 @@ def register_ring_attn( substitute_hf_flash_attn( process_group=get_ring_attn_group(), heads_k_stride=heads_k_stride or 1 ) - elif ring_attn_func in [ - RingAttnFunc.BATCH_RING, - RingAttnFunc.BATCH_ZIGZAG, - RingAttnFunc.BATCH_STRIPE, - ]: + elif ring_attn_func is RingAttnFunc.BATCH_RING: from axolotl.monkeypatch.attention.ring_attn.adapters.batch import ( substitute_hf_flash_attn, ) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 2adf28fdf..90ab10e9f 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -7,7 +7,7 @@ import os import signal import sys import weakref -from contextlib import nullcontext +from contextlib import ExitStack from pathlib import Path from typing import Any, Dict @@ -27,14 +27,13 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder -from axolotl.core.trainers.mixins.sequence_parallel import ( - SequenceParallelContextManager, -) from axolotl.integrations.base import PluginManager +from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager from axolotl.utils.dict import DictDefault 
from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.models import load_model, load_processor, load_tokenizer +from axolotl.utils.schemas.enums import RLType from axolotl.utils.trainer import setup_trainer try: @@ -107,7 +106,7 @@ def setup_reference_model( Reference model if needed for RL training, `None` otherwise. """ model_ref = None - if cfg.rl and cfg.rl != "orpo": + if cfg.rl and cfg.rl != RLType.ORPO: if cfg.adapter and not cfg.rl_adapter_ref_model: # use built-in trl autounwrap LOG.debug("Passing model_ref: None to RL trainer") @@ -188,28 +187,32 @@ def execute_training( trainer: The configured trainer object. resume_from_checkpoint: Path to checkpoint to resume from, if applicable. """ - # Define the context managers to use - flash_context = ( - torch.backends.cuda.sdp_kernel( - enable_flash=True, - enable_math=True, - enable_mem_efficient=True, - ) - if cfg.flash_optimum - else nullcontext() - ) - sequence_parallel_context = ( - SequenceParallelContextManager( - model=trainer.model, - sequence_parallel_degree=cfg.sequence_parallel_degree, - ring_attn_func=cfg.ring_attn_func, - ) - if cfg.sequence_parallel_degree > 1 - else nullcontext() - ) + with ExitStack() as stack: + # Define the context managers to use + if cfg.flash_optimum: + stack.enter_context( + torch.backends.cuda.sdp_kernel( + enable_flash=True, + enable_math=True, + enable_mem_efficient=True, + ) + ) - LOG.info("Starting trainer...") - with flash_context, sequence_parallel_context: + if cfg.sequence_parallel_degree > 1: + models = [trainer.model] + if hasattr(trainer, "ref_model"): + models.append(trainer.ref_model) + + stack.enter_context( + SequenceParallelContextManager( + models=models, + sequence_parallel_degree=cfg.sequence_parallel_degree, + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + ring_attn_func=cfg.ring_attn_func, + ) + ) + + LOG.info("Starting trainer...") 
trainer.train(resume_from_checkpoint=resume_from_checkpoint) diff --git a/src/axolotl/utils/ctx_managers/__init__.py b/src/axolotl/utils/ctx_managers/__init__.py new file mode 100644 index 000000000..e544621b5 --- /dev/null +++ b/src/axolotl/utils/ctx_managers/__init__.py @@ -0,0 +1,6 @@ +"""Init for context manager submodule""" + +# pylint: disable=unused-import +# flake8: noqa + +from .sequence_parallel import SequenceParallelContextManager diff --git a/src/axolotl/utils/ctx_managers/sequence_parallel.py b/src/axolotl/utils/ctx_managers/sequence_parallel.py new file mode 100644 index 000000000..66044f7f0 --- /dev/null +++ b/src/axolotl/utils/ctx_managers/sequence_parallel.py @@ -0,0 +1,335 @@ +"""Module for Axolotl trainer sequence parallelism manager and utilities""" + +import functools + +import torch +import torch.distributed as dist +from torch import nn +from torch.utils.hooks import RemovableHandle +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.utils import ModelOutput + +from axolotl.monkeypatch.attention.ring_attn.patch import ( + get_ring_attn_group, + update_ring_attn_params, +) +from axolotl.utils.schemas.enums import RingAttnFunc + + +# TODO(djsaunde): implement zigzag, stripe patterns here (and elsewhere) in this +# module. Currently, we just focus on batch ring and varlen llama3 for simplicity. +def apply_sequence_parallelism( + batch: dict[str, torch.Tensor], + local_rank: int, + local_world_size: int, + gradient_accumulation_steps: int, + ring_attn_func: RingAttnFunc, # pylint: disable=unused-argument +) -> tuple[dict[str, torch.Tensor], int, int]: + """ + Apply sequence parallelism slicing to a batch. + + Special handling is implemented for integer logits_to_keep, which indicates + to only keep the last N tokens in the sequence during generation. + + Args: + batch: Batch dictionary (e.g., input_ids, attention_mask, etc.). + local_rank: Local rank in the sequence parallel group. 
+ local_world_size: World size of the sequence parallel group. + gradient_accumulation_steps: Number of steps to accumulate gradients over. + ring_attn_func: Which ring attention function to use. Currently unused, but + related to above TODO. + + Returns: + tuple of: + - Batch dictionary with sliced tensors. + - The original sequence length before padding. + - The number of padding tokens added. + """ + original_seq_len = batch["input_ids"].size(1) + + # Update ring attention params if needed + if batch.get("position_ids") is not None: + update_ring_attn_params(position_ids=batch["position_ids"]) + else: + # If position_ids aren't already in the batch, create them + batch["position_ids"] = torch.arange( + 0, + original_seq_len, + dtype=torch.long, + device=batch["input_ids"].device, + ).expand(batch["input_ids"].size(0), -1) + + if "logits_to_keep" in batch and isinstance(batch["logits_to_keep"], int): + logits_to_keep = batch["logits_to_keep"] + + # Calculate which positions in the full sequence contain the last N tokens + start_position = max(0, original_seq_len - logits_to_keep) + chunk_size = original_seq_len // local_world_size + rank_start = local_rank * chunk_size + rank_end = rank_start + chunk_size + + # Create a boolean mask tensor for this rank's chunk + mask = torch.zeros( + chunk_size, + dtype=torch.bool, + device=batch["input_ids"].device, + ) + + if rank_end > start_position: + # Calculate how many of the last N tokens fall within this rank's range + tokens_in_rank = min(rank_end, original_seq_len) - max( + rank_start, start_position + ) + + # Calculate where these tokens start in the local chunk + local_start_idx = max(0, start_position - rank_start) + + # Set the appropriate positions in the mask to True + mask[local_start_idx : local_start_idx + tokens_in_rank] = True + + # Replace the integer with the boolean mask + batch["logits_to_keep"] = mask + + # Add padding to make sequence length divisible by local_world_size + total_seq_len = 
class SequenceParallelContextManager:
    """Apply sequence parallelism to model forward passes inside a `with` block.

    On entry, every model in `models` gets a forward pre-hook that runs
    `apply_sequence_parallelism` on the keyword arguments of the call, plus a forward
    hook that gathers the sharded outputs across the ring attention process group and
    trims the padding recorded by the pre-hook. On exit, every registered hook is
    removed again.

    Args:
        models: Models to register the pre- and post-forward hooks on.
        sequence_parallel_degree: Number of processes to split sequences over.
        gradient_accumulation_steps: Number of steps to accumulate gradients over.
        ring_attn_func: Which ring attention function to use. Currently unused.
    """

    def __init__(
        self,
        models: list[nn.Module],
        sequence_parallel_degree: int,
        gradient_accumulation_steps: int,
        ring_attn_func: RingAttnFunc,
    ):
        self.models = models
        self.sequence_parallel_degree = sequence_parallel_degree
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.ring_attn_func = ring_attn_func

        # Rank and world size are queried for the ring attention group
        group = get_ring_attn_group()
        self.process_group = group
        self.local_rank = dist.get_rank(group)
        self.local_world_size = dist.get_world_size(group)

        # Handles of the hooks registered in `__enter__`; cleared in `__exit__`
        self.hook_handles: list[RemovableHandle] = []

        # Written by the pre-hook on each forward call, read back by the post-hook
        self.original_seq_len = 0
        self.pad_len = 0

        # Bind the per-process arguments once so the hook only supplies `batch`
        self.apply_sequence_parallelism = functools.partial(
            apply_sequence_parallelism,
            local_rank=self.local_rank,
            local_world_size=self.local_world_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            ring_attn_func=self.ring_attn_func,
        )

    def __enter__(self):
        """Register the sharding pre-hook and gathering post-hook on every model."""

        def _shard_inputs(_, args, kwargs):
            # Shard the keyword arguments and remember the length / padding so the
            # post-hook can undo the padding on the gathered outputs
            sharded_kwargs, seq_len, pad_len = self.apply_sequence_parallelism(
                batch=kwargs
            )
            self.original_seq_len = seq_len
            self.pad_len = pad_len

            return args, sharded_kwargs

        def _gather_and_unpad(_, __, output: ModelOutput) -> ModelOutput:
            output = self.gather_outputs(output)

            # Nothing to trim when the pre-hook added no padding
            if self.pad_len <= 0:
                return output

            padded_len = self.original_seq_len + self.pad_len
            for key, value in output.items():
                if not isinstance(value, torch.Tensor) or value.dim() <= 1:
                    continue
                # Only tensors whose dim 1 matches the padded length are trimmed
                if value.size(1) == padded_len:
                    output[key] = value[:, : self.original_seq_len].contiguous()

            return output

        for model in self.models:
            pre_handle = model.register_forward_pre_hook(
                _shard_inputs, with_kwargs=True
            )
            self.hook_handles.append(pre_handle)

            post_handle = model.register_forward_hook(_gather_and_unpad)
            self.hook_handles.append(post_handle)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove every hook registered in `__enter__`."""
        for registered in self.hook_handles:
            registered.remove()
        self.hook_handles = []

    def gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast:
        """Replace each tensor of 2+ dims in `output` with its all-gathered version.

        Entries that are not tensors, or that have fewer than two dimensions, are
        left untouched. `output` is modified in place and returned.
        """
        for key, value in output.items():
            if not isinstance(value, torch.Tensor):
                continue
            if value.dim() > 1:
                output[key] = AllGatherWithGrad.apply(value, self.process_group)

        return output
args, kwargs + + # Forward post-hook to gather outputs + def sequence_parallel_post_hook(_, __, output: ModelOutput) -> ModelOutput: + # Gather the sharded outputs + output = self.gather_outputs(output) + + # Remove padding if it was added + if self.pad_len > 0: + for key, value in output.items(): + if isinstance(value, torch.Tensor) and value.dim() > 1: + if value.size(1) == self.original_seq_len + self.pad_len: + # Slice to remove padding + output[key] = value[:, : self.original_seq_len].contiguous() + + return output + + # Register both hooks + for model in self.models: + self.hook_handles.append( + model.register_forward_pre_hook( + sequence_parallel_pre_hook, with_kwargs=True + ) + ) + self.hook_handles.append( + model.register_forward_hook(sequence_parallel_post_hook) + ) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Remove all hooks + for handle in self.hook_handles: + handle.remove() + self.hook_handles = [] + + def gather_outputs(self, output: CausalLMOutputWithPast) -> CausalLMOutputWithPast: + """Gather sharded outputs from all ranks and reconstruct the full tensor.""" + for key, value in output.items(): + if isinstance(value, torch.Tensor) and value.dim() > 1: + output[key] = AllGatherWithGrad.apply(value, self.process_group) + + return output + + +class AllGatherWithGrad(torch.autograd.Function): + """Custom autograd function for all-gather to preserve gradients.""" + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + input_tensor: torch.Tensor, + group: dist.ProcessGroup, + ) -> torch.Tensor: + """ + Forward pass of all-gather of data with sequence dimension. + + Args: + ctx: `torch.autograd` function context. + input_tensor: Tensor from model output with sequence dimension. + group: `torch.distributed` process group. + + Returns: + Tensor from gathering the `input_tensor` from across the process group and + concatenating along the sequence dimension. 
+ """ + ctx.group = group + ctx.rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + + # Gather shape metadata + local_shape = torch.tensor(list(input_tensor.shape), device=input_tensor.device) + all_shapes = [torch.zeros_like(local_shape) for _ in range(world_size)] + dist.all_gather(all_shapes, local_shape, group=group) + + # Store sequence lengths for backward pass + seq_lens = [int(shape[1].item()) for shape in all_shapes] + ctx.seq_lens = seq_lens + + # Perform all_gather operation + gathered = [ + torch.zeros( + tuple(shape.tolist()), + dtype=input_tensor.dtype, + device=input_tensor.device, + ) + for shape in all_shapes + ] + dist.all_gather(gathered, input_tensor, group=group) + + # Concatenate tensors along sequence dimension + result = torch.cat(gathered, dim=1) + + return result + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor + ) -> tuple[torch.Tensor, None]: + """ + Backward pass for all-gather operation. + + Extracts the gradient slice corresponding to this rank's original input + from the full gradient tensor. + + Args: + ctx: `torch.autograd` function context. + grad_output: Gradient from subsequent layers with respect to the + concatenated output tensor. + + Returns: + Tuple containing the gradient slice for this rank's input tensor and `None` + for the process group parameter which doesn't require gradients. 
+ """ + rank = ctx.rank + seq_lens = ctx.seq_lens + + # Extract gradient for this rank's chunk + offset = sum(seq_lens[:rank]) + grad_slice = grad_output[:, offset : offset + seq_lens[rank]].contiguous() + + return grad_slice, None diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 135de61a3..eaa834822 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -18,8 +18,9 @@ from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.models import load_tokenizer +from axolotl.utils.schemas.enums import RLType -LOG = logging.getLogger("axolotl") +LOG = logging.getLogger(__name__) def _get_path(ds_hash, cfg): @@ -80,7 +81,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs): def drop_long_rl_seq( sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name ): - if rl in ("dpo", "ipo", "orpo", "simpo"): + if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO): if not ( sample.get("prompt") and sample.get("chosen") and sample.get("rejected") ): @@ -100,7 +101,7 @@ def drop_long_rl_seq( len_prompt + len_rejected ) <= sequence_len - if rl == "kto": + if rl is RLType.KTO: if not (sample.get("prompt") and sample.get("completion")): raise ValueError("Prompt and completion keys are required for KTO datasets") @@ -114,7 +115,7 @@ def drop_long_rl_seq( return (len_prompt + len_completion) <= sequence_len - if rl == "grpo": + if rl is RLType.GRPO: return True raise ValueError("Unknown RL type") @@ -137,9 +138,9 @@ def load_prepare_preference_datasets(cfg): if _type: if isinstance(_type, DictDefault): _type = "user_defined.default" - if _cfg.rl == "orpo": + if _cfg.rl is RLType.ORPO: ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i) - elif _cfg.rl == "kto": + elif _cfg.rl is RLType.KTO: ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) else: 
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) @@ -150,7 +151,7 @@ def load_prepare_preference_datasets(cfg): split_datasets[i] = map_dataset( cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs ) - elif _cfg.rl == "kto": + elif _cfg.rl is RLType.KTO: ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) map_kwargs = {} if isinstance(ds_transform_fn, tuple): @@ -185,7 +186,7 @@ def load_prepare_preference_datasets(cfg): ) combined_datasets = concatenate_datasets(split_datasets) - combined_datasets = combined_datasets.shuffle(seed=cfg.seed) + combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42) return combined_datasets @@ -205,6 +206,8 @@ def load_prepare_preference_datasets(cfg): eval_dataset = load_split(cfg.test_datasets, cfg) if not eval_dataset: if cfg.val_set_size: + seed = cfg.seed if cfg.seed is not None else 42 + # ensure we end up with the same fingerprint by doing rank0 first and being able to cache to_hash_train = ( train_dataset._fingerprint # pylint: disable=protected-access @@ -213,7 +216,7 @@ def load_prepare_preference_datasets(cfg): + "|" + "train" + "|" - + str(cfg.seed or 42) + + str(seed) ) to_hash_test = ( train_dataset._fingerprint # pylint: disable=protected-access @@ -222,13 +225,13 @@ def load_prepare_preference_datasets(cfg): + "|" + "test" + "|" - + str(cfg.seed or 42) + + str(seed) ) train_fingerprint = md5(to_hash_train) test_fingerprint = md5(to_hash_test) ds_w_test_split = train_dataset.train_test_split( test_size=cfg.val_set_size, - seed=cfg.seed, + seed=seed, shuffle=False, train_new_fingerprint=train_fingerprint, test_new_fingerprint=test_fingerprint, diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 12f0701f0..5fa0cb60d 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -148,7 +148,7 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): ds_wrapper_partial, max_tokens=cfg.sequence_len, 
batch_size=cfg.micro_batch_size, - seed=cfg.seed or 42, + seed=cfg.seed if cfg.seed is not None else 42, buffer_size=cfg.pretrain_multipack_buffer_size or 10_000, ) # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 @@ -416,6 +416,8 @@ def load_prepare_datasets( ) if split == "train" and val_set_size: + seed = cfg.seed if cfg.seed is not None else 42 + # ensure we end up with the same fingerprint by doing rank0 first and being able to cache to_hash_train = ( dataset._fingerprint # pylint: disable=protected-access @@ -424,7 +426,7 @@ def load_prepare_datasets( + "|" + "train" + "|" - + str(cfg.seed or 42) + + str(seed) ) to_hash_test = ( dataset._fingerprint # pylint: disable=protected-access @@ -433,7 +435,7 @@ def load_prepare_datasets( + "|" + "test" + "|" - + str(cfg.seed or 42) + + str(seed) ) train_fingerprint = md5(to_hash_train) test_fingerprint = md5(to_hash_test) @@ -442,7 +444,7 @@ def load_prepare_datasets( dataset = dataset.train_test_split( test_size=val_set_size, shuffle=False, - seed=cfg.seed or 42, + seed=seed, train_new_fingerprint=train_fingerprint, test_new_fingerprint=test_fingerprint, ) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6aa4dd162..dff6d854b 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -73,6 +73,7 @@ from axolotl.utils.distributed import ( from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper from axolotl.utils.lora_embeddings import get_linear_embedding_layers from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant +from axolotl.utils.schemas.enums import RLType LOG = logging.getLogger(__name__) PLUGIN_MANAGER = PluginManager.get_instance() @@ -1372,7 +1373,7 @@ class ModelLoader: # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config if ( self.cfg.adapter - and 
self.cfg.rl in ["dpo", "ipo", "kto"] + and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO] and not self.cfg.merge_lora ): _, lora_config = load_lora( diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index cd9891e04..34f084f10 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -27,7 +27,7 @@ from axolotl.utils.schemas.datasets import ( StepwiseSupervisedDataset, ) from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters -from axolotl.utils.schemas.enums import ChatTemplate, RLType +from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType from axolotl.utils.schemas.integrations import ( CometConfig, GradioConfig, @@ -260,7 +260,7 @@ class AxolotlInputConfig( sequence_parallel_degree: int | None = None heads_k_stride: int | None = None - ring_attn_func: str | None = None + ring_attn_func: RingAttnFunc | None = None special_tokens: SpecialTokensConfig | None = None tokens: list[str] | None = None @@ -782,7 +782,7 @@ class AxolotlInputConfig( @model_validator(mode="after") def check_simpo_warmup(self): - if self.rl == "simpo" and self.warmup_ratio: + if self.rl is RLType.SIMPO and self.warmup_ratio: raise ValueError( "warmup_ratio is not supported with the simpo trainer. 
@model_validator(mode="before")
@classmethod
def check_grpo_liger_sequence_parallel(cls, data):
    """Reject configs that combine GRPO, the Liger loss and sequence parallelism.

    Args:
        data: Raw config mapping as handed to the "before" validator; read with
            `.get`, so it is assumed to be dict-like.

    Returns:
        `data`, unchanged, when the combination is not requested.

    Raises:
        ValueError: If `rl` is "grpo", `trl.use_liger_loss` is truthy and
            `sequence_parallel_degree` is greater than 1.
    """
    # `trl` may be missing or explicitly null; treat both as "no trl options"
    trl_cfg = data.get("trl") or {}

    # `data.get(key, 1)` returns a stored `None` rather than the default, and
    # `None > 1` raises TypeError, so coalesce an explicit null to 1 here
    sp_degree = data.get("sequence_parallel_degree") or 1

    if (
        data.get("rl") == "grpo"
        and trl_cfg.get("use_liger_loss")
        and sp_degree > 1
    ):
        raise ValueError("GRPO + SP + Liger not currently supported")
    return data
class RingAttnFunc(str, Enum):
    """Supported `ring-flash-attn` implementations, keyed by their config string.

    Members are also `str` instances, so each one compares equal to its plain
    string value.
    """

    # NOTE: not currently selectable (kept out of the enum): "varlen_ring",
    # "varlen_zigzag", "batch_zigzag" and "batch_stripe".
    VARLEN_LLAMA3 = "varlen_llama3"
    BATCH_RING = "batch_ring"
# (False, 2, False), # not yet working + (True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func + (False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func + # (False, 2, True, "batch_zigzag", 2.5), + (False, 2, False, None, 2.5), # defaults to batch_ring ring_attn_func ], ids=[ "sample_packing, varlen_llama3 ring_attn_func", + "no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func", + # "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func", "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func", - "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func", - # "no sample_packing, pad_to_sequence_len", # not yet working ], ) def test_sequence_parallel_training( @@ -118,6 +119,7 @@ class TestSequenceParallelism: micro_batch_size, pad_to_sequence_len, ring_attn_func, + threshold, ): """Test sequence parallel training with different configurations""" self._run_sequence_parallel_test( @@ -126,4 +128,5 @@ class TestSequenceParallelism: micro_batch_size=micro_batch_size, pad_to_sequence_len=pad_to_sequence_len, ring_attn_func=ring_attn_func, + threshold=threshold, ) diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 046c482e3..8efe62940 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -10,14 +10,15 @@ import pytest import torch from accelerate.state import PartialState -from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism from axolotl.monkeypatch.attention.ring_attn import ( - RingAttnFunc, get_ring_attn_group, register_ring_attn, set_ring_attn_group, ) +from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism from axolotl.utils.dict import DictDefault +from axolotl.utils.schemas.enums import RingAttnFunc +from axolotl.utils.schemas.trl import TRLConfig @pytest.fixture @@ -62,12 +63,14 @@ def sequence_parallel_batch(): input_ids = torch.arange(batch_size * 
seq_len).reshape(batch_size, seq_len) attention_mask = torch.ones(batch_size, seq_len) position_ids = torch.arange(seq_len).expand(batch_size, seq_len) + labels = input_ids.clone() # Create test batch batch = { "input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids, + "labels": labels, } return batch @@ -179,12 +182,44 @@ class TestConfigValidation: False, "micro_batch_size must be set to 1", ), + # Valid: Basic GRPO config + ( + { + "sequence_parallel_degree": 2, + "flash_attention": True, + "micro_batch_size": 2, + "trl": {"use_liger_loss": True}, + }, + { + "sequence_parallel_degree": 2, + "flash_attention": True, + "micro_batch_size": 2, + "trl": TRLConfig(use_liger_loss=True), + }, + True, + "GRPO + SP + Liger not currently supported", + ), + # Invalid: GRPO config with Liger loss + ( + { + "rl": "grpo", + "sequence_parallel_degree": 2, + "flash_attention": True, + "micro_batch_size": 2, + "trl": {"use_liger_loss": True}, + }, + None, + False, + "GRPO + SP + Liger not currently supported", + ), ], ids=[ "valid_config", "default_sp_degree", "without_flash_attention", "sample_packing_with_large_batch", + "valid_grpo", + "grpo_with_liger_loss", ], ) def test_sequence_parallel_config_validation( @@ -256,7 +291,7 @@ class TestConfigValidation: AxolotlInputConfig(**cfg) # Verify error message - assert "ring_attn_func: INVALID_FUNC must be in" in str(excinfo.value) + assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value) class TestApplySequenceParallelism: @@ -290,10 +325,11 @@ class TestApplySequenceParallelism: def test_world_size_one(self, sequence_parallel_batch): """Test that function returns original batch when world size is 1.""" - result = apply_sequence_parallelism( + result, _, _ = apply_sequence_parallelism( batch=sequence_parallel_batch, local_rank=0, local_world_size=1, + gradient_accumulation_steps=1, ring_attn_func=RingAttnFunc.BATCH_RING, ) @@ -305,10 +341,11 @@ class 
TestApplySequenceParallelism: batch = sequence_parallel_batch seq_len = batch["input_ids"].size(1) - result = apply_sequence_parallelism( + result, _, _ = apply_sequence_parallelism( batch=batch, local_rank=0, local_world_size=2, + gradient_accumulation_steps=1, ring_attn_func=RingAttnFunc.BATCH_RING, ) @@ -328,57 +365,59 @@ class TestApplySequenceParallelism: seq_len = batch["input_ids"].size(1) original_input_ids = batch["input_ids"].clone() - result = apply_sequence_parallelism( + result, _, _ = apply_sequence_parallelism( batch=batch, local_rank=1, local_world_size=2, + gradient_accumulation_steps=1, ring_attn_func=RingAttnFunc.BATCH_RING, ) # Verify content: rank 1 should get the second half of the sequence assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :]) - def test_batch_zigzag(self, sequence_parallel_batch): - """Test BATCH_ZIGZAG sharding pattern.""" - batch = sequence_parallel_batch - original_input_ids = batch["input_ids"].clone() - seq_len = batch["input_ids"].size(1) + # TODO(djsaunde): add back once implemented. 
+ # def test_batch_zigzag(self, sequence_parallel_batch): + # """Test BATCH_ZIGZAG sharding pattern.""" + # batch = sequence_parallel_batch + # original_input_ids = batch["input_ids"].clone() + # seq_len = batch["input_ids"].size(1) - # Test rank 0 - result_rank0 = apply_sequence_parallelism( - batch={k: v.clone() for k, v in batch.items()}, - local_rank=0, - local_world_size=2, - ring_attn_func=RingAttnFunc.BATCH_ZIGZAG, - ) + # # Test rank 0 + # result_rank0 = apply_sequence_parallelism( + # batch={k: v.clone() for k, v in batch.items()}, + # local_rank=0, + # local_world_size=2, + # ring_attn_func=RingAttnFunc.BATCH_ZIGZAG, + # ) - # Test rank 1 - result_rank1 = apply_sequence_parallelism( - batch={k: v.clone() for k, v in batch.items()}, - local_rank=1, - local_world_size=2, - ring_attn_func=RingAttnFunc.BATCH_ZIGZAG, - ) + # # Test rank 1 + # result_rank1 = apply_sequence_parallelism( + # batch={k: v.clone() for k, v in batch.items()}, + # local_rank=1, + # local_world_size=2, + # ring_attn_func=RingAttnFunc.BATCH_ZIGZAG, + # ) - # Checks for both ranks - assert result_rank0["input_ids"].shape[1] == seq_len // 2 - assert result_rank1["input_ids"].shape[1] == seq_len // 2 + # # Checks for both ranks + # assert result_rank0["input_ids"].shape[1] == seq_len // 2 + # assert result_rank1["input_ids"].shape[1] == seq_len // 2 - # For a 2-rank system with 8 tokens, check specific zigzag pattern - # Rank 0 should get chunks [0, 1] and [6, 7] - # Rank 1 should get chunks [2, 3] and [4, 5] - if seq_len == 8: - # Create expected tensors for comparison - rank0_expected = torch.cat( - [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1 - ) + # # For a 2-rank system with 8 tokens, check specific zigzag pattern + # # Rank 0 should get chunks [0, 1] and [6, 7] + # # Rank 1 should get chunks [2, 3] and [4, 5] + # if seq_len == 8: + # # Create expected tensors for comparison + # rank0_expected = torch.cat( + # [original_input_ids[:, :2], original_input_ids[:, 6:8]], 
dim=1 + # ) - rank1_expected = torch.cat( - [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1 - ) + # rank1_expected = torch.cat( + # [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1 + # ) - assert torch.equal(result_rank0["input_ids"], rank0_expected) - assert torch.equal(result_rank1["input_ids"], rank1_expected) + # assert torch.equal(result_rank0["input_ids"], rank0_expected) + # assert torch.equal(result_rank1["input_ids"], rank1_expected) def test_partial_application(self, sequence_parallel_batch): """Test that we can create a partially applied version of the function.""" @@ -390,11 +429,12 @@ class TestApplySequenceParallelism: apply_sequence_parallelism, local_rank=0, local_world_size=2, + gradient_accumulation_steps=1, ring_attn_func=RingAttnFunc.BATCH_RING, ) # Use the partially applied function - result = rank0_ring_parallel(batch=batch) + result, _, _ = rank0_ring_parallel(batch=batch) # Verify it works as expected assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2 @@ -412,13 +452,15 @@ class TestApplySequenceParallelism: original_input_ids = batch["input_ids"].clone() # This should run without error even though position_ids is missing - result = apply_sequence_parallelism( + result, _, _ = apply_sequence_parallelism( batch=batch, local_rank=0, local_world_size=2, + gradient_accumulation_steps=1, ring_attn_func=RingAttnFunc.BATCH_RING, ) # Verification should pass - assert "position_ids" not in result + assert "position_ids" in result + assert result["input_ids"].shape[1] == result["position_ids"].shape[1] assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2 From 7fa1089cea4e81d2724dee23dae5ad23dfb10399 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 13 May 2025 08:30:58 -0400 Subject: [PATCH 6/7] Atropos support (#2666) [skip ci] * allow peft+liger+grpo and custom vllm serve for atropos support * set trainer class for RL --- src/axolotl/cli/args.py | 6 ++++++ 
src/axolotl/cli/vllm_serve.py | 4 +++- src/axolotl/core/trainer_builder.py | 4 ++++ src/axolotl/utils/schemas/config.py | 24 ++++++++++++------------ 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py index 83febd7f4..088e337e4 100644 --- a/src/axolotl/cli/args.py +++ b/src/axolotl/cli/args.py @@ -82,6 +82,12 @@ class VllmServeCliArgs: "hardware support this feature." }, ) + serve_module: Optional[str] = field( + default=None, + metadata={ + "help": "Module to serve. If not set, the default module will be used." + }, + ) @dataclass diff --git a/src/axolotl/cli/vllm_serve.py b/src/axolotl/cli/vllm_serve.py index 552f33e9e..d3c4ad68d 100644 --- a/src/axolotl/cli/vllm_serve.py +++ b/src/axolotl/cli/vllm_serve.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Union from trl.scripts.vllm_serve import ScriptArguments -from trl.scripts.vllm_serve import main as vllm_serve_main from axolotl.cli.config import load_cfg @@ -28,6 +27,9 @@ def do_vllm_serve( cfg = load_cfg(config) model = cfg.base_model + serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve") + vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main") + tensor_parallel_size = ( cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size ) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 99ab397c7..25d327dcd 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1197,6 +1197,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase): else: raise ValueError(f"Unsupported RL: {self.cfg.rl}") + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + trainer_cls = plugin_manager.get_trainer_cls(self.cfg) + sig = inspect.signature(trainer_cls) if "tokenizer" in sig.parameters.keys(): trainer_kwargs["tokenizer"] = self.tokenizer diff --git a/src/axolotl/utils/schemas/config.py 
b/src/axolotl/utils/schemas/config.py index 34f084f10..25c802959 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1149,18 +1149,18 @@ class AxolotlInputConfig( return data - @model_validator(mode="before") - @classmethod - def check_grpo_peft_liger(cls, data): - if ( - data.get("rl") == "grpo" - and data.get("trl", {}) - and data.get("trl").get("use_liger_loss") - and data.get("adapter") - ): - raise ValueError("PEFT + GRPO + Liger is not yet supported") - return data - + # @model_validator(mode="before") + # @classmethod + # def check_grpo_peft_liger(cls, data): + # if ( + # data.get("rl") == "grpo" + # and data.get("trl", {}) + # and data.get("trl").get("use_liger_loss") + # and data.get("adapter") + # ): + # raise ValueError("PEFT + GRPO + Liger is not yet supported") + # return data + # @model_validator(mode="before") @classmethod def check_grpo_liger_sequence_parallel(cls, data): From c0a0c7534cc3a9a2b4c0ad6235fd5756f35ec17b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 13 May 2025 16:39:39 -0400 Subject: [PATCH 7/7] Activation checkpointing with offloading to disk with prefetch (#2663) * offload activations to disk instead of CPU RAM * add prefetch * Disco :dance: * include offload_disk in e2e test for AC * document and make sure to cleanup * fix annotation to match docs * fix docs build * address PR feedback --- _quarto.yml | 3 +- docs/config.qmd | 2 +- .../utils/gradient_checkpointing/__init__.py | 30 +- .../{unsloth.py => offload_cpu.py} | 4 +- .../gradient_checkpointing/offload_disk.py | 531 ++++++++++++++++++ src/axolotl/utils/models.py | 9 +- src/axolotl/utils/schemas/config.py | 2 +- .../patched/test_activation_checkpointing.py | 7 +- 8 files changed, 577 insertions(+), 11 deletions(-) rename src/axolotl/utils/gradient_checkpointing/{unsloth.py => offload_cpu.py} (95%) create mode 100644 src/axolotl/utils/gradient_checkpointing/offload_disk.py diff --git a/_quarto.yml b/_quarto.yml index 
463f76b34..dc5071838 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -139,7 +139,8 @@ quartodoc: - utils.optimizers.adopt - utils.data.pretraining - utils.data.sft - - utils.gradient_checkpointing.unsloth + - utils.gradient_checkpointing.offload_cpu + - utils.gradient_checkpointing.offload_disk - title: Schemas desc: Pydantic data models for Axolotl config contents: diff --git a/docs/config.qmd b/docs/config.qmd index eba9f4881..10e5a5895 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -539,7 +539,7 @@ train_on_inputs: false # Note that training loss may have an oscillating pattern with this enabled. group_by_length: false -# Whether to use gradient checkpointing. Available options are: true, false, "offload". +# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk". # https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing gradient_checkpointing: false # additional kwargs to pass to the trainer for gradient checkpointing diff --git a/src/axolotl/utils/gradient_checkpointing/__init__.py b/src/axolotl/utils/gradient_checkpointing/__init__.py index f84f76d80..ae0c559e9 100644 --- a/src/axolotl/utils/gradient_checkpointing/__init__.py +++ b/src/axolotl/utils/gradient_checkpointing/__init__.py @@ -5,8 +5,11 @@ from functools import partial from packaging import version -from axolotl.utils.gradient_checkpointing.unsloth import ( - Unsloth_Offloaded_Gradient_Checkpointer, +from axolotl.utils.gradient_checkpointing.offload_cpu import ( + CPU_Offloaded_Gradient_Checkpointer, +) +from axolotl.utils.gradient_checkpointing.offload_disk import ( + Disco, ) transformers_version = version.parse(importlib.metadata.version("transformers")) @@ -26,12 +29,31 @@ def hf_grad_checkpoint_offload_wrapper( decoder_layer, *args, use_reentrant=None ): # pylint: disable=unused-argument if uses_gc_layers(decoder_layer): - return Unsloth_Offloaded_Gradient_Checkpointer.apply( + return 
CPU_Offloaded_Gradient_Checkpointer.apply( decoder_layer, *args, ) - return Unsloth_Offloaded_Gradient_Checkpointer.apply( + return CPU_Offloaded_Gradient_Checkpointer.apply( + ( + decoder_layer.func.__self__ + if isinstance(decoder_layer, partial) + else decoder_layer.__self__ + ), + *args, + ) + + +def hf_grad_checkpoint_disk_offload_wrapper( + decoder_layer, *args, use_reentrant=None +): # pylint: disable=unused-argument + if uses_gc_layers(decoder_layer): + return Disco.apply( + decoder_layer, + *args, + ) + + return Disco.apply( ( decoder_layer.func.__self__ if isinstance(decoder_layer, partial) diff --git a/src/axolotl/utils/gradient_checkpointing/unsloth.py b/src/axolotl/utils/gradient_checkpointing/offload_cpu.py similarity index 95% rename from src/axolotl/utils/gradient_checkpointing/unsloth.py rename to src/axolotl/utils/gradient_checkpointing/offload_cpu.py index 7a14614b1..bbb5ad40d 100644 --- a/src/axolotl/utils/gradient_checkpointing/unsloth.py +++ b/src/axolotl/utils/gradient_checkpointing/offload_cpu.py @@ -1,4 +1,4 @@ -"""Unsloth checkpointing""" +"""CPU offloaded checkpointing""" # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. # @@ -26,7 +26,7 @@ else: torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda") -class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name +class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name torch.autograd.Function ): """ diff --git a/src/axolotl/utils/gradient_checkpointing/offload_disk.py b/src/axolotl/utils/gradient_checkpointing/offload_disk.py new file mode 100644 index 000000000..90e70f504 --- /dev/null +++ b/src/axolotl/utils/gradient_checkpointing/offload_disk.py @@ -0,0 +1,531 @@ +""" +DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching +""" + +# Copyright 2025 Axolotl AI. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import concurrent.futures +import logging +import os +import queue +import shutil +import tempfile +import threading +import time +import uuid +from collections import deque +from concurrent.futures import Future +from typing import Dict + +import torch + +torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda") +torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda") + +# Setup logger +logger = logging.getLogger(__name__) + + +class DiskOffloadManager: + """ + Manages offloaded tensors and handles prefetching in a separate thread. + Includes synchronization to prevent race conditions. + """ + + def __init__( + self, + prefetch_size: int = 3, + prefetch_to_gpu: bool = True, + save_workers: int = 4, + ): + """ + Args: + prefetch_size: Maximum number of tensors to prefetch in the background. + prefetch_to_gpu: Whether to prefetch tensors directly to GPU memory. + save_workers: Maximum number of concurrent save operations. 
+ """ + self.temp_dir = tempfile.mkdtemp(prefix="disco_") + + # Track tensor paths and their status + self.tensor_paths: deque = deque() # Ordered history of tensor paths (LIFO) + self.file_locks: Dict[str, threading.Lock] = ( + {} + ) # Maps file_path -> threading.Lock() + # Maps file_path -> status ("saving", "ready", "prefetching", "loaded", "deleted") + self.file_status: Dict[str, str] = {} + + self.max_prefetch = prefetch_size + self.prefetch_to_gpu = prefetch_to_gpu + + # Thread synchronization + self.manager_lock = threading.RLock() # Used for thread-safe operations + + # Prefetch queue and cache + self.prefetch_queue: queue.Queue = queue.Queue() + self.prefetch_cache: Dict[str, torch.Tensor] = {} # Maps file_path -> tensor + + # Save queue and thread pool + self.save_queue: queue.Queue = queue.Queue() + self.save_pool = concurrent.futures.ThreadPoolExecutor(max_workers=save_workers) + self.save_futures: Dict[str, Future] = {} + self.save_semaphore = threading.Semaphore( + save_workers * 2 + ) # Limit concurrent save operations + + # Start prefetch worker thread + self.stop_event = threading.Event() + # start multiple threads for prefetching + self.prefetch_worker_count = 2 + self.prefetch_workers = [] + for _ in range(self.prefetch_worker_count): + worker = threading.Thread(target=self._prefetch_worker, daemon=True) + worker.start() + self.prefetch_workers.append(worker) + + # Start save worker thread + self.save_worker = threading.Thread(target=self._save_worker, daemon=True) + self.save_worker.start() + self.idx = 0 + + atexit.register(self.cleanup) + + def _save_worker(self): + """Background thread that processes the save queue""" + while not self.stop_event.is_set(): + try: + save_item = self.save_queue.get(timeout=0.5) + if save_item is None: + continue + + tensor, file_path = save_item + + # Submit the save task to the thread pool + future = self.save_pool.submit( + self._save_tensor_to_disk, tensor, file_path + ) + with self.manager_lock: + 
self.save_futures[file_path] = future + + self.save_queue.task_done() + + except queue.Empty: + time.sleep(0.01) # Small sleep to prevent CPU spinning + continue + + def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str): + """Actually save the tensor to disk""" + try: + # Save tensor to disk + cpu_tensor = tensor.detach().cpu() + torch.save(cpu_tensor, file_path) + del cpu_tensor + + with self.manager_lock: + # Mark file as ready + self.file_status[file_path] = "ready" + + # Release semaphore + self.save_semaphore.release() + + return True + except FileNotFoundError as e: + logger.error(f"Error saving tensor to {file_path}: {e}") + with self.manager_lock: + self.file_status[file_path] = "error" + + # Release semaphore + self.save_semaphore.release() + + return False + + def _prefetch_worker(self): + """Background thread that loads tensors from disk ahead of time""" + while not self.stop_event.is_set(): + try: + file_path = self.prefetch_queue.get(timeout=0.5) + if file_path is None: + continue + + # Check if file is available and not already in cache + with self.manager_lock: + if ( + file_path not in self.file_status + or self.file_status[file_path] == "deleted" + ): + self.prefetch_queue.task_done() + if file_path in self.prefetch_cache: + self.prefetch_queue.task_done() + continue + + # If file is still being saved, wait for it + if ( + self.file_status[file_path] == "saving" + and file_path in self.save_futures + ): + # Re-queue this prefetch request with a little delay + self.prefetch_queue.task_done() + time.sleep(0.1) + self.prefetch_queue.put(file_path) + continue + + # Mark file as being prefetched + self.file_status[file_path] = "prefetching" + + # Load tensor from disk and store in cache + try: + if os.path.exists(file_path): + if self.prefetch_to_gpu: + tensor = torch.load( + file_path, + map_location=torch.device("cuda"), + weights_only=True, + ) + else: + tensor = torch.load(file_path, weights_only=True) + + with self.manager_lock: + 
self.prefetch_cache[file_path] = tensor + self.file_status[file_path] = "ready" + else: + with self.manager_lock: + if self.file_status.get(file_path) != "deleted": + logger.warning( + f"Prefetch error: File not found {file_path}" + ) + self.file_status[file_path] = "missing" + + except FileNotFoundError as e: + with self.manager_lock: + if self.file_status.get(file_path) != "deleted": + logger.warning(f"Prefetch error for {file_path}: {e}") + self.file_status[file_path] = "error" + + self.prefetch_queue.task_done() + + except queue.Empty: + time.sleep(0.01) # Small sleep to prevent CPU spinning + continue + + def save_tensor(self, tensor: torch.Tensor): + """Save tensor to disk asynchronously and return file path with thread-safe operations""" + # Generate unique file path + self.idx += 1 + file_path: str = os.path.join( + self.temp_dir, f"{self.idx:06d}-{uuid.uuid4()}.pt" + ) + + with self.manager_lock: + # Mark file as being saved + self.file_locks[file_path] = threading.Lock() + self.file_status[file_path] = "saving" + # Add to history + self.tensor_paths.append(file_path) + + # Acquire semaphore to limit concurrent save operations + self.save_semaphore.acquire() # pylint: disable=consider-using-with + # Queue tensor for saving in background + self.save_queue.put((tensor.detach(), file_path)) + + return file_path + + def wait_for_save(self, file_path, timeout=None) -> None: + """Wait for a tensor to be saved to disk""" + start_time = time.time() + while timeout is None or time.time() - start_time < timeout: + with self.manager_lock: + if self.file_status.get(file_path) == "ready": + return + if self.file_status.get(file_path) in ["error", "missing", "deleted"]: + return + + if file_path in self.save_futures: + future = self.save_futures[file_path] + if future.done(): + return + + # Small sleep to prevent CPU spinning + time.sleep(0.01) + + # Timeout + logger.warning(f"Timeout waiting for tensor to be saved: {file_path}") + return + + def load_tensor(self, 
file_path, target_device="cuda"): + """Load tensor from disk or prefetch cache with proper synchronization""" + # Wait for tensor to be saved if it's still in progress + self.wait_for_save(file_path) + + tensor = None + + # Try to get from cache first + with self.manager_lock: + # Check if tensor is already in cache + if file_path in self.prefetch_cache: + tensor = self.prefetch_cache[file_path] + del self.prefetch_cache[file_path] + self.file_status[file_path] = "loaded" + + if tensor is not None: + # Ensure tensor is on correct device + if target_device != "cpu" and tensor.device.type == "cpu": + tensor = tensor.to(target_device, non_blocking=True) + return tensor + + # If not in cache, load directly from disk + try: + if not os.path.exists(file_path): + logger.error(f"File not found for loading: {file_path}") + raise FileNotFoundError(f"File not found: {file_path}") + + tensor = torch.load(file_path, weights_only=True) + + with self.manager_lock: + self.file_status[file_path] = "loaded" + + if target_device != "cpu": + tensor = tensor.to(target_device, non_blocking=True) + + return tensor + + except Exception as e: + logger.error(f"Error loading tensor from {file_path}: {e}") + raise + + def _safe_delete_file(self, file_path): + """Safely delete a file with proper synchronization""" + with self.manager_lock: + # Make sure any save operation is completed + if file_path in self.save_futures: + future = self.save_futures[file_path] + try: + if not future.done(): + future.cancel() + del self.save_futures[file_path] + except FileNotFoundError as e: + logger.warning( + f"Error canceling save operation for {file_path}: {e}" + ) + + # Only delete if file exists and is not being prefetched + status = self.file_status.get(file_path) + if status in ["ready", "loaded", "error", "missing"]: + try: + if os.path.exists(file_path): + os.remove(file_path) + self.file_status[file_path] = "deleted" + return True + except FileNotFoundError as e: + logger.warning(f"Error deleting 
file {file_path}: {e}") + return False + + def trigger_prefetch(self, n=None): + """Trigger prefetching of the next N tensors with proper synchronization""" + if n is None: + n = self.max_prefetch + + prefetch_paths = [] + with self.manager_lock: + # Find files that are ready to be prefetched (not already in cache or being prefetched) + for path in reversed(self.tensor_paths): + if ( + path not in self.prefetch_cache + and self.file_status.get(path) == "ready" + ): + prefetch_paths.append(path) + if len(prefetch_paths) >= n: + break + + # Queue files for prefetching + for path in prefetch_paths: + self.prefetch_queue.put(path) + + def cleanup_tensor(self, file_path: str): + """Clean up a specific tensor file after it's been used""" + with self.manager_lock: + if file_path in self.tensor_paths: + self.tensor_paths.remove(file_path) + + # Remove from prefetch cache if present + if file_path in self.prefetch_cache: + del self.prefetch_cache[file_path] + + # Remove from save futures if present + if file_path in self.save_futures: + future = self.save_futures[file_path] + if not future.done(): + future.cancel() + del self.save_futures[file_path] + + # Try to delete the file + self._safe_delete_file(file_path) + + def cleanup(self): + """Clean up all temp files and stop prefetch thread with proper synchronization""" + self.stop_event.set() + + # Cancel all pending save operations + with self.manager_lock: + for _, future in self.save_futures.items(): + if not future.done(): + future.cancel() + self.save_futures.clear() + + # Drain the save queue + while not self.save_queue.empty(): + try: + self.save_queue.get_nowait() + self.save_queue.task_done() + except queue.Empty: + break + + # Shutdown the save pool + self.save_pool.shutdown(wait=False) + + # Join the save worker thread + if self.save_worker.is_alive(): + self.save_worker.join(timeout=2.0) + + # Join the prefetch worker threads + for thread in self.prefetch_workers: + if thread.is_alive(): + 
thread.join(timeout=2.0) + + # Clear cache and remove all temporary files + with self.manager_lock: + self.prefetch_cache.clear() + paths_to_delete = list(self.tensor_paths) + self.tensor_paths.clear() + + # Delete all temporary files + for path in paths_to_delete: + self._safe_delete_file(path) + + # Remove temp directory + try: + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir, ignore_errors=True) + except FileNotFoundError as e: + logger.warning(f"Error removing temporary directory {self.temp_dir}: {e}") + + +class Disco(torch.autograd.Function): + """ + Disco: DIsk-based Storage and Checkpointing with Optimized prefetching + Advanced disk-based gradient checkpointer with prefetching. + """ + + # Shared manager instance across all checkpointing operations + _manager = None + + @staticmethod + def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4): + """Get or create the offload manager""" + if Disco._manager is None: + Disco._manager = DiskOffloadManager( + prefetch_size=prefetch_size, + prefetch_to_gpu=prefetch_to_gpu, + save_workers=save_workers, + ) + return Disco._manager + + @staticmethod + @torch_cuda_amp_custom_fwd + def forward( + ctx, + forward_function, + hidden_states, + *args, + prefetch_size=1, + prefetch_to_gpu=True, + save_workers=4, + ): + """Forward pass that offloads activations to disk asynchronously""" + # Get or create the manager + manager = Disco.get_instance( + prefetch_size=prefetch_size, + prefetch_to_gpu=prefetch_to_gpu, + save_workers=save_workers, + ) + + # Save tensor to disk asynchronously + file_path = manager.save_tensor(hidden_states) + + # Run forward pass immediately without waiting for save to complete + with torch.no_grad(): + output = forward_function(hidden_states, *args) + + # Store what we need for backward + ctx.save_for_backward(torch.tensor([0])) # Dummy tensor + ctx.file_path = file_path + ctx.forward_function = forward_function + ctx.args = args + + return output + + @staticmethod + 
@torch_cuda_amp_custom_bwd + def backward(ctx, *grad_outputs): + """Backward pass that loads activations from disk with prefetching""" + # Get the manager + manager = Disco._manager + + # Trigger prefetching for future tensors + # This happens at the start of backward, so should have time to complete + manager.trigger_prefetch() + + # Load hidden states from disk or prefetch cache + file_path = ctx.file_path + try: + # Ensure the file is saved before we try to load it + manager.wait_for_save(file_path) + + hidden_states = manager.load_tensor(file_path) + hidden_states.requires_grad = True + + # Compute gradients + with torch.enable_grad(): + output = ctx.forward_function(hidden_states, *ctx.args) + + # Handle tuple outputs properly + if isinstance(output, tuple): + if len(grad_outputs) == len(output): + torch.autograd.backward(output, grad_outputs) + else: + torch.autograd.backward(output, grad_outputs[0]) + else: + torch.autograd.backward(output, grad_outputs[0]) + + # Clean up the file after we're done with it + manager.cleanup_tensor(file_path) + + return ( + ( + None, # forward_function + hidden_states.grad, # hidden_states grad + ) + + (None,) * len(ctx.args) # for each arg + + ( + None, # prefetch_size + None, # prefetch_to_gpu + None, # save_workers + ) + ) + + except Exception as e: + logger.error(f"Error in backward pass: {e}") + # Clean up the file even on error + manager.cleanup_tensor(file_path) + raise diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index dff6d854b..316fbec8c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -70,7 +70,10 @@ from axolotl.utils.distributed import ( is_local_main_process, is_main_process, ) -from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper +from axolotl.utils.gradient_checkpointing import ( + hf_grad_checkpoint_disk_offload_wrapper, + hf_grad_checkpoint_offload_wrapper, +) from axolotl.utils.lora_embeddings import 
get_linear_embedding_layers from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant from axolotl.utils.schemas.enums import RLType @@ -620,6 +623,10 @@ class ModelLoader: if self.cfg.gradient_checkpointing in ["unsloth", "offload"]: transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper + if self.cfg.gradient_checkpointing == "offload_disk": + transformers.modeling_utils.checkpoint = ( + hf_grad_checkpoint_disk_offload_wrapper + ) if self.cfg.flash_attention: self.patch_attention() diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 25c802959..8ae9d5c04 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -178,7 +178,7 @@ class AxolotlInputConfig( # torch_dtype: torch.dtype | None - gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field( + gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field( default=False ) gradient_checkpointing_kwargs: dict[str, Any] | None = None diff --git a/tests/e2e/patched/test_activation_checkpointing.py b/tests/e2e/patched/test_activation_checkpointing.py index cbabab6fd..45107b871 100644 --- a/tests/e2e/patched/test_activation_checkpointing.py +++ b/tests/e2e/patched/test_activation_checkpointing.py @@ -26,10 +26,15 @@ class TestActivationCheckpointing: E2E tests for activation checkpointing """ + @pytest.mark.parametrize( + "gradient_checkpointing", + ["offload", "offload_disk"], + ) def test_activation_checkpointing_offload( self, temp_dir, fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name + gradient_checkpointing, ): # pylint: disable=duplicate-code cfg = DictDefault( @@ -64,7 +69,7 @@ class TestActivationCheckpointing: "sample_packing": True, "bf16": True, "save_safetensors": True, - "gradient_checkpointing": "offload", + "gradient_checkpointing": gradient_checkpointing, } )