Compare commits
9 Commits
v0.9.1
...
release-v0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54bbc9bb72 | ||
|
|
5aefebe1fe | ||
|
|
5a36b6ff2d | ||
|
|
224da88fa2 | ||
|
|
493eb8e5c6 | ||
|
|
4780ac7c4d | ||
|
|
cf69de2eb9 | ||
|
|
27e3329273 | ||
|
|
27fec49083 |
2
.github/workflows/multi-gpu-e2e.yml
vendored
2
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -3,7 +3,7 @@ name: docker-multigpu-tests-biweekly
|
|||||||
on:
|
on:
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- 'tests/e2e/multigpu/*.py'
|
- 'tests/e2e/multigpu/**.py'
|
||||||
- 'requirements.txt'
|
- 'requirements.txt'
|
||||||
- 'setup.py'
|
- 'setup.py'
|
||||||
- 'pyproject.toml'
|
- 'pyproject.toml'
|
||||||
|
|||||||
264
.github/workflows/tests.yml
vendored
264
.github/workflows/tests.yml
vendored
@@ -44,96 +44,102 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
SKIP: no-commit-to-branch
|
SKIP: no-commit-to-branch
|
||||||
|
|
||||||
preload-cache:
|
# preload-cache:
|
||||||
name: Preload HF cache
|
# name: Preload HF cache
|
||||||
runs-on: ubuntu-latest
|
# runs-on: ubuntu-latest
|
||||||
strategy:
|
# strategy:
|
||||||
fail-fast: false
|
# fail-fast: false
|
||||||
matrix:
|
# matrix:
|
||||||
python_version: ["3.11"]
|
# python_version: ["3.11"]
|
||||||
pytorch_version: ["2.6.0"]
|
# pytorch_version: ["2.6.0"]
|
||||||
timeout-minutes: 20
|
# timeout-minutes: 20
|
||||||
|
#
|
||||||
env:
|
# env:
|
||||||
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
|
# AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
|
||||||
|
#
|
||||||
steps:
|
# steps:
|
||||||
- name: Check out repository code
|
# - name: Check out repository code
|
||||||
uses: actions/checkout@v4
|
# uses: actions/checkout@v4
|
||||||
|
#
|
||||||
- name: Restore HF cache
|
# - name: Restore HF cache
|
||||||
id: hf-cache-restore
|
# id: hf-cache-restore
|
||||||
uses: actions/cache/restore@v4
|
# uses: actions/cache/restore@v4
|
||||||
with:
|
# with:
|
||||||
path: |
|
# path: |
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
# /home/runner/.cache/huggingface/hub/datasets--*
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
# /home/runner/.cache/huggingface/hub/models--*
|
||||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
# key: ${{ runner.os }}-hf-hub-cache-v2
|
||||||
|
#
|
||||||
- name: Setup Python
|
# - name: Restore Cache from S3
|
||||||
uses: actions/setup-python@v5
|
# id: hf-cache-restore-s3
|
||||||
with:
|
# run: |
|
||||||
python-version: ${{ matrix.python_version }}
|
# mkdir -p /home/runner/.cache/huggingface/hub
|
||||||
cache: 'pip' # caching pip dependencies
|
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||||
|
#
|
||||||
- name: upgrade pip
|
# - name: Setup Python
|
||||||
run: |
|
# uses: actions/setup-python@v5
|
||||||
pip3 install --upgrade pip
|
# with:
|
||||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
# python-version: ${{ matrix.python_version }}
|
||||||
|
# cache: 'pip' # caching pip dependencies
|
||||||
- name: Install PyTorch
|
#
|
||||||
run: |
|
# - name: upgrade pip
|
||||||
pip3 install torch==${{ matrix.pytorch_version }}
|
# run: |
|
||||||
|
# pip3 install --upgrade pip
|
||||||
- name: Install dependencies
|
# pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||||
run: |
|
#
|
||||||
pip3 show torch
|
# - name: Install PyTorch
|
||||||
pip3 install --no-build-isolation -U -e .
|
# run: |
|
||||||
python scripts/unsloth_install.py | sh
|
# pip3 install torch==${{ matrix.pytorch_version }}
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
#
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
# - name: Install dependencies
|
||||||
|
# run: |
|
||||||
- name: Make sure PyTorch version wasn't clobbered
|
# pip3 show torch
|
||||||
run: |
|
# pip3 install --no-build-isolation -U -e .
|
||||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
# python scripts/unsloth_install.py | sh
|
||||||
|
# python scripts/cutcrossentropy_install.py | sh
|
||||||
- name: Ensure axolotl CLI was installed
|
# pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
run: |
|
#
|
||||||
axolotl --help
|
# - name: Make sure PyTorch version wasn't clobbered
|
||||||
|
# run: |
|
||||||
- name: Pre-Download dataset fixture
|
# python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||||
run: |
|
#
|
||||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
# - name: Ensure axolotl CLI was installed
|
||||||
|
# run: |
|
||||||
- name: Run tests
|
# axolotl --help
|
||||||
run: |
|
#
|
||||||
pytest -v tests/conftest.py
|
# - name: Pre-Download dataset fixture
|
||||||
|
# run: |
|
||||||
- name: Upload coverage to Codecov
|
# huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||||
uses: codecov/codecov-action@v5
|
#
|
||||||
with:
|
# - name: Run tests
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
# run: |
|
||||||
files: ./coverage.xml
|
# pytest -v tests/conftest.py
|
||||||
flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
#
|
||||||
fail_ci_if_error: false
|
# - name: Upload coverage to Codecov
|
||||||
|
# uses: codecov/codecov-action@v5
|
||||||
- name: cleanup pip cache
|
# with:
|
||||||
run: |
|
# token: ${{ secrets.CODECOV_TOKEN }}
|
||||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
# files: ./coverage.xml
|
||||||
|
# flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
||||||
- name: Save HF cache
|
# fail_ci_if_error: false
|
||||||
id: hf-cache
|
#
|
||||||
uses: actions/cache/save@v4
|
# - name: cleanup pip cache
|
||||||
with:
|
# run: |
|
||||||
path: |
|
# find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
#
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
# - name: Save HF cache
|
||||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
# id: hf-cache
|
||||||
|
# uses: actions/cache/save@v4
|
||||||
|
# with:
|
||||||
|
# path: |
|
||||||
|
# /home/runner/.cache/huggingface/hub/datasets--*
|
||||||
|
# /home/runner/.cache/huggingface/hub/models--*
|
||||||
|
# key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||||
|
|
||||||
pytest:
|
pytest:
|
||||||
name: PyTest
|
name: PyTest
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [preload-cache]
|
# needs: [preload-cache]
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@@ -145,14 +151,20 @@ jobs:
|
|||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Restore HF cache
|
# - name: Restore HF cache
|
||||||
id: hf-cache-restore
|
# id: hf-cache-restore
|
||||||
uses: actions/cache/restore@v4
|
# uses: actions/cache/restore@v4
|
||||||
with:
|
# with:
|
||||||
path: |
|
# path: |
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
# /home/runner/.cache/huggingface/hub/datasets--*
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
# /home/runner/.cache/huggingface/hub/models--*
|
||||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
# key: ${{ runner.os }}-hf-hub-cache-v2
|
||||||
|
|
||||||
|
- name: Restore Cache from S3
|
||||||
|
id: hf-cache-restore-s3
|
||||||
|
run: |
|
||||||
|
mkdir -p /home/runner/.cache/huggingface/hub
|
||||||
|
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
@@ -210,7 +222,7 @@ jobs:
|
|||||||
pytest-sdist:
|
pytest-sdist:
|
||||||
name: PyTest from Source Dist
|
name: PyTest from Source Dist
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [preload-cache]
|
# needs: [preload-cache]
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
@@ -222,14 +234,20 @@ jobs:
|
|||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Restore HF cache
|
# - name: Restore HF cache
|
||||||
id: hf-cache-restore
|
# id: hf-cache-restore
|
||||||
uses: actions/cache/restore@v4
|
# uses: actions/cache/restore@v4
|
||||||
with:
|
# with:
|
||||||
path: |
|
# path: |
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
# /home/runner/.cache/huggingface/hub/datasets--*
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
# /home/runner/.cache/huggingface/hub/models--*
|
||||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
# key: ${{ runner.os }}-hf-hub-cache-v2
|
||||||
|
|
||||||
|
- name: Restore Cache from S3
|
||||||
|
id: hf-cache-restore-s3
|
||||||
|
run: |
|
||||||
|
mkdir -p /home/runner/.cache/huggingface/hub
|
||||||
|
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
@@ -365,3 +383,43 @@ jobs:
|
|||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.e2e_tests
|
modal run cicd.e2e_tests
|
||||||
|
|
||||||
|
docker-e2e-cleanup:
|
||||||
|
runs-on: [self-hosted, modal]
|
||||||
|
timeout-minutes: 90
|
||||||
|
needs: [docker-e2e-tests]
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.6.0
|
||||||
|
num_gpus: 1
|
||||||
|
axolotl_extras: vllm
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Install Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
- name: Install Modal
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install modal==0.71.8 jinja2
|
||||||
|
- name: Update env vars
|
||||||
|
run: |
|
||||||
|
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
|
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||||
|
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||||
|
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||||
|
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||||
|
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||||
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
|
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||||
|
- name: Run tests job on Modal
|
||||||
|
run: |
|
||||||
|
modal run cicd.cleanup
|
||||||
|
|||||||
@@ -57,8 +57,10 @@ async def handler(job):
|
|||||||
logger.info("Training Complete.")
|
logger.info("Training Complete.")
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
del os.environ["WANDB_API_KEY"]
|
if "WANDB_API_KEY" in os.environ:
|
||||||
del os.environ["HF_TOKEN"]
|
del os.environ["WANDB_API_KEY"]
|
||||||
|
if "HF_TOKEN" in os.environ:
|
||||||
|
del os.environ["HF_TOKEN"]
|
||||||
|
|
||||||
|
|
||||||
runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
|
runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
|
||||||
|
|||||||
@@ -124,7 +124,8 @@ quartodoc:
|
|||||||
- utils.optimizers.adopt
|
- utils.optimizers.adopt
|
||||||
- utils.data.pretraining
|
- utils.data.pretraining
|
||||||
- utils.data.sft
|
- utils.data.sft
|
||||||
- utils.gradient_checkpointing.unsloth
|
- utils.gradient_checkpointing.offload_cpu
|
||||||
|
- utils.gradient_checkpointing.offload_disk
|
||||||
- title: Schemas
|
- title: Schemas
|
||||||
desc: Pydantic data models for Axolotl config
|
desc: Pydantic data models for Axolotl config
|
||||||
contents:
|
contents:
|
||||||
|
|||||||
0
cicd/__init__.py
Normal file
0
cicd/__init__.py
Normal file
@@ -18,7 +18,7 @@ pytest -v --durations=10 \
|
|||||||
--cov-append
|
--cov-append
|
||||||
|
|
||||||
# Run patched tests excluding lora kernels with coverage append
|
# Run patched tests excluding lora kernels with coverage append
|
||||||
pytest -v --durations=10 \
|
pytest --full-trace -vvv --durations=10 \
|
||||||
--ignore=tests/e2e/patched/lora_kernels \
|
--ignore=tests/e2e/patched/lora_kernels \
|
||||||
/workspace/axolotl/tests/e2e/patched \
|
/workspace/axolotl/tests/e2e/patched \
|
||||||
--cov=axolotl \
|
--cov=axolotl \
|
||||||
|
|||||||
19
cicd/cleanup.py
Normal file
19
cicd/cleanup.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""Modal app to run axolotl GPU cleanup"""
|
||||||
|
|
||||||
|
from .single_gpu import VOLUME_CONFIG, app, cicd_image, run_cmd
|
||||||
|
|
||||||
|
|
||||||
|
@app.function(
|
||||||
|
image=cicd_image,
|
||||||
|
timeout=60 * 60,
|
||||||
|
cpu=8.0,
|
||||||
|
memory=131072,
|
||||||
|
volumes=VOLUME_CONFIG,
|
||||||
|
)
|
||||||
|
def cleanup():
|
||||||
|
run_cmd("./cicd/cleanup.sh", "/workspace/axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
@app.local_entrypoint()
|
||||||
|
def main():
|
||||||
|
cleanup.remote()
|
||||||
6
cicd/cleanup.sh
Executable file
6
cicd/cleanup.sh
Executable file
@@ -0,0 +1,6 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# cleanup old cache files for datasets processing and intermediate mappings
|
||||||
|
find /workspace/data/huggingface-cache/hub/datasets -name "cache-*" -type f -mtime +1 -exec rm {} \;
|
||||||
|
find /workspace/data/huggingface-cache/hub/datasets -name "*.lock" -type f -mtime +1 -exec rm {} \;
|
||||||
@@ -1,75 +1,12 @@
|
|||||||
"""Modal app to run axolotl GPU tests"""
|
"""Modal app to run axolotl GPU tests"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
|
||||||
|
|
||||||
import os
|
|
||||||
import pathlib
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import jinja2
|
|
||||||
import modal
|
|
||||||
from jinja2 import select_autoescape
|
|
||||||
from modal import App, Image
|
|
||||||
|
|
||||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
|
||||||
|
|
||||||
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
|
||||||
template_env = jinja2.Environment(
|
|
||||||
loader=template_loader, autoescape=select_autoescape()
|
|
||||||
)
|
|
||||||
df_template = template_env.get_template("Dockerfile.jinja")
|
|
||||||
|
|
||||||
df_args = {
|
|
||||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
|
||||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
|
||||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
|
|
||||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
|
|
||||||
"CUDA": os.environ.get("CUDA", "121"),
|
|
||||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
|
||||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
|
||||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
|
||||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
|
||||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
|
||||||
}
|
|
||||||
|
|
||||||
dockerfile_contents = df_template.render(**df_args)
|
|
||||||
|
|
||||||
temp_dir = tempfile.mkdtemp()
|
|
||||||
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
|
||||||
f.write(dockerfile_contents)
|
|
||||||
|
|
||||||
cicd_image = Image.from_dockerfile(
|
|
||||||
pathlib.Path(temp_dir) / "Dockerfile",
|
|
||||||
context_mount=None,
|
|
||||||
force_build=True,
|
|
||||||
gpu="A10G",
|
|
||||||
).env(df_args)
|
|
||||||
|
|
||||||
app = App("Axolotl CI/CD", secrets=[])
|
|
||||||
|
|
||||||
hf_cache_volume = modal.Volume.from_name(
|
|
||||||
"axolotl-ci-hf-hub-cache", create_if_missing=True
|
|
||||||
)
|
|
||||||
VOLUME_CONFIG = {
|
|
||||||
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
|
|
||||||
}
|
|
||||||
|
|
||||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
|
||||||
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
|
|
||||||
|
|
||||||
|
|
||||||
def run_cmd(cmd: str, run_folder: str):
|
|
||||||
import subprocess # nosec
|
|
||||||
|
|
||||||
# Propagate errors from subprocess.
|
|
||||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
|
|
||||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
|
||||||
|
|
||||||
|
|
||||||
@app.function(
|
@app.function(
|
||||||
image=cicd_image,
|
image=cicd_image,
|
||||||
gpu=GPU_CONFIG,
|
gpu=GPU_CONFIG,
|
||||||
timeout=60 * 60,
|
timeout=90 * 60, # 90 min
|
||||||
cpu=8.0,
|
cpu=8.0,
|
||||||
memory=131072,
|
memory=131072,
|
||||||
volumes=VOLUME_CONFIG,
|
volumes=VOLUME_CONFIG,
|
||||||
|
|||||||
66
cicd/single_gpu.py
Normal file
66
cicd/single_gpu.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""Modal app to run axolotl GPU tests"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import jinja2
|
||||||
|
import modal
|
||||||
|
from jinja2 import select_autoescape
|
||||||
|
from modal import App, Image
|
||||||
|
|
||||||
|
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||||
|
|
||||||
|
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
||||||
|
template_env = jinja2.Environment(
|
||||||
|
loader=template_loader, autoescape=select_autoescape()
|
||||||
|
)
|
||||||
|
df_template = template_env.get_template("Dockerfile.jinja")
|
||||||
|
|
||||||
|
df_args = {
|
||||||
|
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||||
|
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||||
|
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
|
||||||
|
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
|
||||||
|
"CUDA": os.environ.get("CUDA", "121"),
|
||||||
|
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||||
|
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||||
|
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||||
|
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||||
|
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||||
|
}
|
||||||
|
|
||||||
|
dockerfile_contents = df_template.render(**df_args)
|
||||||
|
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
||||||
|
f.write(dockerfile_contents)
|
||||||
|
|
||||||
|
cicd_image = Image.from_dockerfile(
|
||||||
|
pathlib.Path(temp_dir) / "Dockerfile",
|
||||||
|
context_mount=None,
|
||||||
|
force_build=True,
|
||||||
|
gpu="A10G",
|
||||||
|
).env(df_args)
|
||||||
|
|
||||||
|
app = App("Axolotl CI/CD", secrets=[])
|
||||||
|
|
||||||
|
hf_cache_volume = modal.Volume.from_name(
|
||||||
|
"axolotl-ci-hf-hub-cache", create_if_missing=True
|
||||||
|
)
|
||||||
|
VOLUME_CONFIG = {
|
||||||
|
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
|
||||||
|
}
|
||||||
|
|
||||||
|
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||||
|
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
|
||||||
|
|
||||||
|
|
||||||
|
def run_cmd(cmd: str, run_folder: str):
|
||||||
|
import subprocess # nosec
|
||||||
|
|
||||||
|
# Propagate errors from subprocess.
|
||||||
|
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
|
||||||
|
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||||
@@ -19,7 +19,7 @@ coverage:
|
|||||||
if_no_uploads: error
|
if_no_uploads: error
|
||||||
if_not_found: success
|
if_not_found: success
|
||||||
if_ci_failed: error
|
if_ci_failed: error
|
||||||
only_pulls: false
|
only_pulls: true
|
||||||
flags: null
|
flags: null
|
||||||
paths: null
|
paths: null
|
||||||
patch:
|
patch:
|
||||||
|
|||||||
@@ -505,6 +505,7 @@ save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of eac
|
|||||||
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
|
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
|
||||||
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
||||||
save_total_limit: # Checkpoints saved at a time
|
save_total_limit: # Checkpoints saved at a time
|
||||||
|
save_only_model: # Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints.
|
||||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||||
# if both are set, num_epochs will not be guaranteed.
|
# if both are set, num_epochs will not be guaranteed.
|
||||||
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
|
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
|
||||||
@@ -538,7 +539,7 @@ train_on_inputs: false
|
|||||||
# Note that training loss may have an oscillating pattern with this enabled.
|
# Note that training loss may have an oscillating pattern with this enabled.
|
||||||
group_by_length: false
|
group_by_length: false
|
||||||
|
|
||||||
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
|
# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
|
||||||
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||||
gradient_checkpointing: false
|
gradient_checkpointing: false
|
||||||
# additional kwargs to pass to the trainer for gradient checkpointing
|
# additional kwargs to pass to the trainer for gradient checkpointing
|
||||||
|
|||||||
@@ -4,4 +4,4 @@ import pkgutil
|
|||||||
|
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||||
|
|
||||||
__version__ = "0.9.1"
|
__version__ = "0.9.2"
|
||||||
|
|||||||
@@ -82,6 +82,12 @@ class VllmServeCliArgs:
|
|||||||
"hardware support this feature."
|
"hardware support this feature."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
serve_module: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Module to serve. If not set, the default module will be used."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from trl.scripts.vllm_serve import ScriptArguments
|
from trl.scripts.vllm_serve import ScriptArguments
|
||||||
from trl.scripts.vllm_serve import main as vllm_serve_main
|
|
||||||
|
|
||||||
from axolotl.cli.config import load_cfg
|
from axolotl.cli.config import load_cfg
|
||||||
|
|
||||||
@@ -28,6 +27,9 @@ def do_vllm_serve(
|
|||||||
cfg = load_cfg(config)
|
cfg = load_cfg(config)
|
||||||
model = cfg.base_model
|
model = cfg.base_model
|
||||||
|
|
||||||
|
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
|
||||||
|
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
|
||||||
|
|
||||||
tensor_parallel_size = (
|
tensor_parallel_size = (
|
||||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1057,6 +1057,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
# default to saving each epoch if not defined
|
# default to saving each epoch if not defined
|
||||||
training_args_kwargs["save_strategy"] = "epoch"
|
training_args_kwargs["save_strategy"] = "epoch"
|
||||||
|
|
||||||
|
training_args_kwargs["save_only_model"] = self.cfg.save_only_model
|
||||||
|
|
||||||
if self.cfg.dataset_processes:
|
if self.cfg.dataset_processes:
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
|
||||||
@@ -1186,6 +1188,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
|
|
||||||
|
if self.cfg.plugins:
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||||
|
|
||||||
sig = inspect.signature(trainer_cls)
|
sig = inspect.signature(trainer_cls)
|
||||||
if "tokenizer" in sig.parameters.keys():
|
if "tokenizer" in sig.parameters.keys():
|
||||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||||
|
|||||||
@@ -5,8 +5,11 @@ from functools import partial
|
|||||||
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from axolotl.utils.gradient_checkpointing.unsloth import (
|
from axolotl.utils.gradient_checkpointing.offload_cpu import (
|
||||||
Unsloth_Offloaded_Gradient_Checkpointer,
|
CPU_Offloaded_Gradient_Checkpointer,
|
||||||
|
)
|
||||||
|
from axolotl.utils.gradient_checkpointing.offload_disk import (
|
||||||
|
Disco,
|
||||||
)
|
)
|
||||||
|
|
||||||
transformers_version = version.parse(importlib.metadata.version("transformers"))
|
transformers_version = version.parse(importlib.metadata.version("transformers"))
|
||||||
@@ -26,12 +29,31 @@ def hf_grad_checkpoint_offload_wrapper(
|
|||||||
decoder_layer, *args, use_reentrant=None
|
decoder_layer, *args, use_reentrant=None
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
if uses_gc_layers(decoder_layer):
|
if uses_gc_layers(decoder_layer):
|
||||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
return CPU_Offloaded_Gradient_Checkpointer.apply(
|
||||||
decoder_layer,
|
decoder_layer,
|
||||||
*args,
|
*args,
|
||||||
)
|
)
|
||||||
|
|
||||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
return CPU_Offloaded_Gradient_Checkpointer.apply(
|
||||||
|
(
|
||||||
|
decoder_layer.func.__self__
|
||||||
|
if isinstance(decoder_layer, partial)
|
||||||
|
else decoder_layer.__self__
|
||||||
|
),
|
||||||
|
*args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def hf_grad_checkpoint_disk_offload_wrapper(
|
||||||
|
decoder_layer, *args, use_reentrant=None
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
if uses_gc_layers(decoder_layer):
|
||||||
|
return Disco.apply(
|
||||||
|
decoder_layer,
|
||||||
|
*args,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Disco.apply(
|
||||||
(
|
(
|
||||||
decoder_layer.func.__self__
|
decoder_layer.func.__self__
|
||||||
if isinstance(decoder_layer, partial)
|
if isinstance(decoder_layer, partial)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Unsloth checkpointing"""
|
"""CPU offloaded checkpointing"""
|
||||||
|
|
||||||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||||
#
|
#
|
||||||
@@ -26,7 +26,7 @@ else:
|
|||||||
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||||
|
|
||||||
|
|
||||||
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||||
torch.autograd.Function
|
torch.autograd.Function
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
531
src/axolotl/utils/gradient_checkpointing/offload_disk.py
Normal file
531
src/axolotl/utils/gradient_checkpointing/offload_disk.py
Normal file
@@ -0,0 +1,531 @@
|
|||||||
|
"""
|
||||||
|
DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Copyright 2025 Axolotl AI. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import atexit
|
||||||
|
import concurrent.futures
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections import deque
|
||||||
|
from concurrent.futures import Future
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||||
|
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||||
|
|
||||||
|
# Setup logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DiskOffloadManager:
|
||||||
|
"""
|
||||||
|
Manages offloaded tensors and handles prefetching in a separate thread.
|
||||||
|
Includes synchronization to prevent race conditions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefetch_size: int = 3,
|
||||||
|
prefetch_to_gpu: bool = True,
|
||||||
|
save_workers: int = 4,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
prefetch_size: Maximum number of tensors to prefetch in the background.
|
||||||
|
prefetch_to_gpu: Whether to prefetch tensors directly to GPU memory.
|
||||||
|
save_workers: Maximum number of concurrent save operations.
|
||||||
|
"""
|
||||||
|
self.temp_dir = tempfile.mkdtemp(prefix="disco_")
|
||||||
|
|
||||||
|
# Track tensor paths and their status
|
||||||
|
self.tensor_paths: deque = deque() # Ordered history of tensor paths (LIFO)
|
||||||
|
self.file_locks: Dict[str, threading.Lock] = (
|
||||||
|
{}
|
||||||
|
) # Maps file_path -> threading.Lock()
|
||||||
|
# Maps file_path -> status ("saving", "ready", "prefetching", "loaded", "deleted")
|
||||||
|
self.file_status: Dict[str, str] = {}
|
||||||
|
|
||||||
|
self.max_prefetch = prefetch_size
|
||||||
|
self.prefetch_to_gpu = prefetch_to_gpu
|
||||||
|
|
||||||
|
# Thread synchronization
|
||||||
|
self.manager_lock = threading.RLock() # Used for thread-safe operations
|
||||||
|
|
||||||
|
# Prefetch queue and cache
|
||||||
|
self.prefetch_queue: queue.Queue = queue.Queue()
|
||||||
|
self.prefetch_cache: Dict[str, torch.Tensor] = {} # Maps file_path -> tensor
|
||||||
|
|
||||||
|
# Save queue and thread pool
|
||||||
|
self.save_queue: queue.Queue = queue.Queue()
|
||||||
|
self.save_pool = concurrent.futures.ThreadPoolExecutor(max_workers=save_workers)
|
||||||
|
self.save_futures: Dict[str, Future] = {}
|
||||||
|
self.save_semaphore = threading.Semaphore(
|
||||||
|
save_workers * 2
|
||||||
|
) # Limit concurrent save operations
|
||||||
|
|
||||||
|
# Start prefetch worker thread
|
||||||
|
self.stop_event = threading.Event()
|
||||||
|
# start multiple threads for prefetching
|
||||||
|
self.prefetch_worker_count = 2
|
||||||
|
self.prefetch_workers = []
|
||||||
|
for _ in range(self.prefetch_worker_count):
|
||||||
|
worker = threading.Thread(target=self._prefetch_worker, daemon=True)
|
||||||
|
worker.start()
|
||||||
|
self.prefetch_workers.append(worker)
|
||||||
|
|
||||||
|
# Start save worker thread
|
||||||
|
self.save_worker = threading.Thread(target=self._save_worker, daemon=True)
|
||||||
|
self.save_worker.start()
|
||||||
|
self.idx = 0
|
||||||
|
|
||||||
|
atexit.register(self.cleanup)
|
||||||
|
|
||||||
|
def _save_worker(self):
|
||||||
|
"""Background thread that processes the save queue"""
|
||||||
|
while not self.stop_event.is_set():
|
||||||
|
try:
|
||||||
|
save_item = self.save_queue.get(timeout=0.5)
|
||||||
|
if save_item is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tensor, file_path = save_item
|
||||||
|
|
||||||
|
# Submit the save task to the thread pool
|
||||||
|
future = self.save_pool.submit(
|
||||||
|
self._save_tensor_to_disk, tensor, file_path
|
||||||
|
)
|
||||||
|
with self.manager_lock:
|
||||||
|
self.save_futures[file_path] = future
|
||||||
|
|
||||||
|
self.save_queue.task_done()
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
time.sleep(0.01) # Small sleep to prevent CPU spinning
|
||||||
|
continue
|
||||||
|
|
||||||
|
def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str):
|
||||||
|
"""Actually save the tensor to disk"""
|
||||||
|
try:
|
||||||
|
# Save tensor to disk
|
||||||
|
cpu_tensor = tensor.detach().cpu()
|
||||||
|
torch.save(cpu_tensor, file_path)
|
||||||
|
del cpu_tensor
|
||||||
|
|
||||||
|
with self.manager_lock:
|
||||||
|
# Mark file as ready
|
||||||
|
self.file_status[file_path] = "ready"
|
||||||
|
|
||||||
|
# Release semaphore
|
||||||
|
self.save_semaphore.release()
|
||||||
|
|
||||||
|
return True
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.error(f"Error saving tensor to {file_path}: {e}")
|
||||||
|
with self.manager_lock:
|
||||||
|
self.file_status[file_path] = "error"
|
||||||
|
|
||||||
|
# Release semaphore
|
||||||
|
self.save_semaphore.release()
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _prefetch_worker(self):
|
||||||
|
"""Background thread that loads tensors from disk ahead of time"""
|
||||||
|
while not self.stop_event.is_set():
|
||||||
|
try:
|
||||||
|
file_path = self.prefetch_queue.get(timeout=0.5)
|
||||||
|
if file_path is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if file is available and not already in cache
|
||||||
|
with self.manager_lock:
|
||||||
|
if (
|
||||||
|
file_path not in self.file_status
|
||||||
|
or self.file_status[file_path] == "deleted"
|
||||||
|
):
|
||||||
|
self.prefetch_queue.task_done()
|
||||||
|
if file_path in self.prefetch_cache:
|
||||||
|
self.prefetch_queue.task_done()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If file is still being saved, wait for it
|
||||||
|
if (
|
||||||
|
self.file_status[file_path] == "saving"
|
||||||
|
and file_path in self.save_futures
|
||||||
|
):
|
||||||
|
# Re-queue this prefetch request with a little delay
|
||||||
|
self.prefetch_queue.task_done()
|
||||||
|
time.sleep(0.1)
|
||||||
|
self.prefetch_queue.put(file_path)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Mark file as being prefetched
|
||||||
|
self.file_status[file_path] = "prefetching"
|
||||||
|
|
||||||
|
# Load tensor from disk and store in cache
|
||||||
|
try:
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
if self.prefetch_to_gpu:
|
||||||
|
tensor = torch.load(
|
||||||
|
file_path,
|
||||||
|
map_location=torch.device("cuda"),
|
||||||
|
weights_only=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tensor = torch.load(file_path, weights_only=True)
|
||||||
|
|
||||||
|
with self.manager_lock:
|
||||||
|
self.prefetch_cache[file_path] = tensor
|
||||||
|
self.file_status[file_path] = "ready"
|
||||||
|
else:
|
||||||
|
with self.manager_lock:
|
||||||
|
if self.file_status.get(file_path) != "deleted":
|
||||||
|
logger.warning(
|
||||||
|
f"Prefetch error: File not found {file_path}"
|
||||||
|
)
|
||||||
|
self.file_status[file_path] = "missing"
|
||||||
|
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
with self.manager_lock:
|
||||||
|
if self.file_status.get(file_path) != "deleted":
|
||||||
|
logger.warning(f"Prefetch error for {file_path}: {e}")
|
||||||
|
self.file_status[file_path] = "error"
|
||||||
|
|
||||||
|
self.prefetch_queue.task_done()
|
||||||
|
|
||||||
|
except queue.Empty:
|
||||||
|
time.sleep(0.01) # Small sleep to prevent CPU spinning
|
||||||
|
continue
|
||||||
|
|
||||||
|
def save_tensor(self, tensor: torch.Tensor):
|
||||||
|
"""Save tensor to disk asynchronously and return file path with thread-safe operations"""
|
||||||
|
# Generate unique file path
|
||||||
|
self.idx += 1
|
||||||
|
file_path: str = os.path.join(
|
||||||
|
self.temp_dir, f"{self.idx:06d}-{uuid.uuid4()}.pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.manager_lock:
|
||||||
|
# Mark file as being saved
|
||||||
|
self.file_locks[file_path] = threading.Lock()
|
||||||
|
self.file_status[file_path] = "saving"
|
||||||
|
# Add to history
|
||||||
|
self.tensor_paths.append(file_path)
|
||||||
|
|
||||||
|
# Acquire semaphore to limit concurrent save operations
|
||||||
|
self.save_semaphore.acquire() # pylint: disable=consider-using-with
|
||||||
|
# Queue tensor for saving in background
|
||||||
|
self.save_queue.put((tensor.detach(), file_path))
|
||||||
|
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
def wait_for_save(self, file_path, timeout=None) -> None:
|
||||||
|
"""Wait for a tensor to be saved to disk"""
|
||||||
|
start_time = time.time()
|
||||||
|
while timeout is None or time.time() - start_time < timeout:
|
||||||
|
with self.manager_lock:
|
||||||
|
if self.file_status.get(file_path) == "ready":
|
||||||
|
return
|
||||||
|
if self.file_status.get(file_path) in ["error", "missing", "deleted"]:
|
||||||
|
return
|
||||||
|
|
||||||
|
if file_path in self.save_futures:
|
||||||
|
future = self.save_futures[file_path]
|
||||||
|
if future.done():
|
||||||
|
return
|
||||||
|
|
||||||
|
# Small sleep to prevent CPU spinning
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
# Timeout
|
||||||
|
logger.warning(f"Timeout waiting for tensor to be saved: {file_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
def load_tensor(self, file_path, target_device="cuda"):
|
||||||
|
"""Load tensor from disk or prefetch cache with proper synchronization"""
|
||||||
|
# Wait for tensor to be saved if it's still in progress
|
||||||
|
self.wait_for_save(file_path)
|
||||||
|
|
||||||
|
tensor = None
|
||||||
|
|
||||||
|
# Try to get from cache first
|
||||||
|
with self.manager_lock:
|
||||||
|
# Check if tensor is already in cache
|
||||||
|
if file_path in self.prefetch_cache:
|
||||||
|
tensor = self.prefetch_cache[file_path]
|
||||||
|
del self.prefetch_cache[file_path]
|
||||||
|
self.file_status[file_path] = "loaded"
|
||||||
|
|
||||||
|
if tensor is not None:
|
||||||
|
# Ensure tensor is on correct device
|
||||||
|
if target_device != "cpu" and tensor.device.type == "cpu":
|
||||||
|
tensor = tensor.to(target_device, non_blocking=True)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
# If not in cache, load directly from disk
|
||||||
|
try:
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
logger.error(f"File not found for loading: {file_path}")
|
||||||
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
|
tensor = torch.load(file_path, weights_only=True)
|
||||||
|
|
||||||
|
with self.manager_lock:
|
||||||
|
self.file_status[file_path] = "loaded"
|
||||||
|
|
||||||
|
if target_device != "cpu":
|
||||||
|
tensor = tensor.to(target_device, non_blocking=True)
|
||||||
|
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading tensor from {file_path}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _safe_delete_file(self, file_path):
|
||||||
|
"""Safely delete a file with proper synchronization"""
|
||||||
|
with self.manager_lock:
|
||||||
|
# Make sure any save operation is completed
|
||||||
|
if file_path in self.save_futures:
|
||||||
|
future = self.save_futures[file_path]
|
||||||
|
try:
|
||||||
|
if not future.done():
|
||||||
|
future.cancel()
|
||||||
|
del self.save_futures[file_path]
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error canceling save operation for {file_path}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only delete if file exists and is not being prefetched
|
||||||
|
status = self.file_status.get(file_path)
|
||||||
|
if status in ["ready", "loaded", "error", "missing"]:
|
||||||
|
try:
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
os.remove(file_path)
|
||||||
|
self.file_status[file_path] = "deleted"
|
||||||
|
return True
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.warning(f"Error deleting file {file_path}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def trigger_prefetch(self, n=None):
|
||||||
|
"""Trigger prefetching of the next N tensors with proper synchronization"""
|
||||||
|
if n is None:
|
||||||
|
n = self.max_prefetch
|
||||||
|
|
||||||
|
prefetch_paths = []
|
||||||
|
with self.manager_lock:
|
||||||
|
# Find files that are ready to be prefetched (not already in cache or being prefetched)
|
||||||
|
for path in reversed(self.tensor_paths):
|
||||||
|
if (
|
||||||
|
path not in self.prefetch_cache
|
||||||
|
and self.file_status.get(path) == "ready"
|
||||||
|
):
|
||||||
|
prefetch_paths.append(path)
|
||||||
|
if len(prefetch_paths) >= n:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Queue files for prefetching
|
||||||
|
for path in prefetch_paths:
|
||||||
|
self.prefetch_queue.put(path)
|
||||||
|
|
||||||
|
def cleanup_tensor(self, file_path: str):
|
||||||
|
"""Clean up a specific tensor file after it's been used"""
|
||||||
|
with self.manager_lock:
|
||||||
|
if file_path in self.tensor_paths:
|
||||||
|
self.tensor_paths.remove(file_path)
|
||||||
|
|
||||||
|
# Remove from prefetch cache if present
|
||||||
|
if file_path in self.prefetch_cache:
|
||||||
|
del self.prefetch_cache[file_path]
|
||||||
|
|
||||||
|
# Remove from save futures if present
|
||||||
|
if file_path in self.save_futures:
|
||||||
|
future = self.save_futures[file_path]
|
||||||
|
if not future.done():
|
||||||
|
future.cancel()
|
||||||
|
del self.save_futures[file_path]
|
||||||
|
|
||||||
|
# Try to delete the file
|
||||||
|
self._safe_delete_file(file_path)
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Clean up all temp files and stop prefetch thread with proper synchronization"""
|
||||||
|
self.stop_event.set()
|
||||||
|
|
||||||
|
# Cancel all pending save operations
|
||||||
|
with self.manager_lock:
|
||||||
|
for _, future in self.save_futures.items():
|
||||||
|
if not future.done():
|
||||||
|
future.cancel()
|
||||||
|
self.save_futures.clear()
|
||||||
|
|
||||||
|
# Drain the save queue
|
||||||
|
while not self.save_queue.empty():
|
||||||
|
try:
|
||||||
|
self.save_queue.get_nowait()
|
||||||
|
self.save_queue.task_done()
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Shutdown the save pool
|
||||||
|
self.save_pool.shutdown(wait=False)
|
||||||
|
|
||||||
|
# Join the save worker thread
|
||||||
|
if self.save_worker.is_alive():
|
||||||
|
self.save_worker.join(timeout=2.0)
|
||||||
|
|
||||||
|
# Join the prefetch worker threads
|
||||||
|
for thread in self.prefetch_workers:
|
||||||
|
if thread.is_alive():
|
||||||
|
thread.join(timeout=2.0)
|
||||||
|
|
||||||
|
# Clear cache and remove all temporary files
|
||||||
|
with self.manager_lock:
|
||||||
|
self.prefetch_cache.clear()
|
||||||
|
paths_to_delete = list(self.tensor_paths)
|
||||||
|
self.tensor_paths.clear()
|
||||||
|
|
||||||
|
# Delete all temporary files
|
||||||
|
for path in paths_to_delete:
|
||||||
|
self._safe_delete_file(path)
|
||||||
|
|
||||||
|
# Remove temp directory
|
||||||
|
try:
|
||||||
|
if os.path.exists(self.temp_dir):
|
||||||
|
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.warning(f"Error removing temporary directory {self.temp_dir}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class Disco(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
Disco: DIsk-based Storage and Checkpointing with Optimized prefetching
|
||||||
|
Advanced disk-based gradient checkpointer with prefetching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Shared manager instance across all checkpointing operations
|
||||||
|
_manager = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4):
|
||||||
|
"""Get or create the offload manager"""
|
||||||
|
if Disco._manager is None:
|
||||||
|
Disco._manager = DiskOffloadManager(
|
||||||
|
prefetch_size=prefetch_size,
|
||||||
|
prefetch_to_gpu=prefetch_to_gpu,
|
||||||
|
save_workers=save_workers,
|
||||||
|
)
|
||||||
|
return Disco._manager
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch_cuda_amp_custom_fwd
|
||||||
|
def forward(
|
||||||
|
ctx,
|
||||||
|
forward_function,
|
||||||
|
hidden_states,
|
||||||
|
*args,
|
||||||
|
prefetch_size=1,
|
||||||
|
prefetch_to_gpu=True,
|
||||||
|
save_workers=4,
|
||||||
|
):
|
||||||
|
"""Forward pass that offloads activations to disk asynchronously"""
|
||||||
|
# Get or create the manager
|
||||||
|
manager = Disco.get_instance(
|
||||||
|
prefetch_size=prefetch_size,
|
||||||
|
prefetch_to_gpu=prefetch_to_gpu,
|
||||||
|
save_workers=save_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save tensor to disk asynchronously
|
||||||
|
file_path = manager.save_tensor(hidden_states)
|
||||||
|
|
||||||
|
# Run forward pass immediately without waiting for save to complete
|
||||||
|
with torch.no_grad():
|
||||||
|
output = forward_function(hidden_states, *args)
|
||||||
|
|
||||||
|
# Store what we need for backward
|
||||||
|
ctx.save_for_backward(torch.tensor([0])) # Dummy tensor
|
||||||
|
ctx.file_path = file_path
|
||||||
|
ctx.forward_function = forward_function
|
||||||
|
ctx.args = args
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch_cuda_amp_custom_bwd
|
||||||
|
def backward(ctx, *grad_outputs):
|
||||||
|
"""Backward pass that loads activations from disk with prefetching"""
|
||||||
|
# Get the manager
|
||||||
|
manager = Disco._manager
|
||||||
|
|
||||||
|
# Trigger prefetching for future tensors
|
||||||
|
# This happens at the start of backward, so should have time to complete
|
||||||
|
manager.trigger_prefetch()
|
||||||
|
|
||||||
|
# Load hidden states from disk or prefetch cache
|
||||||
|
file_path = ctx.file_path
|
||||||
|
try:
|
||||||
|
# Ensure the file is saved before we try to load it
|
||||||
|
manager.wait_for_save(file_path)
|
||||||
|
|
||||||
|
hidden_states = manager.load_tensor(file_path)
|
||||||
|
hidden_states.requires_grad = True
|
||||||
|
|
||||||
|
# Compute gradients
|
||||||
|
with torch.enable_grad():
|
||||||
|
output = ctx.forward_function(hidden_states, *ctx.args)
|
||||||
|
|
||||||
|
# Handle tuple outputs properly
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
if len(grad_outputs) == len(output):
|
||||||
|
torch.autograd.backward(output, grad_outputs)
|
||||||
|
else:
|
||||||
|
torch.autograd.backward(output, grad_outputs[0])
|
||||||
|
else:
|
||||||
|
torch.autograd.backward(output, grad_outputs[0])
|
||||||
|
|
||||||
|
# Clean up the file after we're done with it
|
||||||
|
manager.cleanup_tensor(file_path)
|
||||||
|
|
||||||
|
return (
|
||||||
|
(
|
||||||
|
None, # forward_function
|
||||||
|
hidden_states.grad, # hidden_states grad
|
||||||
|
)
|
||||||
|
+ (None,) * len(ctx.args) # for each arg
|
||||||
|
+ (
|
||||||
|
None, # prefetch_size
|
||||||
|
None, # prefetch_to_gpu
|
||||||
|
None, # save_workers
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in backward pass: {e}")
|
||||||
|
# Clean up the file even on error
|
||||||
|
manager.cleanup_tensor(file_path)
|
||||||
|
raise
|
||||||
@@ -70,7 +70,10 @@ from axolotl.utils.distributed import (
|
|||||||
is_local_main_process,
|
is_local_main_process,
|
||||||
is_main_process,
|
is_main_process,
|
||||||
)
|
)
|
||||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
|
from axolotl.utils.gradient_checkpointing import (
|
||||||
|
hf_grad_checkpoint_disk_offload_wrapper,
|
||||||
|
hf_grad_checkpoint_offload_wrapper,
|
||||||
|
)
|
||||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||||
|
|
||||||
@@ -603,6 +606,10 @@ class ModelLoader:
|
|||||||
|
|
||||||
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
||||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
||||||
|
if self.cfg.gradient_checkpointing == "offload_disk":
|
||||||
|
transformers.modeling_utils.checkpoint = (
|
||||||
|
hf_grad_checkpoint_disk_offload_wrapper
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
self.patch_attention()
|
self.patch_attention()
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from concurrent.futures import ProcessPoolExecutor
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
from multiprocessing import cpu_count
|
from multiprocessing import cpu_count, get_context
|
||||||
from typing import Iterable, Union
|
from typing import Iterable, Union
|
||||||
|
|
||||||
import numba
|
import numba
|
||||||
@@ -78,15 +78,11 @@ def pack_group(
|
|||||||
Returns:
|
Returns:
|
||||||
List of bins, where each bin contains indices of sequences assigned to it
|
List of bins, where each bin contains indices of sequences assigned to it
|
||||||
"""
|
"""
|
||||||
# Get sorting indices and sort lengths in descending order
|
|
||||||
indices = np.argsort(sequence_lengths)[::-1]
|
|
||||||
sorted_lengths = sequence_lengths[indices]
|
|
||||||
|
|
||||||
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
|
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
|
||||||
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
|
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
|
||||||
|
|
||||||
for seq_id, size in enumerate(sorted_lengths):
|
for seq_id, size in enumerate(sequence_lengths):
|
||||||
global_idx = indices[seq_id] + group_offset
|
global_idx = seq_id + group_offset
|
||||||
|
|
||||||
# Try to place sequence in existing bins
|
# Try to place sequence in existing bins
|
||||||
add_new_bin = True
|
add_new_bin = True
|
||||||
@@ -130,6 +126,7 @@ def pack_parallel(
|
|||||||
bin_size: int,
|
bin_size: int,
|
||||||
num_processes: int | None = None,
|
num_processes: int | None = None,
|
||||||
safe_mode: bool = True,
|
safe_mode: bool = True,
|
||||||
|
mp_start_method: str | None = "spawn",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Pack sequences into bins using parallel processing
|
Pack sequences into bins using parallel processing
|
||||||
@@ -141,7 +138,9 @@ def pack_parallel(
|
|||||||
bin_size: Maximum number of bins to use
|
bin_size: Maximum number of bins to use
|
||||||
num_processes: Number of parallel processes to use
|
num_processes: Number of parallel processes to use
|
||||||
safe_mode: If True, use a more conservative packing approach
|
safe_mode: If True, use a more conservative packing approach
|
||||||
|
mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver').
|
||||||
|
'spawn' is often safer with Numba/PyTorch.
|
||||||
|
Set to None to use system default.
|
||||||
Returns:
|
Returns:
|
||||||
List of bins, where each bin contains indices of sequences assigned to it
|
List of bins, where each bin contains indices of sequences assigned to it
|
||||||
"""
|
"""
|
||||||
@@ -158,9 +157,33 @@ def pack_parallel(
|
|||||||
|
|
||||||
# Process groups in parallel
|
# Process groups in parallel
|
||||||
all_bins = []
|
all_bins = []
|
||||||
with ProcessPoolExecutor(max_workers=num_processes) as executor:
|
|
||||||
for group_bins in executor.map(_process_group, tasks):
|
mp_ctx = None
|
||||||
|
if mp_start_method:
|
||||||
|
try:
|
||||||
|
mp_ctx = get_context(mp_start_method)
|
||||||
|
except ValueError:
|
||||||
|
LOG.warning(
|
||||||
|
f"Failed to get multiprocessing context '{mp_start_method}'. "
|
||||||
|
f"Falling back to default. Available: {get_context().get_all_start_methods()}"
|
||||||
|
)
|
||||||
|
mp_ctx = (
|
||||||
|
None # Fallback to default context if specified one is not available
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_processes == 1:
|
||||||
|
LOG.debug("Using single process for pack_parallel, running sequentially.")
|
||||||
|
for task_args in tasks:
|
||||||
|
group_bins = _process_group(task_args)
|
||||||
all_bins.extend(group_bins)
|
all_bins.extend(group_bins)
|
||||||
|
else:
|
||||||
|
# Use ProcessPoolExecutor only if num_processes > 1
|
||||||
|
# Pass mp_context if available
|
||||||
|
with ProcessPoolExecutor(
|
||||||
|
max_workers=num_processes, mp_context=mp_ctx
|
||||||
|
) as executor:
|
||||||
|
for group_bins in executor.map(_process_group, tasks):
|
||||||
|
all_bins.extend(group_bins)
|
||||||
|
|
||||||
return all_bins
|
return all_bins
|
||||||
|
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
# torch_dtype: torch.dtype | None
|
# torch_dtype: torch.dtype | None
|
||||||
|
|
||||||
gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
|
gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field(
|
||||||
default=False
|
default=False
|
||||||
)
|
)
|
||||||
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
gradient_checkpointing_kwargs: dict[str, Any] | None = None
|
||||||
@@ -1149,16 +1149,28 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
# @model_validator(mode="before")
|
||||||
|
# @classmethod
|
||||||
|
# def check_grpo_peft_liger(cls, data):
|
||||||
|
# if (
|
||||||
|
# data.get("rl") == "grpo"
|
||||||
|
# and data.get("trl", {})
|
||||||
|
# and data.get("trl").get("use_liger_loss")
|
||||||
|
# and data.get("adapter")
|
||||||
|
# ):
|
||||||
|
# raise ValueError("PEFT + GRPO + Liger is not yet supported")
|
||||||
|
# return data
|
||||||
|
#
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_grpo_peft_liger(cls, data):
|
def check_grpo_liger_sequence_parallel(cls, data):
|
||||||
if (
|
if (
|
||||||
data.get("rl") == "grpo"
|
data.get("rl") == "grpo"
|
||||||
and data.get("trl", {})
|
and data.get("trl", {})
|
||||||
and data.get("trl").get("use_liger_loss")
|
and data.get("trl").get("use_liger_loss")
|
||||||
and data.get("adapter")
|
and data.get("sequence_parallel_degree", 1) > 1
|
||||||
):
|
):
|
||||||
raise ValueError("PEFT + GRPO + Liger is not yet supported")
|
raise ValueError("GRPO + SP + Liger not currently supported")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
@@ -1345,6 +1357,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
):
|
):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
|
||||||
|
if data.get("lora_dropout") != 0:
|
||||||
|
return data
|
||||||
|
|
||||||
# Check multi-GPU compatibility
|
# Check multi-GPU compatibility
|
||||||
capabilities = data.get("capabilities")
|
capabilities = data.get("capabilities")
|
||||||
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ class TestKnowledgeDistillation:
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -121,5 +121,5 @@ class TestKnowledgeDistillation:
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -166,6 +166,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="flaky test")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"num_gpus",
|
"num_gpus",
|
||||||
[1, 2],
|
[1, 2],
|
||||||
@@ -227,7 +228,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
|
|
||||||
current_env = os.environ.copy()
|
current_env = os.environ.copy()
|
||||||
env = {
|
env = {
|
||||||
"NCCL_P2P_LEVEL": "NVL",
|
"NCCL_P2P_LEVEL": "LOC",
|
||||||
**current_env,
|
**current_env,
|
||||||
"CUDA_VISIBLE_DEVICES": "1",
|
"CUDA_VISIBLE_DEVICES": "1",
|
||||||
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
||||||
@@ -257,7 +258,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
f"{get_torch_dist_unique_port()}",
|
f"{get_torch_dist_unique_port()}",
|
||||||
],
|
],
|
||||||
env={
|
env={
|
||||||
"NCCL_P2P_LEVEL": "NVL",
|
"NCCL_P2P_LEVEL": "LOC",
|
||||||
"NCCL_DEBUG": "INFO",
|
"NCCL_DEBUG": "INFO",
|
||||||
**current_env,
|
**current_env,
|
||||||
},
|
},
|
||||||
@@ -265,6 +266,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
finally:
|
finally:
|
||||||
recursive_kill(vllm_process)
|
recursive_kill(vllm_process)
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="flaky test")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"num_gpus",
|
"num_gpus",
|
||||||
[1, 2],
|
[1, 2],
|
||||||
@@ -320,7 +322,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
|
|
||||||
current_env = os.environ.copy()
|
current_env = os.environ.copy()
|
||||||
env = {
|
env = {
|
||||||
"NCCL_P2P_LEVEL": "NVL", # nccl can be brittle, assume P2P isn't reliable
|
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
|
||||||
**current_env,
|
**current_env,
|
||||||
"CUDA_VISIBLE_DEVICES": "1",
|
"CUDA_VISIBLE_DEVICES": "1",
|
||||||
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
||||||
@@ -350,7 +352,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
f"{get_torch_dist_unique_port()}",
|
f"{get_torch_dist_unique_port()}",
|
||||||
],
|
],
|
||||||
env={
|
env={
|
||||||
"NCCL_P2P_LEVEL": "NVL",
|
"NCCL_P2P_LEVEL": "LOC",
|
||||||
"NCCL_DEBUG": "INFO",
|
"NCCL_DEBUG": "INFO",
|
||||||
**current_env,
|
**current_env,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -57,9 +57,9 @@ class Test4dMultipackLlama(unittest.TestCase):
|
|||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"max_steps": 20,
|
"max_steps": 5,
|
||||||
"save_steps": 10,
|
"save_steps": 3,
|
||||||
"eval_steps": 10,
|
"eval_steps": 4,
|
||||||
"fp16": True,
|
"fp16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -105,9 +105,9 @@ class Test4dMultipackLlama(unittest.TestCase):
|
|||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"max_steps": 20,
|
"max_steps": 5,
|
||||||
"save_steps": 10,
|
"save_steps": 3,
|
||||||
"eval_steps": 10,
|
"eval_steps": 4,
|
||||||
"fp16": True,
|
"fp16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -26,10 +26,15 @@ class TestActivationCheckpointing:
|
|||||||
E2E tests for activation checkpointing
|
E2E tests for activation checkpointing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"gradient_checkpointing",
|
||||||
|
["offload", "offload_disk"],
|
||||||
|
)
|
||||||
def test_activation_checkpointing_offload(
|
def test_activation_checkpointing_offload(
|
||||||
self,
|
self,
|
||||||
temp_dir,
|
temp_dir,
|
||||||
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
|
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
|
||||||
|
gradient_checkpointing,
|
||||||
):
|
):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -64,7 +69,7 @@ class TestActivationCheckpointing:
|
|||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"bf16": True,
|
"bf16": True,
|
||||||
"save_safetensors": True,
|
"save_safetensors": True,
|
||||||
"gradient_checkpointing": "offload",
|
"gradient_checkpointing": gradient_checkpointing,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -57,9 +57,9 @@ class TestMistral(unittest.TestCase):
|
|||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"max_steps": 20,
|
"max_steps": 5,
|
||||||
"save_steps": 10,
|
"save_steps": 3,
|
||||||
"eval_steps": 10,
|
"eval_steps": 4,
|
||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -99,9 +99,9 @@ class TestMistral(unittest.TestCase):
|
|||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"max_steps": 20,
|
"max_steps": 5,
|
||||||
"save_steps": 10,
|
"save_steps": 3,
|
||||||
"eval_steps": 10,
|
"eval_steps": 4,
|
||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -54,9 +54,9 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_bnb_8bit",
|
"optimizer": "adamw_bnb_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"max_steps": 20,
|
"max_steps": 5,
|
||||||
"save_steps": 10,
|
"save_steps": 3,
|
||||||
"eval_steps": 10,
|
"eval_steps": 4,
|
||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -93,9 +93,9 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_bnb_8bit",
|
"optimizer": "adamw_bnb_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"max_steps": 20,
|
"max_steps": 5,
|
||||||
"save_steps": 10,
|
"save_steps": 3,
|
||||||
"eval_steps": 10,
|
"eval_steps": 4,
|
||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -56,9 +56,9 @@ class TestPhiMultipack(unittest.TestCase):
|
|||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_bnb_8bit",
|
"optimizer": "adamw_bnb_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"max_steps": 20,
|
"max_steps": 5,
|
||||||
"eval_steps": 10,
|
"eval_steps": 3,
|
||||||
"save_steps": 10,
|
"save_steps": 4,
|
||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -108,9 +108,9 @@ class TestPhiMultipack(unittest.TestCase):
|
|||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_bnb_8bit",
|
"optimizer": "adamw_bnb_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"max_steps": 20,
|
"max_steps": 5,
|
||||||
"eval_steps": 10,
|
"eval_steps": 3,
|
||||||
"save_steps": 10,
|
"save_steps": 4,
|
||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -106,3 +106,4 @@ class TestBatchedSamplerPacking:
|
|||||||
|
|
||||||
original_idxs = set(range(len(train_dataset)))
|
original_idxs = set(range(len(train_dataset)))
|
||||||
assert original_idxs == set(batch_idxs)
|
assert original_idxs == set(batch_idxs)
|
||||||
|
assert len(batch_idxs) == len(set(batch_idxs))
|
||||||
|
|||||||
Reference in New Issue
Block a user