Files
axolotl/tests/e2e/test_profiler.py
Wing Lian 38359a8997 allow profiling in mid-training rather than from the start (#2899) [skip ci]
* allow profiling in mid-training rather than from the start

* simplify based on PR feedback

* fix logic, improve saving at end, add tests
2025-07-14 20:11:11 -04:00

114 lines
3.1 KiB
Python

"""
e2e gpu test for the pytorch profiler callback
"""
from pathlib import Path
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="profiler_base_cfg")
def fixture_profiler_base_cfg():
    """Base config shared by the profiler e2e tests.

    A small LoRA fine-tune of SmolLM2-135M over a 2k-row alpaca dataset;
    individual tests merge in their own `output_dir` / `max_steps` /
    `profiler_*` settings on top of this.
    """
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    return DictDefault(
        {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "tokenizer_type": "AutoTokenizer",
            "sequence_len": 1024,
            "load_in_8bit": True,
            "adapter": "lora",
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.05,
            "lora_target_linear": True,
            "val_set_size": 0.02,
            "special_tokens": {"pad_token": "<|endoftext|>"},
            "datasets": [alpaca_dataset],
            "num_epochs": 1,
            "micro_batch_size": 2,
            "gradient_accumulation_steps": 1,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch_fused",
            "lr_scheduler": "cosine",
        }
    )
class TestProfiler:
    """
    E2E test cases for the PyTorch profiler callback.

    Each test runs a short (5-step) training job with different
    `profiler_steps` / `profiler_steps_start` combinations and checks
    whether the profiler snapshot was written to the output directory.
    """

    @staticmethod
    def _train_and_get_snapshot_path(cfg, temp_dir):
        """Run the standard train pipeline and return the expected snapshot path.

        Validates and normalizes `cfg`, loads the datasets, trains, and
        returns the `snapshot.pickle` path under `temp_dir` (which may or
        may not exist, depending on whether the profiler ever started).
        """
        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)
        train(cfg=cfg, dataset_meta=dataset_meta)
        return Path(temp_dir) / "snapshot.pickle"

    def test_profiler_saves(self, profiler_base_cfg, temp_dir):
        # Profiler with no explicit start: profiles from step 0 for 3 steps.
        cfg = profiler_base_cfg | DictDefault(
            output_dir=temp_dir,
            max_steps=5,
            profiler_steps=3,
        )
        assert self._train_and_get_snapshot_path(cfg, temp_dir).exists()

    def test_profiler_saves_w_start(self, profiler_base_cfg, temp_dir):
        # Profiling window (start=1, 3 steps) ends before training does.
        cfg = profiler_base_cfg | DictDefault(
            output_dir=temp_dir,
            max_steps=5,
            profiler_steps=3,
            profiler_steps_start=1,
        )
        assert self._train_and_get_snapshot_path(cfg, temp_dir).exists()

    @pytest.mark.parametrize(
        "profiler_steps_start",
        [3, 5],
    )
    def test_profiler_saves_past_end(
        self, profiler_base_cfg, temp_dir, profiler_steps_start
    ):
        # Profiling window extends past (or starts at) the final training
        # step; the snapshot should still be saved when training ends.
        cfg = profiler_base_cfg | DictDefault(
            output_dir=temp_dir,
            max_steps=5,
            profiler_steps=3,
            profiler_steps_start=profiler_steps_start,
        )
        assert self._train_and_get_snapshot_path(cfg, temp_dir).exists()

    def test_profiler_never_started(self, profiler_base_cfg, temp_dir):
        # Start step (6) is beyond max_steps (5): the profiler never
        # activates, so no snapshot file should be written.
        cfg = profiler_base_cfg | DictDefault(
            output_dir=temp_dir,
            max_steps=5,
            profiler_steps=3,
            profiler_steps_start=6,
        )
        assert not self._train_and_get_snapshot_path(cfg, temp_dir).exists()