Compare commits
18 Commits
torch_tens
...
revert-290
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6f6d917a99 | ||
|
|
10ba1622f7 | ||
|
|
d320ef6199 | ||
|
|
354eaaf0d3 | ||
|
|
a061446540 | ||
|
|
cd079b5536 | ||
|
|
5cc16040a8 | ||
|
|
38359a8997 | ||
|
|
7dc3ac6cb3 | ||
|
|
99187cd208 | ||
|
|
aa684122f1 | ||
|
|
ca4d4ef793 | ||
|
|
37edbe4999 | ||
|
|
e581c15d40 | ||
|
|
af92151a7b | ||
|
|
80dc4c261a | ||
|
|
7ccbbd8e77 | ||
|
|
5081db7f8a |
7
.github/workflows/multi-gpu-e2e.yml
vendored
7
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -33,6 +33,13 @@ jobs:
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras: vllm
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
|
||||
15
.github/workflows/nightlies.yml
vendored
15
.github/workflows/nightlies.yml
vendored
@@ -12,11 +12,16 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -60,15 +65,15 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
|
||||
@@ -276,6 +276,7 @@ website:
|
||||
- docs/torchao.qmd
|
||||
- docs/custom_integrations.qmd
|
||||
- docs/sequence_parallelism.qmd
|
||||
- docs/gradient_checkpointing.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
contents:
|
||||
|
||||
29
docs/gradient_checkpointing.qmd
Normal file
29
docs/gradient_checkpointing.qmd
Normal file
@@ -0,0 +1,29 @@
|
||||
---
|
||||
title: Gradient Checkpointing and Activation Offloading
|
||||
---
|
||||
|
||||
Gradient checkpointing and activation offloading are techniques that reduce the memory footprint of training deep learning
|
||||
models, trading some additional computation or data transfer for lower GPU memory usage.
|
||||
|
||||
### Enabling Gradient Checkpointing
|
||||
|
||||
```yaml
|
||||
gradient_checkpointing: true
|
||||
```
|
||||
|
||||
### Enabling Activation Offloading
|
||||
|
||||
```yaml
|
||||
gradient_checkpointing: true # required for activation offloading
|
||||
activation_offloading: true
|
||||
```
|
||||
|
||||
Activation offloading variants:
|
||||
|
||||
The default `activation_offloading: true` offloads activations to CPU and uses CUDA streams
|
||||
to overlap the communications and computations when offloading.
|
||||
|
||||
The `activation_offloading: legacy` variant naively offloads activations to CPU without additional optimizations.
|
||||
|
||||
For resource-constrained environments with limited CPU memory, `activation_offloading: disk` offloads
|
||||
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
|
||||
@@ -6,19 +6,19 @@ triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.5.10
|
||||
liger-kernel==0.6.0
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
|
||||
huggingface_hub==0.32.2
|
||||
peft==0.15.2
|
||||
transformers==4.53.1
|
||||
huggingface_hub>=0.33.0
|
||||
peft==0.16.0
|
||||
transformers==4.53.2
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.8.1
|
||||
datasets==3.6.0
|
||||
datasets==4.0.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.18.2
|
||||
trl==0.19.1
|
||||
hf_xet==1.1.2
|
||||
|
||||
optimum==1.16.2
|
||||
|
||||
8
setup.py
8
setup.py
@@ -73,9 +73,9 @@ def parse_requirements(extras_require_map):
|
||||
extras_require_map["vllm"] = ["vllm>=0.9.0"]
|
||||
elif (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append(
|
||||
"xformers==0.0.29.post2"
|
||||
) # vllm needs post2 w torch 2.6
|
||||
_install_requires.append("xformers==0.0.29.post3")
|
||||
# since we only support 2.6.0+cu126
|
||||
_dependency_links.append("https://download.pytorch.org/whl/cu126")
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
@@ -121,7 +121,7 @@ extras_require = {
|
||||
"yunchang==0.6.0",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.17.1",
|
||||
"deepspeed==0.17.2",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
|
||||
@@ -16,7 +16,6 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.config import (
|
||||
migrate_fsdp_config,
|
||||
normalize_cfg_datasets,
|
||||
normalize_config,
|
||||
validate_config,
|
||||
@@ -227,7 +226,6 @@ def load_cfg(
|
||||
},
|
||||
)
|
||||
|
||||
migrate_fsdp_config(cfg)
|
||||
prepare_optim_env(cfg)
|
||||
prepare_opinionated_env(cfg)
|
||||
normalize_config(cfg)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""CLI to run preprocessing of a dataset."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
@@ -95,6 +96,7 @@ def do_cli(
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parsed_cfg.is_preprocess = True
|
||||
parser = transformers.HfArgumentParser(PreprocessCliArgs)
|
||||
|
||||
@@ -37,7 +37,6 @@ def do_vllm_serve(
|
||||
Returns:
|
||||
process_id: the process id of the started VLLM server
|
||||
"""
|
||||
patch_vllm_worker()
|
||||
cfg = load_cfg(config)
|
||||
model = cfg.base_model
|
||||
|
||||
@@ -47,6 +46,9 @@ def do_vllm_serve(
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
data_parallel_size = (
|
||||
cli_args.get("data_parallel_size") or cfg.vllm.data_parallel_size
|
||||
)
|
||||
host = cli_args.get("host") or cfg.vllm.host
|
||||
port = cli_args.get("port") or cfg.vllm.port
|
||||
gpu_memory_utilization = (
|
||||
@@ -68,6 +70,7 @@ def do_vllm_serve(
|
||||
vllm_script_args = AxolotlScriptArguments(
|
||||
model=model,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
data_parallel_size=data_parallel_size,
|
||||
host=host,
|
||||
port=port,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
|
||||
@@ -112,13 +112,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
||||
)
|
||||
|
||||
if self.cfg.profiler_steps:
|
||||
callbacks.append(
|
||||
PytorchProfilerCallback(
|
||||
steps_to_profile=self.cfg.profiler_steps,
|
||||
)
|
||||
)
|
||||
|
||||
if self.cfg.gc_steps:
|
||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||
|
||||
@@ -145,6 +138,14 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
callbacks.append(GPUStatsCallback(cfg=self.cfg))
|
||||
|
||||
if self.cfg.profiler_steps:
|
||||
callbacks.append(
|
||||
PytorchProfilerCallback(
|
||||
steps_to_profile=self.cfg.profiler_steps,
|
||||
profiler_steps_start=self.cfg.profiler_steps_start,
|
||||
)
|
||||
)
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
@@ -418,6 +419,9 @@ class TrainerBuilderBase(abc.ABC):
|
||||
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||
True
|
||||
)
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
256
|
||||
)
|
||||
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
|
||||
if self.cfg.torch_compile_backend:
|
||||
training_args_kwargs["torch_compile_backend"] = (
|
||||
@@ -426,8 +430,16 @@ class TrainerBuilderBase(abc.ABC):
|
||||
if self.cfg.torch_compile_mode:
|
||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||
|
||||
def _configure_accelerator_config(self, training_args_kwargs: dict):
|
||||
if self.cfg.accelerator_config:
|
||||
training_args_kwargs["accelerator_config"] = self.cfg.accelerator_config
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.gradient_checkpointing:
|
||||
if self.cfg.activation_offloading is True:
|
||||
# don't use the HF gradient checkpointing, manually wrap
|
||||
training_args_kwargs["gradient_checkpointing"] = False
|
||||
training_args_kwargs["activation_offloading"] = True
|
||||
elif self.cfg.gradient_checkpointing:
|
||||
training_args_kwargs["gradient_checkpointing"] = (
|
||||
self.cfg.gradient_checkpointing
|
||||
)
|
||||
@@ -510,5 +522,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
self._configure_scheduler(training_args_kwargs)
|
||||
self._configure_optimizer(training_args_kwargs, trainer_kwargs)
|
||||
self._configure_torch_compile(training_args_kwargs)
|
||||
self._configure_accelerator_config(training_args_kwargs)
|
||||
|
||||
return training_args_kwargs, trainer_kwargs
|
||||
|
||||
@@ -310,11 +310,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.neftune_noise_alpha
|
||||
)
|
||||
|
||||
if self.cfg.accelerator_config:
|
||||
training_arguments_kwargs["accelerator_config"] = (
|
||||
self.cfg.accelerator_config
|
||||
)
|
||||
|
||||
if self.cfg.image_size:
|
||||
training_arguments_kwargs["image_size"] = self.cfg.image_size
|
||||
if self.cfg.image_resize_algorithm:
|
||||
|
||||
@@ -25,6 +25,7 @@ from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.mixins import (
|
||||
ActivationOffloadingMixin,
|
||||
CheckpointSaveMixin,
|
||||
OptimizerMixin,
|
||||
PackingMixin,
|
||||
@@ -48,6 +49,7 @@ class AxolotlTrainer(
|
||||
OptimizerMixin,
|
||||
RngLoaderMixin,
|
||||
CheckpointSaveMixin,
|
||||
ActivationOffloadingMixin,
|
||||
Trainer,
|
||||
):
|
||||
"""Extend the base Trainer for axolotl helpers"""
|
||||
@@ -75,18 +77,6 @@ class AxolotlTrainer(
|
||||
if self.args.orpo_alpha:
|
||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.torch_compile:
|
||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||
256
|
||||
)
|
||||
model = torch.compile(
|
||||
model,
|
||||
backend=self.args.torch_compile_backend,
|
||||
mode=self.args.torch_compile_mode,
|
||||
)
|
||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||
|
||||
def _create_multipack_sampler(
|
||||
self, base_sampler: Sampler, dataset: Dataset
|
||||
) -> MultipackBatchSampler:
|
||||
|
||||
@@ -14,6 +14,7 @@ from axolotl.core.trainers.grpo.trainer import (
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
from axolotl.utils.schemas.vllm import VllmConfig
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -41,9 +42,18 @@ class GRPOStrategy:
|
||||
return grpo_args_kwargs
|
||||
|
||||
trl: TRLConfig = cfg.trl # type: ignore
|
||||
vllm_cfg: VllmConfig = cfg.vllm # type: ignore
|
||||
|
||||
if trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||
grpo_args_kwargs["vllm_mode"] = trl.vllm_mode
|
||||
if trl.vllm_mode == "colocate":
|
||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||
vllm_cfg.gpu_memory_utilization
|
||||
)
|
||||
grpo_args_kwargs["vllm_tensor_parallel_size"] = (
|
||||
vllm_cfg.tensor_parallel_size
|
||||
)
|
||||
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host # type: ignore[attr-defined]
|
||||
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port # type: ignore[attr-defined]
|
||||
if trl.vllm_server_timeout:
|
||||
|
||||
@@ -59,42 +59,6 @@ class AxolotlGRPOTrainer(
|
||||
|
||||
_tag_names = ["trl", "grpo", "axolotl"]
|
||||
|
||||
def get_train_dataloader(self):
|
||||
if self.train_dataset is None:
|
||||
raise ValueError("Trainer: training requires a train_dataset.")
|
||||
|
||||
train_dataset = self.train_dataset
|
||||
data_collator = self.data_collator
|
||||
if isinstance(train_dataset, datasets.Dataset):
|
||||
train_dataset = self._remove_unused_columns(
|
||||
train_dataset, description="training"
|
||||
)
|
||||
else:
|
||||
data_collator = self._get_collator_with_removed_columns(
|
||||
data_collator, description="training"
|
||||
)
|
||||
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size
|
||||
* self.args.steps_per_generation, # < this is the change
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
"persistent_workers": self.args.dataloader_persistent_workers,
|
||||
}
|
||||
|
||||
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_train_sampler()
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = partial(
|
||||
seed_worker,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
rank=self.args.process_index,
|
||||
)
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
||||
|
||||
|
||||
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
"""Extend the base GRPOTrainer for sequence parallelism handling"""
|
||||
@@ -252,7 +216,11 @@ class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
if not is_eval:
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
dataloader_params["worker_init_fn"] = partial(
|
||||
seed_worker,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
rank=self.args.process_index,
|
||||
)
|
||||
|
||||
# Create the dataloader
|
||||
dataloader = DataLoader(dataset, **dataloader_params)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# pylint: disable=unused-import
|
||||
# flake8: noqa
|
||||
|
||||
from .activation_checkpointing import ActivationOffloadingMixin
|
||||
from .checkpoints import CheckpointSaveMixin
|
||||
from .optimizer import OptimizerMixin
|
||||
from .packing import PackingMixin
|
||||
|
||||
37
src/axolotl/core/trainers/mixins/activation_checkpointing.py
Normal file
37
src/axolotl/core/trainers/mixins/activation_checkpointing.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Trainer mixin for activation checkpointing w offloading
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
|
||||
from torch import nn
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
apply_activation_checkpointing,
|
||||
)
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from transformers import GradientCheckpointingLayer, Trainer
|
||||
from trl.models.activation_offloading import get_act_offloading_ctx_manager
|
||||
|
||||
|
||||
class ActivationOffloadingMixin(Trainer):
|
||||
"""
|
||||
Trainer mixin class for activation checkpointing w offloading
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.args.activation_offloading:
|
||||
self.activation_offload_context = get_act_offloading_ctx_manager(
|
||||
self.model, use_streams=True
|
||||
)
|
||||
else:
|
||||
self.activation_offload_context = contextlib.nullcontext()
|
||||
|
||||
def training_step(self, *args, **kwargs):
|
||||
with self.activation_offload_context:
|
||||
return super().training_step(*args, **kwargs)
|
||||
|
||||
|
||||
def ac_wrap_hf_model(model: nn.Module, **kwargs):
|
||||
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
|
||||
apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
|
||||
@@ -217,6 +217,11 @@ class AxolotlTrainingMixins:
|
||||
},
|
||||
)
|
||||
|
||||
activation_offloading: bool | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Use activation offloading with CUDA streams for training."},
|
||||
)
|
||||
|
||||
# multi-modal section
|
||||
|
||||
image_size: int | tuple[int, int] | None = field(
|
||||
|
||||
@@ -6,15 +6,21 @@ from typing import Optional, Union, Unpack
|
||||
|
||||
import torch
|
||||
from transformers import Cache
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.utils import LossKwargs
|
||||
|
||||
try:
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.utils import LossKwargs
|
||||
|
||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
|
||||
"""
|
||||
placeholder kwargs for hf model classes
|
||||
"""
|
||||
class TransformersKwargs(FlashAttentionKwargs, LossKwargs):
|
||||
"""
|
||||
placeholder kwargs for hf model classes
|
||||
"""
|
||||
|
||||
except ImportError:
|
||||
from transformers.utils.generic import ( # type: ignore[no-redef]
|
||||
TransformersKwargs,
|
||||
)
|
||||
|
||||
|
||||
def kldiv_forward_llama_like(
|
||||
@@ -33,7 +39,7 @@ def kldiv_forward_llama_like(
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument
|
||||
**kwargs: Unpack[KwargsForCausalLM], # type: ignore[misc]
|
||||
**kwargs: Unpack[TransformersKwargs], # type: ignore[misc]
|
||||
) -> CausalLMOutputWithPast:
|
||||
# pylint: disable=duplicate-code
|
||||
output_attentions = (
|
||||
|
||||
@@ -198,12 +198,22 @@ class ModelLoader:
|
||||
):
|
||||
self.model = self.model.merge_and_unload()
|
||||
|
||||
self._apply_activation_checkpointing()
|
||||
self._resize_token_embeddings()
|
||||
self._adjust_model_config()
|
||||
self._configure_embedding_dtypes()
|
||||
self._configure_qat()
|
||||
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
|
||||
|
||||
def _apply_activation_checkpointing(self):
|
||||
if self.cfg.activation_offloading is True:
|
||||
from axolotl.core.trainers.mixins.activation_checkpointing import (
|
||||
ac_wrap_hf_model,
|
||||
)
|
||||
|
||||
# ^^ importing this at the module level breaks plugins
|
||||
ac_wrap_hf_model(self.model)
|
||||
|
||||
def _resize_token_embeddings(self):
|
||||
"""Resize token embeddings if needed."""
|
||||
embeddings_len = (
|
||||
|
||||
@@ -7,7 +7,6 @@ import importlib.util
|
||||
from functools import cached_property
|
||||
|
||||
import addict
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
@@ -168,28 +167,19 @@ class PatchManager:
|
||||
|
||||
def _apply_gradient_checkpointing_patches(self):
|
||||
"""Apply patches for gradient checkpointing."""
|
||||
if self.cfg.gradient_checkpointing in ["unsloth", "offload"]:
|
||||
if (
|
||||
self.cfg.gradient_checkpointing
|
||||
and self.cfg.activation_offloading == "legacy"
|
||||
):
|
||||
from axolotl.monkeypatch.gradient_checkpointing import (
|
||||
CheckpointFunctionWithCPUOffload,
|
||||
hf_grad_checkpoint_offload_wrapper,
|
||||
)
|
||||
|
||||
if (
|
||||
self.cfg.gradient_checkpointing_kwargs
|
||||
and "use_reentrant" in self.cfg.gradient_checkpointing_kwargs
|
||||
and self.cfg.gradient_checkpointing_kwargs["use_reentrant"] is False
|
||||
):
|
||||
transformers.modeling_utils.checkpoint = (
|
||||
hf_grad_checkpoint_offload_wrapper
|
||||
)
|
||||
else:
|
||||
transformers.modeling_utils.checkpoint.CheckpointFunction = (
|
||||
CheckpointFunctionWithCPUOffload
|
||||
)
|
||||
torch.utils.checkpoint.CheckpointFunction = (
|
||||
CheckpointFunctionWithCPUOffload
|
||||
)
|
||||
if self.cfg.gradient_checkpointing == "offload_disk":
|
||||
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_offload_wrapper
|
||||
elif (
|
||||
self.cfg.gradient_checkpointing
|
||||
and self.cfg.activation_offloading == "offload_disk"
|
||||
):
|
||||
from axolotl.monkeypatch.gradient_checkpointing import (
|
||||
hf_grad_checkpoint_disk_offload_wrapper,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ from functools import partial
|
||||
from packaging import version
|
||||
|
||||
from axolotl.monkeypatch.gradient_checkpointing.offload_cpu import ( # noqa: F401
|
||||
CheckpointFunctionWithCPUOffload,
|
||||
CPU_Offloaded_Gradient_Checkpointer,
|
||||
)
|
||||
from axolotl.monkeypatch.gradient_checkpointing.offload_disk import (
|
||||
|
||||
@@ -14,18 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.utils.checkpoint import (
|
||||
_get_autocast_kwargs,
|
||||
_get_device_module,
|
||||
_infer_device_type,
|
||||
check_backward_validity,
|
||||
detach_variable,
|
||||
get_device_states,
|
||||
set_device_states,
|
||||
)
|
||||
|
||||
@@ -76,153 +69,3 @@ class CPU_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||
) + (
|
||||
None,
|
||||
) * len(ctx.args)
|
||||
|
||||
|
||||
# Copyright 2025 Snowflake Inc.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# https://github.com/snowflakedb/ArcticTraining/blob/main/arctic_training/monkey_patches.py
|
||||
class CheckpointFunctionWithCPUOffload(torch.autograd.Function):
|
||||
"""
|
||||
This is a torch/utils/checkpoint.py CheckpointFunction monkey patch that offloads the first tensor to cpu during forward and back to cuda during backward. This allows significant memory savings when using a very long seqlen. e.g. for llama 8b at 100k it's 24GB saved per gpu: `((100_000*4096)*2*32/2**30)`
|
||||
In the case of a very long seqlen 100k+ the copying to/from cpu overhead is not big, because dense quadratic attention compute will dominate.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, preserve_rng_state, *args):
|
||||
check_backward_validity(args)
|
||||
ctx.run_function = run_function
|
||||
ctx.preserve_rng_state = preserve_rng_state
|
||||
# Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
|
||||
ctx.device_type = _infer_device_type(*args)
|
||||
ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
|
||||
ctx.device_type
|
||||
)
|
||||
if preserve_rng_state:
|
||||
ctx.fwd_cpu_state = torch.get_rng_state()
|
||||
# Don't eagerly initialize the cuda context by accident.
|
||||
# (If the user intends that the context is initialized later, within their
|
||||
# run_function, we SHOULD actually stash the cuda state here. Unfortunately,
|
||||
# we have no way to anticipate this will happen before we run the function.)
|
||||
ctx.had_device_in_fwd = False
|
||||
device_module = _get_device_module(ctx.device_type)
|
||||
if getattr(device_module, "_initialized", False):
|
||||
ctx.had_device_in_fwd = True
|
||||
ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)
|
||||
|
||||
# Save non-tensor inputs in ctx, keep a placeholder None for tensors
|
||||
# to be filled out during the backward.
|
||||
ctx.inputs = []
|
||||
ctx.tensor_indices = []
|
||||
tensor_inputs = []
|
||||
# x = None
|
||||
for i, arg in enumerate(args):
|
||||
if torch.is_tensor(arg):
|
||||
# cpu-offload
|
||||
# we don't want the 2nd tensor - usually it's a shared 4D attn mask which is huge [seq,seq]
|
||||
# upstream could accept a list of arg indices to offload
|
||||
if i == 0:
|
||||
# print(f"{arg.shape=}")
|
||||
ctx.x_device = arg.device
|
||||
ctx.x_requires_grad = arg.requires_grad
|
||||
t = arg.detach().cpu()
|
||||
else:
|
||||
t = arg
|
||||
tensor_inputs.append(t)
|
||||
ctx.tensor_indices.append(i)
|
||||
ctx.inputs.append(None)
|
||||
else:
|
||||
ctx.inputs.append(arg)
|
||||
|
||||
ctx.save_for_backward(*tensor_inputs)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = run_function(*args)
|
||||
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
if (
|
||||
not torch.autograd._is_checkpoint_valid() # pylint: disable=protected-access
|
||||
):
|
||||
raise RuntimeError(
|
||||
"When use_reentrant=True, torch.utils.checkpoint is incompatible"
|
||||
" with .grad() or passing an `inputs` parameter to .backward()."
|
||||
" To resolve this error, you can either set use_reentrant=False,"
|
||||
" or call .backward() without passing the `inputs` argument."
|
||||
)
|
||||
# Copy the list to avoid modifying original list.
|
||||
inputs = list(ctx.inputs)
|
||||
tensor_indices = ctx.tensor_indices
|
||||
tensors = ctx.saved_tensors
|
||||
|
||||
# Fill in inputs with appropriate saved tensors.
|
||||
for i, idx in enumerate(tensor_indices):
|
||||
if i == 0:
|
||||
t = (
|
||||
tensors[i]
|
||||
.to(ctx.x_device)
|
||||
.detach()
|
||||
.requires_grad_(ctx.x_requires_grad)
|
||||
)
|
||||
else:
|
||||
t = tensors[i]
|
||||
inputs[idx] = t
|
||||
|
||||
# Stash the surrounding rng state, and mimic the state that was
|
||||
# present at this time during forward. Restore the surrounding state
|
||||
# when we're done.
|
||||
rng_devices = []
|
||||
if ctx.preserve_rng_state and ctx.had_device_in_fwd:
|
||||
rng_devices = ctx.fwd_devices
|
||||
with torch.random.fork_rng(
|
||||
devices=rng_devices,
|
||||
enabled=ctx.preserve_rng_state,
|
||||
device_type=ctx.device_type,
|
||||
):
|
||||
if ctx.preserve_rng_state:
|
||||
torch.set_rng_state(ctx.fwd_cpu_state)
|
||||
if ctx.had_device_in_fwd:
|
||||
if has_device_type:
|
||||
# newer pytorch (as early as 2.7)
|
||||
set_device_states(
|
||||
ctx.fwd_devices,
|
||||
ctx.fwd_device_states,
|
||||
device_type=ctx.device_type,
|
||||
)
|
||||
else:
|
||||
# older pytorch (at least 2.4)
|
||||
set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
|
||||
detached_inputs = detach_variable(tuple(inputs))
|
||||
|
||||
device_autocast_ctx = (
|
||||
torch.amp.autocast(
|
||||
device_type=ctx.device_type, **ctx.device_autocast_kwargs
|
||||
)
|
||||
if torch.amp.is_autocast_available(ctx.device_type)
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
|
||||
outputs = ctx.run_function(*detached_inputs)
|
||||
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = (outputs,)
|
||||
|
||||
# run backward() with only tensor that requires grad
|
||||
outputs_with_grad = []
|
||||
args_with_grad = []
|
||||
for i in range(len(outputs)): # pylint: disable=consider-using-enumerate
|
||||
if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
|
||||
outputs_with_grad.append(outputs[i])
|
||||
args_with_grad.append(args[i])
|
||||
if len(outputs_with_grad) == 0:
|
||||
raise RuntimeError(
|
||||
"none of output has requires_grad=True, this checkpoint() is not necessary"
|
||||
)
|
||||
torch.autograd.backward(outputs_with_grad, args_with_grad)
|
||||
grads = tuple(
|
||||
inp.grad if isinstance(inp, torch.Tensor) else None
|
||||
for inp in detached_inputs
|
||||
)
|
||||
|
||||
return (None, None) + grads
|
||||
|
||||
@@ -379,6 +379,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
Public method that can handle either a single prompt or a batch of prompts.
|
||||
"""
|
||||
|
||||
def _remove_none_values(obj):
|
||||
"""
|
||||
Remove null from a dictionary-like obj or list.
|
||||
These can appear due to Dataset loading causing schema merge.
|
||||
See https://github.com/axolotl-ai-cloud/axolotl/pull/2909
|
||||
"""
|
||||
if hasattr(obj, "items"):
|
||||
return {
|
||||
k: _remove_none_values(v) for k, v in obj.items() if v is not None
|
||||
}
|
||||
if isinstance(obj, list):
|
||||
return [_remove_none_values(elem) for elem in obj]
|
||||
return obj
|
||||
|
||||
prompt = _remove_none_values(prompt)
|
||||
|
||||
if not self.is_prompt_batched(prompt) or not self.supports_batched:
|
||||
return self._tokenize_single_prompt(prompt)
|
||||
|
||||
|
||||
@@ -224,6 +224,9 @@ def execute_training(
|
||||
# torch.set_default_dtype(torch.bfloat16)
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_train(cfg, trainer.model)
|
||||
|
||||
|
||||
def save_trained_model(
|
||||
cfg: DictDefault,
|
||||
@@ -510,6 +513,9 @@ def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) ->
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_trainer_create(cfg, trainer)
|
||||
|
||||
return (
|
||||
trainer,
|
||||
model,
|
||||
@@ -541,9 +547,6 @@ def train(
|
||||
processor,
|
||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_trainer_create(cfg, trainer)
|
||||
|
||||
# Handle untrained tokens if configured
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
@@ -566,6 +569,4 @@ def train(
|
||||
if not cfg.use_ray:
|
||||
cleanup_distributed()
|
||||
|
||||
plugin_manager.post_train(cfg, model)
|
||||
|
||||
return model, tokenizer, trainer
|
||||
|
||||
@@ -841,21 +841,35 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
class GCCallback(TrainerCallback):
|
||||
"""Callback to garbage collect torch cache"""
|
||||
|
||||
def __init__(self, gc_steps=None):
|
||||
self.gc_steps = gc_steps
|
||||
def __init__(self, gc_steps: int | None = -1):
|
||||
self.gc_steps: int = gc_steps or -1
|
||||
self.next_gc_on_begin_step: int = -1
|
||||
|
||||
def _gc(self):
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
def on_step_begin(
|
||||
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
||||
):
|
||||
if self.next_gc_on_begin_step == state.global_step:
|
||||
self._gc()
|
||||
|
||||
def on_step_end(
|
||||
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
||||
):
|
||||
if self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
if control.should_evaluate:
|
||||
# automatically GC before evals so the eval memory spike from the CEL doesn't OOM the trainer
|
||||
self._gc()
|
||||
# also GC on the start of the next step after the eval
|
||||
self.next_gc_on_begin_step = state.global_step + 1
|
||||
elif self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
|
||||
self._gc()
|
||||
|
||||
def on_epoch_end(
|
||||
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
||||
):
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
self._gc()
|
||||
|
||||
|
||||
def colab_inference_post_train_callback(trainer: Trainer):
|
||||
|
||||
@@ -19,9 +19,27 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
|
||||
"""
|
||||
|
||||
def __init__(self, steps_to_profile: int = 5):
|
||||
self.steps_to_profile = steps_to_profile
|
||||
if self.steps_to_profile:
|
||||
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
|
||||
# steps are 0 indexed, so to start at 0-th step, we start at beginning of first step,
|
||||
# and finish at end of last step, so 5 steps_to_profile is steps [0, 1, 2, 3, 4]
|
||||
self.profiler_steps_end = profiler_steps_start + steps_to_profile - 1
|
||||
if profiler_steps_start == 0:
|
||||
# start recording memory allocations before everything is allocated, because if we start
|
||||
# at the beginning of step 0, we won't have any memory allocations in the traces
|
||||
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
|
||||
enabled="all"
|
||||
)
|
||||
profiler_steps_start = -1
|
||||
self.profiler_steps_start = profiler_steps_start
|
||||
|
||||
def on_step_begin( # pylint: disable=unused-argument
|
||||
self,
|
||||
args: TrainingArguments, # pylint: disable=unused-argument
|
||||
state: TrainerState,
|
||||
control: TrainerControl, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if state.global_step == self.profiler_steps_start:
|
||||
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
|
||||
enabled="all"
|
||||
)
|
||||
@@ -33,7 +51,28 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
control: TrainerControl, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if state.global_step == self.steps_to_profile:
|
||||
if state.global_step == self.profiler_steps_end:
|
||||
snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access
|
||||
with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
|
||||
dump(snapshot, fout)
|
||||
|
||||
# tell CUDA to stop recording memory allocations now
|
||||
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
|
||||
enabled=None
|
||||
)
|
||||
|
||||
def on_train_end( # pylint: disable=unused-argument
|
||||
self,
|
||||
args: TrainingArguments, # pylint: disable=unused-argument
|
||||
state: TrainerState,
|
||||
control: TrainerControl, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
# make sure to record if we happen to have more steps than steps to profile
|
||||
if (
|
||||
state.global_step >= self.profiler_steps_start
|
||||
and state.global_step < self.profiler_steps_end
|
||||
):
|
||||
snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access
|
||||
with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
|
||||
dump(snapshot, fout)
|
||||
|
||||
@@ -314,16 +314,3 @@ def prepare_plugins(cfg):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
for plugin_name in cfg["plugins"]:
|
||||
plugin_manager.register(plugin_name)
|
||||
|
||||
|
||||
# TODO @SalmanMohammadi remove this function in 0.12
|
||||
def migrate_fsdp_config(cfg):
|
||||
if cfg.get("fsdp_config"):
|
||||
fsdp_config_keys = cfg.fsdp_config.keys()
|
||||
if "fsdp_version" in fsdp_config_keys:
|
||||
cfg.fsdp_version = cfg.fsdp_config.pop("fsdp_version")
|
||||
|
||||
for key in list(fsdp_config_keys):
|
||||
if key.startswith("fsdp_") and key != "fsdp_version":
|
||||
cfg.fsdp_config[key.replace("fsdp_", "")] = cfg.fsdp_config[key]
|
||||
del cfg.fsdp_config[key]
|
||||
|
||||
@@ -497,3 +497,131 @@ class HFMistralTokenizer:
|
||||
return [
|
||||
self._mistral.instruct_tokenizer.tokenizer.id_to_piece(id) for id in ids
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: str | list[str],
|
||||
add_special_tokens: bool = True,
|
||||
padding: bool | str = False,
|
||||
truncation: bool = False,
|
||||
max_length: int | None = None,
|
||||
return_tensors: str | None = None,
|
||||
**kwargs,
|
||||
) -> dict[str, list[int] | np.ndarray | Tensor]:
|
||||
"""
|
||||
Tokenize text and return a dictionary with input_ids and attention_mask.
|
||||
|
||||
Args:
|
||||
text: Input text string or list of strings to tokenize.
|
||||
add_special_tokens: Whether to add special tokens (BOS/EOS).
|
||||
padding: Whether to pad sequences. Can be True, False, "longest", or "max_length".
|
||||
truncation: Whether to truncate sequences to max_length.
|
||||
max_length: Maximum sequence length for truncation/padding.
|
||||
return_tensors: Return format ("pt" for PyTorch, "np" for NumPy, None for lists).
|
||||
|
||||
Returns:
|
||||
Dictionary with "input_ids" and "attention_mask" keys.
|
||||
"""
|
||||
# if kwargs passed, raise error
|
||||
if kwargs:
|
||||
raise ValueError(
|
||||
f"Unsupported kwargs: {kwargs}. Please create an issue on GitHub."
|
||||
)
|
||||
|
||||
# `np` can work with inhomogeneous shapes but let's not support it until needed.
|
||||
if (
|
||||
isinstance(text, list)
|
||||
and len(text) > 1
|
||||
and return_tensors in ("pt", "np")
|
||||
and padding is False
|
||||
and truncation is False
|
||||
):
|
||||
raise ValueError(
|
||||
"return_tensors='pt' or 'np' requires padding or truncation."
|
||||
)
|
||||
|
||||
# Handle single string input
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
|
||||
# Encode all texts
|
||||
# TODO: figure out how to parallelize this
|
||||
batch_input_ids = []
|
||||
for single_text in text:
|
||||
input_ids = self.encode(single_text, add_special_tokens=add_special_tokens)
|
||||
|
||||
# Handle truncation
|
||||
if truncation and max_length is not None and len(input_ids) > max_length:
|
||||
input_ids = input_ids[:max_length]
|
||||
|
||||
batch_input_ids.append(input_ids)
|
||||
|
||||
# Create attention masks (1 for real tokens, 0 for padding)
|
||||
attention_masks = [[1] * len(input_ids) for input_ids in batch_input_ids]
|
||||
|
||||
# Handle padding
|
||||
if padding in (True, "longest"):
|
||||
# Pad to longest sequence in batch
|
||||
max_len = max(len(input_ids) for input_ids in batch_input_ids)
|
||||
|
||||
for i, input_ids in enumerate(batch_input_ids):
|
||||
pad_length = max_len - len(input_ids)
|
||||
if pad_length > 0:
|
||||
if self.padding_side == "right":
|
||||
batch_input_ids[i] = (
|
||||
input_ids + [self.pad_token_id] * pad_length
|
||||
)
|
||||
attention_masks[i] = attention_masks[i] + [0] * pad_length
|
||||
else: # left padding
|
||||
batch_input_ids[i] = [
|
||||
self.pad_token_id
|
||||
] * pad_length + input_ids
|
||||
attention_masks[i] = [0] * pad_length + attention_masks[i]
|
||||
|
||||
elif padding == "max_length":
|
||||
if max_length is None:
|
||||
raise ValueError(
|
||||
"max_length must be specified when padding='max_length'"
|
||||
)
|
||||
|
||||
for i, input_ids in enumerate(batch_input_ids):
|
||||
pad_length = max_length - len(input_ids)
|
||||
if pad_length > 0:
|
||||
if self.padding_side == "right":
|
||||
batch_input_ids[i] = (
|
||||
input_ids + [self.pad_token_id] * pad_length
|
||||
)
|
||||
attention_masks[i] = attention_masks[i] + [0] * pad_length
|
||||
else: # left padding
|
||||
batch_input_ids[i] = [
|
||||
self.pad_token_id
|
||||
] * pad_length + input_ids
|
||||
attention_masks[i] = [0] * pad_length + attention_masks[i]
|
||||
|
||||
# Prepare result
|
||||
result = {}
|
||||
|
||||
# Handle return tensor format
|
||||
if return_tensors == "pt":
|
||||
import torch
|
||||
|
||||
result["input_ids"] = torch.tensor(batch_input_ids, dtype=torch.long)
|
||||
result["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
|
||||
elif return_tensors == "np":
|
||||
result["input_ids"] = np.array(batch_input_ids, dtype=np.int64)
|
||||
result["attention_mask"] = np.array(attention_masks, dtype=np.int64)
|
||||
elif return_tensors is None:
|
||||
result["input_ids"] = batch_input_ids
|
||||
result["attention_mask"] = attention_masks
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported return_tensors='{return_tensors}'. "
|
||||
"Only 'pt' and 'np' are supported."
|
||||
)
|
||||
|
||||
# If single input, return single sequences (not batched)
|
||||
if len(text) == 1 and return_tensors is None:
|
||||
result["input_ids"] = result["input_ids"][0]
|
||||
result["attention_mask"] = result["attention_mask"][0]
|
||||
|
||||
return result
|
||||
|
||||
@@ -320,7 +320,12 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
|
||||
gc_steps: int | None = None
|
||||
gc_steps: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Run garbage collection every `gc_steps` steps. -1 will run on epoch end and before evaluations. Default is 0 (disabled)."
|
||||
},
|
||||
)
|
||||
|
||||
bf16: Literal["auto"] | bool | None = Field(
|
||||
default="auto",
|
||||
@@ -360,6 +365,12 @@ class AxolotlInputConfig(
|
||||
"description": "Additional kwargs to pass to the trainer for gradient checkpointing"
|
||||
},
|
||||
)
|
||||
activation_offloading: Literal["legacy", "disk"] | bool | None = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
|
||||
},
|
||||
)
|
||||
|
||||
unfrozen_parameters: list[str] | None = None
|
||||
|
||||
@@ -573,6 +584,12 @@ class AxolotlInputConfig(
|
||||
"description": "Deepspeed config path. e.g., deepspeed_configs/zero3.json"
|
||||
},
|
||||
)
|
||||
deepcompile: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to use deepcompile for faster training with deepspeed"
|
||||
},
|
||||
)
|
||||
fsdp: list[str] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "FSDP configuration"},
|
||||
@@ -618,7 +635,12 @@ class AxolotlInputConfig(
|
||||
"description": "One of 'varlen_llama3', 'batch_ring', 'batch_zigzag', 'batch_stripe'. Defaults to 'varlen_llama3' in the sample packing case, and 'batch_ring' in the non-sample packing case."
|
||||
},
|
||||
)
|
||||
|
||||
tensor_parallel_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
|
||||
},
|
||||
)
|
||||
special_tokens: SpecialTokensConfig | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -730,6 +752,12 @@ class AxolotlInputConfig(
|
||||
"description": "Enable the pytorch profiler to capture the first N steps of training to the output_dir. see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information. Snapshots can be visualized @ https://pytorch.org/memory_viz"
|
||||
},
|
||||
)
|
||||
profiler_steps_start: int | None = Field(
|
||||
default=0,
|
||||
json_schema_extra={
|
||||
"description": "Which step to start the profiler at. Useful for only capturing a few steps mid-run."
|
||||
},
|
||||
)
|
||||
include_tokens_per_second: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -1143,72 +1171,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_version(cls, data):
|
||||
fsdp_config = data.get("fsdp_config", {})
|
||||
if fsdp_config and str(data.get("fsdp_version")) != "2":
|
||||
LOG.info(
|
||||
"FSDP1 will be deprecated in an upcoming release of Axolotl."
|
||||
"We recommend that you use FSDP version 2 for better performance and compatibility. "
|
||||
"Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp "
|
||||
"For more details on migrating your config. "
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp2_base_model_quant_ram_efficient_loading(cls, data):
|
||||
fsdp_config = data.get("fsdp_config")
|
||||
if fsdp_config and data.get("fsdp_version") == 2:
|
||||
if fsdp_config.get("cpu_ram_efficient_loading") and (
|
||||
data.get("load_in_8bit") or data.get("load_in_4bit")
|
||||
):
|
||||
raise ValueError(
|
||||
"FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
|
||||
"set fsdp_version to 1, or disable cpu_ram_efficient_loading."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp2_base_model_quant_dpo(cls, data):
|
||||
if data.get("fsdp_version") == 2 and data.get("rl") in [
|
||||
RLType.DPO,
|
||||
RLType.KTO,
|
||||
RLType.ORPO,
|
||||
RLType.IPO,
|
||||
]:
|
||||
if data.get("load_in_8bit") or data.get("load_in_4bit"):
|
||||
raise ValueError(
|
||||
"FSDP2 does not support load_in_8bit or load_in_4bit with DPO. Please use DeepSpeed or set `fsdp_version` to 1."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_version_in_fsdp_config(cls, data):
|
||||
if fsdp_config := data.get("fsdp_config"):
|
||||
if fsdp_config.get("fsdp_version"):
|
||||
LOG.warning(
|
||||
"Configuring `fsdp_version` in `fsdp_config` is deprecated. "
|
||||
"Please configure `fsdp_version` as a top-level field."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_config_kwargs_prefix(cls, data):
|
||||
if fsdp_config := data.get("fsdp_config"):
|
||||
for key, _ in fsdp_config.items():
|
||||
if key.startswith("fsdp_"):
|
||||
LOG.warning_once(
|
||||
"Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
|
||||
"Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def default_dataloader_opts(cls, data):
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Pydantic models for TRL trainer configuration"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -27,6 +29,12 @@ class TRLConfig(BaseModel):
|
||||
default=False,
|
||||
json_schema_extra={"description": "Whether to use VLLM for RL training."},
|
||||
)
|
||||
vllm_mode: Literal["server", "colocate"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "VLLM mode to use, one of 'server' or 'colocate'"
|
||||
},
|
||||
)
|
||||
vllm_server_host: str | None = Field(
|
||||
default="0.0.0.0", # nosec B104
|
||||
json_schema_extra={"description": "Host of the vLLM server to connect to."},
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
"""Module with validation methods for config pydantic model."""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import (
|
||||
field_validator,
|
||||
@@ -12,6 +15,8 @@ from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
||||
@@ -748,43 +753,181 @@ class OptimizationValidationMixin:
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_offload_w_8bit_optimizer(cls, data):
|
||||
if (
|
||||
data.get("fsdp")
|
||||
and "8bit" in data.get("optimizer", "")
|
||||
and data.get("fsdp_config")
|
||||
and data["fsdp_config"].get("fsdp_offload_params")
|
||||
and str(data["fsdp_config"].get("fsdp_version")) != "2"
|
||||
):
|
||||
raise ValueError(
|
||||
f"FSDP Offload not compatible with {data.get('optimizer')}"
|
||||
def check_fsdp_version(cls, data):
|
||||
fsdp_config = data.get("fsdp_config", {})
|
||||
if fsdp_config and str(data.get("fsdp_version")) != "2":
|
||||
LOG.info(
|
||||
"FSDP1 will be deprecated in an upcoming release of Axolotl."
|
||||
"We recommend that you use FSDP version 2 for better performance and compatibility. "
|
||||
"Please see this link for more details: https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp "
|
||||
"For more details on migrating your config. "
|
||||
)
|
||||
if (
|
||||
data.get("fsdp")
|
||||
and "8bit" in data.get("optimizer", "")
|
||||
and data.get("fsdp_config")
|
||||
and str(data["fsdp_config"].get("fsdp_version")) == "2"
|
||||
):
|
||||
if data.get("optimizer", "") in ["adamw_8bit", "adamw_bnb_8bit"]:
|
||||
# CUDA ops errors with bnb 8bit optimizer + FSDP2
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_fsdp2_base_model_quant_ram_efficient_loading(self):
|
||||
fsdp_config = self.fsdp_config if hasattr(self, "fsdp_config") else None
|
||||
fsdp_version = self.fsdp_version if hasattr(self, "fsdp_version") else None
|
||||
load_in_8bit = self.load_in_8bit if hasattr(self, "load_in_8bit") else None
|
||||
load_in_4bit = self.load_in_4bit if hasattr(self, "load_in_4bit") else None
|
||||
if fsdp_config and fsdp_version == 2:
|
||||
if fsdp_config.get("cpu_ram_efficient_loading") and (
|
||||
load_in_8bit or load_in_4bit
|
||||
):
|
||||
raise ValueError(
|
||||
f"FSDP2 not compatible with {data.get('optimizer')}, use `adamw_torch_8bit` instead"
|
||||
"FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading. Please do one of the following: use DeepSpeed, "
|
||||
"set fsdp_version to 1, or disable cpu_ram_efficient_loading."
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_fsdp2_base_model_quant_rl(cls, data):
    """Reject 8-bit/4-bit base-model loading with FSDP2 for DPO/KTO/ORPO/IPO.

    Runs on the raw config mapping. `fsdp_version` is matched both as the
    number 2 and as the string "2", in line with the other FSDP validators in
    this module that normalize the version with `str(...)`.

    Raises:
        ValueError: If `fsdp_version` is 2, `rl` is one of the listed methods,
            and `load_in_8bit` or `load_in_4bit` is set.
    """
    fsdp_version = data.get("fsdp_version")
    # Raw config values may carry the version as int or str; accept both.
    is_fsdp2 = fsdp_version == 2 or str(fsdp_version) == "2"
    affected_rl_types = [
        RLType.DPO,
        RLType.KTO,
        RLType.ORPO,
        RLType.IPO,
    ]
    if is_fsdp2 and data.get("rl") in affected_rl_types:
        if data.get("load_in_8bit") or data.get("load_in_4bit"):
            raise ValueError(
                f"FSDP2 does not support load_in_8bit or load_in_4bit with {data.get('rl')}. Please use DeepSpeed or set `fsdp_version` to 1."
            )

    return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_sharded_state_dict_w_safetensors(cls, data):
|
||||
def check_fsdp_version_in_fsdp_config(cls, data):
    """Move a deprecated `fsdp_config.fsdp_version` to the top-level field.

    When `fsdp_config` holds a truthy `fsdp_version`, a deprecation warning is
    logged, the entry is popped from `fsdp_config`, and its value is stored as
    `data["fsdp_version"]`. Otherwise `data` is returned untouched.
    """
    if not data.get("fsdp_config"):
        return data
    if not data.get("fsdp_config", {}).get("fsdp_version"):
        return data

    LOG.warning(
        "Configuring `fsdp_version` in `fsdp_config` is deprecated. "
        "Please configure `fsdp_version` as a top-level field."
    )
    nested_config = data.get("fsdp_config")
    data["fsdp_version"] = nested_config.pop("fsdp_version")
    return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
    """Strip the deprecated `fsdp_` prefix from keys in `fsdp_config`.

    If any key of `fsdp_config` starts with `fsdp_`, a deprecation warning is
    logged and `data["fsdp_config"]` is replaced by a new dict in which every
    such key (except `fsdp_version`, which is kept as is) has the leading
    `fsdp_` removed. Only the leading prefix is removed; a later occurrence of
    `fsdp_` inside the key name is preserved.
    """
    fsdp_config = data.get("fsdp_config")
    if not fsdp_config:
        return data

    if not any(key.startswith("fsdp_") for key in fsdp_config):
        return data

    LOG.warning_once(
        "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
        "Please omit the `fsdp_` prefix from the any fields in `fsdp_config`."
    )
    # `removeprefix` only touches the start of the key, unlike `str.replace`,
    # which would also delete `fsdp_` appearing later in the name.
    data["fsdp_config"] = {
        (key if key == "fsdp_version" else key.removeprefix("fsdp_")): value
        for key, value in fsdp_config.items()
    }
    return data
|
||||
|
||||
@model_validator(mode="after")
def check_fsdp_offload_w_8bit_optimizer(self):
    """Reject parameter offloading with an 8-bit optimizer unless FSDP2 is used.

    Raises:
        ValueError: If `fsdp_config` is set with a truthy `offload_params`,
            the optimizer name contains "8bit", and `fsdp_version` is not 2.
    """
    if (
        hasattr(self, "fsdp_config")
        and self.fsdp_config
        and self.optimizer
        and "8bit" in self.optimizer.value
        # `.get` so an fsdp_config without `offload_params` skips this check
        # instead of failing with a KeyError.
        and self.fsdp_config.get("offload_params")
        and str(self.fsdp_version) != "2"
    ):
        raise ValueError(
            f"FSDP Offload not compatible with {str(self.optimizer.value)}"
        )
    return self
|
||||
|
||||
@model_validator(mode="after")
def check_fsdp2_w_8bit_optimizer(self):
    """Reject the bitsandbytes 8-bit AdamW optimizers when FSDP2 is configured.

    Raises:
        ValueError: If `fsdp_config` is set, `fsdp_version` is 2, and the
            optimizer is `adamw_8bit` or `adamw_bnb_8bit`.
    """
    if not hasattr(self, "fsdp_config") or not self.fsdp_config:
        return self
    if not self.optimizer or "8bit" not in self.optimizer.value:
        return self
    if str(self.fsdp_version) != "2":
        return self

    bnb_8bit_optimizers = ["adamw_8bit", "adamw_bnb_8bit"]
    if self.optimizer in bnb_8bit_optimizers:
        # CUDA ops errors with bnb 8bit optimizer + FSDP2
        raise ValueError(
            f"FSDP2 not compatible with {self.optimizer.value}, use `adamw_torch_8bit` instead"
        )

    return self
|
||||
|
||||
@model_validator(mode="after")
def check_fsdp_sharded_state_dict_w_safetensors(self):
    """Reject SHARDED_STATE_DICT together with save_safetensors outside FSDP2.

    Raises:
        ValueError: If `fsdp_config` requests `state_dict_type`
            "SHARDED_STATE_DICT", `save_safetensors` is truthy, and
            `fsdp_version` (treated as "1" when absent) is not 2.
    """
    has_fsdp_config = hasattr(self, "fsdp_config") and self.fsdp_config
    if not has_fsdp_config:
        return self

    wants_safetensors = hasattr(self, "save_safetensors") and self.save_safetensors
    if not wants_safetensors:
        return self

    state_dict_type = self.fsdp_config.get("state_dict_type", "")
    if state_dict_type != "SHARDED_STATE_DICT":
        return self

    if str(getattr(self, "fsdp_version", "1")) == "2":
        return self

    raise ValueError(
        "FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
    )
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_tensor_parallel_size_update_ds_json(cls, data):
    """Patch the DeepSpeed JSON for AutoTP when `tensor_parallel_size` > 1.

    Loads the JSON file named by `data["deepspeed"]` and, if missing, adds
    `tensor_parallel.autotp_size` and
    `zero_optimization.gather_16bit_weights_on_model_save`. When anything was
    added, the patched config is written to `autotp_ds.json` in a fresh
    temporary directory and `data["deepspeed"]` is pointed at that file. The
    temporary directory is not removed here.

    Raises:
        ValueError: If `tensor_parallel_size` > 1 but no `deepspeed` config
            is set.
    """
    tensor_parallel_size = data.get("tensor_parallel_size")
    if tensor_parallel_size is None or tensor_parallel_size <= 1:
        return data

    if not data.get("deepspeed"):
        raise ValueError(
            "Tensor parallelism (TP) is only supported with DeepSpeed"
        )

    with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
        ds_config = json.load(ds_fin)

    should_save = False
    if "tensor_parallel" not in ds_config:
        ds_config["tensor_parallel"] = {"autotp_size": tensor_parallel_size}
        should_save = True

    # A DeepSpeed config is not required to carry a `zero_optimization`
    # section; create it on demand instead of failing with a KeyError.
    zero_optimization = ds_config.setdefault("zero_optimization", {})
    if "gather_16bit_weights_on_model_save" not in zero_optimization:
        zero_optimization["gather_16bit_weights_on_model_save"] = True
        should_save = True

    if should_save:
        patched_path = Path(tempfile.mkdtemp()) / "autotp_ds.json"
        with open(patched_path, "w", encoding="utf-8") as ds_fout:
            json.dump(ds_config, ds_fout, indent=4)
        data["deepspeed"] = str(patched_path)

    return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_deepcompile(cls, data):
    """Enable DeepCompile in the DeepSpeed JSON when `deepcompile` is set.

    Loads the JSON file named by `data["deepspeed"]`. If it has no `compile`
    section, one is added with `deepcompile: true`, the result is written to
    `deepcompile_ds.json` in a fresh temporary directory, and
    `data["deepspeed"]` is pointed at that file. An existing `compile`
    section is left as it is.

    Raises:
        ValueError: If `deepcompile` is set but no `deepspeed` config is.
    """
    if not data.get("deepcompile"):
        return data

    if not data.get("deepspeed"):
        raise ValueError("DeepCompile is only supported with DeepSpeed")

    with open(data.get("deepspeed"), "r", encoding="utf-8") as ds_fin:
        ds_config = json.load(ds_fin)

    if "compile" in ds_config:
        return data

    ds_config["compile"] = {"deepcompile": True}
    temp_dir = tempfile.mkdtemp()
    patched_path = Path(temp_dir) / "deepcompile_ds.json"
    with open(patched_path, "w", encoding="utf-8") as ds_fout:
        json.dump(ds_config, ds_fout, indent=4)
    data["deepspeed"] = str(patched_path)

    return data
|
||||
|
||||
|
||||
@@ -932,6 +1075,28 @@ class ModelCompatibilityValidationMixin:
|
||||
self.gradient_checkpointing = "offload"
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
def check_gradient_checkpointing_w_offload(self):
    """Translate deprecated `gradient_checkpointing` offload values.

    `"offload"` becomes `gradient_checkpointing=True` with
    `activation_offloading=True`; `"offload_disk"` becomes
    `gradient_checkpointing=True` with `activation_offloading="disk"`.
    A deprecation warning is logged in either case.
    """
    deprecated_modes = (
        (
            "offload",
            "`offload` is deprecated for gradient_checkpointing, use `activation_offloading: true`",
            True,
        ),
        (
            "offload_disk",
            "`offload_disk` is deprecated for gradient_checkpointing, use `activation_offloading: disk`",
            "disk",
        ),
    )
    for legacy_value, notice, offloading in deprecated_modes:
        if self.gradient_checkpointing == legacy_value:
            LOG.warning(notice)
            self.gradient_checkpointing = True
            self.activation_offloading = offloading
            # gradient_checkpointing is now True, so no later entry can match
            break
    return self
|
||||
|
||||
@model_validator(mode="after")
def check_activation_offloading_wo_gc(self):
    """Require gradient checkpointing whenever activation offloading is on.

    Raises:
        ValueError: If `activation_offloading` is truthy while
            `gradient_checkpointing` is falsy.
    """
    if not self.activation_offloading or self.gradient_checkpointing:
        return self
    raise ValueError("activation_offloading requires gradient_checkpointing")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_better_transformers(self):
|
||||
if self.flash_optimum is True:
|
||||
@@ -1019,6 +1184,12 @@ class ComplexValidationMixin:
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
def check_tensor_parallel_size(self):
    """Default a falsy `tensor_parallel_size` (e.g. None or 0) to 1."""
    size_is_unset = not self.tensor_parallel_size
    if size_is_unset:
        self.tensor_parallel_size = 1
    return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_sequence_parallel_degree(self):
|
||||
if not self.sequence_parallel_degree:
|
||||
|
||||
@@ -18,6 +18,10 @@ class VllmConfig(BaseModel):
|
||||
default=None,
|
||||
json_schema_extra={"description": "Tensor parallel size for VLLM"},
|
||||
)
|
||||
data_parallel_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Data parallel size for VLLM"},
|
||||
)
|
||||
gpu_memory_utilization: float | None = Field(
|
||||
default=0.9,
|
||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||
|
||||
@@ -443,6 +443,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
)
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
* cfg.tensor_parallel_size
|
||||
)
|
||||
LOG.debug(
|
||||
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
|
||||
@@ -481,7 +482,10 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
# on the agreed on value for sample_packing_eff_est
|
||||
total_num_steps = int(
|
||||
math.floor(
|
||||
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
|
||||
data_loader_len
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
* cfg.tensor_parallel_size
|
||||
)
|
||||
)
|
||||
if cfg.dataloader_drop_last:
|
||||
@@ -508,6 +512,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
len(train_dataset)
|
||||
* cfg.num_epochs
|
||||
* cfg.sequence_parallel_degree
|
||||
* cfg.tensor_parallel_size
|
||||
/ cfg.batch_size
|
||||
)
|
||||
)
|
||||
@@ -546,7 +551,10 @@ def setup_deepspeed_env(cfg, stage=None):
|
||||
# NOTE(djsaunde): The distributed state cannot be initialized prior to the
|
||||
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
|
||||
# to model load.
|
||||
if int(os.environ.get("WORLD_SIZE", "1")) == 1:
|
||||
if (
|
||||
int(os.environ.get("WORLD_SIZE", "1")) == 1
|
||||
and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
|
||||
):
|
||||
os.environ["WORLD_SIZE"] = "1" # force it in case not set
|
||||
os.environ["LOCAL_RANK"] = "0" # force it in case not set
|
||||
os.environ["RANK"] = os.environ.get("LOCAL_RANK", "0")
|
||||
|
||||
@@ -22,6 +22,8 @@ from huggingface_hub.errors import LocalEntryNotFoundError
|
||||
from tokenizers import AddedToken
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.hf_offline_utils import (
|
||||
enable_hf_offline,
|
||||
hf_offline_context,
|
||||
@@ -539,6 +541,22 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture(name="min_base_cfg")
def fixture_min_base_cfg():
    """Minimal config: SmolLM2-135M base model, one alpaca dataset, batch size 1."""
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    settings = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "learning_rate": 1e-3,
        "datasets": [alpaca_dataset],
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
    }
    return DictDefault(**settings)
|
||||
|
||||
|
||||
# # pylint: disable=redefined-outer-name,unused-argument
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
|
||||
|
||||
@@ -65,6 +65,7 @@ def fixture_base_cfg():
|
||||
"dataloader_pin_memory": True,
|
||||
"dataloader_prefetch_factor": 2,
|
||||
"sequence_parallel_degree": 1,
|
||||
"tensor_parallel_size": 1,
|
||||
# Dtype
|
||||
"fp16": False,
|
||||
"bf16": False,
|
||||
|
||||
@@ -141,6 +141,7 @@ def recursive_kill(process: subprocess.Popen):
|
||||
os.kill(process.pid, 9)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="flaky vllm tests in modal")
|
||||
class TestGRPO:
|
||||
"""
|
||||
Test case for GRPO training using multiple GPUs
|
||||
|
||||
@@ -391,7 +391,10 @@ class TestMultiGPULlama:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fsdp_state_dict_type",
|
||||
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
|
||||
[
|
||||
"FULL_STATE_DICT",
|
||||
# "SHARDED_STATE_DICT", # not supported since intermediate checkpoints fail with fsdp1
|
||||
],
|
||||
)
|
||||
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -413,7 +416,8 @@ class TestMultiGPULlama:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 3,
|
||||
"save_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
# "gradient_checkpointing": True,
|
||||
@@ -597,7 +601,7 @@ class TestMultiGPULlama:
|
||||
"fsdp_use_orig_params": False,
|
||||
"fsdp_cpu_ram_efficient_loading": True,
|
||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||
"fsdp_state_dict_type": "FULL_STATE_DICT",
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
@@ -707,7 +711,7 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.4, "Train Loss (%s) is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.45, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
58
tests/e2e/test_preprocess.py
Normal file
58
tests/e2e/test_preprocess.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""E2E Test the preprocess cli"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class TestPreprocess:
    """End-to-end checks for the `axolotl preprocess` CLI command."""

    def test_w_deepspeed(self, temp_dir):
        """Preprocess must succeed when the config references a DeepSpeed JSON.

        Writes a config to `temp_dir`, runs `axolotl preprocess` on it in a
        subprocess, and checks that the prepared dataset directory exists.
        """
        # pylint: disable=duplicate-code
        alpaca_subset = {
            "path": "tatsu-lab/alpaca",
            "type": "alpaca",
            "split": "train[:10%]",
        }
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [alpaca_subset],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "bf16": "auto",
                "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
                "dataset_prepared_path": temp_dir + "/last_run_prepared",
            }
        )

        # serialize the config next to the expected output
        work_dir = Path(temp_dir)
        work_dir.mkdir(parents=True, exist_ok=True)
        config_path = work_dir / "config.yaml"
        serialized_cfg = yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)
        with open(config_path, "w", encoding="utf-8") as fout:
            fout.write(serialized_cfg)

        command = ["axolotl", "preprocess", str(config_path)]
        execute_subprocess_async(command)

        prepared_dir = work_dir / "last_run_prepared"
        assert prepared_dir.exists()
|
||||
113
tests/e2e/test_profiler.py
Normal file
113
tests/e2e/test_profiler.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
e2e gpu test for the pytorch profiler callback
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="profiler_base_cfg")
def fixture_profiler_base_cfg():
    """Return the 8-bit LoRA SmolLM2-135M config that every profiler test extends."""
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    return DictDefault(
        base_model="HuggingFaceTB/SmolLM2-135M",
        tokenizer_type="AutoTokenizer",
        sequence_len=1024,
        load_in_8bit=True,
        adapter="lora",
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        lora_target_linear=True,
        val_set_size=0.02,
        special_tokens={"pad_token": "<|endoftext|>"},
        datasets=[alpaca_dataset],
        num_epochs=1,
        micro_batch_size=2,
        gradient_accumulation_steps=1,
        learning_rate=0.00001,
        optimizer="adamw_torch_fused",
        lr_scheduler="cosine",
    )
|
||||
|
||||
|
||||
class TestProfiler:
    """
    test cases for the pytorch profiler callback

    Every test trains for 5 steps with ``profiler_steps=3`` and then checks
    whether ``snapshot.pickle`` was written to the output directory.
    """

    @staticmethod
    def _train_and_get_snapshot_path(base_cfg, output_dir, **profiler_overrides):
        """
        Run a short training job and return the expected snapshot location.

        Args:
            base_cfg: the shared base config (``profiler_base_cfg`` fixture).
            output_dir: directory used as the training ``output_dir``.
            **profiler_overrides: extra config keys merged on top of the
                defaults, e.g. ``profiler_steps_start``.

        Returns:
            ``Path`` to ``snapshot.pickle`` inside ``output_dir``; the file may
            or may not exist, which is what the calling test asserts on.
        """
        cfg = base_cfg | DictDefault(
            output_dir=output_dir,
            max_steps=5,
            profiler_steps=3,
            **profiler_overrides,
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)

        train(cfg=cfg, dataset_meta=dataset_meta)
        return Path(output_dir) / "snapshot.pickle"

    def test_profiler_saves(self, profiler_base_cfg, temp_dir):
        """A snapshot is written when no explicit start step is configured."""
        snapshot = self._train_and_get_snapshot_path(profiler_base_cfg, temp_dir)
        assert snapshot.exists()

    def test_profiler_saves_w_start(self, profiler_base_cfg, temp_dir):
        """A snapshot is written when profiling starts at step 1."""
        snapshot = self._train_and_get_snapshot_path(
            profiler_base_cfg, temp_dir, profiler_steps_start=1
        )
        assert snapshot.exists()

    @pytest.mark.parametrize(
        "profiler_steps_start",
        [3, 5],
    )
    def test_profiler_saves_past_end(
        self, profiler_base_cfg, temp_dir, profiler_steps_start
    ):
        """
        A snapshot is still written when start + profiler_steps exceeds max_steps,
        i.e. training finishes before the requested profiling window is over.
        """
        snapshot = self._train_and_get_snapshot_path(
            profiler_base_cfg, temp_dir, profiler_steps_start=profiler_steps_start
        )
        assert snapshot.exists()

    def test_profiler_never_started(self, profiler_base_cfg, temp_dir):
        """No snapshot is written when the start step (6) lies beyond max_steps (5)."""
        snapshot = self._train_and_get_snapshot_path(
            profiler_base_cfg, temp_dir, profiler_steps_start=6
        )
        assert not snapshot.exists()
|
||||
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Tests for chat template prompt strategy with schema unification for none fields
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import StrategyLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="messages_w_tools")
def fixture_messages_w_tools():
    """
    Build a three-row tool-calling dataset.

    All rows advertise the same three tools (move, turn, invalid_prompt) but
    each assistant turn calls a different one, so the ``arguments`` objects
    carry different keys from row to row.
    """
    # Tool declarations shared by every row; parsed afresh per row so that no
    # objects are shared between rows.
    tools_json = """
    [
      {"type": "function", "function": {
        "name": "move",
        "description": "Move to a given location measured in meters",
        "parameters": {
          "type": "object",
          "properties": {
            "x": {"type": "number", "description": "The x coordinate of the location, negative values are to the left, positive values are to the right"},
            "y": {"type": "number", "description": "The y coordinate of the location, negative values are backward, positive values are forward"}
          },
          "required": ["x", "y"]
        }
      }},
      {"type": "function", "function": {
        "name": "turn",
        "description": "Turn the robot to a given direction",
        "parameters": {
          "type": "object",
          "properties": {
            "theta": {"type": "integer", "description": "The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}
          },
          "required": ["theta"]
        }
      }},
      {"type": "function", "function": {
        "name": "invalid_prompt",
        "description": "call when the user's prompt is invalid",
        "parameters": {
          "type": "object",
          "properties": {
            "message": {"type": "string", "description": "why the prompt is invalid"}
          },
          "required": ["message"]
        }
      }}
    ]
    """

    # One conversation per row: a user request and the assistant's tool call.
    conversations_json = [
        '[{"role":"user","content":"move to (0, 1)"},'
        '{"role":"assistant","content":"","tool_calls":'
        '[{"function":{"name":"move","arguments":{"x":0,"y":1}}}]}]',
        '[{"role":"user","content":"turn 270 degree"},'
        '{"role":"assistant","content":"","tool_calls":'
        '[{"function":{"name":"turn","arguments":{"theta": 270}}}]}]',
        '[{"role":"user","content":"jump high"},'
        '{"role":"assistant","content":"","tool_calls":'
        '[{"function":{"name":"invalid_prompt","arguments":{"message": "jump is not a valid action"}}}]}]',
    ]

    rows = [
        {
            "messages": json.loads(conversation),
            "tools": json.loads(tools_json),
            "add_generation_prompt": False,
        }
        for conversation in conversations_json
    ]
    return Dataset.from_list(rows)
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
    download_qwen3_half_billion_model,
):  # pylint: disable=unused-argument
    """
    Load the Qwen/Qwen3-0.6B tokenizer.

    The download fixture is requested but never read here; presumably it makes
    sure the model files are available before loading (confirm in conftest).
    """
    return AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_prompt_strategy")
def qwen3_chat_template_strategy(qwen3_tokenizer):
    """Build a ``chat_template`` prompt strategy that uses the qwen3 template."""
    train_cfg = DictDefault(
        sequence_len=2048,
        chat_template="qwen3",
        eot_tokens=["<|im_end|>"],
    )
    dataset_cfg = DictDefault(type="chat_template")

    loader = StrategyLoader()
    return loader(qwen3_tokenizer, train_cfg, dataset_cfg)
|
||||
|
||||
|
||||
class TestSchemaUnification:
    """
    Test class on handling null fields for tool calling

    Both tests tokenize the tool-calling rows and verify that the rendered
    tool call does not contain null-valued ``message`` / ``theta`` arguments.
    """

    @staticmethod
    def _assert_no_null_arguments(tokenizer, input_ids):
        """Decode ``input_ids`` and check the last ``<tool_call>`` section for nulls."""
        decoded = tokenizer.decode(input_ids)
        # Text after the last opening tag, cut at the first closing tag.
        after_open = decoded.rpartition("<tool_call>")[2]
        tool_call = after_open.partition("</tool_call>")[0]
        assert '"message": null' not in tool_call
        assert '"theta": null' not in tool_call

    def test_schema_unification_single_prompt(
        self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer
    ):
        """Tokenize the rows one at a time."""
        for row in messages_w_tools:
            tokenized = qwen3_prompt_strategy.tokenize_prompt(row)
            self._assert_no_null_arguments(qwen3_tokenizer, tokenized["input_ids"])

    def test_schema_unification_batched(
        self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer
    ):
        """Tokenize the rows through a batched ``Dataset.map`` call."""
        tokenized_rows = messages_w_tools.map(
            qwen3_prompt_strategy.tokenize_prompt, batched=True
        )
        for tokenized in tokenized_rows:
            self._assert_no_null_arguments(qwen3_tokenizer, tokenized["input_ids"])
|
||||
@@ -6,6 +6,8 @@ from typing import TYPE_CHECKING
|
||||
import pytest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.utils.mistral_tokenizer import HFMistralTokenizer
|
||||
|
||||
|
||||
@@ -748,5 +750,100 @@ def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"):
|
||||
assert "Not the same number of function calls and responses" in str(e)
|
||||
|
||||
|
||||
def test_magistral_tokenizer_call_method(
    magistral_tokenizer: "HFMistralTokenizer", llama3_tokenizer: "PreTrainedTokenizer"
):
    """Check that calling the Mistral tokenizer mirrors a HuggingFace tokenizer's output."""
    from copy import deepcopy

    import numpy as np
    import torch

    # Set the pad token on a copy so the llama3 fixture object itself is untouched.
    reference_tokenizer = deepcopy(llama3_tokenizer)
    reference_tokenizer.pad_token = reference_tokenizer.eos_token

    single_text = "Hello, how are you?"
    text_batch = ["Hello world", "How are you?"]

    # ---- single string, return_tensors=None -> plain python containers ----
    hf_plain = reference_tokenizer(single_text, return_tensors=None)
    mistral_plain = magistral_tokenizer(single_text, return_tensors=None)

    assert isinstance(mistral_plain, dict)
    assert set(mistral_plain.keys()) == {"input_ids", "attention_mask"}
    # Container types must match the reference tokenizer's (list).
    assert isinstance(mistral_plain["input_ids"], type(hf_plain["input_ids"]))
    assert isinstance(
        mistral_plain["attention_mask"], type(hf_plain["attention_mask"])
    )
    assert len(mistral_plain["input_ids"]) == len(mistral_plain["attention_mask"])
    assert np.all(mistral_plain["attention_mask"])
    assert np.array(mistral_plain["input_ids"]).ndim == 1  # flat, not batched

    # ---- single string, return_tensors="pt" ----
    hf_pt = reference_tokenizer(single_text, return_tensors="pt")
    mistral_pt = magistral_tokenizer(single_text, return_tensors="pt")

    assert isinstance(mistral_pt["input_ids"], torch.Tensor)
    assert isinstance(mistral_pt["attention_mask"], torch.Tensor)

    # Same rank and batch size; the token dimension is tokenizer specific.
    assert hf_pt["input_ids"].ndim == mistral_pt["input_ids"].ndim
    assert hf_pt["input_ids"].shape[0] == mistral_pt["input_ids"].shape[0]
    assert mistral_pt["attention_mask"].shape == mistral_pt["input_ids"].shape
    assert torch.all(mistral_pt["attention_mask"] == 1)

    # ---- batch of strings with padding ----
    hf_batch = reference_tokenizer(text_batch, return_tensors="pt", padding=True)
    mistral_batch = magistral_tokenizer(
        text_batch, return_tensors="pt", padding=True
    )

    assert hf_batch["input_ids"].ndim == mistral_batch["input_ids"].ndim
    assert hf_batch["input_ids"].shape[0] == mistral_batch["input_ids"].shape[0]
    assert mistral_batch["attention_mask"].shape == mistral_batch["input_ids"].shape
    batch_mask = mistral_batch["attention_mask"]
    assert torch.any(batch_mask[0] == 0)  # padding in shorter sequence
    assert torch.all(batch_mask[1] == 1)  # no padding in longer sequence

    # ---- numpy output ----
    mistral_np = magistral_tokenizer(single_text, return_tensors="np")
    assert isinstance(mistral_np["input_ids"], np.ndarray)
    assert isinstance(mistral_np["attention_mask"], np.ndarray)

    # ---- __call__ must agree with encode() ----
    encoded_ids = magistral_tokenizer.encode(single_text, add_special_tokens=True)
    called_pt = magistral_tokenizer(single_text, return_tensors="pt")
    assert encoded_ids == called_pt["input_ids"][0].tolist()

    # ---- error handling ----
    with pytest.raises(ValueError, match="Unsupported kwargs"):
        magistral_tokenizer(single_text, unsupported_param=True)

    # Tensor output for a batch is rejected unless padding or truncation is requested.
    with pytest.raises(
        ValueError, match="return_tensors='pt' or 'np' requires padding or truncation"
    ):
        magistral_tokenizer(text_batch, return_tensors="pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -6,9 +6,9 @@ import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.utils.config import (
|
||||
migrate_fsdp_config,
|
||||
normalize_cfg_datasets,
|
||||
normalize_config,
|
||||
validate_config,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -27,6 +27,13 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"learning_rate": 0.0001,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -97,7 +104,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
|
||||
def test_migrate_fsdp_config(self):
|
||||
"""Test basic FSDP config migration with and without fsdp_version"""
|
||||
cfg_with_version = DictDefault(
|
||||
cfg_with_version = self._get_base_cfg() | DictDefault(
|
||||
{
|
||||
"fsdp_config": {
|
||||
"fsdp_version": 2,
|
||||
@@ -109,7 +116,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
migrate_fsdp_config(cfg_with_version)
|
||||
cfg_with_version = validate_config(cfg_with_version)
|
||||
|
||||
self.assertEqual(cfg_with_version.fsdp_version, 2)
|
||||
self.assertEqual(
|
||||
@@ -125,7 +132,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
self.assertNotIn("fsdp_version", cfg_with_version.fsdp_config)
|
||||
self.assertNotIn("version", cfg_with_version.fsdp_config)
|
||||
|
||||
cfg_without_version = DictDefault(
|
||||
cfg_without_version = self._get_base_cfg() | DictDefault(
|
||||
{
|
||||
"fsdp_config": {
|
||||
"fsdp_auto_wrap_policy": "SIZE_BASED_WRAP",
|
||||
@@ -135,7 +142,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
migrate_fsdp_config(cfg_without_version)
|
||||
cfg_without_version = validate_config(cfg_without_version)
|
||||
|
||||
self.assertNotIn("fsdp_version", cfg_without_version)
|
||||
self.assertEqual(
|
||||
@@ -149,26 +156,25 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
|
||||
def test_migrate_fsdp_config_no_fsdp_config(self):
|
||||
"""Test that function doesn't crash when no fsdp_config is present"""
|
||||
cfg = DictDefault({"some_other_config": "value"})
|
||||
cfg = self._get_base_cfg()
|
||||
|
||||
migrate_fsdp_config(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
|
||||
self.assertNotIn("fsdp_config", cfg)
|
||||
self.assertNotIn("fsdp_version", cfg)
|
||||
self.assertEqual(cfg.some_other_config, "value")
|
||||
|
||||
def test_migrate_fsdp_config_empty_fsdp_config(self):
|
||||
"""Test migration with empty fsdp_config"""
|
||||
cfg = DictDefault({"fsdp_config": {}})
|
||||
cfg = self._get_base_cfg() | DictDefault({"fsdp_config": {}})
|
||||
|
||||
migrate_fsdp_config(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
|
||||
self.assertNotIn("fsdp_version", cfg)
|
||||
self.assertEqual(cfg.fsdp_config, {})
|
||||
|
||||
def test_migrate_fsdp_config_mixed_keys(self):
|
||||
"""Test migration with a mix of fsdp_ and non-fsdp_ keys"""
|
||||
cfg = DictDefault(
|
||||
cfg = self._get_base_cfg() | DictDefault(
|
||||
{
|
||||
"fsdp_config": {
|
||||
"fsdp_version": 1,
|
||||
@@ -180,7 +186,7 @@ class NormalizeConfigTestCase(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
migrate_fsdp_config(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
|
||||
self.assertEqual(cfg.fsdp_version, 1)
|
||||
self.assertEqual(cfg.fsdp_config.state_dict_type, "FULL_STATE_DICT")
|
||||
|
||||
@@ -7,21 +7,16 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="train_base_cfg")
|
||||
def fixture_train_base_cfg():
|
||||
return DictDefault(
|
||||
base_model="gpt2",
|
||||
learning_rate=1e-3,
|
||||
datasets=[
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
micro_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
sequence_len=2048,
|
||||
sample_packing=True,
|
||||
num_epochs=1,
|
||||
def fixture_train_base_cfg(min_base_cfg):
|
||||
return (
|
||||
DictDefault(
|
||||
micro_batch_size=2,
|
||||
gradient_accumulation_steps=4,
|
||||
sequence_len=2048,
|
||||
sample_packing=True,
|
||||
num_epochs=1,
|
||||
)
|
||||
| min_base_cfg
|
||||
)
|
||||
|
||||
|
||||
|
||||
139
tests/utils/schemas/validation/test_fsdp.py
Normal file
139
tests/utils/schemas/validation/test_fsdp.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
tests for pydantic fsdp validation
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestFSDPValidation:
    """
    Pydantic validation tests for FSDP-related config options
    """

    def test_fsdp_version_in_fsdp_config(self, min_base_cfg):
        """``fsdp_version`` nested in ``fsdp_config`` ends up at the top level."""
        raw_cfg = min_base_cfg | DictDefault(fsdp_config={"fsdp_version": 2})

        validated = validate_config(raw_cfg)

        assert validated.fsdp_version == 2
        assert validated.fsdp_config.fsdp_version is None

    def test_fsdp_sharded_state_dict_safetensors(self, min_base_cfg):
        """SHARDED_STATE_DICT plus save_safetensors is rejected, with or without the ``fsdp_`` prefix."""
        for state_dict_key in ("fsdp_state_dict_type", "state_dict_type"):
            raw_cfg = min_base_cfg | DictDefault(
                fsdp_config={state_dict_key: "SHARDED_STATE_DICT"},
                save_safetensors=True,
            )
            with pytest.raises(
                ValueError,
                match="FSDP SHARDED_STATE_DICT not compatible with save_safetensors",
            ):
                validate_config(raw_cfg)

    def test_fsdp_offload_w_8bit_optim(self, min_base_cfg):
        """FSDP (version 1) parameter offload is rejected together with adamw_8bit."""
        raw_cfg = min_base_cfg | DictDefault(
            fsdp_config={"offload_params": True},
            optimizer="adamw_8bit",
            fsdp_version=1,
        )
        expected_error = "FSDP Offload not compatible with adamw_8bit"
        with pytest.raises(ValueError, match=expected_error):
            validate_config(raw_cfg)

    def test_fsdp2_w_8bit_optim(self, min_base_cfg):
        """FSDP2 is rejected together with adamw_8bit."""
        raw_cfg = min_base_cfg | DictDefault(
            fsdp_config={"offload_params": True},
            optimizer="adamw_8bit",
            fsdp_version=2,
        )
        expected_error = (
            "FSDP2 not compatible with adamw_8bit, use `adamw_torch_8bit` instead"
        )
        with pytest.raises(ValueError, match=expected_error):
            validate_config(raw_cfg)

    def test_fsdp2_w_cpu_ram_efficient_loading(self, min_base_cfg):
        """FSDP2 rejects 8-bit loading combined with cpu_ram_efficient_loading."""
        raw_cfg = min_base_cfg | DictDefault(
            load_in_8bit=True,
            adapter="lora",
            fsdp_config={"cpu_ram_efficient_loading": True},
            fsdp_version=2,
        )
        expected_error = "FSDP2 does not support load_in_8bit or load_in_4bit with cpu_ram_efficient_loading."
        with pytest.raises(ValueError, match=expected_error):
            validate_config(raw_cfg)

    def test_fsdp_prefixes_removed(self, min_base_cfg):
        """Validation strips the ``fsdp_`` prefix from every ``fsdp_config`` key."""
        prefixed_options = {
            "fsdp_version": 2,
            "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
            "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
            "fsdp_reshard_after_forward": True,
        }
        raw_cfg = min_base_cfg | DictDefault(fsdp_config=prefixed_options)

        validated = validate_config(raw_cfg)
        fsdp_config = validated.fsdp_config

        assert validated.fsdp_version == 2
        assert fsdp_config.fsdp_version is None
        for option_name in fsdp_config.keys():
            assert not option_name.startswith("fsdp_")
        assert fsdp_config.auto_wrap_policy == "TRANSFORMER_BASED_WRAP"
        assert fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer"
        assert fsdp_config.reshard_after_forward is True

    @pytest.mark.parametrize(
        "rl",
        [
            "dpo",
            "kto",
            "orpo",
            "ipo",
        ],
    )
    def test_fsdp2_dpo(self, min_base_cfg, rl):
        """FSDP2 rejects 8-bit loading for each of the parametrized RL methods."""
        raw_cfg = min_base_cfg | DictDefault(
            fsdp_version=2,
            fsdp_config={"reshard_after_forward": True},
            rl=rl,
            load_in_8bit=True,
            adapter="lora",
            remove_unused_columns=False,
        )
        expected_error = "FSDP2 does not support load_in_8bit or load_in_4bit with "
        with pytest.raises(ValueError, match=expected_error):
            validate_config(raw_cfg)
|
||||
Reference in New Issue
Block a user