Compare commits

..

18 Commits

Author SHA1 Message Date
Dan Saunders
6f6d917a99 Revert "checkpoint model on first step callback (#2906)"
This reverts commit 10ba1622f7.
2025-07-15 15:01:12 -04:00
Dan Saunders
10ba1622f7 checkpoint model on first step callback (#2906)
* checkpoint model on first step callback

* remove debug

* add test cases; update existing tests not to save on first step

* move test out of solo

* delete

* default to False

* typo
2025-07-15 15:00:48 -04:00
Wing Lian
d320ef6199 fix for upstream refactor of KwargsForCausalLM (#2911) 2025-07-15 11:28:41 -04:00
NanoCode012
354eaaf0d3 feat: add call method to mistral tokenizer wrapper (#2898) 2025-07-14 22:33:35 -04:00
greenhestu
a061446540 Fix: Prevents merging of tool arguments during preprocessing (#2909) 2025-07-14 22:33:10 -04:00
Wing Lian
cd079b5536 Tensor parallel w DeepSpeed AutoTP (#2574)
* support for deepspeed autotup

* bump to latest deepspeed that supports deepcompile too

* add deepcompile support too

* fix total steps calculation for TP

* setup fixture for tp

* update ds config to ensure weights are gathered for checkpoint

* fix duplicate validation names

* chore: lint
2025-07-14 21:33:48 -04:00
Wing Lian
5cc16040a8 move the plugin post trainer create to the setup trainer (#2907)
* move the plugin post trainer create to the setup trainer

* move post-train plugins to execute-training fn
2025-07-14 20:11:33 -04:00
Wing Lian
38359a8997 allow profiling in mid-training rather from the start (#2899) [skip ci]
* allow profiling in mid-training rather from the start

* simplify based on PR feedback

* fix logic, improve saving at end, add tests
2025-07-14 20:11:11 -04:00
Wing Lian
7dc3ac6cb3 update nightlies builds (#2921) [skip ci] 2025-07-14 20:10:43 -04:00
Wing Lian
99187cd208 Activation Offloading w CUDA Streams (#2900) [skip ci]
* use cuda streams for activation offloading

* use torch native ops

* update cfg schema for streams

* fix literal constructor for set

* use context for training step so it doesn't affect evals

* disable streams

* auto gc on eval steps

* use activation_offloading config arg

* add docs for gradient checkpointing

* handle validation for gc/ao

* use cuda streams for act offloading

* add more validation for AC w/o GC

* fix docs

* move activation_offloading lower in definition so it doesn't break args/kwargs

* fix kd due to import order
2025-07-14 20:10:20 -04:00
Wing Lian
aa684122f1 upgrade peft==0.16.0 and datasets==4.0.0 (#2917) [skip ci]
* upgrade peft to 0.16.0

* upgrade datasets to 4.0.0

* refactor dupes from merge/rebase

* fix check for fsdp1 + sharded_state_dict

* use full state dict for ci
2025-07-14 20:09:26 -04:00
Wing Lian
ca4d4ef793 don't init distributed for deepspeed if preprocessing (#2920)
* don't init distributed for deepspeed if preprocessing

* add e2e test to validate preprocess cli with deepspeed

* ignore duplicate code for cfg
2025-07-14 14:19:19 -04:00
Dan Saunders
37edbe4999 Remove extra torch.compile call (#2904)
* debug

* debug

* debug

* moving validation code to transformers

* revert unneeded change

* add accelerator config to base trainer builder

* add back accumulated_cache_size_limit setting

* lint
2025-07-14 12:32:45 -04:00
Wing Lian
e581c15d40 refactor dupes from merge/rebase (#2919) [skip ci] 2025-07-14 10:05:26 -04:00
Wing Lian
af92151a7b FSDP2 fix validation and add tests (#2910)
* fix validation and add tests

* remove debugging and add more tests

* remove migrate_fsdp
2025-07-14 09:25:44 -04:00
Wing Lian
80dc4c261a fix xformers version for python 2.6 (#2916) [skip ci] 2025-07-14 09:24:29 -04:00
Wing Lian
7ccbbd8e77 upgrade liger to 0.6.0 (#2893) [skip ci] 2025-07-14 09:24:07 -04:00
Wing Lian
5081db7f8a upgrade trl==0.19.1 (#2892) [skip ci]
* upgrade trl==0.19.1

* add vllm for tests for grpo

* fixes to work with latest trl

* need data_parallel_size config too

* support for vllm_mode for server / colocate

* vllm settings for colocate

* relax vllm version

* bump min hf hub for latest vllm support

* add hints on string literal for vllm mode

* use latest transformers 4.53.2

* tweak acceptable loss on flaky test_ds_zero3_packed test

* don't run flaky vllm/grpo tests for now
2025-07-14 09:23:42 -04:00
44 changed files with 1176 additions and 419 deletions

View File

@@ -33,6 +33,13 @@ jobs:
axolotl_extras: axolotl_extras:
num_gpus: 2 num_gpus: 2
nightly_build: "true" nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"

View File

@@ -12,11 +12,16 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
include: include:
- cuda: 124 - cuda: 126
cuda_version: 12.4.1 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -60,15 +65,15 @@ jobs:
strategy: strategy:
matrix: matrix:
include: include:
- cuda: 124 - cuda: 126
cuda_version: 12.4.1 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: axolotl_extras:
- cuda: 126 - cuda: 126
cuda_version: 12.6.3 cuda_version: 12.6.3
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.7.1
axolotl_extras: axolotl_extras:
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:

View File

@@ -276,6 +276,7 @@ website:
- docs/torchao.qmd - docs/torchao.qmd
- docs/custom_integrations.qmd - docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd - docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- section: "Troubleshooting" - section: "Troubleshooting"
contents: contents:

View File

@@ -0,0 +1,29 @@
---
title: Gradient Checkpointing and Activation Offloading
---
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
models by reducing the memory footprint and improving computational efficiency.
### Enabling Gradient Checkpointing
```yaml
gradient_checkpointing: true
```
### Enabling Activation Offloading
```yaml
gradient_checkpointing: true # required for activation offloading
activation_offloading: true
```
Activation offloading variants:
The default `activation_offloading: true` offloads activations to CPU and uses CUDA streams
to overlap the communications and computations when offloading.
The `activation_offloading: legacy` naively offloads activations to CPU and without additional optimizations.
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.

View File

@@ -6,19 +6,19 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1 xformers>=0.0.23.post1
autoawq==0.2.7.post3 autoawq==0.2.7.post3
liger-kernel==0.5.10 liger-kernel==0.6.0
# END section # END section
packaging==23.2 packaging==23.2
huggingface_hub==0.32.2 huggingface_hub>=0.33.0
peft==0.15.2 peft==0.16.0
transformers==4.53.1 transformers==4.53.2
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.8.1 accelerate==1.8.1
datasets==3.6.0 datasets==4.0.0
deepspeed>=0.17.0 deepspeed>=0.17.0
trl==0.18.2 trl==0.19.1
hf_xet==1.1.2 hf_xet==1.1.2
optimum==1.16.2 optimum==1.16.2

View File

@@ -73,9 +73,9 @@ def parse_requirements(extras_require_map):
extras_require_map["vllm"] = ["vllm>=0.9.0"] extras_require_map["vllm"] = ["vllm>=0.9.0"]
elif (major, minor) >= (2, 6): elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append( _install_requires.append("xformers==0.0.29.post3")
"xformers==0.0.29.post2" # since we only support 2.6.0+cu126
) # vllm needs post2 w torch 2.6 _dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map["vllm"] = ["vllm==0.8.5.post1"] extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
elif (major, minor) >= (2, 5): elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
@@ -121,7 +121,7 @@ extras_require = {
"yunchang==0.6.0", "yunchang==0.6.0",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed==0.17.1", "deepspeed==0.17.2",
"deepspeed-kernels", "deepspeed-kernels",
], ],
"mamba-ssm": [ "mamba-ssm": [

View File

@@ -16,7 +16,6 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.utils.comet_ import setup_comet_env_vars from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import ( from axolotl.utils.config import (
migrate_fsdp_config,
normalize_cfg_datasets, normalize_cfg_datasets,
normalize_config, normalize_config,
validate_config, validate_config,
@@ -227,7 +226,6 @@ def load_cfg(
}, },
) )
migrate_fsdp_config(cfg)
prepare_optim_env(cfg) prepare_optim_env(cfg)
prepare_opinionated_env(cfg) prepare_opinionated_env(cfg)
normalize_config(cfg) normalize_config(cfg)

View File

@@ -1,5 +1,6 @@
"""CLI to run preprocessing of a dataset.""" """CLI to run preprocessing of a dataset."""
import os
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@@ -95,6 +96,7 @@ def do_cli(
kwargs: Additional keyword arguments to override config file values. kwargs: Additional keyword arguments to override config file values.
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs) parser = transformers.HfArgumentParser(PreprocessCliArgs)

View File

@@ -37,7 +37,6 @@ def do_vllm_serve(
Returns: Returns:
process_id: the process id of the started VLLM server process_id: the process id of the started VLLM server
""" """
patch_vllm_worker()
cfg = load_cfg(config) cfg = load_cfg(config)
model = cfg.base_model model = cfg.base_model
@@ -47,6 +46,9 @@ def do_vllm_serve(
tensor_parallel_size = ( tensor_parallel_size = (
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
) )
data_parallel_size = (
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
)
host = cli_args.get("host") or cfg.vllm.host host = cli_args.get("host") or cfg.vllm.host
port = cli_args.get("port") or cfg.vllm.port port = cli_args.get("port") or cfg.vllm.port
gpu_memory_utilization = ( gpu_memory_utilization = (
@@ -68,6 +70,7 @@ def do_vllm_serve(
vllm_script_args = AxolotlScriptArguments( vllm_script_args = AxolotlScriptArguments(
model=model, model=model,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
data_parallel_size=data_parallel_size,
host=host, host=host,
port=port, port=port,
gpu_memory_utilization=gpu_memory_utilization, gpu_memory_utilization=gpu_memory_utilization,

View File

@@ -112,13 +112,6 @@ class TrainerBuilderBase(abc.ABC):
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model) plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
) )
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
)
)
if self.cfg.gc_steps: if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps)) callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
@@ -145,6 +138,14 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(GPUStatsCallback(cfg=self.cfg)) callbacks.append(GPUStatsCallback(cfg=self.cfg))
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
profiler_steps_start=self.cfg.profiler_steps_start,
)
)
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
@@ -418,6 +419,9 @@ class TrainerBuilderBase(abc.ABC):
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True True
) )
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
training_args_kwargs["torch_compile"] = self.cfg.torch_compile training_args_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend: if self.cfg.torch_compile_backend:
training_args_kwargs["torch_compile_backend"] = ( training_args_kwargs["torch_compile_backend"] = (
@@ -426,8 +430,16 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.torch_compile_mode: if self.cfg.torch_compile_mode:
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
def _configure_accelerator_config(self, training_args_kwargs: dict):
if self.cfg.accelerator_config:
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
def _configure_gradient_checkpointing(self, training_args_kwargs: dict): def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.gradient_checkpointing: if self.cfg.activation_offloading is True:
# don't use the HF gradient checkpointing, manually wrap
training_args_kwargs["gradient_checkpointing"] = False
training_args_kwargs["activation_offloading"] = True
elif self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = ( training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing self.cfg.gradient_checkpointing
) )
@@ -510,5 +522,6 @@ class TrainerBuilderBase(abc.ABC):
self._configure_scheduler(training_args_kwargs) self._configure_scheduler(training_args_kwargs)
self._configure_optimizer(training_args_kwargs, trainer_kwargs) self._configure_optimizer(training_args_kwargs, trainer_kwargs)
self._configure_torch_compile(training_args_kwargs) self._configure_torch_compile(training_args_kwargs)
self._configure_accelerator_config(training_args_kwargs)
return training_args_kwargs, trainer_kwargs return training_args_kwargs, trainer_kwargs

View File

@@ -310,11 +310,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.neftune_noise_alpha self.cfg.neftune_noise_alpha
) )
if self.cfg.accelerator_config:
training_arguments_kwargs["accelerator_config"] = (
self.cfg.accelerator_config
)
if self.cfg.image_size: if self.cfg.image_size:
training_arguments_kwargs["image_size"] = self.cfg.image_size training_arguments_kwargs["image_size"] = self.cfg.image_size
if self.cfg.image_resize_algorithm: if self.cfg.image_resize_algorithm:

View File

@@ -25,6 +25,7 @@ from trl.trainer.utils import pad_to_length
from typing_extensions import override from typing_extensions import override
from axolotl.core.trainers.mixins import ( from axolotl.core.trainers.mixins import (
ActivationOffloadingMixin,
CheckpointSaveMixin, CheckpointSaveMixin,
OptimizerMixin, OptimizerMixin,
PackingMixin, PackingMixin,
@@ -48,6 +49,7 @@ class AxolotlTrainer(
OptimizerMixin, OptimizerMixin,
RngLoaderMixin, RngLoaderMixin,
CheckpointSaveMixin, CheckpointSaveMixin,
ActivationOffloadingMixin,
Trainer, Trainer,
): ):
"""Extend the base Trainer for axolotl helpers""" """Extend the base Trainer for axolotl helpers"""
@@ -75,18 +77,6 @@ class AxolotlTrainer(
if self.args.orpo_alpha: if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def _create_multipack_sampler( def _create_multipack_sampler(
self, base_sampler: Sampler, dataset: Dataset self, base_sampler: Sampler, dataset: Dataset
) -> MultipackBatchSampler: ) -> MultipackBatchSampler:

View File

@@ -14,6 +14,7 @@ from axolotl.core.trainers.grpo.trainer import (
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.trl import TRLConfig from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -41,9 +42,18 @@ class GRPOStrategy:
return grpo_args_kwargs return grpo_args_kwargs
trl: TRLConfig = cfg.trl # type: ignore trl: TRLConfig = cfg.trl # type: ignore
vllm_cfg: VllmConfig = cfg.vllm # type: ignore
if trl.use_vllm: if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
if trl.vllm_mode == "colocate":
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
vllm_cfg.gpu_memory_utilization
)
grpo_args_kwargs["vllm_tensor_parallel_size"] = (
vllm_cfg.tensor_parallel_size
)
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host # type: ignore[attr-defined] grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host # type: ignore[attr-defined]
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port # type: ignore[attr-defined] grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port # type: ignore[attr-defined]
if trl.vllm_server_timeout: if trl.vllm_server_timeout:

View File

@@ -59,42 +59,6 @@ class AxolotlGRPOTrainer(
_tag_names = ["trl", "grpo", "axolotl"] _tag_names = ["trl", "grpo", "axolotl"]
def get_train_dataloader(self):
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
data_collator = self.data_collator
if isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)
dataloader_params = {
"batch_size": self._train_batch_size
* self.args.steps_per_generation, # < this is the change
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer): class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling""" """Extend the base GRPOTrainer for sequence parallelism handling"""
@@ -252,7 +216,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["drop_last"] = self.args.dataloader_drop_last
if not is_eval: if not is_eval:
dataloader_params["worker_init_fn"] = seed_worker dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
# Create the dataloader # Create the dataloader
dataloader = DataLoader(dataset, **dataloader_params) dataloader = DataLoader(dataset, **dataloader_params)

View File

@@ -3,6 +3,7 @@
# pylint: disable=unused-import # pylint: disable=unused-import
# flake8: noqa # flake8: noqa
from .activation_checkpointing import ActivationOffloadingMixin
from .checkpoints import CheckpointSaveMixin from .checkpoints import CheckpointSaveMixin
from .optimizer import OptimizerMixin from .optimizer import OptimizerMixin
from .packing import PackingMixin from .packing import PackingMixin

View File

@@ -0,0 +1,37 @@
"""
Trainer mixin for activation checkpointing w offloading
"""
import contextlib
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers import GradientCheckpointingLayer, Trainer
from trl.models.activation_offloading import get_act_offloading_ctx_manager
class ActivationOffloadingMixin(Trainer):
"""
Trainer mixin class for activation checkpointing w offloading
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.args.activation_offloading:
self.activation_offload_context = get_act_offloading_ctx_manager(
self.model, use_streams=True
)
else:
self.activation_offload_context = contextlib.nullcontext()
def training_step(self, *args, **kwargs):
with self.activation_offload_context:
return super().training_step(*args, **kwargs)
def ac_wrap_hf_model(model: nn.Module, **kwargs):
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)

View File

@@ -217,6 +217,11 @@ class AxolotlTrainingMixins:
}, },
) )
activation_offloading: bool | None = field(
default=None,
metadata={"help": "Use activation offloading with CUDA streams for training."},
)
# multi-modal section # multi-modal section
image_size: int | tuple[int, int] | None = field( image_size: int | tuple[int, int] | None = field(

View File

@@ -6,15 +6,21 @@ from typing import Optional, Union, Unpack
import torch import torch
from transformers import Cache from transformers import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import LossKwargs
try:
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.utils import LossKwargs
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): class TransformersKwargs(FlashAttentionKwargs, LossKwargs):
""" """
placeholder kwargs for hf model classes placeholder kwargs for hf model classes
""" """
except ImportError:
from transformers.utils.generic import ( # type: ignore[no-redef]
TransformersKwargs,
)
def kldiv_forward_llama_like( def kldiv_forward_llama_like(
@@ -33,7 +39,7 @@ def kldiv_forward_llama_like(
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument
**kwargs: Unpack[KwargsForCausalLM], # type: ignore[misc] **kwargs: Unpack[TransformersKwargs], # type: ignore[misc]
) -> CausalLMOutputWithPast: ) -> CausalLMOutputWithPast:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_attentions = ( output_attentions = (

View File

@@ -198,12 +198,22 @@ class ModelLoader:
): ):
self.model = self.model.merge_and_unload() self.model = self.model.merge_and_unload()
self._apply_activation_checkpointing()
self._resize_token_embeddings() self._resize_token_embeddings()
self._adjust_model_config() self._adjust_model_config()
self._configure_embedding_dtypes() self._configure_embedding_dtypes()
self._configure_qat() self._configure_qat()
log_gpu_memory_usage(LOG, "Memory usage after model load", 0) log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
def _apply_activation_checkpointing(self):
if self.cfg.activation_offloading is True:
from axolotl.core.trainers.mixins.activation_checkpointing import (
ac_wrap_hf_model,
)
# ^^ importing this at the module level breaks plugins
ac_wrap_hf_model(self.model)
def _resize_token_embeddings(self): def _resize_token_embeddings(self):
"""Resize token embeddings if needed.""" """Resize token embeddings if needed."""
embeddings_len = ( embeddings_len = (

View File

@@ -7,7 +7,6 @@ import importlib.util
from functools import cached_property from functools import cached_property
import addict import addict
import torch
import transformers import transformers
from transformers import PretrainedConfig, PreTrainedModel from transformers import PretrainedConfig, PreTrainedModel
@@ -168,28 +167,19 @@ class PatchManager:
def _apply_gradient_checkpointing_patches(self): def _apply_gradient_checkpointing_patches(self):
"""Apply patches for gradient checkpointing.""" """Apply patches for gradient checkpointing."""
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]: if (
self.cfg.gradient_checkpointing
and self.cfg.activation_offloading == "legacy"
):
from axolotl.monkeypatch.gradient_checkpointing import ( from axolotl.monkeypatch.gradient_checkpointing import (
CheckpointFunctionWithCPUOffload,
hf_grad_checkpoint_offload_wrapper, hf_grad_checkpoint_offload_wrapper,
) )
if ( transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
self.cfg.gradient_checkpointing_kwargs elif (
and "use_reentrant" in self.cfg.gradient_checkpointing_kwargs self.cfg.gradient_checkpointing
and self.cfg.gradient_checkpointing_kwargs["use_reentrant"] is False and self.cfg.activation_offloading == "offload_disk"
): ):
transformers.modeling_utils.checkpoint = (
hf_grad_checkpoint_offload_wrapper
)
else:
transformers.modeling_utils.checkpoint.CheckpointFunction = (
CheckpointFunctionWithCPUOffload
)
torch.utils.checkpoint.CheckpointFunction = (
CheckpointFunctionWithCPUOffload
)
if self.cfg.gradient_checkpointing == "offload_disk":
from axolotl.monkeypatch.gradient_checkpointing import ( from axolotl.monkeypatch.gradient_checkpointing import (
hf_grad_checkpoint_disk_offload_wrapper, hf_grad_checkpoint_disk_offload_wrapper,
) )

View File

@@ -6,7 +6,6 @@ from functools import partial
from packaging import version from packaging import version
from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import ( # noqa: F401 from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import ( # noqa: F401
CheckpointFunctionWithCPUOffload,
CPU_Offloaded_Gradient_Checkpointer, CPU_Offloaded_Gradient_Checkpointer,
) )
from axolotl.monkeypatch.gradient_checkpointing.offload_disk import ( from axolotl.monkeypatch.gradient_checkpointing.offload_disk import (

View File

@@ -14,18 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import inspect import inspect
import torch import torch
from packaging import version from packaging import version
from torch.utils.checkpoint import ( from torch.utils.checkpoint import (
_get_autocast_kwargs,
_get_device_module,
_infer_device_type,
check_backward_validity,
detach_variable,
get_device_states,
set_device_states, set_device_states,
) )
@@ -76,153 +69,3 @@ class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
) + ( ) + (
None, None,
) * len(ctx.args) ) * len(ctx.args)
# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
# https://github.com/snowflakedb/ArcticTraining/blob/main/arctic_training/monkey_patches.py
class CheckpointFunctionWithCPUOffload(torch.autograd.Function):
"""
This is a torch/utils/checkpoint.py CheckpointFunction monkey patch that offloads the first tensor to cpu during forward and back to cuda during backward. This allows significant memory savings when using a very long seqlen. e.g. for llama 8b at 100k it's 24GB saved per gpu: `((100_000*4096)*2*32/2**30)`
In the case of a very long seqlen 100k+ the copying to/from cpu overhead is not big, because dense quadratic attention compute will dominate.
"""
@staticmethod
def forward(ctx, run_function, preserve_rng_state, *args):
check_backward_validity(args)
ctx.run_function = run_function
ctx.preserve_rng_state = preserve_rng_state
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
ctx.device_type = _infer_device_type(*args)
ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
ctx.device_type
)
if preserve_rng_state:
ctx.fwd_cpu_state = torch.get_rng_state()
# Don't eagerly initialize the cuda context by accident.
# (If the user intends that the context is initialized later, within their
# run_function, we SHOULD actually stash the cuda state here. Unfortunately,
# we have no way to anticipate this will happen before we run the function.)
ctx.had_device_in_fwd = False
device_module = _get_device_module(ctx.device_type)
if getattr(device_module, "_initialized", False):
ctx.had_device_in_fwd = True
ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
# to be filled out during the backward.
ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
# x = None
for i, arg in enumerate(args):
if torch.is_tensor(arg):
# cpu-offload
# we don't want the 2nd tensor - usually it's a shared 4D attn mask which is huge [seq,seq]
# upstream could accept a list of arg indices to offload
if i == 0:
# print(f"{arg.shape=}")
ctx.x_device = arg.device
ctx.x_requires_grad = arg.requires_grad
t = arg.detach().cpu()
else:
t = arg
tensor_inputs.append(t)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)
with torch.no_grad():
outputs = run_function(*args)
return outputs
@staticmethod
def backward(ctx, *args):
if (
not torch.autograd._is_checkpoint_valid() # pylint: disable=protected-access
):
raise RuntimeError(
"When use_reentrant=True, torch.utils.checkpoint is incompatible"
" with .grad() or passing an `inputs` parameter to .backward()."
" To resolve this error, you can either set use_reentrant=False,"
" or call .backward() without passing the `inputs` argument."
)
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors
# Fill in inputs with appropriate saved tensors.
for i, idx in enumerate(tensor_indices):
if i == 0:
t = (
tensors[i]
.to(ctx.x_device)
.detach()
.requires_grad_(ctx.x_requires_grad)
)
else:
t = tensors[i]
inputs[idx] = t
# Stash the surrounding rng state, and mimic the state that was
# present at this time during forward. Restore the surrounding state
# when we're done.
rng_devices = []
if ctx.preserve_rng_state and ctx.had_device_in_fwd:
rng_devices = ctx.fwd_devices
with torch.random.fork_rng(
devices=rng_devices,
enabled=ctx.preserve_rng_state,
device_type=ctx.device_type,
):
if ctx.preserve_rng_state:
torch.set_rng_state(ctx.fwd_cpu_state)
if ctx.had_device_in_fwd:
if has_device_type:
# newer pytorch (as early as 2.7)
set_device_states(
ctx.fwd_devices,
ctx.fwd_device_states,
device_type=ctx.device_type,
)
else:
# older pytorch (at least 2.4)
set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
detached_inputs = detach_variable(tuple(inputs))
device_autocast_ctx = (
torch.amp.autocast(
device_type=ctx.device_type, **ctx.device_autocast_kwargs
)
if torch.amp.is_autocast_available(ctx.device_type)
else contextlib.nullcontext()
)
with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
# run backward() with only tensor that requires grad
outputs_with_grad = []
args_with_grad = []
for i in range(len(outputs)): # pylint: disable=consider-using-enumerate
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
outputs_with_grad.append(outputs[i])
args_with_grad.append(args[i])
if len(outputs_with_grad) == 0:
raise RuntimeError(
"none of output has requires_grad=True, this checkpoint() is not necessary"
)
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(
inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs
)
return (None, None) + grads

View File

@@ -379,6 +379,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
Public method that can handle either a single prompt or a batch of prompts. Public method that can handle either a single prompt or a batch of prompts.
""" """
def _remove_none_values(obj):
"""
Remove null from a dictionary-like obj or list.
These can appear due to Dataset loading causing schema merge.
See https://github.com/axolotl-ai-cloud/axolotl/pull/2909
"""
if hasattr(obj, "items"):
return {
k: _remove_none_values(v) for k, v in obj.items() if v is not None
}
if isinstance(obj, list):
return [_remove_none_values(elem) for elem in obj]
return obj
prompt = _remove_none_values(prompt)
if not self.is_prompt_batched(prompt) or not self.supports_batched: if not self.is_prompt_batched(prompt) or not self.supports_batched:
return self._tokenize_single_prompt(prompt) return self._tokenize_single_prompt(prompt)

View File

@@ -224,6 +224,9 @@ def execute_training(
# torch.set_default_dtype(torch.bfloat16) # torch.set_default_dtype(torch.bfloat16)
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train(cfg, trainer.model)
def save_trained_model( def save_trained_model(
cfg: DictDefault, cfg: DictDefault,
@@ -510,6 +513,9 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
peft_config=peft_config, peft_config=peft_config,
) )
plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer)
return ( return (
trainer, trainer,
model, model,
@@ -541,9 +547,6 @@ def train(
processor, processor,
) = setup_model_and_trainer(cfg, dataset_meta) ) = setup_model_and_trainer(cfg, dataset_meta)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer)
# Handle untrained tokens if configured # Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True
train_dataset = dataset_meta.train_dataset train_dataset = dataset_meta.train_dataset
@@ -566,6 +569,4 @@ def train(
if not cfg.use_ray: if not cfg.use_ray:
cleanup_distributed() cleanup_distributed()
plugin_manager.post_train(cfg, model)
return model, tokenizer, trainer return model, tokenizer, trainer

View File

@@ -841,21 +841,35 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
class GCCallback(TrainerCallback): class GCCallback(TrainerCallback):
"""Callback to garbage collect torch cache""" """Callback to garbage collect torch cache"""
def __init__(self, gc_steps=None): def __init__(self, gc_steps: int | None = -1):
self.gc_steps = gc_steps self.gc_steps: int = gc_steps or -1
self.next_gc_on_begin_step: int = -1
def _gc(self):
torch.cuda.empty_cache()
gc.collect()
def on_step_begin(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if self.next_gc_on_begin_step == state.global_step:
self._gc()
def on_step_end( def on_step_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument self, args, state, control, **kwargs # pylint: disable=unused-argument
): ):
if self.gc_steps > 0 and state.global_step % self.gc_steps == 0: if control.should_evaluate:
torch.cuda.empty_cache() # automatically GC before evals so the eval memory spike from the CEL doesn't OOM the trainer
gc.collect() self._gc()
# also GC on the start of the next step after the eval
self.next_gc_on_begin_step = state.global_step + 1
elif self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
self._gc()
def on_epoch_end( def on_epoch_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument self, args, state, control, **kwargs # pylint: disable=unused-argument
): ):
torch.cuda.empty_cache() self._gc()
gc.collect()
def colab_inference_post_train_callback(trainer: Trainer): def colab_inference_post_train_callback(trainer: Trainer):

View File

@@ -19,9 +19,27 @@ class PytorchProfilerCallback(TrainerCallback):
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps. PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
""" """
def __init__(self, steps_to_profile: int = 5): def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
self.steps_to_profile = steps_to_profile # steps are 0 indexed, so to start at 0-th step, we start at beginning of first step,
if self.steps_to_profile: # and finish at end of last step, so 5 steps_to_profile is steps [0, 1, 2, 3, 4]
self.profiler_steps_end = profiler_steps_start + steps_to_profile - 1
if profiler_steps_start == 0:
# start recording memory allocations before everything is allocated, because if we start
# at the beginning of step 0, we won't have any memory allocations in the traces
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
enabled="all"
)
profiler_steps_start = -1
self.profiler_steps_start = profiler_steps_start
def on_step_begin( # pylint: disable=unused-argument
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
if state.global_step == self.profiler_steps_start:
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
enabled="all" enabled="all"
) )
@@ -33,7 +51,28 @@ class PytorchProfilerCallback(TrainerCallback):
control: TrainerControl, # pylint: disable=unused-argument control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
): ):
if state.global_step == self.steps_to_profile: if state.global_step == self.profiler_steps_end:
snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access
with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
dump(snapshot, fout)
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
enabled=None
)
def on_train_end( # pylint: disable=unused-argument
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
# make sure to record if we happen to have more steps than steps to profile
if (
state.global_step >= self.profiler_steps_start
and state.global_step < self.profiler_steps_end
):
snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access
with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout: with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
dump(snapshot, fout) dump(snapshot, fout)

View File

@@ -314,16 +314,3 @@ def prepare_plugins(cfg):
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]: for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name) plugin_manager.register(plugin_name)
# TODO @SalmanMohammadi remove this function in 0.12
def migrate_fsdp_config(cfg):
if cfg.get("fsdp_config"):
fsdp_config_keys = cfg.fsdp_config.keys()
if "fsdp_version" in fsdp_config_keys:
cfg.fsdp_version = cfg.fsdp_config.pop("fsdp_version")
for key in list(fsdp_config_keys):
if key.startswith("fsdp_") and key != "fsdp_version":
cfg.fsdp_config[key.replace("fsdp_", "")] = cfg.fsdp_config[key]
del cfg.fsdp_config[key]

View File

@@ -497,3 +497,131 @@ class HFMistralTokenizer:
return [ return [
self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids
] ]
def __call__(
self,
text: str | list[str],
add_special_tokens: bool = True,
padding: bool | str = False,
truncation: bool = False,
max_length: int | None = None,
return_tensors: str | None = None,
**kwargs,
) -> dict[str, list[int] | np.ndarray | Tensor]:
"""
Tokenize text and return a dictionary with input_ids and attention_mask.
Args:
text: Input text string or list of strings to tokenize.
add_special_tokens: Whether to add special tokens (BOS/EOS).
padding: Whether to pad sequences. Can be True, False, "longest", or "max_length".
truncation: Whether to truncate sequences to max_length.
max_length: Maximum sequence length for truncation/padding.
return_tensors: Return format ("pt" for PyTorch, "np" for NumPy, None for lists).
Returns:
Dictionary with "input_ids" and "attention_mask" keys.
"""
# if kwargs passed, raise error
if kwargs:
raise ValueError(
f"Unsupported kwargs: {kwargs}. Please create an issue on GitHub."
)
# `np` can work with inhomogeneous shapes but let's not support it until needed.
if (
isinstance(text, list)
and len(text) > 1
and return_tensors in ("pt", "np")
and padding is False
and truncation is False
):
raise ValueError(
"return_tensors='pt' or 'np' requires padding or truncation."
)
# Handle single string input
if isinstance(text, str):
text = [text]
# Encode all texts
# TODO: figure out how to parallelize this
batch_input_ids = []
for single_text in text:
input_ids = self.encode(single_text, add_special_tokens=add_special_tokens)
# Handle truncation
if truncation and max_length is not None and len(input_ids) > max_length:
input_ids = input_ids[:max_length]
batch_input_ids.append(input_ids)
# Create attention masks (1 for real tokens, 0 for padding)
attention_masks = [[1] * len(input_ids) for input_ids in batch_input_ids]
# Handle padding
if padding in (True, "longest"):
# Pad to longest sequence in batch
max_len = max(len(input_ids) for input_ids in batch_input_ids)
for i, input_ids in enumerate(batch_input_ids):
pad_length = max_len - len(input_ids)
if pad_length > 0:
if self.padding_side == "right":
batch_input_ids[i] = (
input_ids + [self.pad_token_id] * pad_length
)
attention_masks[i] = attention_masks[i] + [0] * pad_length
else: # left padding
batch_input_ids[i] = [
self.pad_token_id
] * pad_length + input_ids
attention_masks[i] = [0] * pad_length + attention_masks[i]
elif padding == "max_length":
if max_length is None:
raise ValueError(
"max_length must be specified when padding='max_length'"
)
for i, input_ids in enumerate(batch_input_ids):
pad_length = max_length - len(input_ids)
if pad_length > 0:
if self.padding_side == "right":
batch_input_ids[i] = (
input_ids + [self.pad_token_id] * pad_length
)
attention_masks[i] = attention_masks[i] + [0] * pad_length
else: # left padding
batch_input_ids[i] = [
self.pad_token_id
] * pad_length + input_ids
attention_masks[i] = [0] * pad_length + attention_masks[i]
# Prepare result
result = {}
# Handle return tensor format
if return_tensors == "pt":
import torch
result["input_ids"] = torch.tensor(batch_input_ids, dtype=torch.long)
result["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
elif return_tensors == "np":
result["input_ids"] = np.array(batch_input_ids, dtype=np.int64)
result["attention_mask"] = np.array(attention_masks, dtype=np.int64)
elif return_tensors is None:
result["input_ids"] = batch_input_ids
result["attention_mask"] = attention_masks
else:
raise ValueError(
f"Unsupported return_tensors='{return_tensors}'. "
"Only 'pt' and 'np' are supported."
)
# If single input, return single sequences (not batched)
if len(text) == 1 and return_tensors is None:
result["input_ids"] = result["input_ids"][0]
result["attention_mask"] = result["attention_mask"][0]
return result

View File

@@ -320,7 +320,12 @@ class AxolotlInputConfig(
}, },
) )
gc_steps: int | None = None gc_steps: int | None = Field(
default=None,
json_schema_extra={
"description": "Run garbage collection every `gc_steps` steps. -1 will run on epoch end and before evaluations. Default is 0 (disabled)."
},
)
bf16: Literal["auto"] | bool | None = Field( bf16: Literal["auto"] | bool | None = Field(
default="auto", default="auto",
@@ -360,6 +365,12 @@ class AxolotlInputConfig(
"description": "Additional kwargs to pass to the trainer for gradient checkpointing" "description": "Additional kwargs to pass to the trainer for gradient checkpointing"
}, },
) )
activation_offloading: Literal["legacy", "disk"] | bool | None = Field(
default=False,
json_schema_extra={
"description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
},
)
unfrozen_parameters: list[str] | None = None unfrozen_parameters: list[str] | None = None
@@ -573,6 +584,12 @@ class AxolotlInputConfig(
"description": "Deepspeed config path. e.g., deepspeed_configs/zero3.json" "description": "Deepspeed config path. e.g., deepspeed_configs/zero3.json"
}, },
) )
deepcompile: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use deepcompile for faster training with deepspeed"
},
)
fsdp: list[str] | None = Field( fsdp: list[str] | None = Field(
default=None, default=None,
json_schema_extra={"description": "FSDP configuration"}, json_schema_extra={"description": "FSDP configuration"},
@@ -618,7 +635,12 @@ class AxolotlInputConfig(
"description": "One of 'varlen_llama3', 'batch_ring', 'batch_zigzag', 'batch_stripe'. Defaults to 'varlen_llama3' in the sample packing case, and 'batch_ring' in the non-sample packing case." "description": "One of 'varlen_llama3', 'batch_ring', 'batch_zigzag', 'batch_stripe'. Defaults to 'varlen_llama3' in the sample packing case, and 'batch_ring' in the non-sample packing case."
}, },
) )
tensor_parallel_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
},
)
special_tokens: SpecialTokensConfig | None = Field( special_tokens: SpecialTokensConfig | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
@@ -730,6 +752,12 @@ class AxolotlInputConfig(
"description": "Enable the pytorch profiler to capture the first N steps of training to the output_dir. see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information. Snapshots can be visualized @ https://pytorch.org/memory_viz" "description": "Enable the pytorch profiler to capture the first N steps of training to the output_dir. see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information. Snapshots can be visualized @ https://pytorch.org/memory_viz"
}, },
) )
profiler_steps_start: int | None = Field(
default=0,
json_schema_extra={
"description": "Which step to start the profiler at. Useful for only capturing a few steps mid-run."
},
)
include_tokens_per_second: bool | None = Field( include_tokens_per_second: bool | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
@@ -1143,72 +1171,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
return data return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version(cls, data):
fsdp_config = data.get("fsdp_config", {})
if fsdp_config and str(data.get("fsdp_version")) != "2":
LOG.info(
"FSDP1 will be deprecated in an upcoming release of Axolotl."
"We recommend that you use FSDP version 2 for better performance and compatibility. "
"Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp "
"For more details on migrating your config. "
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp2_base_model_quant_ram_efficient_loading(cls, data):
fsdp_config = data.get("fsdp_config")
if fsdp_config and data.get("fsdp_version") == 2:
if fsdp_config.get("cpu_ram_efficient_loading") and (
data.get("load_in_8bit") or data.get("load_in_4bit")
):
raise ValueError(
"FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
"set fsdp_version to 1, or disable cpu_ram_efficient_loading."
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp2_base_model_quant_dpo(cls, data):
if data.get("fsdp_version") == 2 and data.get("rl") in [
RLType.DPO,
RLType.KTO,
RLType.ORPO,
RLType.IPO,
]:
if data.get("load_in_8bit") or data.get("load_in_4bit"):
raise ValueError(
"FSDP2 does not support load_in_8bit or load_in_4bit with DPO. Please use DeepSpeed or set `fsdp_version` to 1."
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
if fsdp_config := data.get("fsdp_config"):
if fsdp_config.get("fsdp_version"):
LOG.warning(
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
"Please configure `fsdp_version` as a top-level field."
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
if fsdp_config := data.get("fsdp_config"):
for key, _ in fsdp_config.items():
if key.startswith("fsdp_"):
LOG.warning_once(
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
)
return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def default_dataloader_opts(cls, data): def default_dataloader_opts(cls, data):

View File

@@ -1,5 +1,7 @@
"""Pydantic models for TRL trainer configuration""" """Pydantic models for TRL trainer configuration"""
from typing import Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -27,6 +29,12 @@ class TRLConfig(BaseModel):
default=False, default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training."}, json_schema_extra={"description": "Whether to use VLLM for RL training."},
) )
vllm_mode: Literal["server", "colocate"] | None = Field(
default=None,
json_schema_extra={
"description": "VLLM mode to use, one of 'server' or 'colocate'"
},
)
vllm_server_host: str | None = Field( vllm_server_host: str | None = Field(
default="0.0.0.0", # nosec B104 default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Host of the vLLM server to connect to."}, json_schema_extra={"description": "Host of the vLLM server to connect to."},

View File

@@ -1,8 +1,11 @@
"""Module with validation methods for config pydantic model.""" """Module with validation methods for config pydantic model."""
# pylint: disable=too-many-lines # pylint: disable=too-many-boolean-expressions
import json
import logging import logging
import tempfile
from pathlib import Path
from pydantic import ( from pydantic import (
field_validator, field_validator,
@@ -12,6 +15,8 @@ from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
# pylint: disable=too-many-lines
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
@@ -748,43 +753,181 @@ class OptimizationValidationMixin:
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_fsdp_offload_w_8bit_optimizer(cls, data): def check_fsdp_version(cls, data):
if ( fsdp_config = data.get("fsdp_config", {})
data.get("fsdp") if fsdp_config and str(data.get("fsdp_version")) != "2":
and "8bit" in data.get("optimizer", "") LOG.info(
and data.get("fsdp_config") "FSDP1 will be deprecated in an upcoming release of Axolotl."
and data["fsdp_config"].get("fsdp_offload_params") "We recommend that you use FSDP version 2 for better performance and compatibility. "
and str(data["fsdp_config"].get("fsdp_version")) != "2" "Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp "
): "For more details on migrating your config. "
raise ValueError(
f"FSDP Offload not compatible with {data.get('optimizer')}"
) )
if ( return data
data.get("fsdp")
and "8bit" in data.get("optimizer", "") @model_validator(mode="after")
and data.get("fsdp_config") def check_fsdp2_base_model_quant_ram_efficient_loading(self):
and str(data["fsdp_config"].get("fsdp_version")) == "2" fsdp_config = self.fsdp_config if hasattr(self, "fsdp_config") else None
): fsdp_version = self.fsdp_version if hasattr(self, "fsdp_version") else None
if data.get("optimizer", "") in ["adamw_8bit", "adamw_bnb_8bit"]: load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None
# CUDA ops errors with bnb 8bit optimizer + FSDP2 load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None
if fsdp_config and fsdp_version == 2:
if fsdp_config.get("cpu_ram_efficient_loading") and (
load_in_8bit or load_in_4bit
):
raise ValueError( raise ValueError(
f"FSDP2 not compatible with {data.get('optimizer')}, use `adamw_torch_8bit` instead" "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
"set fsdp_version to 1, or disable cpu_ram_efficient_loading."
)
return self
@model_validator(mode="before")
@classmethod
def check_fsdp2_base_model_quant_rl(cls, data):
if data.get("fsdp_version") == 2 and data.get("rl") in [
RLType.DPO,
RLType.KTO,
RLType.ORPO,
RLType.IPO,
]:
if data.get("load_in_8bit") or data.get("load_in_4bit"):
raise ValueError(
f"FSDP2 does not support load_in_8bit or load_in_4bit with {data.get('rl')}. Please use DeepSpeed or set `fsdp_version` to 1."
) )
return data return data
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_fsdp_sharded_state_dict_w_safetensors(cls, data): def check_fsdp_version_in_fsdp_config(cls, data):
if data.get("fsdp_config"):
if data.get("fsdp_config", {}).get("fsdp_version"):
LOG.warning(
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
"Please configure `fsdp_version` as a top-level field."
)
data["fsdp_version"] = data.get("fsdp_config").pop("fsdp_version")
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
if fsdp_config := data.get("fsdp_config"):
should_fix = False
for key, _ in fsdp_config.items():
if key.startswith("fsdp_"):
should_fix = True
LOG.warning_once(
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
)
if should_fix:
update_fsdp_config = {}
for key, value in fsdp_config.items():
if key.startswith("fsdp_") and key != "fsdp_version":
update_fsdp_config[key.replace("fsdp_", "")] = value
else:
update_fsdp_config[key] = value
data["fsdp_config"] = update_fsdp_config
return data
@model_validator(mode="after")
def check_fsdp_offload_w_8bit_optimizer(self):
if ( if (
data.get("fsdp_config") hasattr(self, "fsdp_config")
and data.get("save_safetensors") and self.fsdp_config
and data.get("fsdp_config") and self.optimizer
and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT" and "8bit" in self.optimizer.value
and self.fsdp_config["offload_params"]
and str(self.fsdp_version) != "2"
):
raise ValueError(
f"FSDP Offload not compatible with {str(self.optimizer.value)}"
)
return self
@model_validator(mode="after")
def check_fsdp2_w_8bit_optimizer(self):
if (
hasattr(self, "fsdp_config")
and self.fsdp_config
and self.optimizer
and "8bit" in self.optimizer.value
and str(self.fsdp_version) == "2"
):
if self.optimizer in ["adamw_8bit", "adamw_bnb_8bit"]:
# CUDA ops errors with bnb 8bit optimizer + FSDP2
raise ValueError(
f"FSDP2 not compatible with {self.optimizer.value}, use `adamw_torch_8bit` instead"
)
return self
@model_validator(mode="after")
def check_fsdp_sharded_state_dict_w_safetensors(self):
if (
hasattr(self, "fsdp_config")
and self.fsdp_config
and hasattr(self, "save_safetensors")
and self.save_safetensors
and self.fsdp_config.get("state_dict_type", "") == "SHARDED_STATE_DICT"
and str(getattr(self, "fsdp_version", "1")) != "2"
): ):
raise ValueError( raise ValueError(
"FSDP SHARDED_STATE_DICT not compatible with save_safetensors" "FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
) )
return self
@model_validator(mode="before")
@classmethod
def check_tensor_parallel_size_update_ds_json(cls, data):
tensor_parallel_size = data.get("tensor_parallel_size")
if tensor_parallel_size is not None and tensor_parallel_size > 1:
if not data.get("deepspeed"):
raise ValueError(
"Tensor parallelism (TP) is only supported with DeepSpeed"
)
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
ds_config = json.load(ds_fin)
should_save = False
if "tensor_parallel" not in ds_config:
ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
should_save = True
if (
"gather_16bit_weights_on_model_save"
not in ds_config["zero_optimization"]
):
ds_config["zero_optimization"][
"gather_16bit_weights_on_model_save"
] = True
should_save = True
if should_save:
temp_dir = tempfile.mkdtemp()
with open(
Path(temp_dir) / "autotp_ds.json", "w", encoding="utf-8"
) as ds_fout:
json.dump(ds_config, ds_fout, indent=4)
data["deepspeed"] = str(Path(temp_dir) / "autotp_ds.json")
return data
@model_validator(mode="before")
@classmethod
def check_deepcompile(cls, data):
deepcompile = data.get("deepcompile")
if deepcompile:
if not data.get("deepspeed"):
raise ValueError("DeepCompile is only supported with DeepSpeed")
with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
ds_config = json.load(ds_fin)
if "compile" not in ds_config:
ds_config["compile"] = {"deepcompile": True}
temp_dir = tempfile.mkdtemp()
with open(
Path(temp_dir) / "deepcompile_ds.json", "w", encoding="utf-8"
) as ds_fout:
json.dump(ds_config, ds_fout, indent=4)
data["deepspeed"] = str(Path(temp_dir) / "deepcompile_ds.json")
return data return data
@@ -932,6 +1075,28 @@ class ModelCompatibilityValidationMixin:
self.gradient_checkpointing = "offload" self.gradient_checkpointing = "offload"
return self return self
@model_validator(mode="after")
def check_gradient_checkpointing_w_offload(self):
if self.gradient_checkpointing == "offload":
LOG.warning(
"`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true`"
)
self.gradient_checkpointing = True
self.activation_offloading = True
if self.gradient_checkpointing == "offload_disk":
LOG.warning(
"`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`"
)
self.gradient_checkpointing = True
self.activation_offloading = "disk"
return self
@model_validator(mode="after")
def check_activation_offloading_wo_gc(self):
if self.activation_offloading and not self.gradient_checkpointing:
raise ValueError("activation_offloading requires gradient_checkpointing")
return self
@model_validator(mode="after") @model_validator(mode="after")
def check_better_transformers(self): def check_better_transformers(self):
if self.flash_optimum is True: if self.flash_optimum is True:
@@ -1019,6 +1184,12 @@ class ComplexValidationMixin:
) )
return self return self
@model_validator(mode="after")
def check_tensor_parallel_size(self):
if not self.tensor_parallel_size:
self.tensor_parallel_size = 1
return self
@model_validator(mode="after") @model_validator(mode="after")
def check_sequence_parallel_degree(self): def check_sequence_parallel_degree(self):
if not self.sequence_parallel_degree: if not self.sequence_parallel_degree:

View File

@@ -18,6 +18,10 @@ class VllmConfig(BaseModel):
default=None, default=None,
json_schema_extra={"description": "Tensor parallel size for VLLM"}, json_schema_extra={"description": "Tensor parallel size for VLLM"},
) )
data_parallel_size: int | None = Field(
default=None,
json_schema_extra={"description": "Data parallel size for VLLM"},
)
gpu_memory_utilization: float | None = Field( gpu_memory_utilization: float | None = Field(
default=0.9, default=0.9,
json_schema_extra={"description": "GPU memory utilization for VLLM"}, json_schema_extra={"description": "GPU memory utilization for VLLM"},

View File

@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
) )
* cfg.num_epochs * cfg.num_epochs
* cfg.sequence_parallel_degree * cfg.sequence_parallel_degree
* cfg.tensor_parallel_size
) )
LOG.debug( LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}" f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
@@ -481,7 +482,10 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
# on the agreed on value for sample_packing_eff_est # on the agreed on value for sample_packing_eff_est
total_num_steps = int( total_num_steps = int(
math.floor( math.floor(
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree data_loader_len
* cfg.num_epochs
* cfg.sequence_parallel_degree
* cfg.tensor_parallel_size
) )
) )
if cfg.dataloader_drop_last: if cfg.dataloader_drop_last:
@@ -508,6 +512,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
len(train_dataset) len(train_dataset)
* cfg.num_epochs * cfg.num_epochs
* cfg.sequence_parallel_degree * cfg.sequence_parallel_degree
* cfg.tensor_parallel_size
/ cfg.batch_size / cfg.batch_size
) )
) )
@@ -546,7 +551,10 @@ def setup_deepspeed_env(cfg, stage=None):
# NOTE(djsaunde): The distribued state cannot be initialized prior to the # NOTE(djsaunde): The distribued state cannot be initialized prior to the
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior # ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
# to model load. # to model load.
if int(os.environ.get("WORLD_SIZE", "1")) == 1: if (
int(os.environ.get("WORLD_SIZE", "1")) == 1
and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
):
os.environ["WORLD_SIZE"] = "1" # force it in case not set os.environ["WORLD_SIZE"] = "1" # force it in case not set
os.environ["LOCAL_RANK"] = "0" # force it in case not set os.environ["LOCAL_RANK"] = "0" # force it in case not set
os.environ["RANK"] = os.environ.get("LOCAL_RANK", "0") os.environ["RANK"] = os.environ.get("LOCAL_RANK", "0")

View File

@@ -22,6 +22,8 @@ from huggingface_hub.errors import LocalEntryNotFoundError
from tokenizers import AddedToken from tokenizers import AddedToken
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import ( from tests.hf_offline_utils import (
enable_hf_offline, enable_hf_offline,
hf_offline_context, hf_offline_context,
@@ -539,6 +541,22 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
return datasets.load_from_disk(ds_path)["train"] return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture(name="min_base_cfg")
def fixture_min_base_cfg():
return DictDefault(
base_model="HuggingFaceTB/SmolLM2-135M",
learning_rate=1e-3,
datasets=[
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
micro_batch_size=1,
gradient_accumulation_steps=1,
)
# # pylint: disable=redefined-outer-name,unused-argument # # pylint: disable=redefined-outer-name,unused-argument
@pytest.mark.skipif( @pytest.mark.skipif(
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1", os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",

View File

@@ -65,6 +65,7 @@ def fixture_base_cfg():
"dataloader_pin_memory": True, "dataloader_pin_memory": True,
"dataloader_prefetch_factor": 2, "dataloader_prefetch_factor": 2,
"sequence_parallel_degree": 1, "sequence_parallel_degree": 1,
"tensor_parallel_size": 1,
# Dtype # Dtype
"fp16": False, "fp16": False,
"bf16": False, "bf16": False,

View File

@@ -141,6 +141,7 @@ def recursive_kill(process: subprocess.Popen):
os.kill(process.pid, 9) os.kill(process.pid, 9)
@pytest.mark.skip(reason="flaky vllm tests in modal")
class TestGRPO: class TestGRPO:
""" """
Test case for GRPO training using multilpe GPUs Test case for GRPO training using multilpe GPUs

View File

@@ -391,7 +391,10 @@ class TestMultiGPULlama:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"fsdp_state_dict_type", "fsdp_state_dict_type",
["FULL_STATE_DICT", "SHARDED_STATE_DICT"], [
"FULL_STATE_DICT",
# "SHARDED_STATE_DICT", # not supported since intermediate checkpoints fail with fsdp1
],
) )
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type): def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -413,7 +416,8 @@ class TestMultiGPULlama:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"max_steps": 2, "max_steps": 3,
"save_steps": 2,
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 2,
# "gradient_checkpointing": True, # "gradient_checkpointing": True,
@@ -597,7 +601,7 @@ class TestMultiGPULlama:
"fsdp_use_orig_params": False, "fsdp_use_orig_params": False,
"fsdp_cpu_ram_efficient_loading": True, "fsdp_cpu_ram_efficient_loading": True,
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer", "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"fsdp_state_dict_type": "SHARDED_STATE_DICT", "fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP", "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
}, },
"use_tensorboard": True, "use_tensorboard": True,
@@ -707,7 +711,7 @@ class TestMultiGPULlama:
) )
check_tensorboard( check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.4, "Train Loss (%s) is too high" temp_dir + "/runs", "train/train_loss", 2.45, "Train Loss (%s) is too high"
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(

View File

@@ -0,0 +1,58 @@
"""E2E Test the preprocess cli"""
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault
AXOLOTL_ROOT = Path(__file__).parent.parent.parent
class TestPreprocess:
"""test cases for preprocess"""
def test_w_deepspeed(self, temp_dir):
"""make sure preproces doesn't choke when using deepspeed in the config"""
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"bf16": "auto",
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"dataset_prepared_path": temp_dir + "/last_run_prepared",
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"preprocess",
str(Path(temp_dir) / "config.yaml"),
]
)
assert (Path(temp_dir) / "last_run_prepared").exists()

113
tests/e2e/test_profiler.py Normal file
View File

@@ -0,0 +1,113 @@
"""
e2e gpu test for the pytorch profiler callback
"""
from pathlib import Path
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="profiler_base_cfg")
def fixture_profiler_base_cfg():
cfg = DictDefault(
base_model="HuggingFaceTB/SmolLM2-135M",
tokenizer_type="AutoTokenizer",
sequence_len=1024,
load_in_8bit=True,
adapter="lora",
lora_r=8,
lora_alpha=16,
lora_dropout=0.05,
lora_target_linear=True,
val_set_size=0.02,
special_tokens={"pad_token": "<|endoftext|>"},
datasets=[
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
num_epochs=1,
micro_batch_size=2,
gradient_accumulation_steps=1,
learning_rate=0.00001,
optimizer="adamw_torch_fused",
lr_scheduler="cosine",
)
return cfg
class TestProfiler:
"""
test cases for the pytorch profiler callback
"""
def test_profiler_saves(self, profiler_base_cfg, temp_dir):
cfg = profiler_base_cfg | DictDefault(
output_dir=temp_dir,
max_steps=5,
profiler_steps=3,
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "snapshot.pickle").exists()
def test_profiler_saves_w_start(self, profiler_base_cfg, temp_dir):
cfg = profiler_base_cfg | DictDefault(
output_dir=temp_dir,
max_steps=5,
profiler_steps=3,
profiler_steps_start=1,
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "snapshot.pickle").exists()
@pytest.mark.parametrize(
"profiler_steps_start",
[3, 5],
)
def test_profiler_saves_past_end(
self, profiler_base_cfg, temp_dir, profiler_steps_start
):
cfg = profiler_base_cfg | DictDefault(
output_dir=temp_dir,
max_steps=5,
profiler_steps=3,
profiler_steps_start=profiler_steps_start,
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "snapshot.pickle").exists()
def test_profiler_never_started(self, profiler_base_cfg, temp_dir):
cfg = profiler_base_cfg | DictDefault(
output_dir=temp_dir,
max_steps=5,
profiler_steps=3,
profiler_steps_start=6,
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert not (Path(temp_dir) / "snapshot.pickle").exists()

View File

@@ -0,0 +1,75 @@
"""
Tests for chat template prompt strategy with schema unification for none fields
"""
import json
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.prompt_strategies.chat_template import StrategyLoader
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="messages_w_tools")
def fixture_messages_w_tools():
jsons = """
{"messages":[{"role":"user","content":"move to (0, 1)"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"move","arguments":{"x":0,"y":1}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
{"messages":[{"role":"user","content":"turn 270 degree"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"turn","arguments":{"theta": 270}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
{"messages":[{"role":"user","content":"jump high"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"invalid_prompt","arguments":{"message": "jump is not a valid action"}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
""".strip().split(
"\n"
)
rows = [json.loads(row) for row in jsons]
return Dataset.from_list(rows)
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
): # pylint: disable=unused-argument
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
return tokenizer
@pytest.fixture(name="qwen3_prompt_strategy")
def qwen3_chat_template_strategy(qwen3_tokenizer):
cfg = DictDefault(
sequence_len=2048,
chat_template="qwen3",
eot_tokens=["<|im_end|>"],
)
ds_cfg = DictDefault(
type="chat_template",
)
load = StrategyLoader()
strat = load(qwen3_tokenizer, cfg, ds_cfg)
return strat
class TestSchemaUnification:
"""
Test class on handling null fields for tool calling
"""
def test_schema_unification_single_prompt(
self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer
):
for row in messages_w_tools:
inputs = qwen3_prompt_strategy.tokenize_prompt(row)
decoded = qwen3_tokenizer.decode(inputs["input_ids"])
tool_call = decoded.split("<tool_call>")[-1].split("</tool_call>")[0]
assert '"message": null' not in tool_call
assert '"theta": null' not in tool_call
def test_schema_unification_batched(
self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer
):
rows = messages_w_tools.map(qwen3_prompt_strategy.tokenize_prompt, batched=True)
for row in rows:
decoded = qwen3_tokenizer.decode(row["input_ids"])
tool_call = decoded.split("<tool_call>")[-1].split("</tool_call>")[0]
assert '"message": null' not in tool_call
assert '"theta": null' not in tool_call

View File

@@ -6,6 +6,8 @@ from typing import TYPE_CHECKING
import pytest import pytest
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
@@ -748,5 +750,100 @@ def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"):
assert "Not the same number of function calls and responses" in str(e) assert "Not the same number of function calls and responses" in str(e)
def test_magistral_tokenizer_call_method(
magistral_tokenizer: "HFMistralTokenizer", llama3_tokenizer: "PreTrainedTokenizer"
):
"""Test the __call__ method behavior matches HuggingFace standards"""
from copy import deepcopy
import numpy as np
import torch
hf_tokenizer = deepcopy(llama3_tokenizer)
hf_tokenizer.pad_token = hf_tokenizer.eos_token
test_text = "Hello, how are you?"
batch_texts = ["Hello world", "How are you?"]
# Test single string with return_tensors=None
hf_result: dict[str, list[int]] = hf_tokenizer(test_text, return_tensors=None)
mistral_result: dict[str, list[int]] = magistral_tokenizer(
test_text, return_tensors=None
)
assert isinstance(mistral_result, dict)
assert set(mistral_result.keys()) == {"input_ids", "attention_mask"}
assert isinstance(mistral_result["input_ids"], type(hf_result["input_ids"])) # list
assert isinstance(
mistral_result["attention_mask"], type(hf_result["attention_mask"])
)
assert len(mistral_result["input_ids"]) == len(mistral_result["attention_mask"])
assert np.all(mistral_result["attention_mask"])
assert len(np.array(mistral_result["input_ids"]).shape) == 1 # 1D array
# Test single string with return_tensors='pt'
hf_result_pt: dict[str, torch.Tensor] = hf_tokenizer(test_text, return_tensors="pt")
mistral_result_pt: dict[str, torch.Tensor] = magistral_tokenizer(
test_text, return_tensors="pt"
)
# Check structure and types
assert isinstance(mistral_result_pt["input_ids"], torch.Tensor)
assert isinstance(mistral_result_pt["attention_mask"], torch.Tensor)
# Check shapes match (don't compare token dimension)
assert len(hf_result_pt["input_ids"].shape) == len(
mistral_result_pt["input_ids"].shape
)
assert hf_result_pt["input_ids"].shape[0] == mistral_result_pt["input_ids"].shape[0]
assert (
mistral_result_pt["attention_mask"].shape
== mistral_result_pt["input_ids"].shape
)
assert torch.all(mistral_result_pt["attention_mask"] == 1)
# Test batch input with padding
hf_batch: dict[str, torch.Tensor] = hf_tokenizer(
batch_texts, return_tensors="pt", padding=True
)
mistral_batch: dict[str, torch.Tensor] = magistral_tokenizer(
batch_texts, return_tensors="pt", padding=True
)
# Check batch behavior
assert len(hf_batch["input_ids"].shape) == len(mistral_batch["input_ids"].shape)
assert hf_batch["input_ids"].shape[0] == mistral_batch["input_ids"].shape[0]
assert mistral_batch["attention_mask"].shape == mistral_batch["input_ids"].shape
assert torch.any(
mistral_batch["attention_mask"][0] == 0
) # padding in shorter sequence
assert torch.all(
mistral_batch["attention_mask"][1] == 1
) # no padding in longer sequence
# Test numpy tensors
mistral_result_np: dict[str, np.ndarray] = magistral_tokenizer(
test_text, return_tensors="np"
)
assert isinstance(mistral_result_np["input_ids"], np.ndarray)
assert isinstance(mistral_result_np["attention_mask"], np.ndarray)
# Test consistency with encode()
encoded: list[int] = magistral_tokenizer.encode(test_text, add_special_tokens=True)
called: dict[str, torch.Tensor] = magistral_tokenizer(
test_text, return_tensors="pt"
)
assert encoded == called["input_ids"][0].tolist()
# Test Error handling
with pytest.raises(ValueError, match="Unsupported kwargs"):
magistral_tokenizer(test_text, unsupported_param=True)
with pytest.raises(
ValueError, match="return_tensors='pt' or 'np' requires padding or truncation"
):
magistral_tokenizer(batch_texts, return_tensors="pt")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -6,9 +6,9 @@ import unittest
from unittest.mock import patch from unittest.mock import patch
from axolotl.utils.config import ( from axolotl.utils.config import (
migrate_fsdp_config,
normalize_cfg_datasets, normalize_cfg_datasets,
normalize_config, normalize_config,
validate_config,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -27,6 +27,13 @@ class NormalizeConfigTestCase(unittest.TestCase):
"num_epochs": 1, "num_epochs": 1,
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"learning_rate": 0.0001,
} }
) )
@@ -97,7 +104,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
def test_migrate_fsdp_config(self): def test_migrate_fsdp_config(self):
"""Test basic FSDP config migration with and without fsdp_version""" """Test basic FSDP config migration with and without fsdp_version"""
cfg_with_version = DictDefault( cfg_with_version = self._get_base_cfg() | DictDefault(
{ {
"fsdp_config": { "fsdp_config": {
"fsdp_version": 2, "fsdp_version": 2,
@@ -109,7 +116,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
} }
) )
migrate_fsdp_config(cfg_with_version) cfg_with_version = validate_config(cfg_with_version)
self.assertEqual(cfg_with_version.fsdp_version, 2) self.assertEqual(cfg_with_version.fsdp_version, 2)
self.assertEqual( self.assertEqual(
@@ -125,7 +132,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
self.assertNotIn("fsdp_version", cfg_with_version.fsdp_config) self.assertNotIn("fsdp_version", cfg_with_version.fsdp_config)
self.assertNotIn("version", cfg_with_version.fsdp_config) self.assertNotIn("version", cfg_with_version.fsdp_config)
cfg_without_version = DictDefault( cfg_without_version = self._get_base_cfg() | DictDefault(
{ {
"fsdp_config": { "fsdp_config": {
"fsdp_auto_wrap_policy": "SIZE_BASED_WRAP", "fsdp_auto_wrap_policy": "SIZE_BASED_WRAP",
@@ -135,7 +142,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
} }
) )
migrate_fsdp_config(cfg_without_version) cfg_without_version = validate_config(cfg_without_version)
self.assertNotIn("fsdp_version", cfg_without_version) self.assertNotIn("fsdp_version", cfg_without_version)
self.assertEqual( self.assertEqual(
@@ -149,26 +156,25 @@ class NormalizeConfigTestCase(unittest.TestCase):
def test_migrate_fsdp_config_no_fsdp_config(self): def test_migrate_fsdp_config_no_fsdp_config(self):
"""Test that function doesn't crash when no fsdp_config is present""" """Test that function doesn't crash when no fsdp_config is present"""
cfg = DictDefault({"some_other_config": "value"}) cfg = self._get_base_cfg()
migrate_fsdp_config(cfg) cfg = validate_config(cfg)
self.assertNotIn("fsdp_config", cfg) self.assertNotIn("fsdp_config", cfg)
self.assertNotIn("fsdp_version", cfg) self.assertNotIn("fsdp_version", cfg)
self.assertEqual(cfg.some_other_config, "value")
def test_migrate_fsdp_config_empty_fsdp_config(self): def test_migrate_fsdp_config_empty_fsdp_config(self):
"""Test migration with empty fsdp_config""" """Test migration with empty fsdp_config"""
cfg = DictDefault({"fsdp_config": {}}) cfg = self._get_base_cfg() | DictDefault({"fsdp_config": {}})
migrate_fsdp_config(cfg) cfg = validate_config(cfg)
self.assertNotIn("fsdp_version", cfg) self.assertNotIn("fsdp_version", cfg)
self.assertEqual(cfg.fsdp_config, {}) self.assertEqual(cfg.fsdp_config, {})
def test_migrate_fsdp_config_mixed_keys(self): def test_migrate_fsdp_config_mixed_keys(self):
"""Test migration with a mix of fsdp_ and non-fsdp_ keys""" """Test migration with a mix of fsdp_ and non-fsdp_ keys"""
cfg = DictDefault( cfg = self._get_base_cfg() | DictDefault(
{ {
"fsdp_config": { "fsdp_config": {
"fsdp_version": 1, "fsdp_version": 1,
@@ -180,7 +186,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
} }
) )
migrate_fsdp_config(cfg) cfg = validate_config(cfg)
self.assertEqual(cfg.fsdp_version, 1) self.assertEqual(cfg.fsdp_version, 1)
self.assertEqual(cfg.fsdp_config.state_dict_type, "FULL_STATE_DICT") self.assertEqual(cfg.fsdp_config.state_dict_type, "FULL_STATE_DICT")

View File

@@ -7,21 +7,16 @@ from axolotl.utils.dict import DictDefault
@pytest.fixture(name="train_base_cfg") @pytest.fixture(name="train_base_cfg")
def fixture_train_base_cfg(): def fixture_train_base_cfg(min_base_cfg):
return DictDefault( return (
base_model="gpt2", DictDefault(
learning_rate=1e-3, micro_batch_size=2,
datasets=[ gradient_accumulation_steps=4,
{ sequence_len=2048,
"path": "mhenrichsen/alpaca_2k_test", sample_packing=True,
"type": "alpaca", num_epochs=1,
}, )
], | min_base_cfg
micro_batch_size=2,
gradient_accumulation_steps=4,
sequence_len=2048,
sample_packing=True,
num_epochs=1,
) )

View File

@@ -0,0 +1,139 @@
"""
tests for pydantic fsdp validation
"""
# pylint: disable=too-many-boolean-expressions
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
class TestFSDPValidation:
"""
test class for pydantic fsdp validation
"""
def test_fsdp_version_in_fsdp_config(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={
"fsdp_version": 2,
},
)
cfg = validate_config(
cfg,
)
assert cfg.fsdp_version == 2
assert cfg.fsdp_config.fsdp_version is None
def test_fsdp_sharded_state_dict_safetensors(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
},
save_safetensors=True,
)
with pytest.raises(
ValueError,
match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
):
validate_config(cfg)
# test w/o prefix too
cfg = min_base_cfg | DictDefault(
fsdp_config={
"state_dict_type": "SHARDED_STATE_DICT",
},
save_safetensors=True,
)
with pytest.raises(
ValueError,
match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
):
validate_config(cfg)
def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={
"offload_params": True,
},
optimizer="adamw_8bit",
fsdp_version=1,
)
with pytest.raises(
ValueError, match="FSDP Offload not compatible with adamw_8bit"
):
validate_config(cfg)
def test_fsdp2_w_8bit_optim(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={
"offload_params": True,
},
optimizer="adamw_8bit",
fsdp_version=2,
)
with pytest.raises(
ValueError,
match="FSDP2 not compatible with adamw_8bit, use `adamw_torch_8bit` instead",
):
validate_config(cfg)
def test_fsdp2_w_cpu_ram_efficient_loading(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
load_in_8bit=True,
adapter="lora",
fsdp_config={
"cpu_ram_efficient_loading": True,
},
fsdp_version=2,
)
with pytest.raises(
ValueError,
match="FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading.",
):
validate_config(cfg)
def test_fsdp_prefixes_removed(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
fsdp_config={
"fsdp_version": 2,
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"fsdp_reshard_after_forward": True,
}
)
cfg = validate_config(cfg)
assert cfg.fsdp_version == 2
assert cfg.fsdp_config.fsdp_version is None
for keys in cfg.fsdp_config.keys():
assert not keys.startswith("fsdp_")
assert cfg.fsdp_config.auto_wrap_policy == "TRANSFORMER_BASED_WRAP"
assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer"
assert cfg.fsdp_config.reshard_after_forward is True
@pytest.mark.parametrize(
"rl",
[
"dpo",
"kto",
"orpo",
"ipo",
],
)
def test_fsdp2_dpo(self, min_base_cfg, rl):
cfg = min_base_cfg | DictDefault(
fsdp_version=2,
fsdp_config={
"reshard_after_forward": True,
},
rl=rl,
load_in_8bit=True,
adapter="lora",
remove_unused_columns=False,
)
with pytest.raises(
ValueError,
match="FSDP2 does not support load_in_8bit or load_in_4bit with ",
):
validate_config(cfg)