[feature] add pytorch profiling (#2182)

* add pytorch profiling

* kick off the profiler asap since things may get allcoated before train start

* document feature

* add url for visualizer [skip ci]
This commit is contained in:
Wing Lian
2024-12-16 12:38:43 -05:00
committed by GitHub
parent effc4dc409
commit 33090486d7
4 changed files with 56 additions and 0 deletions

View File

@@ -363,6 +363,10 @@ eval_table_size: # Approximate number of predictions sent to wandb depending on
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
# see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information
# snapshots can be visualized @ https://pytorch.org/memory_viz
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)

View File

@@ -65,6 +65,7 @@ from axolotl.utils.callbacks import (
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
@@ -1363,6 +1364,13 @@ class TrainerBuilderBase(abc.ABC):
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(
steps_to_profile=self.cfg.profiler_steps,
)
)
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)

View File

@@ -0,0 +1,43 @@
"""
HF Trainer callback for creating pytorch profiling snapshots
"""
from pathlib import Path
from pickle import dump # nosec B403
import torch
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
class PytorchProfilerCallback(TrainerCallback):
"""
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
"""
def __init__(self, steps_to_profile: int = 5):
self.steps_to_profile = steps_to_profile
if self.steps_to_profile:
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
enabled="all"
)
def on_step_end( # pylint: disable=unused-argument
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
if state.global_step == self.steps_to_profile:
snapshot = torch.cuda.memory._snapshot() # pylint: disable=protected-access
with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
dump(snapshot, fout)
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history( # pylint: disable=protected-access
enabled=None
)

View File

@@ -762,6 +762,7 @@ class AxolotlInputConfig(
load_best_model_at_end: Optional[bool] = False
save_only_model: Optional[bool] = False
use_tensorboard: Optional[bool] = None
profiler_steps: Optional[int] = None
neftune_noise_alpha: Optional[float] = None