diff --git a/tests/e2e/patched/test_activation_checkpointing.py b/tests/e2e/patched/test_activation_checkpointing.py
new file mode 100644
index 000000000..cbabab6fd
--- /dev/null
+++ b/tests/e2e/patched/test_activation_checkpointing.py
@@ -0,0 +1,77 @@
+"""
+E2E tests for activation checkpointing
+"""
+
+import pytest
+import transformers
+from torch.utils.checkpoint import checkpoint
+
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
+from axolotl.train import train
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+from ..utils import check_model_output_exists
+
+
+@pytest.fixture()
+def fix_checkpoint_after_test():
+    yield
+    transformers.modeling_utils.checkpoint = checkpoint
+
+
+class TestActivationCheckpointing:
+    """
+    E2E tests for activation checkpointing
+    """
+
+    def test_activation_checkpointing_offload(
+        self,
+        temp_dir,
+        fix_checkpoint_after_test,  # pylint: disable=unused-argument,redefined-outer-name
+    ):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sequence_len": 1024,
+                "val_set_size": 0.0,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                    "eos_token": "<|im_end|>",
+                },
+                "datasets": [
+                    {
+                        "chat_template": "chatml",
+                        "path": "mlabonne/FineTome-100k",
+                        "type": "chat_template",
+                        "split": "train[:10%]",
+                        "field_messages": "conversations",
+                        "message_field_role": "from",
+                        "message_field_content": "value",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 5,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "sample_packing": True,
+                "bf16": True,
+                "save_safetensors": True,
+                "gradient_checkpointing": "offload",
+            }
+        )
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(temp_dir, cfg)