allow profiling in mid-training rather than from the start (#2899) [skip ci]
* allow profiling in mid-training rather than from the start * simplify based on PR feedback * fix logic, improve saving at end, add tests
This commit is contained in:
@@ -112,13 +112,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.profiler_steps:
|
|
||||||
callbacks.append(
|
|
||||||
PytorchProfilerCallback(
|
|
||||||
steps_to_profile=self.cfg.profiler_steps,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.cfg.gc_steps:
|
if self.cfg.gc_steps:
|
||||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||||
|
|
||||||
@@ -145,6 +138,14 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
|
|
||||||
callbacks.append(GPUStatsCallback(cfg=self.cfg))
|
callbacks.append(GPUStatsCallback(cfg=self.cfg))
|
||||||
|
|
||||||
|
if self.cfg.profiler_steps:
|
||||||
|
callbacks.append(
|
||||||
|
PytorchProfilerCallback(
|
||||||
|
steps_to_profile=self.cfg.profiler_steps,
|
||||||
|
profiler_steps_start=self.cfg.profiler_steps_start,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
def get_post_trainer_create_callbacks(self, trainer):
|
||||||
|
|||||||
@@ -19,9 +19,27 @@ class PytorchProfilerCallback(TrainerCallback):
|
|||||||
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
|
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, steps_to_profile: int = 5):
|
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
|
||||||
self.steps_to_profile = steps_to_profile
|
# steps are 0 indexed, so to start at 0-th step, we start at beginning of first step,
|
||||||
if self.steps_to_profile:
|
# and finish at end of last step, so 5 steps_to_profile is steps [0, 1, 2, 3, 4]
|
||||||
|
self.profiler_steps_end = profiler_steps_start + steps_to_profile - 1
|
||||||
|
if profiler_steps_start == 0:
|
||||||
|
# start recording memory allocations before everything is allocated, because if we start
|
||||||
|
# at the beginning of step 0, we won't have any memory allocations in the traces
|
||||||
|
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
|
||||||
|
enabled="all"
|
||||||
|
)
|
||||||
|
profiler_steps_start = -1
|
||||||
|
self.profiler_steps_start = profiler_steps_start
|
||||||
|
|
||||||
|
def on_step_begin( # pylint: disable=unused-argument
|
||||||
|
self,
|
||||||
|
args: TrainingArguments, # pylint: disable=unused-argument
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl, # pylint: disable=unused-argument
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
if state.global_step == self.profiler_steps_start:
|
||||||
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
|
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
|
||||||
enabled="all"
|
enabled="all"
|
||||||
)
|
)
|
||||||
@@ -33,7 +51,28 @@ class PytorchProfilerCallback(TrainerCallback):
|
|||||||
control: TrainerControl, # pylint: disable=unused-argument
|
control: TrainerControl, # pylint: disable=unused-argument
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
if state.global_step == self.steps_to_profile:
|
if state.global_step == self.profiler_steps_end:
|
||||||
|
snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access
|
||||||
|
with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
|
||||||
|
dump(snapshot, fout)
|
||||||
|
|
||||||
|
# tell CUDA to stop recording memory allocations now
|
||||||
|
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
|
||||||
|
enabled=None
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_train_end( # pylint: disable=unused-argument
|
||||||
|
self,
|
||||||
|
args: TrainingArguments, # pylint: disable=unused-argument
|
||||||
|
state: TrainerState,
|
||||||
|
control: TrainerControl, # pylint: disable=unused-argument
|
||||||
|
**kwargs, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
# make sure to record if we happen to have more steps than steps to profile
|
||||||
|
if (
|
||||||
|
state.global_step >= self.profiler_steps_start
|
||||||
|
and state.global_step < self.profiler_steps_end
|
||||||
|
):
|
||||||
snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access
|
snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access
|
||||||
with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
|
with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
|
||||||
dump(snapshot, fout)
|
dump(snapshot, fout)
|
||||||
|
|||||||
@@ -741,6 +741,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Enable the pytorch profiler to capture the first N steps of training to the output_dir. see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information. Snapshots can be visualized @ https://pytorch.org/memory_viz"
|
"description": "Enable the pytorch profiler to capture the first N steps of training to the output_dir. see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information. Snapshots can be visualized @ https://pytorch.org/memory_viz"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
profiler_steps_start: int | None = Field(
|
||||||
|
default=0,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Which step to start the profiler at. Useful for only capturing a few steps mid-run."
|
||||||
|
},
|
||||||
|
)
|
||||||
include_tokens_per_second: bool | None = Field(
|
include_tokens_per_second: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
113
tests/e2e/test_profiler.py
Normal file
113
tests/e2e/test_profiler.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""
|
||||||
|
e2e gpu test for the pytorch profiler callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.common.datasets import load_datasets
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="profiler_base_cfg")
|
||||||
|
def fixture_profiler_base_cfg():
|
||||||
|
cfg = DictDefault(
|
||||||
|
base_model="HuggingFaceTB/SmolLM2-135M",
|
||||||
|
tokenizer_type="AutoTokenizer",
|
||||||
|
sequence_len=1024,
|
||||||
|
load_in_8bit=True,
|
||||||
|
adapter="lora",
|
||||||
|
lora_r=8,
|
||||||
|
lora_alpha=16,
|
||||||
|
lora_dropout=0.05,
|
||||||
|
lora_target_linear=True,
|
||||||
|
val_set_size=0.02,
|
||||||
|
special_tokens={"pad_token": "<|endoftext|>"},
|
||||||
|
datasets=[
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
num_epochs=1,
|
||||||
|
micro_batch_size=2,
|
||||||
|
gradient_accumulation_steps=1,
|
||||||
|
learning_rate=0.00001,
|
||||||
|
optimizer="adamw_torch_fused",
|
||||||
|
lr_scheduler="cosine",
|
||||||
|
)
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
class TestProfiler:
|
||||||
|
"""
|
||||||
|
test cases for the pytorch profiler callback
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_profiler_saves(self, profiler_base_cfg, temp_dir):
|
||||||
|
cfg = profiler_base_cfg | DictDefault(
|
||||||
|
output_dir=temp_dir,
|
||||||
|
max_steps=5,
|
||||||
|
profiler_steps=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "snapshot.pickle").exists()
|
||||||
|
|
||||||
|
def test_profiler_saves_w_start(self, profiler_base_cfg, temp_dir):
|
||||||
|
cfg = profiler_base_cfg | DictDefault(
|
||||||
|
output_dir=temp_dir,
|
||||||
|
max_steps=5,
|
||||||
|
profiler_steps=3,
|
||||||
|
profiler_steps_start=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "snapshot.pickle").exists()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"profiler_steps_start",
|
||||||
|
[3, 5],
|
||||||
|
)
|
||||||
|
def test_profiler_saves_past_end(
|
||||||
|
self, profiler_base_cfg, temp_dir, profiler_steps_start
|
||||||
|
):
|
||||||
|
cfg = profiler_base_cfg | DictDefault(
|
||||||
|
output_dir=temp_dir,
|
||||||
|
max_steps=5,
|
||||||
|
profiler_steps=3,
|
||||||
|
profiler_steps_start=profiler_steps_start,
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "snapshot.pickle").exists()
|
||||||
|
|
||||||
|
def test_profiler_never_started(self, profiler_base_cfg, temp_dir):
|
||||||
|
cfg = profiler_base_cfg | DictDefault(
|
||||||
|
output_dir=temp_dir,
|
||||||
|
max_steps=5,
|
||||||
|
profiler_steps=3,
|
||||||
|
profiler_steps_start=6,
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
assert not (Path(temp_dir) / "snapshot.pickle").exists()
|
||||||
Reference in New Issue
Block a user