diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py index e80e905b8..4df010040 100644 --- a/src/axolotl/core/builders/base.py +++ b/src/axolotl/core/builders/base.py @@ -112,13 +112,6 @@ class TrainerBuilderBase(abc.ABC): plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model) ) - if self.cfg.profiler_steps: - callbacks.append( - PytorchProfilerCallback( - steps_to_profile=self.cfg.profiler_steps, - ) - ) - if self.cfg.gc_steps: callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps)) @@ -145,6 +138,14 @@ class TrainerBuilderBase(abc.ABC): callbacks.append(GPUStatsCallback(cfg=self.cfg)) + if self.cfg.profiler_steps: + callbacks.append( + PytorchProfilerCallback( + steps_to_profile=self.cfg.profiler_steps, + profiler_steps_start=self.cfg.profiler_steps_start, + ) + ) + return callbacks def get_post_trainer_create_callbacks(self, trainer): diff --git a/src/axolotl/utils/callbacks/profiler.py b/src/axolotl/utils/callbacks/profiler.py index 36604813f..d26b7f9dd 100644 --- a/src/axolotl/utils/callbacks/profiler.py +++ b/src/axolotl/utils/callbacks/profiler.py @@ -19,9 +19,27 @@ class PytorchProfilerCallback(TrainerCallback): PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps. 
""" - def __init__(self, steps_to_profile: int = 5): - self.steps_to_profile = steps_to_profile - if self.steps_to_profile: + def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0): + # steps are 0 indexed, so to start at 0-th step, we start at beginning of first step, + # and finish at end of last step, so 5 steps_to_profile is steps [0, 1, 2, 3, 4] + self.profiler_steps_end = profiler_steps_start + steps_to_profile - 1 + if profiler_steps_start == 0: + # start recording memory allocations before everything is allocated, because if we start + # at the beginning of step 0, we won't have any memory allocations in the traces + torch.cuda.memory._record_memory_history( # pylint: disable=protected-access + enabled="all" + ) + profiler_steps_start = -1 + self.profiler_steps_start = profiler_steps_start + + def on_step_begin( # pylint: disable=unused-argument + self, + args: TrainingArguments, # pylint: disable=unused-argument + state: TrainerState, + control: TrainerControl, # pylint: disable=unused-argument + **kwargs, # pylint: disable=unused-argument + ): + if state.global_step == self.profiler_steps_start: torch.cuda.memory._record_memory_history( # pylint: disable=protected-access enabled="all" ) @@ -33,7 +51,28 @@ class PytorchProfilerCallback(TrainerCallback): control: TrainerControl, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument ): - if state.global_step == self.steps_to_profile: + if state.global_step == self.profiler_steps_end: + snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access + with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout: + dump(snapshot, fout) + + # tell CUDA to stop recording memory allocations now + torch.cuda.memory._record_memory_history( # pylint: disable=protected-access + enabled=None + ) + + def on_train_end( # pylint: disable=unused-argument + self, + args: TrainingArguments, # pylint: disable=unused-argument + state: TrainerState, + control: 
TrainerControl, # pylint: disable=unused-argument + **kwargs, # pylint: disable=unused-argument + ): + # still dump a snapshot if training ends before reaching the profiler end step + if ( + state.global_step >= self.profiler_steps_start + and state.global_step < self.profiler_steps_end + ): snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access + with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout: + dump(snapshot, fout) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index f757cc5b0..1726feb67 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -741,6 +741,12 @@ class AxolotlInputConfig( "description": "Enable the pytorch profiler to capture the first N steps of training to the output_dir. see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information. Snapshots can be visualized @ https://pytorch.org/memory_viz" }, ) + profiler_steps_start: int | None = Field( + default=0, + json_schema_extra={ + "description": "Which step to start the profiler at. Useful for only capturing a few steps mid-run." 
+ }, + ) include_tokens_per_second: bool | None = Field( default=None, json_schema_extra={ diff --git a/tests/e2e/test_profiler.py b/tests/e2e/test_profiler.py new file mode 100644 index 000000000..ab273b981 --- /dev/null +++ b/tests/e2e/test_profiler.py @@ -0,0 +1,113 @@ +""" +e2e gpu test for the pytorch profiler callback +""" + +from pathlib import Path + +import pytest + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + + +@pytest.fixture(name="profiler_base_cfg") +def fixture_profiler_base_cfg(): + cfg = DictDefault( + base_model="HuggingFaceTB/SmolLM2-135M", + tokenizer_type="AutoTokenizer", + sequence_len=1024, + load_in_8bit=True, + adapter="lora", + lora_r=8, + lora_alpha=16, + lora_dropout=0.05, + lora_target_linear=True, + val_set_size=0.02, + special_tokens={"pad_token": "<|endoftext|>"}, + datasets=[ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + num_epochs=1, + micro_batch_size=2, + gradient_accumulation_steps=1, + learning_rate=0.00001, + optimizer="adamw_torch_fused", + lr_scheduler="cosine", + ) + return cfg + + +class TestProfiler: + """ + test cases for the pytorch profiler callback + """ + + def test_profiler_saves(self, profiler_base_cfg, temp_dir): + cfg = profiler_base_cfg | DictDefault( + output_dir=temp_dir, + max_steps=5, + profiler_steps=3, + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "snapshot.pickle").exists() + + def test_profiler_saves_w_start(self, profiler_base_cfg, temp_dir): + cfg = profiler_base_cfg | DictDefault( + output_dir=temp_dir, + max_steps=5, + profiler_steps=3, + profiler_steps_start=1, + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, 
dataset_meta=dataset_meta) + assert (Path(temp_dir) / "snapshot.pickle").exists() + + @pytest.mark.parametrize( + "profiler_steps_start", + [3, 5], + ) + def test_profiler_saves_past_end( + self, profiler_base_cfg, temp_dir, profiler_steps_start + ): + cfg = profiler_base_cfg | DictDefault( + output_dir=temp_dir, + max_steps=5, + profiler_steps=3, + profiler_steps_start=profiler_steps_start, + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "snapshot.pickle").exists() + + def test_profiler_never_started(self, profiler_base_cfg, temp_dir): + cfg = profiler_base_cfg | DictDefault( + output_dir=temp_dir, + max_steps=5, + profiler_steps=3, + profiler_steps_start=6, + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + assert not (Path(temp_dir) / "snapshot.pickle").exists()