* allow profiling in mid-training rather from the start * simplify based on PR feedback * fix logic, improve saving at end, add tests
114 lines
3.1 KiB
Python
114 lines
3.1 KiB
Python
"""
|
|
e2e gpu test for the pytorch profiler callback
|
|
"""
|
|
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from axolotl.common.datasets import load_datasets
|
|
from axolotl.train import train
|
|
from axolotl.utils.config import normalize_config, validate_config
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
|
|
@pytest.fixture(name="profiler_base_cfg")
|
|
def fixture_profiler_base_cfg():
|
|
cfg = DictDefault(
|
|
base_model="HuggingFaceTB/SmolLM2-135M",
|
|
tokenizer_type="AutoTokenizer",
|
|
sequence_len=1024,
|
|
load_in_8bit=True,
|
|
adapter="lora",
|
|
lora_r=8,
|
|
lora_alpha=16,
|
|
lora_dropout=0.05,
|
|
lora_target_linear=True,
|
|
val_set_size=0.02,
|
|
special_tokens={"pad_token": "<|endoftext|>"},
|
|
datasets=[
|
|
{
|
|
"path": "mhenrichsen/alpaca_2k_test",
|
|
"type": "alpaca",
|
|
},
|
|
],
|
|
num_epochs=1,
|
|
micro_batch_size=2,
|
|
gradient_accumulation_steps=1,
|
|
learning_rate=0.00001,
|
|
optimizer="adamw_torch_fused",
|
|
lr_scheduler="cosine",
|
|
)
|
|
return cfg
|
|
|
|
|
|
class TestProfiler:
|
|
"""
|
|
test cases for the pytorch profiler callback
|
|
"""
|
|
|
|
def test_profiler_saves(self, profiler_base_cfg, temp_dir):
|
|
cfg = profiler_base_cfg | DictDefault(
|
|
output_dir=temp_dir,
|
|
max_steps=5,
|
|
profiler_steps=3,
|
|
)
|
|
|
|
cfg = validate_config(cfg)
|
|
normalize_config(cfg)
|
|
dataset_meta = load_datasets(cfg=cfg)
|
|
|
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
assert (Path(temp_dir) / "snapshot.pickle").exists()
|
|
|
|
def test_profiler_saves_w_start(self, profiler_base_cfg, temp_dir):
|
|
cfg = profiler_base_cfg | DictDefault(
|
|
output_dir=temp_dir,
|
|
max_steps=5,
|
|
profiler_steps=3,
|
|
profiler_steps_start=1,
|
|
)
|
|
|
|
cfg = validate_config(cfg)
|
|
normalize_config(cfg)
|
|
dataset_meta = load_datasets(cfg=cfg)
|
|
|
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
assert (Path(temp_dir) / "snapshot.pickle").exists()
|
|
|
|
@pytest.mark.parametrize(
|
|
"profiler_steps_start",
|
|
[3, 5],
|
|
)
|
|
def test_profiler_saves_past_end(
|
|
self, profiler_base_cfg, temp_dir, profiler_steps_start
|
|
):
|
|
cfg = profiler_base_cfg | DictDefault(
|
|
output_dir=temp_dir,
|
|
max_steps=5,
|
|
profiler_steps=3,
|
|
profiler_steps_start=profiler_steps_start,
|
|
)
|
|
|
|
cfg = validate_config(cfg)
|
|
normalize_config(cfg)
|
|
dataset_meta = load_datasets(cfg=cfg)
|
|
|
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
assert (Path(temp_dir) / "snapshot.pickle").exists()
|
|
|
|
def test_profiler_never_started(self, profiler_base_cfg, temp_dir):
|
|
cfg = profiler_base_cfg | DictDefault(
|
|
output_dir=temp_dir,
|
|
max_steps=5,
|
|
profiler_steps=3,
|
|
profiler_steps_start=6,
|
|
)
|
|
|
|
cfg = validate_config(cfg)
|
|
normalize_config(cfg)
|
|
dataset_meta = load_datasets(cfg=cfg)
|
|
|
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
assert not (Path(temp_dir) / "snapshot.pickle").exists()
|