Compare commits

..

30 Commits

Author SHA1 Message Date
Wing Lian
9bdf4b1c23 improve handling and error if fa3 requested but not installeD 2025-05-19 10:11:14 -07:00
Wing Lian
d6f64a3684 handle args to drop dropout 2025-05-18 15:17:40 -07:00
Wing Lian
0735454782 move fa3 tests to multigpu since we only run those on hopper 2025-05-18 15:17:39 -07:00
Wing Lian
bb6464c4c6 use get_device_capability since CI setting in cfg is unreliable 2025-05-18 15:17:39 -07:00
Wing Lian
323a9cb153 handle return sig change for fa3 2025-05-18 15:17:39 -07:00
Wing Lian
b22150751f check for fa first 2025-05-18 15:17:39 -07:00
Wing Lian
8c4bc59bfc fa3 doesn't support dropout_p, fix unpatching 2025-05-18 15:17:39 -07:00
Wing Lian
a064f1c9b4 ci for fa3 2025-05-18 15:17:39 -07:00
Wing Lian
fb5ef6d445 use updated package name for fa3 2025-05-18 15:17:38 -07:00
Wing Lian
34b68ddaae curl with apt instead of pip 2025-05-18 15:17:38 -07:00
Wing Lian
9a3d0c919b make sure curl is installed 2025-05-18 15:17:38 -07:00
Wing Lian
bd34d0b861 install for hopper from pre-built wheel 2025-05-18 15:17:38 -07:00
Wing Lian
37220ab90a install pybind11 for fa3 build 2025-05-18 15:17:38 -07:00
Wing Lian
e1b74d710b update docker args to minimums used and use MAX_JOBS already set as arg 2025-05-18 15:17:38 -07:00
Wing Lian
79daf5b934 reduce max jobs for build of fa3 2025-05-18 15:17:38 -07:00
Wing Lian
ddd7c55576 build hopper w fa3 on torch 2.6 2025-05-18 15:17:37 -07:00
Wing Lian
65c6c98a76 whitespace fix in dockerfile 2025-05-18 15:17:37 -07:00
Wing Lian
4ef2e8293f fix the bash in docker base 2025-05-18 15:17:37 -07:00
Wing Lian
c126d5cd04 fix suffix for tag 2025-05-18 15:17:37 -07:00
Wing Lian
9b0be4f15c fix 12.8 image and add flash-attn v3 hopper base image 2025-05-18 15:17:37 -07:00
Wing Lian
a27b909c5c GRPO fixes (peft) (#2676)
* don't set peft_config on grpo to prevent double peft wrap

* remove overrides needed to support bug

* fix grpo tests

* require more CPU for multigpu to help with torch compile for vllm
2025-05-16 15:47:03 -04:00
xzuyn
6cb07b9d12 Fix for setting adam_beta3 and adam_epsilon2 for CAME Optimizer (#2654) [skip ci]
* make setting `adam_beta3` and `adam_epsilon2` work correctly

* update config docs so users know args are specific to CAME optim

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-16 15:46:50 -04:00
C080
288653adb6 Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifa… (#2675) [skip ci]
* Fix: Make MLflow config artifact logging respect hf_mlflow_log_artifacts setting

* cleanup and lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-16 15:46:31 -04:00
NanoCode012
3a5b495a74 Fix: improve doc on merge/inference cli visibility (#2674)
* feat: improve visibility for merge doc

* feat: add tip on reuse config between modes
2025-05-16 13:07:40 -04:00
xzuyn
f661858fc4 Print dataset name (#2668) [skip ci] 2025-05-16 13:06:58 -04:00
Eric Meier
c837c4a424 Add missing init file to liger plugin (#2670) [skip ci] 2025-05-16 13:06:46 -04:00
michelyang
c9797de6bb Add num_proc to fix data set slow processing issue (#2681) [skip ci] 2025-05-16 13:06:20 -04:00
Wing Lian
8f8a7afb05 Add ci and images for CUDA 12.8 for B200s (#2683) [skip ci]
* Add ci and images for CUDA 12.8 for B200s

* add comments explaining CI [skip e2e]
2025-05-16 13:06:08 -04:00
NanoCode012
86472715da fix: remove doc string imports in monkeypatches (#2671) [skip ci] 2025-05-16 13:05:55 -04:00
Wing Lian
c0a0c7534c Activation checkpointing with offloading to disk with prefetch (#2663)
* offload activations to disk instead of CPU RAM

* add prefetch

* Disco :dance:

* include offload_disk in e2e test for AC

* document and make sure to cleanup

* fix annotation to match docs

* fix docs build

* address PR feedback
2025-05-13 16:39:39 -04:00
41 changed files with 777 additions and 657 deletions

View File

@@ -47,11 +47,18 @@ jobs:
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.6.3
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
suffix: "-hopper"
torch_cuda_arch_list: "9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -87,7 +94,7 @@ jobs:
context: .
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}${{ matrix.suffix || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-args: |
CUDA_VERSION=${{ matrix.cuda_version }}

View File

@@ -31,6 +31,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -94,6 +99,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -32,21 +32,25 @@ jobs:
pytorch: 2.6.0
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
suffix: "-hopper"
num_gpus: 2
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -68,7 +72,6 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |

View File

@@ -295,6 +295,7 @@ jobs:
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests-1st:
# Run this job first as a gate for running the remainder of the test matrix
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
@@ -341,6 +342,8 @@ jobs:
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
# Only run the remainder of the matrix if the first e2e check passed;
# this is to save on wasted compute costs for known failures that get caught in the first run
needs: [pre-commit, pytest, docker-e2e-tests-1st]
strategy:
@@ -365,6 +368,12 @@ jobs:
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -139,7 +139,8 @@ quartodoc:
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.sft
- utils.gradient_checkpointing.unsloth
- utils.gradient_checkpointing.offload_cpu
- utils.gradient_checkpointing.offload_disk
- title: Schemas
desc: Pydantic data models for Axolotl config
contents:

View File

@@ -32,6 +32,11 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi
RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "126" ] ; then \
curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
image=cicd_image,
gpu=GPU_CONFIG,
timeout=90 * 60,
cpu=8.0,
cpu=16.0,
memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,
)

View File

@@ -1,5 +1,5 @@
ARG CUDA_VERSION="11.8.0"
ARG CUDNN_VERSION="8"
ARG CUDA_VERSION="12.4.1"
ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
@@ -7,16 +7,16 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION A
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.10"
ARG PYTORCH_VERSION="2.1.2"
ARG CUDA="118"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.5.1"
ARG CUDA="124"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
@@ -38,6 +38,10 @@ RUN git lfs install --skip-repo && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
curl -L -O https://d1dttdx32dkk5p.cloudfront.net/fa3/cu${CUDA}/torch-${PYTORCH_VERSION}/flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
pip3 install --no-cache-dir flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
rm flash_attn_3-3.0.0b1-cp311-cp311-linux_x86_64.whl; \
elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -539,7 +539,7 @@ train_on_inputs: false
# Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false
# additional kwargs to pass to the trainer for gradient checkpointing
@@ -633,7 +633,9 @@ weight_decay:
# adamw hyperparams
adam_beta1:
adam_beta2:
adam_beta3: # only used for CAME Optimizer
adam_epsilon:
adam_epsilon2: # only used for CAME Optimizer
# Gradient clipping max norm
max_grad_norm:

View File

@@ -104,7 +104,7 @@ the `alpaca` dataset format, which has the following format:
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
format them.
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca`
format):
```json
@@ -120,6 +120,12 @@ axolotl train my_training.yml
## Common Tasks {#sec-common-tasks}
::: {.callout-tip}
The same yaml file is used for training, inference, and merging.
:::
### Testing Your Model {#sec-testing}
After training, test your model:
@@ -128,6 +134,16 @@ After training, test your model:
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
```
More details can be found in [Inference](inference.qmd).
### Using a UI {#sec-ui}
Launch a Gradio interface:
```bash
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
```
### Preprocessing Data {#sec-preprocessing}
For large datasets, preprocess first:
@@ -136,14 +152,22 @@ For large datasets, preprocess first:
axolotl preprocess my_training.yml
```
### Using a UI {#sec-ui}
Please make sure to set `dataset_prepared_path: ` in your config to set the path to save the prepared dataset.
Launch a Gradio interface:
More details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd).
### Merging LoRA weights {#sec-merging-lora}
To merge the LoRA weights back into the base model, run:
```bash
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
axolotl merge-lora my_training.yml --lora-model-dir="./outputs/lora-out"
```
The merged model will be saved in the `{output_dir}/merged` directory.
More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).
## Next Steps {#sec-next-steps}
Now that you have the basics, you might want to:
@@ -156,6 +180,7 @@ Now that you have the basics, you might want to:
Check our other guides for details on these topics:
- [Configuration Guide](config.qmd) - Full configuration options
- [Dataset Loading](dataset-loading.qmd) - Loading datasets from various sources
- [Dataset Formats](dataset-formats) - Working with different data formats
- [Multi-GPU Training](multi-gpu.qmd)
- [Multi-Node Training](multi-node.qmd)

View File

@@ -342,13 +342,6 @@ def delinearize_llama4(model: str, output: str) -> None:
do_delinearize_llama4(model, output)
@cli.command()
def wizard():
from axolotl.cli.wizard import do_wizard
do_wizard()
cli.add_command(lm_eval)

View File

@@ -1,429 +0,0 @@
"""Wizard for creating yaml configs."""
import click
import torch
import yaml
from packaging import version
from transformers.training_args import OptimizerNames
from axolotl.cli.art import print_axolotl_text_art
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
def do_wizard():
print_axolotl_text_art()
# Ask where to save the config
cfg = DictDefault({})
config_path = click.prompt(
"Where do you want to save the config?", type=str, default="config.yaml"
)
# Ask base model
base_model = click.prompt("What base model do you want to use?", type=str)
cfg["base_model"] = base_model.strip()
# Ask whether want to enable Vision model
# TODO: check if model has vision layers instead of asking user
train_vision_model = click.confirm(
"If this model has vision layers, do you want to train them?", default=False
)
if train_vision_model:
cfg["processor_type"] = "AutoProcessor"
cfg["skip_prepare_dataset"] = True
cfg["remove_unused_columns"] = False
cfg["sample_packing"] = False
# Ask whether they want to set any advanced model features (custom tokenizer, custom config, etc)
advanced_model_features = click.confirm(
"Do you want to set any advanced model features? (custom tokenizer, custom config, remote code etc)",
default=False,
)
if advanced_model_features:
# Ask whether they want to use a custom config
base_model_config = click.prompt(
"What model config do you want to use? (leave blank for default)",
type=str,
default="",
)
if base_model_config:
cfg["base_model_config"] = base_model_config
# Ask whether they want to use a specific revision of the model
revision_of_model = click.prompt(
"What revision of the model do you want to use? (leave blank for default)",
type=str,
default="",
)
if revision_of_model:
cfg["revision_of_model"] = revision_of_model
# Ask whether they want to use a custom tokenizer
tokenizer_config = click.prompt(
"What tokenizer do you want to use? (leave blank for default)",
type=str,
default="",
)
if tokenizer_config:
cfg["tokenizer_config"] = tokenizer_config
# Ask whether they want to use remote code
trust_remote_code = click.confirm(
"Do you want to use remote code?", default=False
)
if trust_remote_code:
cfg["trust_remote_code"] = trust_remote_code
# Whether to resize token embeddings
resize_token_embeddings_to_32x = click.confirm(
"Do you want to resize token embeddings to 32x?", default=False
)
if resize_token_embeddings_to_32x:
cfg["resize_token_embeddings_to_32x"] = resize_token_embeddings_to_32x
# Whether to shrink embeddings to len(tokenizer)
shrink_embeddings = click.confirm(
"Do you want to shrink embeddings to len(tokenizer)?", default=False
)
if shrink_embeddings:
cfg["shrink_embeddings"] = shrink_embeddings
# Whether to skip upcast embeddings
embeddings_skip_upcast = click.confirm(
"Do you want to skip upcast embeddings?", default=False
)
if embeddings_skip_upcast:
cfg["embeddings_skip_upcast"] = embeddings_skip_upcast
# Whether to random init weights
random_init_weights = click.confirm(
"Do you want to random init weights?", default=False
)
if random_init_weights:
cfg["random_init_weights"] = random_init_weights
# Get model type
config = load_model_config(cfg)
model_type = config.model_type
# Ask sequence length
sequence_length = click.prompt("What sequence length do you want to use?", type=int)
cfg["sequence_length"] = sequence_length
# Whether to turn on sample packing
if cfg["sample_packing"] is None:
cfg["sample_packing"] = click.confirm(
"Do you want to turn on sample packing? This will speed up training by packing multiple samples into a single batch.",
default=True,
)
if cfg["sample_packing"]:
cfg["pad_to_sequence_len"] = True
# Whether to turn off eval sample packing
no_eval_sample_packing = click.confirm(
"Do you want to turn off eval sample packing? This will slow down evaluation but is recommended if you are using a small validation set.",
default=False,
)
if no_eval_sample_packing:
cfg["eval_sample_packing"] = False
# Hardware check
try:
is_ampere_or_newer = torch.cuda.get_device_capability()[0] >= 8
except RuntimeError:
is_ampere_or_newer = False
except AssertionError: # this is raised if no cuda is available
is_ampere_or_newer = False
# Get num gpus
try:
num_gpus = torch.cuda.device_count()
except RuntimeError:
num_gpus = 0
# Get torch version
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
is_torch_2_6_or_newer = version.parse(torch_version) >= version.parse("2.6.0")
# Whether to turn on attention
opt = ["xformers", "sdp"]
if is_ampere_or_newer:
opt.append("flash")
if is_torch_2_6_or_newer:
opt.append("flex")
if cfg["sample_packing"]:
if "flash" in opt:
default_opt = "flash"
elif "flex" in opt:
default_opt = "flex"
else:
default_opt = opt[0]
attention = click.prompt(
"Which attention backend do you want to use? Sample packing requires an attention backend to be set.",
type=click.Choice(opt),
default=default_opt,
)
else:
# non-sample packing supports no attention and S2
opt.extend(["none", "s2"])
attention = click.prompt(
"Which attention backend do you want to use?",
type=click.Choice(opt),
default="none",
)
if attention == "none":
attention = None
# TODO: if xformers, check if FA is installed
# TODO: flex doc mentioned requiring seq len to be divisible by 128. Unclear if limitation still exists
# TODO: requires #2489
cfg["attention"] = attention
# Whether to turn on gradient checkpointing
# TODO: need to wait for offload_disk PR to be merged
gradient_checkpointing = click.prompt(
"Which gradient checkpointing strategy do you want to use?",
type=click.Choice(["none", "true", "offload", "offload_disk"]),
default="true",
)
if gradient_checkpointing == "none":
gradient_checkpointing = False
elif gradient_checkpointing == "true":
gradient_checkpointing = True
# Ask whether to set use_reentrant
# TODO: get correct defaults based on SFT/RL mode and single/multigpu
# use_reentrant = click.confirm(
# "Do you want to set use_reentrant?",
# default=True,
# )
# if use_reentrant:
# cfg["use_reentrant"] = use_reentrant
# Optimizer
cfg["optimizer"] = click.prompt(
"Which optimizer do you want to use?",
type=click.Choice((OptimizerNames | CustomSupportedOptimizers)),
default=OptimizerNames.ADAMW_TORCH_FUSED,
)
cfg["lr_scheduler"] = click.prompt(
"Which learning rate scheduler do you want to use?",
type=click.Choice(
[
"cosine",
"one_cycle",
"rex",
"log_sweep",
"linear",
"cosine_with_restarts",
"polynomial",
"constant",
"constant_with_warmup",
"inverse_sqrt",
"reduce_lr_on_plateau",
"cosine_with_min_lr",
"warmup_stable_decay",
]
),
default="cosine",
)
# Plugins
cfg["plugins"] = []
# Whether to turn on cut cross entropy
if is_ampere_or_newer:
# Note: This may error if users don't have CCE installed
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
CUT_CROSS_ENTROPY_MODEL_MAPPING,
)
if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING:
cut_cross_entropy = click.confirm(
"Do you want to turn on cut cross entropy? This will save VRAM if the model has a large vocab size.",
default=True,
)
if cut_cross_entropy:
cfg["plugins"].append(
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin"
)
cfg["cut_cross_entropy"] = True
use_liger_kernel = click.confirm(
"Do you want to use the liger kernel? This will speed up training and save VRAM.",
default=True,
)
if use_liger_kernel:
cfg["plugins"].append("axolotl.integrations.liger.LigerPlugin")
cfg["liger_rope"] = click.confirm(
"Do you want to enable liger rope?",
default=True,
)
cfg["liger_rms_norm"] = click.confirm(
"Do you want to enable liger rms norm?",
default=True,
)
cfg["liger_glu_activation"] = click.confirm(
"Do you want to enable liger glu activation?",
default=True,
)
cfg["liger_layer_norm"] = click.confirm(
"Do you want to enable liger layer norm?",
default=True,
)
if cfg["cut_cross_entropy"] is not True:
cfg["liger_fused_linear_cross_entropy"] = click.confirm(
"Do you want to enable liger fused linear cross entropy?",
default=True,
)
# TODO: lora kernels (but they auto enable via validator already)
# TODO: is there incompat between torch compile and liger?
cfg["torch_compile"] = click.confirm(
"Do you want to enable torch compile?",
default=True,
)
# Multi-gpu
if num_gpus > 1:
# Ask whether to use DDP/Deepspeed/FSDP
multi_gpu_mode = click.prompt(
"Which multi-gpu mode do you want to use?",
type=click.Choice(["ddp", "deepspeed", "fsdp"]),
default="ddp",
)
if multi_gpu_mode == "deepspeed":
# Ask which deepspeed config to use
cfg["deepspeed"] = click.prompt(
"Which deepspeed config do you want to use? The higher the number, the more VRAM you will save, but the slower it will run.",
type=click.Choice(
[
"zero1.json",
"zero1_torch_compile.json",
"zero2.json",
"zero3.json",
"zero3_bf16.json",
"zero3_bf16_cpuoffload_all.json",
"zero3_bf16_cpuoffload_params.json",
]
),
default="zero1.json",
)
elif multi_gpu_mode == "fsdp":
fsdp_version = click.prompt(
"Which fsdp version do you want to use?",
type=click.Choice([1, 2]),
default=2,
)
# TODO: Handle FSDP config
if fsdp_version == 1:
cfg["fsdp"] = ["full_shard", "auto_wrap"]
# Ask which state dict type to use
fsdp_state_dict_type = click.prompt(
"Which fsdp state dict type do you want to use?",
type=click.Choice(["FULL_STATE_DICT", "SHARDED_STATE_DICT"]),
default="FULL_STATE_DICT",
)
fsdp_offload_params = click.confirm(
"Do you want to offload parameters?",
default=True,
)
# TODO: can we load the model class and auto pull a default for this?
fsdp_transformer_layer_cls_to_wrap = click.prompt(
"Which transformer layer class to wrap? It is usually the Decoder layer class.",
type=str,
)
# TODO: add other options
cfg["fsdp_config"] = {
"state_dict_type": fsdp_state_dict_type,
"offload_params": fsdp_offload_params,
"transformer_layer_cls_to_wrap": fsdp_transformer_layer_cls_to_wrap,
}
elif fsdp_version == 2:
raise NotImplementedError()
# Training mode (sft or rl)
training_mode = click.prompt(
"Which training mode do you want to use?",
type=click.Choice(["sft", "rl"]),
default="sft",
)
if training_mode == "rl":
cfg["rl"] = click.prompt(
"Which rl mode do you want to use?",
type=click.Choice(["dpo", "ipo", "orpo", "kto", "grpo", "simpo"]),
)
# TODO: handle RL options
# Whether to use adapter
# Get batch/grad accu
# Get learning rate
# Get weight decay
# Get max grad norm
# Get num train epochs
# Get warmup ratio
# Get save ratio
# Get eval ratio
# Get dataset config
# Load metric tracker
# Save config to yaml
# TODO: improve output yaml formatting. Need to add comments to help separate sections
with open(config_path, "w", encoding="utf-8") as f:
yaml.dump(cfg.to_dict(), f, sort_keys=False)

View File

@@ -387,8 +387,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
if self.cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
if self.cfg.adam_beta3:
training_arguments_kwargs["adam_beta3"] = self.cfg.adam_beta3
if self.cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
if self.cfg.adam_epsilon2:
training_arguments_kwargs["adam_epsilon2"] = self.cfg.adam_epsilon2
if self.cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
@@ -713,7 +717,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999)
eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
adam_kwargs["betas"] = (beta1, beta2, beta3)
@@ -1170,7 +1174,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.rl is not RLType.GRPO:
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs

View File

@@ -3,7 +3,6 @@
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings
from contextlib import nullcontext
from typing import Any
import datasets
@@ -14,7 +13,7 @@ from accelerate.utils import (
broadcast_object_list,
gather,
gather_object,
is_peft_model,
is_peft_available,
)
from datasets import Dataset, IterableDataset
from torch import nn
@@ -30,15 +29,13 @@ from transformers import (
TrainerCallback,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_peft_available
from trl import GRPOTrainer
from trl.data_utils import (
apply_chat_template,
is_conversational,
maybe_apply_chat_template,
)
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.import_utils import is_deepspeed_available
from trl.extras.profiling import profiling_context
from trl.models import unwrap_model_for_generation
from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc, nanstd
@@ -52,62 +49,12 @@ if is_peft_available():
# pylint: disable=unused-import
from peft import PeftConfig
if is_deepspeed_available():
import deepspeed
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""Extend the base GRPOTrainer for axolotl helpers"""
_tag_names = ["trl", "grpo", "axolotl"]
@profiling_decorator
def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
gather_if_zero3 = (
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
)
if is_peft_model(self.model):
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
# adapters in a sharded manner is not supported.
with gather_if_zero3(list(self.model.parameters())):
self.model.merge_adapter()
# Update vLLM weights while parameters are gathered
for name, param in self.model.named_parameters():
# When using PEFT, we need to recover the original parameter name and discard some parameters
name = (
name.removeprefix("base_model.model.")
.removeprefix("base_model.model.")
.replace(".base_layer", "")
)
if self.model.prefix in name:
continue
# When module to save, remove its prefix and discard the original module
if "original_module" in name:
continue
name = name.replace("modules_to_save.default.", "")
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
# Unmerge adapters while parameters are still gathered
self.model.unmerge_adapter()
# Parameters will automatically be repartitioned when exiting the context
else:
# For non-PEFT models, simply gather and update each parameter individually.
for name, param in self.model.named_parameters():
with gather_if_zero3([param]):
if self.accelerator.is_main_process:
self.vllm_client.update_named_param(name, param.data)
# Reset cache on main process
if self.accelerator.is_main_process:
self.vllm_client.reset_prefix_cache()
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""

View File

@@ -227,6 +227,19 @@ class AxolotlTrainingMixins:
},
)
adam_beta3: Optional[float] = field(
default=None,
metadata={
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
},
)
adam_epsilon2: Optional[float] = field(
default=None,
metadata={
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(

View File

@@ -20,25 +20,15 @@ from cut_cross_entropy.transformers.utils import (
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.cohere.modeling_cohere import (
_CONFIG_FOR_DOC,
COHERE_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,

View File

@@ -17,25 +17,15 @@ from cut_cross_entropy.transformers.utils import (
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma.modeling_gemma import (
_CONFIG_FOR_DOC,
GEMMA_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,

View File

@@ -20,15 +20,11 @@ from torch import nn
from transformers.cache_utils import Cache, HybridCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma3.modeling_gemma3 import (
_CONFIG_FOR_DOC,
GEMMA3_INPUTS_DOCSTRING,
Gemma3CausalLMOutputWithPast,
logger,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
@@ -38,10 +34,6 @@ _PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -170,10 +162,6 @@ def cce_forward(
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: torch.LongTensor | None = None,

View File

@@ -19,15 +19,9 @@ from transformers.modeling_outputs import (
CausalLMOutputWithPast,
)
from transformers.models.llama.modeling_llama import (
_CONFIG_FOR_DOC,
LLAMA_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
@@ -36,10 +30,6 @@ _PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -16,22 +16,12 @@ from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama4.modeling_llama4 import (
_CONFIG_FOR_DOC,
LLAMA4_INPUTS_DOCSTRING,
Llama4CausalLMOutputWithPast,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
_PATCH_OPTS: PatchOptions | None = None
@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
@@ -160,9 +150,6 @@ def cce_forward(
)
@replace_return_docstrings(
output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: torch.LongTensor | None = None, # type: ignore

View File

@@ -19,15 +19,11 @@ from transformers.models.mistral3.modeling_mistral3 import (
Mistral3CausalLMOutputWithPast,
)
from transformers.models.mistral.modeling_mistral import (
_CONFIG_FOR_DOC,
MISTRAL_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
@@ -35,10 +31,6 @@ _PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,

View File

@@ -13,16 +13,10 @@ from cut_cross_entropy.transformers.utils import (
apply_lce,
)
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
_CONFIG_FOR_DOC,
QWEN2MOE_INPUTS_DOCSTRING,
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
load_balancing_loss_func,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
@@ -31,10 +25,6 @@ _PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -14,22 +14,12 @@ from cut_cross_entropy.transformers.utils import (
)
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
_CONFIG_FOR_DOC,
QWEN2_VL_INPUTS_DOCSTRING,
Qwen2VLCausalLMOutputWithPast,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
_PATCH_OPTS: PatchOptions | None = None
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward_multimodal(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -12,20 +12,13 @@ from cut_cross_entropy.transformers.utils import (
TransformersModelT,
apply_lce,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
_CONFIG_FOR_DOC,
QWEN3_MOE_INPUTS_DOCSTRING,
KwargsForCausalLM,
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
load_balancing_loss_func,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple
@@ -34,10 +27,6 @@ _PATCH_OPTS: PatchOptions | None = None
@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,

View File

@@ -14,10 +14,6 @@ from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
# @replace_return_docstrings(
# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
# )
def lce_forward(
self,
input_ids: torch.LongTensor = None,

View File

@@ -13,21 +13,11 @@ from liger_kernel.transformers.fused_linear_cross_entropy import (
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.models.jamba.modeling_jamba import (
_CONFIG_FOR_DOC,
JAMBA_INPUTS_DOCSTRING,
HybridMambaAttentionDynamicCache,
load_balancing_loss_func,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
self,
input_ids: torch.LongTensor = None,

View File

@@ -7,24 +7,16 @@ from typing import Optional, Tuple, Union
import torch
from transformers.cache_utils import Cache
from transformers.models.gemma3.modeling_gemma3 import (
_CONFIG_FOR_DOC,
GEMMA3_INPUTS_DOCSTRING,
Gemma3CausalLMOutputWithPast,
logger,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def new_forward(
self,
input_ids: torch.LongTensor = None,

View File

@@ -1,6 +1,7 @@
"""MLFlow module for trainer callbacks"""
import logging
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
@@ -16,6 +17,11 @@ if TYPE_CHECKING:
LOG = logging.getLogger("axolotl.callbacks")
def should_log_artifacts() -> bool:
truths = ["TRUE", "1", "YES"]
return os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in truths
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
# pylint: disable=duplicate-code
"""Callback to save axolotl config to mlflow"""
@@ -32,13 +38,18 @@ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
):
if is_main_process():
try:
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
if should_log_artifacts():
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
)
else:
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
"Skipping logging artifacts to MLflow (hf_mlflow_log_artifacts is false)"
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")

View File

@@ -72,6 +72,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
data_set = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",
num_proc=cfg.dataset_processes,
**map_kwargs,
)

View File

@@ -484,7 +484,7 @@ def get_dataset_wrapper(
}
LOG.info(
f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
f"Loading dataset: {config_dataset['path']} with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
)
if (

View File

@@ -5,8 +5,11 @@ from functools import partial
from packaging import version
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
from axolotl.utils.gradient_checkpointing.offload_cpu import (
CPU_Offloaded_Gradient_Checkpointer,
)
from axolotl.utils.gradient_checkpointing.offload_disk import (
Disco,
)
transformers_version = version.parse(importlib.metadata.version("transformers"))
@@ -26,12 +29,31 @@ def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
if uses_gc_layers(decoder_layer):
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
return CPU_Offloaded_Gradient_Checkpointer.apply(
decoder_layer,
*args,
)
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
return CPU_Offloaded_Gradient_Checkpointer.apply(
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)
else decoder_layer.__self__
),
*args,
)
def hf_grad_checkpoint_disk_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
if uses_gc_layers(decoder_layer):
return Disco.apply(
decoder_layer,
*args,
)
return Disco.apply(
(
decoder_layer.func.__self__
if isinstance(decoder_layer, partial)

View File

@@ -1,4 +1,4 @@
"""Unsloth checkpointing"""
"""CPU offloaded checkpointing"""
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
@@ -26,7 +26,7 @@ else:
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
torch.autograd.Function
):
"""

View File

@@ -0,0 +1,531 @@
"""
DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
"""
# Copyright 2025 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import concurrent.futures
import logging
import os
import queue
import shutil
import tempfile
import threading
import time
import uuid
from collections import deque
from concurrent.futures import Future
from typing import Dict
import torch
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
# Setup logger
logger = logging.getLogger(__name__)
class DiskOffloadManager:
"""
Manages offloaded tensors and handles prefetching in a separate thread.
Includes synchronization to prevent race conditions.
"""
def __init__(
self,
prefetch_size: int = 3,
prefetch_to_gpu: bool = True,
save_workers: int = 4,
):
"""
Args:
prefetch_size: Maximum number of tensors to prefetch in the background.
prefetch_to_gpu: Whether to prefetch tensors directly to GPU memory.
save_workers: Maximum number of concurrent save operations.
"""
self.temp_dir = tempfile.mkdtemp(prefix="disco_")
# Track tensor paths and their status
self.tensor_paths: deque = deque() # Ordered history of tensor paths (LIFO)
self.file_locks: Dict[str, threading.Lock] = (
{}
) # Maps file_path -> threading.Lock()
# Maps file_path -> status ("saving", "ready", "prefetching", "loaded", "deleted")
self.file_status: Dict[str, str] = {}
self.max_prefetch = prefetch_size
self.prefetch_to_gpu = prefetch_to_gpu
# Thread synchronization
self.manager_lock = threading.RLock() # Used for thread-safe operations
# Prefetch queue and cache
self.prefetch_queue: queue.Queue = queue.Queue()
self.prefetch_cache: Dict[str, torch.Tensor] = {} # Maps file_path -> tensor
# Save queue and thread pool
self.save_queue: queue.Queue = queue.Queue()
self.save_pool = concurrent.futures.ThreadPoolExecutor(max_workers=save_workers)
self.save_futures: Dict[str, Future] = {}
self.save_semaphore = threading.Semaphore(
save_workers * 2
) # Limit concurrent save operations
# Start prefetch worker thread
self.stop_event = threading.Event()
# start multiple threads for prefetching
self.prefetch_worker_count = 2
self.prefetch_workers = []
for _ in range(self.prefetch_worker_count):
worker = threading.Thread(target=self._prefetch_worker, daemon=True)
worker.start()
self.prefetch_workers.append(worker)
# Start save worker thread
self.save_worker = threading.Thread(target=self._save_worker, daemon=True)
self.save_worker.start()
self.idx = 0
atexit.register(self.cleanup)
def _save_worker(self):
"""Background thread that processes the save queue"""
while not self.stop_event.is_set():
try:
save_item = self.save_queue.get(timeout=0.5)
if save_item is None:
continue
tensor, file_path = save_item
# Submit the save task to the thread pool
future = self.save_pool.submit(
self._save_tensor_to_disk, tensor, file_path
)
with self.manager_lock:
self.save_futures[file_path] = future
self.save_queue.task_done()
except queue.Empty:
time.sleep(0.01) # Small sleep to prevent CPU spinning
continue
def _save_tensor_to_disk(self, tensor: torch.Tensor, file_path: str):
"""Actually save the tensor to disk"""
try:
# Save tensor to disk
cpu_tensor = tensor.detach().cpu()
torch.save(cpu_tensor, file_path)
del cpu_tensor
with self.manager_lock:
# Mark file as ready
self.file_status[file_path] = "ready"
# Release semaphore
self.save_semaphore.release()
return True
except FileNotFoundError as e:
logger.error(f"Error saving tensor to {file_path}: {e}")
with self.manager_lock:
self.file_status[file_path] = "error"
# Release semaphore
self.save_semaphore.release()
return False
def _prefetch_worker(self):
"""Background thread that loads tensors from disk ahead of time"""
while not self.stop_event.is_set():
try:
file_path = self.prefetch_queue.get(timeout=0.5)
if file_path is None:
continue
# Check if file is available and not already in cache
with self.manager_lock:
if (
file_path not in self.file_status
or self.file_status[file_path] == "deleted"
):
self.prefetch_queue.task_done()
if file_path in self.prefetch_cache:
self.prefetch_queue.task_done()
continue
# If file is still being saved, wait for it
if (
self.file_status[file_path] == "saving"
and file_path in self.save_futures
):
# Re-queue this prefetch request with a little delay
self.prefetch_queue.task_done()
time.sleep(0.1)
self.prefetch_queue.put(file_path)
continue
# Mark file as being prefetched
self.file_status[file_path] = "prefetching"
# Load tensor from disk and store in cache
try:
if os.path.exists(file_path):
if self.prefetch_to_gpu:
tensor = torch.load(
file_path,
map_location=torch.device("cuda"),
weights_only=True,
)
else:
tensor = torch.load(file_path, weights_only=True)
with self.manager_lock:
self.prefetch_cache[file_path] = tensor
self.file_status[file_path] = "ready"
else:
with self.manager_lock:
if self.file_status.get(file_path) != "deleted":
logger.warning(
f"Prefetch error: File not found {file_path}"
)
self.file_status[file_path] = "missing"
except FileNotFoundError as e:
with self.manager_lock:
if self.file_status.get(file_path) != "deleted":
logger.warning(f"Prefetch error for {file_path}: {e}")
self.file_status[file_path] = "error"
self.prefetch_queue.task_done()
except queue.Empty:
time.sleep(0.01) # Small sleep to prevent CPU spinning
continue
def save_tensor(self, tensor: torch.Tensor):
"""Save tensor to disk asynchronously and return file path with thread-safe operations"""
# Generate unique file path
self.idx += 1
file_path: str = os.path.join(
self.temp_dir, f"{self.idx:06d}-{uuid.uuid4()}.pt"
)
with self.manager_lock:
# Mark file as being saved
self.file_locks[file_path] = threading.Lock()
self.file_status[file_path] = "saving"
# Add to history
self.tensor_paths.append(file_path)
# Acquire semaphore to limit concurrent save operations
self.save_semaphore.acquire() # pylint: disable=consider-using-with
# Queue tensor for saving in background
self.save_queue.put((tensor.detach(), file_path))
return file_path
def wait_for_save(self, file_path, timeout=None) -> None:
"""Wait for a tensor to be saved to disk"""
start_time = time.time()
while timeout is None or time.time() - start_time < timeout:
with self.manager_lock:
if self.file_status.get(file_path) == "ready":
return
if self.file_status.get(file_path) in ["error", "missing", "deleted"]:
return
if file_path in self.save_futures:
future = self.save_futures[file_path]
if future.done():
return
# Small sleep to prevent CPU spinning
time.sleep(0.01)
# Timeout
logger.warning(f"Timeout waiting for tensor to be saved: {file_path}")
return
def load_tensor(self, file_path, target_device="cuda"):
"""Load tensor from disk or prefetch cache with proper synchronization"""
# Wait for tensor to be saved if it's still in progress
self.wait_for_save(file_path)
tensor = None
# Try to get from cache first
with self.manager_lock:
# Check if tensor is already in cache
if file_path in self.prefetch_cache:
tensor = self.prefetch_cache[file_path]
del self.prefetch_cache[file_path]
self.file_status[file_path] = "loaded"
if tensor is not None:
# Ensure tensor is on correct device
if target_device != "cpu" and tensor.device.type == "cpu":
tensor = tensor.to(target_device, non_blocking=True)
return tensor
# If not in cache, load directly from disk
try:
if not os.path.exists(file_path):
logger.error(f"File not found for loading: {file_path}")
raise FileNotFoundError(f"File not found: {file_path}")
tensor = torch.load(file_path, weights_only=True)
with self.manager_lock:
self.file_status[file_path] = "loaded"
if target_device != "cpu":
tensor = tensor.to(target_device, non_blocking=True)
return tensor
except Exception as e:
logger.error(f"Error loading tensor from {file_path}: {e}")
raise
def _safe_delete_file(self, file_path):
"""Safely delete a file with proper synchronization"""
with self.manager_lock:
# Make sure any save operation is completed
if file_path in self.save_futures:
future = self.save_futures[file_path]
try:
if not future.done():
future.cancel()
del self.save_futures[file_path]
except FileNotFoundError as e:
logger.warning(
f"Error canceling save operation for {file_path}: {e}"
)
# Only delete if file exists and is not being prefetched
status = self.file_status.get(file_path)
if status in ["ready", "loaded", "error", "missing"]:
try:
if os.path.exists(file_path):
os.remove(file_path)
self.file_status[file_path] = "deleted"
return True
except FileNotFoundError as e:
logger.warning(f"Error deleting file {file_path}: {e}")
return False
def trigger_prefetch(self, n=None):
"""Trigger prefetching of the next N tensors with proper synchronization"""
if n is None:
n = self.max_prefetch
prefetch_paths = []
with self.manager_lock:
# Find files that are ready to be prefetched (not already in cache or being prefetched)
for path in reversed(self.tensor_paths):
if (
path not in self.prefetch_cache
and self.file_status.get(path) == "ready"
):
prefetch_paths.append(path)
if len(prefetch_paths) >= n:
break
# Queue files for prefetching
for path in prefetch_paths:
self.prefetch_queue.put(path)
def cleanup_tensor(self, file_path: str):
"""Clean up a specific tensor file after it's been used"""
with self.manager_lock:
if file_path in self.tensor_paths:
self.tensor_paths.remove(file_path)
# Remove from prefetch cache if present
if file_path in self.prefetch_cache:
del self.prefetch_cache[file_path]
# Remove from save futures if present
if file_path in self.save_futures:
future = self.save_futures[file_path]
if not future.done():
future.cancel()
del self.save_futures[file_path]
# Try to delete the file
self._safe_delete_file(file_path)
def cleanup(self):
"""Clean up all temp files and stop prefetch thread with proper synchronization"""
self.stop_event.set()
# Cancel all pending save operations
with self.manager_lock:
for _, future in self.save_futures.items():
if not future.done():
future.cancel()
self.save_futures.clear()
# Drain the save queue
while not self.save_queue.empty():
try:
self.save_queue.get_nowait()
self.save_queue.task_done()
except queue.Empty:
break
# Shutdown the save pool
self.save_pool.shutdown(wait=False)
# Join the save worker thread
if self.save_worker.is_alive():
self.save_worker.join(timeout=2.0)
# Join the prefetch worker threads
for thread in self.prefetch_workers:
if thread.is_alive():
thread.join(timeout=2.0)
# Clear cache and remove all temporary files
with self.manager_lock:
self.prefetch_cache.clear()
paths_to_delete = list(self.tensor_paths)
self.tensor_paths.clear()
# Delete all temporary files
for path in paths_to_delete:
self._safe_delete_file(path)
# Remove temp directory
try:
if os.path.exists(self.temp_dir):
shutil.rmtree(self.temp_dir, ignore_errors=True)
except FileNotFoundError as e:
logger.warning(f"Error removing temporary directory {self.temp_dir}: {e}")
class Disco(torch.autograd.Function):
"""
Disco: DIsk-based Storage and Checkpointing with Optimized prefetching
Advanced disk-based gradient checkpointer with prefetching.
"""
# Shared manager instance across all checkpointing operations
_manager = None
@staticmethod
def get_instance(prefetch_size=1, prefetch_to_gpu=True, save_workers=4):
"""Get or create the offload manager"""
if Disco._manager is None:
Disco._manager = DiskOffloadManager(
prefetch_size=prefetch_size,
prefetch_to_gpu=prefetch_to_gpu,
save_workers=save_workers,
)
return Disco._manager
@staticmethod
@torch_cuda_amp_custom_fwd
def forward(
ctx,
forward_function,
hidden_states,
*args,
prefetch_size=1,
prefetch_to_gpu=True,
save_workers=4,
):
"""Forward pass that offloads activations to disk asynchronously"""
# Get or create the manager
manager = Disco.get_instance(
prefetch_size=prefetch_size,
prefetch_to_gpu=prefetch_to_gpu,
save_workers=save_workers,
)
# Save tensor to disk asynchronously
file_path = manager.save_tensor(hidden_states)
# Run forward pass immediately without waiting for save to complete
with torch.no_grad():
output = forward_function(hidden_states, *args)
# Store what we need for backward
ctx.save_for_backward(torch.tensor([0])) # Dummy tensor
ctx.file_path = file_path
ctx.forward_function = forward_function
ctx.args = args
return output
@staticmethod
@torch_cuda_amp_custom_bwd
def backward(ctx, *grad_outputs):
"""Backward pass that loads activations from disk with prefetching"""
# Get the manager
manager = Disco._manager
# Trigger prefetching for future tensors
# This happens at the start of backward, so should have time to complete
manager.trigger_prefetch()
# Load hidden states from disk or prefetch cache
file_path = ctx.file_path
try:
# Ensure the file is saved before we try to load it
manager.wait_for_save(file_path)
hidden_states = manager.load_tensor(file_path)
hidden_states.requires_grad = True
# Compute gradients
with torch.enable_grad():
output = ctx.forward_function(hidden_states, *ctx.args)
# Handle tuple outputs properly
if isinstance(output, tuple):
if len(grad_outputs) == len(output):
torch.autograd.backward(output, grad_outputs)
else:
torch.autograd.backward(output, grad_outputs[0])
else:
torch.autograd.backward(output, grad_outputs[0])
# Clean up the file after we're done with it
manager.cleanup_tensor(file_path)
return (
(
None, # forward_function
hidden_states.grad, # hidden_states grad
)
+ (None,) * len(ctx.args) # for each arg
+ (
None, # prefetch_size
None, # prefetch_to_gpu
None, # save_workers
)
)
except Exception as e:
logger.error(f"Error in backward pass: {e}")
# Clean up the file even on error
manager.cleanup_tensor(file_path)
raise

View File

@@ -70,7 +70,10 @@ from axolotl.utils.distributed import (
is_local_main_process,
is_main_process,
)
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_offload_wrapper
from axolotl.utils.gradient_checkpointing import (
hf_grad_checkpoint_disk_offload_wrapper,
hf_grad_checkpoint_offload_wrapper,
)
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
@@ -620,8 +623,55 @@ class ModelLoader:
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
if self.cfg.gradient_checkpointing == "offload_disk":
transformers.modeling_utils.checkpoint = (
hf_grad_checkpoint_disk_offload_wrapper
)
if self.cfg.flash_attention:
use_fa3 = False
if self.cfg.use_flash_attention_3 is True:
use_fa3 = True
elif self.cfg.use_flash_attention_3 == "auto":
if torch.cuda.get_device_capability() >= (9, 0):
# FA3 is only available on Hopper GPUs and newer
use_fa3 = True
if not importlib.util.find_spec("flash_attn_interface"):
use_fa3 = False
if use_fa3 and not importlib.util.find_spec("flash_attn_interface"):
# this can happen when use_flash_attention_3 is explicity set to True
# and flash_attn_interface is not installed
raise ModuleNotFoundError(
"Please install the flash_attn_interface library to use Flash Attention 3.x"
)
if use_fa3 and importlib.util.find_spec("flash_attn_interface") is not None:
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
def flash_attn_func_v3_wrapper(*args, **kwargs):
kwargs.pop("dropout_p", None)
if "softmax_scale" in kwargs and len(args) >= 4:
# if softmax_scale is provided, then the 3rd position is dropout_p that we need to drop
args = (*args[:3],) + args[4:]
return flash_attn_func_v3(*args, **kwargs)[0]
def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
kwargs.pop("dropout_p", None)
if "softmax_scale" in kwargs and len(args) >= 4:
# if softmax_scale is provided, then the 3rd position is dropout_p that we need to drop
args = (*args[:3],) + args[4:]
return flash_attn_varlen_func_v3(*args, **kwargs)[0]
transformers.modeling_flash_attention_utils.flash_attn_func = (
flash_attn_func_v3_wrapper
)
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
flash_attn_varlen_func_v3_wrapper
)
LOG.info("Switched to Flash Attention v3")
self.patch_attention()
if self.cfg.sample_packing and self.cfg.s2_attention:
@@ -692,6 +742,7 @@ class ModelLoader:
patch_mllama()
# TODO deprecate soon
if self.model_config.model_type == "btlm":
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
replace_btlm_attn_with_flash_attn,
@@ -699,6 +750,7 @@ class ModelLoader:
replace_btlm_attn_with_flash_attn(self.cfg.base_model)
# TODO deprecate soon
if (
self.model_config.model_type == "stablelm_epoch"
and self.cfg.sample_packing

View File

@@ -83,7 +83,6 @@ class AxolotlInputConfig(
# optionally shrink the embeddings when the tokenizer vocab size is smaller
shrink_embeddings: bool | None = None
embeddings_skip_upcast: bool | None = None
random_init_weights: bool | None = None
rl: RLType | None = None
trl: TRLConfig | None = Field(
@@ -179,7 +178,7 @@ class AxolotlInputConfig(
# torch_dtype: torch.dtype | None
gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field(
gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field(
default=False
)
gradient_checkpointing_kwargs: dict[str, Any] | None = None
@@ -234,6 +233,7 @@ class AxolotlInputConfig(
flash_attn_fuse_qkv: bool | None = None
flash_attn_fuse_mlp: bool | None = None
flash_optimum: bool | None = None
use_flash_attention_3: Literal["auto"] | bool | None = None
eager_attention: bool | None = None

View File

@@ -421,6 +421,7 @@ def temp_dir():
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
import transformers.modeling_flash_attention_utils
from transformers import Trainer
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
LlamaAttention,
@@ -434,6 +435,19 @@ def cleanup_monkeypatches():
Trainer._inner_training_loop # pylint: disable=protected-access
)
original_trainer_training_step = Trainer.training_step
original_fa_func = None
original_fa_varlen_func = None
if (
importlib.util.find_spec("flash_attn")
and hasattr(transformers.modeling_flash_attention_utils, "flash_attn_func")
and hasattr(
transformers.modeling_flash_attention_utils, "flash_attn_varlen_func"
)
):
original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
original_fa_varlen_func = (
transformers.modeling_flash_attention_utils.flash_attn_varlen_func
)
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
@@ -444,6 +458,11 @@ def cleanup_monkeypatches():
original_trainer_inner_training_loop
)
Trainer.training_step = original_trainer_training_step
if original_fa_func:
transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
original_fa_varlen_func
)
# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
@@ -458,6 +477,7 @@ def cleanup_monkeypatches():
("transformers.trainer",),
("transformers", ["Trainer"]),
("transformers.loss.loss_utils",),
("transformers.modeling_flash_attention_utils",),
]
for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0]

View File

@@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"""
)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
}
vllm_process = start_vllm(
cfg.base_model,
@@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
finally:
recursive_kill(vllm_process)
@pytest.mark.skip(reason="flaky test")
@pytest.mark.parametrize(
"num_gpus",
[1, 2],
@@ -325,8 +321,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
"VLLM_DISABLE_COMPILE_CACHE": "1",
# "VLLM_USE_V1": "0",
}
vllm_process = start_vllm(
cfg.base_model,

View File

@@ -101,7 +101,13 @@ class TestMultiGPULlama:
"gradient_accumulation_steps",
[1, 2],
)
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
@pytest.mark.parametrize(
"use_flash_attention_3",
[False, "auto"],
)
def test_lora_ddp_packed(
self, temp_dir, gradient_accumulation_steps, use_flash_attention_3
):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
@@ -138,6 +144,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
"use_flash_attention_3": use_flash_attention_3,
}
)

View File

@@ -26,10 +26,15 @@ class TestActivationCheckpointing:
E2E tests for activation checkpointing
"""
@pytest.mark.parametrize(
"gradient_checkpointing",
["offload", "offload_disk"],
)
def test_activation_checkpointing_offload(
self,
temp_dir,
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
gradient_checkpointing,
):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -64,7 +69,7 @@ class TestActivationCheckpointing:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": "offload",
"gradient_checkpointing": gradient_checkpointing,
}
)

View File

@@ -4,7 +4,6 @@ E2E tests for packed training
import logging
import os
import unittest
from transformers.utils import is_torch_bf16_gpu_available
@@ -14,18 +13,17 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
from .utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPackedLlama(unittest.TestCase):
class TestPackedLlama:
"""
Test case for Packed training of llama models
"""
@with_temp_dir
def test_loss_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(