allow profiling in mid-training rather than from the start (#2899) [skip ci]

* allow profiling in mid-training rather than from the start

* simplify based on PR feedback

* fix logic, improve saving at end, add tests
This commit is contained in:
Wing Lian
2025-07-14 20:11:11 -04:00
committed by GitHub
parent 7dc3ac6cb3
commit 38359a8997
4 changed files with 170 additions and 11 deletions

View File

@@ -112,13 +112,6 @@ class TrainerBuilderBase(abc.ABC):
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
)
)
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
@@ -145,6 +138,14 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(GPUStatsCallback(cfg=self.cfg))
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
profiler_steps_start=self.cfg.profiler_steps_start,
)
)
return callbacks
def get_post_trainer_create_callbacks(self, trainer):

View File

@@ -19,9 +19,27 @@ class PytorchProfilerCallback(TrainerCallback):
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
"""
def __init__(self, steps_to_profile: int = 5):
self.steps_to_profile = steps_to_profile
if self.steps_to_profile:
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
    """Configure which training steps the CUDA memory profiler covers.

    Args:
        steps_to_profile: Number of consecutive steps to capture.
        profiler_steps_start: Global step at which capturing begins.
            Defaults to 0 (profile from the very start of training).
    """
    # The config declares this field as `int | None`; coerce None to 0 so
    # the window arithmetic below cannot raise a TypeError.
    profiler_steps_start = profiler_steps_start or 0
    # steps are 0 indexed, so to start at 0-th step, we start at beginning of first step,
    # and finish at end of last step, so 5 steps_to_profile is steps [0, 1, 2, 3, 4]
    self.profiler_steps_end = profiler_steps_start + steps_to_profile - 1
    if profiler_steps_start == 0:
        # start recording memory allocations before everything is allocated, because if we start
        # at the beginning of step 0, we won't have any memory allocations in the traces
        torch.cuda.memory._record_memory_history(  # pylint: disable=protected-access
            enabled="all"
        )
        # Sentinel: recording is already on, so on_step_begin must never
        # match a real global_step (those are always >= 0).
        profiler_steps_start = -1
    self.profiler_steps_start = profiler_steps_start
def on_step_begin(  # pylint: disable=unused-argument
    self,
    args: TrainingArguments,  # pylint: disable=unused-argument
    state: TrainerState,
    control: TrainerControl,  # pylint: disable=unused-argument
    **kwargs,  # pylint: disable=unused-argument
):
    """Begin recording CUDA allocations once the configured start step is reached."""
    if state.global_step != self.profiler_steps_start:
        return
    # Recording stays enabled until a later callback disables it.
    torch.cuda.memory._record_memory_history(  # pylint: disable=protected-access
        enabled="all"
    )
@@ -33,7 +51,28 @@ class PytorchProfilerCallback(TrainerCallback):
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
if state.global_step == self.steps_to_profile:
if state.global_step == self.profiler_steps_end:
snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access
with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
dump(snapshot, fout)
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
enabled=None
)
def on_train_end(  # pylint: disable=unused-argument
    self,
    args: TrainingArguments,
    state: TrainerState,
    control: TrainerControl,  # pylint: disable=unused-argument
    **kwargs,  # pylint: disable=unused-argument
):
    """Save a memory snapshot if training ended inside the profiling window.

    on_step_end only writes the snapshot when training reaches
    profiler_steps_end; if the run finishes early (fewer total steps than
    the window), flush whatever was recorded so it is not lost.
    """
    if self.profiler_steps_start <= state.global_step < self.profiler_steps_end:
        snapshot = torch.cuda.memory._snapshot()  # pylint: disable=protected-access
        with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
            dump(snapshot, fout)
        # Mirror on_step_end: stop recording once the snapshot is written,
        # so allocation tracking does not stay enabled after training.
        torch.cuda.memory._record_memory_history(  # pylint: disable=protected-access
            enabled=None
        )

View File

@@ -741,6 +741,12 @@ class AxolotlInputConfig(
"description": "Enable the pytorch profiler to capture the first N steps of training to the output_dir. see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information. Snapshots can be visualized @ https://pytorch.org/memory_viz"
},
)
profiler_steps_start: int | None = Field(
default=0,
json_schema_extra={
"description": "Which step to start the profiler at. Useful for only capturing a few steps mid-run."
},
)
include_tokens_per_second: bool | None = Field(
default=None,
json_schema_extra={