Compare commits

...

7 Commits

Author SHA1 Message Date
Wing Lian
54bbc9bb72 set v0.9.2 version for tag
2025-05-13 17:52:33 -04:00
Wing Lian
5aefebe1fe Activation checkpointing with offloading to disk with prefetch (#2663)
* offload activations to disk instead of CPU RAM

* add prefetch

* Disco :dance:

* include offload_disk in e2e test for AC

* document and make sure to cleanup

* fix annotation to match docs

* fix docs build

* address PR feedback
2025-05-13 17:06:31 -04:00
Wing Lian
5a36b6ff2d Atropos support (#2666) [skip ci]
* allow peft+liger+grpo and custom vllm serve for atropos support

* set trainer class for RL
2025-05-13 17:06:05 -04:00
NanoCode012
224da88fa2 fix: disable auto lora kernel if dropout nonzero (#2655) [skip ci]
* fix: disable auto lora kernel if dropout nonzero

* Add comment from PR feedback

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-13 17:05:20 -04:00
Wing Lian
493eb8e5c6 update doc and use P2P=LOC for brittle grpo test (#2649)
* update doc and skip brittle grpo test

* fix the path to run the multigpu tests

* increase timeout, use LOC instead of NVL

* typo

* use hf cache from s3 backed cloudfront

* mark grpo as flaky test due to vllm start
2025-05-13 17:05:11 -04:00
Wing Lian
4780ac7c4d guard on deleting secrets from env (#2653) [skip ci] 2025-05-13 17:03:27 -04:00
Wing Lian
cf69de2eb9 Various fixes for CI, save_only_model for RL, prevent packing multiprocessing deadlocks (#2661)
* lean mistral ft tests, remove e2e torch 2.4.1 test

* make sure to pass save_only_model for RL

* more tests to make ci leaner, add cleanup to modal ci

* fix module for import in e2e tests

* use mp spawn to prevent deadlocks with packing

* make sure cleanup shell script is executable when cloned out
2025-05-13 17:03:08 -04:00
28 changed files with 935 additions and 221 deletions

View File

@@ -3,7 +3,7 @@ name: docker-multigpu-tests-biweekly
on:
pull_request:
paths:
-- 'tests/e2e/multigpu/*.py'
+- 'tests/e2e/multigpu/**.py'
- 'requirements.txt'
- 'setup.py'
- 'pyproject.toml'

View File

@@ -44,96 +44,102 @@ jobs:
env:
SKIP: no-commit-to-branch
preload-cache:
name: Preload HF cache
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0"]
timeout-minutes: 20
env:
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v tests/conftest.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
# preload-cache:
# name: Preload HF cache
# runs-on: ubuntu-latest
# strategy:
# fail-fast: false
# matrix:
# python_version: ["3.11"]
# pytorch_version: ["2.6.0"]
# timeout-minutes: 20
#
# env:
# AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
#
# steps:
# - name: Check out repository code
# uses: actions/checkout@v4
#
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
#
# - name: Restore Cache from S3
# id: hf-cache-restore-s3
# run: |
# mkdir -p /home/runner/.cache/huggingface/hub
# curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
#
# - name: Setup Python
# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python_version }}
# cache: 'pip' # caching pip dependencies
#
# - name: upgrade pip
# run: |
# pip3 install --upgrade pip
# pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
#
# - name: Install PyTorch
# run: |
# pip3 install torch==${{ matrix.pytorch_version }}
#
# - name: Install dependencies
# run: |
# pip3 show torch
# pip3 install --no-build-isolation -U -e .
# python scripts/unsloth_install.py | sh
# python scripts/cutcrossentropy_install.py | sh
# pip3 install -r requirements-dev.txt -r requirements-tests.txt
#
# - name: Make sure PyTorch version wasn't clobbered
# run: |
# python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
#
# - name: Ensure axolotl CLI was installed
# run: |
# axolotl --help
#
# - name: Pre-Download dataset fixture
# run: |
# huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
#
# - name: Run tests
# run: |
# pytest -v tests/conftest.py
#
# - name: Upload coverage to Codecov
# uses: codecov/codecov-action@v5
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# files: ./coverage.xml
# flags: unittests,pytorch-${{ matrix.pytorch_version }}
# fail_ci_if_error: false
#
# - name: cleanup pip cache
# run: |
# find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
#
# - name: Save HF cache
# id: hf-cache
# uses: actions/cache/save@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
-needs: [preload-cache]
+# needs: [preload-cache]
strategy:
fail-fast: false
matrix:
@@ -145,14 +151,20 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5
@@ -210,7 +222,7 @@ jobs:
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
-needs: [preload-cache]
+# needs: [preload-cache]
strategy:
fail-fast: false
matrix:
@@ -222,14 +234,20 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
# - name: Restore HF cache
# id: hf-cache-restore
# uses: actions/cache/restore@v4
# with:
# path: |
# /home/runner/.cache/huggingface/hub/datasets--*
# /home/runner/.cache/huggingface/hub/models--*
# key: ${{ runner.os }}-hf-hub-cache-v2
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5
@@ -365,3 +383,43 @@ jobs:
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
docker-e2e-cleanup:
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [docker-e2e-tests]
strategy:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras: vllm
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.cleanup

View File

@@ -57,8 +57,10 @@ async def handler(job):
logger.info("Training Complete.")
# Cleanup
del os.environ["WANDB_API_KEY"]
del os.environ["HF_TOKEN"]
if "WANDB_API_KEY" in os.environ:
del os.environ["WANDB_API_KEY"]
if "HF_TOKEN" in os.environ:
del os.environ["HF_TOKEN"]
runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
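The guarded deletion above keeps the handler from raising KeyError when it runs without those secrets set. The same effect can be had with os.environ.pop, which folds the membership check into one call; a minimal sketch (the helper name is illustrative, not part of this changeset):

import os

def scrub_secrets(*names: str) -> None:
    """Drop secret environment variables if present; missing keys are ignored."""
    for name in names:
        # pop() with a default never raises, unlike `del os.environ[name]`
        os.environ.pop(name, None)

scrub_secrets("WANDB_API_KEY", "HF_TOKEN")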

View File

@@ -124,7 +124,8 @@ quartodoc:
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.sft
-- utils.gradient_checkpointing.unsloth
+- utils.gradient_checkpointing.offload_cpu
+- utils.gradient_checkpointing.offload_disk
- title: Schemas
desc: Pydantic data models for Axolotl config
contents:

cicd/__init__.py (new file, empty)
View File

View File

@@ -18,7 +18,7 @@ pytest -v --durations=10 \
--cov-append
# Run patched tests excluding lora kernels with coverage append
-pytest -v --durations=10 \
+pytest --full-trace -vvv --durations=10 \
--ignore=tests/e2e/patched/lora_kernels \
/workspace/axolotl/tests/e2e/patched \
--cov=axolotl \

cicd/cleanup.py (new file, 19 lines)
View File

@@ -0,0 +1,19 @@
"""Modal app to run axolotl GPU cleanup"""
from .single_gpu import VOLUME_CONFIG, app, cicd_image, run_cmd
@app.function(
image=cicd_image,
timeout=60 * 60,
cpu=8.0,
memory=131072,
volumes=VOLUME_CONFIG,
)
def cleanup():
run_cmd("./cicd/cleanup.sh", "/workspace/axolotl")
@app.local_entrypoint()
def main():
cleanup.remote()

cicd/cleanup.sh (new executable file, 6 lines)
View File

@@ -0,0 +1,6 @@
#!/bin/bash
set -e
# cleanup old cache files for datasets processing and intermediate mappings
find /workspace/data/huggingface-cache/hub/datasets -name "cache-*" -type f -mtime +1 -exec rm {} \;
find /workspace/data/huggingface-cache/hub/datasets -name "*.lock" -type f -mtime +1 -exec rm {} \;

View File

@@ -1,75 +1,12 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import App, Image
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
).env(df_args)
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
@app.function(
image=cicd_image,
gpu=GPU_CONFIG,
-timeout=60 * 60,
+timeout=90 * 60,  # 90 min
cpu=8.0,
memory=131072,
volumes=VOLUME_CONFIG,

cicd/single_gpu.py (new file, 66 lines)
View File

@@ -0,0 +1,66 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import App, Image
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
).env(df_args)
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit

View File

@@ -19,7 +19,7 @@ coverage:
if_no_uploads: error
if_not_found: success
if_ci_failed: error
-only_pulls: false
+only_pulls: true
flags: null
paths: null
patch:

View File

@@ -505,6 +505,7 @@ save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of eac
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: # Checkpoints saved at a time
save_only_model: # Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints.
# Maximum number of iterations to train for. It precedes num_epochs which means that
# if both are set, num_epochs will not be guaranteed.
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
@@ -538,7 +539,7 @@ train_on_inputs: false
# Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false
-# Whether to use gradient checkpointing. Available options are: true, false, "offload".
+# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.9.1.post1"
__version__ = "0.9.2"

View File

@@ -82,6 +82,12 @@ class VllmServeCliArgs:
"hardware support this feature."
},
)
serve_module: Optional[str] = field(
default=None,
metadata={
"help": "Module to serve. If not set, the default module will be used."
},
)
@dataclass

View File

@@ -6,7 +6,6 @@ from pathlib import Path
from typing import Union
from trl.scripts.vllm_serve import ScriptArguments
-from trl.scripts.vllm_serve import main as vllm_serve_main
from axolotl.cli.config import load_cfg
@@ -28,6 +27,9 @@ def do_vllm_serve(
cfg = load_cfg(config)
model = cfg.base_model
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
)
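The new serve_module argument is resolved at runtime via getattr(__import__(...), "main"). A small sketch of the same dynamic-dispatch idea using importlib; the custom module name at the end is made up purely to show the shape of an override:

import importlib

def resolve_serve_main(serve_module: str = "trl.scripts.vllm_serve"):
    """Return the `main` callable exported by the requested serve module."""
    module = importlib.import_module(serve_module)
    return module.main

serve_main = resolve_serve_main()  # default TRL vLLM server
# serve_main = resolve_serve_main("my_project.atropos_vllm_serve")  # hypothetical override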

View File

@@ -1057,6 +1057,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["save_only_model"] = self.cfg.save_only_model
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
@@ -1186,6 +1188,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters.keys():
dpo_trainer_kwargs["tokenizer"] = self.tokenizer

View File

@@ -5,8 +5,11 @@ from functools import partial
from packaging import version
-from axolotl.utils.gradient_checkpointing.unsloth import (
-Unsloth_Offloaded_Gradient_Checkpointer,
+from axolotl.utils.gradient_checkpointing.offload_cpu import (
+CPU_Offloaded_Gradient_Checkpointer,
)
+from axolotl.utils.gradient_checkpointing.offload_disk import (
+Disco,
+)
transformers_version = version.parse(importlib.metadata.version("transformers"))
@@ -26,12 +29,31 @@ def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
if uses_gc_layers(decoder_layer):
-return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+return CPU_Offloaded_Gradient_Checkpointer.apply(
decoder_layer,
*args,
)
-return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+return CPU_Offloaded_Gradient_Checkpointer.apply(
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)
else decoder_layer.__self__
),
*args,
)
def hf_grad_checkpoint_disk_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
if uses_gc_layers(decoder_layer):
return Disco.apply(
decoder_layer,
*args,
)
return Disco.apply(
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)

View File

@@ -1,4 +1,4 @@
"""Unsloth checkpointing"""
"""CPU offloaded checkpointing"""
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
@@ -26,7 +26,7 @@ else:
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
-class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
+class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
torch.autograd.Function
):
"""

View File

@@ -0,0 +1,531 @@
"""
DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
"""
# Copyright 2025 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import concurrent.futures
import logging
import os
import queue
import shutil
import tempfile
import threading
import time
import uuid
from collections import deque
from concurrent.futures import Future
from typing import Dict
import torch
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
# Setup logger
logger = logging.getLogger(__name__)
class DiskOffloadManager:
"""
Manages offloaded tensors and handles prefetching in a separate thread.
Includes synchronization to prevent race conditions.
"""
def __init__(
self,
prefetch_size: int = 3,
prefetch_to_gpu: bool = True,
save_workers: int = 4,
):
"""
Args:
prefetch_size: Maximum number of tensors to prefetch in the background.
prefetch_to_gpu: Whether to prefetch tensors directly to GPU memory.
save_workers: Maximum number of concurrent save operations.
"""
self.temp_dir = tempfile.mkdtemp(prefix="disco_")
# Track tensor paths and their status
self.tensor_paths: deque = deque() # Ordered history of tensor paths (LIFO)
self.file_locks: Dict[str, threading.Lock] = (
{}
) # Maps file_path -> threading.Lock()
# Maps file_path -> status ("saving", "ready", "prefetching", "loaded", "deleted")
self.file_status: Dict[str, str] = {}
self.max_prefetch = prefetch_size
self.prefetch_to_gpu = prefetch_to_gpu
# Thread synchronization
self.manager_lock = threading.RLock() # Used for thread-safe operations
# Prefetch queue and cache
self.prefetch_queue: queue.Queue = queue.Queue()
self.prefetch_cache: Dict[str, torch.Tensor] = {} # Maps file_path -> tensor
# Save queue and thread pool
self.save_queue: queue.Queue = queue.Queue()
self.save_pool = concurrent.futures.ThreadPoolExecutor(max_workers=save_workers)
self.save_futures: Dict[str, Future] = {}
self.save_semaphore = threading.Semaphore(
save_workers * 2
) # Limit concurrent save operations
# Start prefetch worker thread
self.stop_event = threading.Event()
# start multiple threads for prefetching
self.prefetch_worker_count = 2
self.prefetch_workers = []
for _ in range(self.prefetch_worker_count):
worker = threading.Thread(target=self._prefetch_worker, daemon=True)
worker.start()
self.prefetch_workers.append(worker)
# Start save worker thread
self.save_worker = threading.Thread(target=self._save_worker, daemon=True)
self.save_worker.start()
self.idx = 0
atexit.register(self.cleanup)
def _save_worker(self):
"""Background thread that processes the save queue"""
while not self.stop_event.is_set():
try:
save_item = self.save_queue.get(timeout=0.5)
if save_item is None:
continue
tensor, file_path = save_item
# Submit the save task to the thread pool
future = self.save_pool.submit(
self._save_tensor_to_disk, tensor, file_path
)
with self.manager_lock:
self.save_futures[file_path] = future
self.save_queue.task_done()
except queue.Empty:
time.sleep(0.01) # Small sleep to prevent CPU spinning
continue
def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str):
"""Actually save the tensor to disk"""
try:
# Save tensor to disk
cpu_tensor = tensor.detach().cpu()
torch.save(cpu_tensor, file_path)
del cpu_tensor
with self.manager_lock:
# Mark file as ready
self.file_status[file_path] = "ready"
# Release semaphore
self.save_semaphore.release()
return True
except FileNotFoundError as e:
logger.error(f"Error saving tensor to {file_path}: {e}")
with self.manager_lock:
self.file_status[file_path] = "error"
# Release semaphore
self.save_semaphore.release()
return False
def _prefetch_worker(self):
"""Background thread that loads tensors from disk ahead of time"""
while not self.stop_event.is_set():
try:
file_path = self.prefetch_queue.get(timeout=0.5)
if file_path is None:
continue
# Check if file is available and not already in cache
with self.manager_lock:
if (
file_path not in self.file_status
or self.file_status[file_path] == "deleted"
):
self.prefetch_queue.task_done()
if file_path in self.prefetch_cache:
self.prefetch_queue.task_done()
continue
# If file is still being saved, wait for it
if (
self.file_status[file_path] == "saving"
and file_path in self.save_futures
):
# Re-queue this prefetch request with a little delay
self.prefetch_queue.task_done()
time.sleep(0.1)
self.prefetch_queue.put(file_path)
continue
# Mark file as being prefetched
self.file_status[file_path] = "prefetching"
# Load tensor from disk and store in cache
try:
if os.path.exists(file_path):
if self.prefetch_to_gpu:
tensor = torch.load(
file_path,
map_location=torch.device("cuda"),
weights_only=True,
)
else:
tensor = torch.load(file_path, weights_only=True)
with self.manager_lock:
self.prefetch_cache[file_path] = tensor
self.file_status[file_path] = "ready"
else:
with self.manager_lock:
if self.file_status.get(file_path) != "deleted":
logger.warning(
f"Prefetch error: File not found {file_path}"
)
self.file_status[file_path] = "missing"
except FileNotFoundError as e:
with self.manager_lock:
if self.file_status.get(file_path) != "deleted":
logger.warning(f"Prefetch error for {file_path}: {e}")
self.file_status[file_path] = "error"
self.prefetch_queue.task_done()
except queue.Empty:
time.sleep(0.01) # Small sleep to prevent CPU spinning
continue
def save_tensor(self, tensor: torch.Tensor):
"""Save tensor to disk asynchronously and return file path with thread-safe operations"""
# Generate unique file path
self.idx += 1
file_path: str = os.path.join(
self.temp_dir, f"{self.idx:06d}-{uuid.uuid4()}.pt"
)
with self.manager_lock:
# Mark file as being saved
self.file_locks[file_path] = threading.Lock()
self.file_status[file_path] = "saving"
# Add to history
self.tensor_paths.append(file_path)
# Acquire semaphore to limit concurrent save operations
self.save_semaphore.acquire() # pylint: disable=consider-using-with
# Queue tensor for saving in background
self.save_queue.put((tensor.detach(), file_path))
return file_path
def wait_for_save(self, file_path, timeout=None) -> None:
"""Wait for a tensor to be saved to disk"""
start_time = time.time()
while timeout is None or time.time() - start_time < timeout:
with self.manager_lock:
if self.file_status.get(file_path) == "ready":
return
if self.file_status.get(file_path) in ["error", "missing", "deleted"]:
return
if file_path in self.save_futures:
future = self.save_futures[file_path]
if future.done():
return
# Small sleep to prevent CPU spinning
time.sleep(0.01)
# Timeout
logger.warning(f"Timeout waiting for tensor to be saved: {file_path}")
return
def load_tensor(self, file_path, target_device="cuda"):
"""Load tensor from disk or prefetch cache with proper synchronization"""
# Wait for tensor to be saved if it's still in progress
self.wait_for_save(file_path)
tensor = None
# Try to get from cache first
with self.manager_lock:
# Check if tensor is already in cache
if file_path in self.prefetch_cache:
tensor = self.prefetch_cache[file_path]
del self.prefetch_cache[file_path]
self.file_status[file_path] = "loaded"
if tensor is not None:
# Ensure tensor is on correct device
if target_device != "cpu" and tensor.device.type == "cpu":
tensor = tensor.to(target_device, non_blocking=True)
return tensor
# If not in cache, load directly from disk
try:
if not os.path.exists(file_path):
logger.error(f"File not found for loading: {file_path}")
raise FileNotFoundError(f"File not found: {file_path}")
tensor = torch.load(file_path, weights_only=True)
with self.manager_lock:
self.file_status[file_path] = "loaded"
if target_device != "cpu":
tensor = tensor.to(target_device, non_blocking=True)
return tensor
except Exception as e:
logger.error(f"Error loading tensor from {file_path}: {e}")
raise
def _safe_delete_file(self, file_path):
"""Safely delete a file with proper synchronization"""
with self.manager_lock:
# Make sure any save operation is completed
if file_path in self.save_futures:
future = self.save_futures[file_path]
try:
if not future.done():
future.cancel()
del self.save_futures[file_path]
except FileNotFoundError as e:
logger.warning(
f"Error canceling save operation for {file_path}: {e}"
)
# Only delete if file exists and is not being prefetched
status = self.file_status.get(file_path)
if status in ["ready", "loaded", "error", "missing"]:
try:
if os.path.exists(file_path):
os.remove(file_path)
self.file_status[file_path] = "deleted"
return True
except FileNotFoundError as e:
logger.warning(f"Error deleting file {file_path}: {e}")
return False
def trigger_prefetch(self, n=None):
"""Trigger prefetching of the next N tensors with proper synchronization"""
if n is None:
n = self.max_prefetch
prefetch_paths = []
with self.manager_lock:
# Find files that are ready to be prefetched (not already in cache or being prefetched)
for path in reversed(self.tensor_paths):
if (
path not in self.prefetch_cache
and self.file_status.get(path) == "ready"
):
prefetch_paths.append(path)
if len(prefetch_paths) >= n:
break
# Queue files for prefetching
for path in prefetch_paths:
self.prefetch_queue.put(path)
def cleanup_tensor(self, file_path: str):
"""Clean up a specific tensor file after it's been used"""
with self.manager_lock:
if file_path in self.tensor_paths:
self.tensor_paths.remove(file_path)
# Remove from prefetch cache if present
if file_path in self.prefetch_cache:
del self.prefetch_cache[file_path]
# Remove from save futures if present
if file_path in self.save_futures:
future = self.save_futures[file_path]
if not future.done():
future.cancel()
del self.save_futures[file_path]
# Try to delete the file
self._safe_delete_file(file_path)
def cleanup(self):
"""Clean up all temp files and stop prefetch thread with proper synchronization"""
self.stop_event.set()
# Cancel all pending save operations
with self.manager_lock:
for _, future in self.save_futures.items():
if not future.done():
future.cancel()
self.save_futures.clear()
# Drain the save queue
while not self.save_queue.empty():
try:
self.save_queue.get_nowait()
self.save_queue.task_done()
except queue.Empty:
break
# Shutdown the save pool
self.save_pool.shutdown(wait=False)
# Join the save worker thread
if self.save_worker.is_alive():
self.save_worker.join(timeout=2.0)
# Join the prefetch worker threads
for thread in self.prefetch_workers:
if thread.is_alive():
thread.join(timeout=2.0)
# Clear cache and remove all temporary files
with self.manager_lock:
self.prefetch_cache.clear()
paths_to_delete = list(self.tensor_paths)
self.tensor_paths.clear()
# Delete all temporary files
for path in paths_to_delete:
self._safe_delete_file(path)
# Remove temp directory
try:
if os.path.exists(self.temp_dir):
shutil.rmtree(self.temp_dir, ignore_errors=True)
except FileNotFoundError as e:
logger.warning(f"Error removing temporary directory {self.temp_dir}: {e}")
class Disco(torch.autograd.Function):
"""
Disco: DIsk-based Storage and Checkpointing with Optimized prefetching
Advanced disk-based gradient checkpointer with prefetching.
"""
# Shared manager instance across all checkpointing operations
_manager = None
@staticmethod
def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4):
"""Get or create the offload manager"""
if Disco._manager is None:
Disco._manager = DiskOffloadManager(
prefetch_size=prefetch_size,
prefetch_to_gpu=prefetch_to_gpu,
save_workers=save_workers,
)
return Disco._manager
@staticmethod
@torch_cuda_amp_custom_fwd
def forward(
ctx,
forward_function,
hidden_states,
*args,
prefetch_size=1,
prefetch_to_gpu=True,
save_workers=4,
):
"""Forward pass that offloads activations to disk asynchronously"""
# Get or create the manager
manager = Disco.get_instance(
prefetch_size=prefetch_size,
prefetch_to_gpu=prefetch_to_gpu,
save_workers=save_workers,
)
# Save tensor to disk asynchronously
file_path = manager.save_tensor(hidden_states)
# Run forward pass immediately without waiting for save to complete
with torch.no_grad():
output = forward_function(hidden_states, *args)
# Store what we need for backward
ctx.save_for_backward(torch.tensor([0])) # Dummy tensor
ctx.file_path = file_path
ctx.forward_function = forward_function
ctx.args = args
return output
@staticmethod
@torch_cuda_amp_custom_bwd
def backward(ctx, *grad_outputs):
"""Backward pass that loads activations from disk with prefetching"""
# Get the manager
manager = Disco._manager
# Trigger prefetching for future tensors
# This happens at the start of backward, so should have time to complete
manager.trigger_prefetch()
# Load hidden states from disk or prefetch cache
file_path = ctx.file_path
try:
# Ensure the file is saved before we try to load it
manager.wait_for_save(file_path)
hidden_states = manager.load_tensor(file_path)
hidden_states.requires_grad = True
# Compute gradients
with torch.enable_grad():
output = ctx.forward_function(hidden_states, *ctx.args)
# Handle tuple outputs properly
if isinstance(output, tuple):
if len(grad_outputs) == len(output):
torch.autograd.backward(output, grad_outputs)
else:
torch.autograd.backward(output, grad_outputs[0])
else:
torch.autograd.backward(output, grad_outputs[0])
# Clean up the file after we're done with it
manager.cleanup_tensor(file_path)
return (
(
None, # forward_function
hidden_states.grad, # hidden_states grad
)
+ (None,) * len(ctx.args) # for each arg
+ (
None, # prefetch_size
None, # prefetch_to_gpu
None, # save_workers
)
)
except Exception as e:
logger.error(f"Error in backward pass: {e}")
# Clean up the file even on error
manager.cleanup_tensor(file_path)
raise
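For orientation, a minimal usage sketch of the checkpointer defined above: a toy block's input activation is queued for an asynchronous write to disk on the forward pass and reloaded, ideally from the prefetch cache, during backward. The module and shapes are invented for illustration, and a CUDA device is assumed since the manager prefetches to GPU by default:

import torch
import torch.nn as nn

from axolotl.utils.gradient_checkpointing.offload_disk import Disco

block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64)).cuda()
hidden_states = torch.randn(2, 16, 64, device="cuda", requires_grad=True)

# forward: hidden_states is saved to disk asynchronously while the block runs under no_grad
out = Disco.apply(block, hidden_states)
# backward: the activation is reloaded from disk (or the prefetch cache) and the block re-runs with grad enabled
out.sum().backward()
print(hidden_states.grad.shape)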

View File

@@ -70,7 +70,10 @@ from axolotl.utils.distributed import (
is_local_main_process,
is_main_process,
)
-from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
+from axolotl.utils.gradient_checkpointing import (
+hf_grad_checkpoint_disk_offload_wrapper,
+hf_grad_checkpoint_offload_wrapper,
+)
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -603,6 +606,10 @@ class ModelLoader:
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
if self.cfg.gradient_checkpointing == "offload_disk":
transformers.modeling_utils.checkpoint = (
hf_grad_checkpoint_disk_offload_wrapper
)
if self.cfg.flash_attention:
self.patch_attention()

View File

@@ -6,7 +6,7 @@ into fixed-capacity batches to optimize memory usage and training throughput.
import logging
import math
from concurrent.futures import ProcessPoolExecutor
-from multiprocessing import cpu_count
+from multiprocessing import cpu_count, get_context
from typing import Iterable, Union
import numba
@@ -126,6 +126,7 @@ def pack_parallel(
bin_size: int,
num_processes: int | None = None,
safe_mode: bool = True,
mp_start_method: str | None = "spawn",
):
"""
Pack sequences into bins using parallel processing
@@ -137,7 +138,9 @@ def pack_parallel(
bin_size: Maximum number of bins to use
num_processes: Number of parallel processes to use
safe_mode: If True, use a more conservative packing approach
mp_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver').
'spawn' is often safer with Numba/PyTorch.
Set to None to use system default.
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
@@ -154,9 +157,33 @@ def pack_parallel(
# Process groups in parallel
all_bins = []
-with ProcessPoolExecutor(max_workers=num_processes) as executor:
-for group_bins in executor.map(_process_group, tasks):
mp_ctx = None
if mp_start_method:
try:
mp_ctx = get_context(mp_start_method)
except ValueError:
LOG.warning(
f"Failed to get multiprocessing context '{mp_start_method}'. "
f"Falling back to default. Available: {get_context().get_all_start_methods()}"
)
mp_ctx = (
None # Fallback to default context if specified one is not available
)
if num_processes == 1:
LOG.debug("Using single process for pack_parallel, running sequentially.")
for task_args in tasks:
group_bins = _process_group(task_args)
all_bins.extend(group_bins)
else:
# Use ProcessPoolExecutor only if num_processes > 1
# Pass mp_context if available
with ProcessPoolExecutor(
max_workers=num_processes, mp_context=mp_ctx
) as executor:
for group_bins in executor.map(_process_group, tasks):
all_bins.extend(group_bins)
return all_bins
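The packing change above defaults the worker pool to the 'spawn' start method because forking a parent process that already holds Numba/PyTorch thread state can deadlock. A standalone sketch of the same pattern, with a placeholder worker standing in for _process_group:

from concurrent.futures import ProcessPoolExecutor
from multiprocessing import get_context

def _square(n: int) -> int:
    # placeholder for axolotl's _process_group worker
    return n * n

if __name__ == "__main__":
    # 'spawn' starts clean interpreters instead of forking the thread-heavy parent
    ctx = get_context("spawn")
    with ProcessPoolExecutor(max_workers=2, mp_context=ctx) as executor:
        print(list(executor.map(_square, range(8))))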

View File

@@ -178,7 +178,7 @@ class AxolotlInputConfig(
# torch_dtype: torch.dtype | None
gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field(
default=False
)
gradient_checkpointing_kwargs: dict[str, Any] | None = None
@@ -1149,16 +1149,28 @@ class AxolotlInputConfig(
return data
# @model_validator(mode="before")
# @classmethod
# def check_grpo_peft_liger(cls, data):
# if (
# data.get("rl") == "grpo"
# and data.get("trl", {})
# and data.get("trl").get("use_liger_loss")
# and data.get("adapter")
# ):
# raise ValueError("PEFT + GRPO + Liger is not yet supported")
# return data
#
@model_validator(mode="before")
@classmethod
-def check_grpo_peft_liger(cls, data):
+def check_grpo_liger_sequence_parallel(cls, data):
if (
data.get("rl") == "grpo"
and data.get("trl", {})
and data.get("trl").get("use_liger_loss")
-and data.get("adapter")
+and data.get("sequence_parallel_degree", 1) > 1
):
-raise ValueError("PEFT + GRPO + Liger is not yet supported")
+raise ValueError("GRPO + SP + Liger not currently supported")
return data
@model_validator(mode="after")
@@ -1345,6 +1357,10 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
):
return data
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
if data.get("lora_dropout") != 0:
return data
# Check multi-GPU compatibility
capabilities = data.get("capabilities")
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1

View File

@@ -166,6 +166,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"""
)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -227,7 +228,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "NVL",
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
@@ -257,7 +258,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "NVL",
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},
@@ -265,6 +266,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
finally:
recursive_kill(vllm_process)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -320,7 +322,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "NVL", # nccl can be brittle, assume P2P isn't reliable
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
@@ -350,7 +352,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "NVL",
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},

View File

@@ -57,9 +57,9 @@ class Test4dMultipackLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
}
)
@@ -105,9 +105,9 @@ class Test4dMultipackLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
}
)

View File

@@ -26,10 +26,15 @@ class TestActivationCheckpointing:
E2E tests for activation checkpointing
"""
@pytest.mark.parametrize(
"gradient_checkpointing",
["offload", "offload_disk"],
)
def test_activation_checkpointing_offload(
self,
temp_dir,
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
gradient_checkpointing,
):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -64,7 +69,7 @@ class TestActivationCheckpointing:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": "offload",
"gradient_checkpointing": gradient_checkpointing,
}
)

View File

@@ -57,9 +57,9 @@ class TestMistral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
}
)
@@ -99,9 +99,9 @@ class TestMistral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
}
)

View File

@@ -54,9 +54,9 @@ class TestMixtral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
}
)
@@ -93,9 +93,9 @@ class TestMixtral(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 5,
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
}
)

View File

@@ -56,9 +56,9 @@ class TestPhiMultipack(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"eval_steps": 10,
"save_steps": 10,
"max_steps": 5,
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
}
)
@@ -108,9 +108,9 @@ class TestPhiMultipack(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"eval_steps": 10,
"save_steps": 10,
"max_steps": 5,
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
}
)