[feature] add pytorch profiling (#2182)
* add pytorch profiling
* kick off the profiler asap since things may get allocated before train start
* document feature
* add url for visualizer [skip ci]
@@ -363,6 +363,10 @@ eval_table_size: # Approximate number of predictions sent to wandb depending on
 eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
 eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
 
+profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
+                # see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information
+                # snapshots can be visualized @ https://pytorch.org/memory_viz
+
 loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
 loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
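
For reference, a minimal sketch (not part of this commit) of sanity-checking the resulting snapshot.pickle before loading it into https://pytorch.org/memory_viz. The "segments" and "device_traces" keys and the outputs/ path are assumptions based on the snapshot format described in the linked blog post, not a stable public API:

    # Sketch: inspect a snapshot produced via profiler_steps. "outputs"
    # stands in for your configured output_dir; the dict keys follow the
    # (private, version-dependent) format from the PyTorch blog post above.
    import pickle

    with open("outputs/snapshot.pickle", "rb") as fin:
        snapshot = pickle.load(fin)

    print(f"allocator segments: {len(snapshot['segments'])}")
    print(f"devices traced:     {len(snapshot['device_traces'])}")
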
@@ -65,6 +65,7 @@ from axolotl.utils.callbacks import (
     log_prediction_callback_factory,
 )
 from axolotl.utils.callbacks.lisa import lisa_callback_factory
+from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
 from axolotl.utils.chat_templates import get_chat_template
 from axolotl.utils.collators import (
     BatchSamplerDataCollatorForSeq2Seq,
@@ -1363,6 +1364,13 @@ class TrainerBuilderBase(abc.ABC):
             plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
         )
 
+        if self.cfg.profiler_steps:
+            callbacks.append(
+                PytorchProfilerCallback(
+                    steps_to_profile=self.cfg.profiler_steps,
+                )
+            )
+
         if self.cfg.use_wandb:
             callbacks.append(
                 SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
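
Since recording starts in the callback's constructor (the "kick off the profiler asap" note above), allocations made before the first training step are captured too. A rough sketch of equivalent manual wiring with a plain HF Trainer; model, training_args, and train_dataset are placeholders, not objects from this commit:

    # Sketch: registering the callback with a plain transformers Trainer,
    # mirroring what the builder does when cfg.profiler_steps is set.
    from transformers import Trainer
    from axolotl.utils.callbacks.profiler import PytorchProfilerCallback

    trainer = Trainer(
        model=model,
        args=training_args,  # its output_dir is where snapshot.pickle lands
        train_dataset=train_dataset,
        # constructing the callback starts memory recording immediately
        callbacks=[PytorchProfilerCallback(steps_to_profile=5)],
    )
    trainer.train()
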
src/axolotl/utils/callbacks/profiler.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+"""
+HF Trainer callback for creating pytorch profiling snapshots
+"""
+from pathlib import Path
+from pickle import dump  # nosec B403
+
+import torch
+from transformers import (
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+    TrainingArguments,
+)
+
+
+class PytorchProfilerCallback(TrainerCallback):
+    """
+    PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
+    """
+
+    def __init__(self, steps_to_profile: int = 5):
+        self.steps_to_profile = steps_to_profile
+        if self.steps_to_profile:
+            torch.cuda.memory._record_memory_history(  # pylint: disable=protected-access
+                enabled="all"
+            )
+
+    def on_step_end(  # pylint: disable=unused-argument
+        self,
+        args: TrainingArguments,  # pylint: disable=unused-argument
+        state: TrainerState,
+        control: TrainerControl,  # pylint: disable=unused-argument
+        **kwargs,  # pylint: disable=unused-argument
+    ):
+        if state.global_step == self.steps_to_profile:
+            snapshot = torch.cuda.memory._snapshot()  # pylint: disable=protected-access
+            with open(Path(args.output_dir) / "snapshot.pickle", "wb") as fout:
+                dump(snapshot, fout)
+
+            # tell CUDA to stop recording memory allocations now
+            torch.cuda.memory._record_memory_history(  # pylint: disable=protected-access
+                enabled=None
+            )
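
The underscore-prefixed torch.cuda.memory helpers used here are private APIs (hence the pylint disables) and may change between PyTorch releases. Stripped of the Trainer plumbing, the pattern the callback wraps looks roughly like this standalone sketch (CUDA required):

    # Sketch: the raw record/snapshot/stop pattern behind the callback,
    # using the same private torch.cuda.memory helpers as the commit.
    import pickle

    import torch

    torch.cuda.memory._record_memory_history(enabled="all")  # start recording

    x = torch.randn(1024, 1024, device="cuda")  # some allocations to capture
    y = x @ x

    snapshot = torch.cuda.memory._snapshot()  # grab the history so far
    with open("snapshot.pickle", "wb") as fout:
        pickle.dump(snapshot, fout)

    torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
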
@@ -762,6 +762,7 @@ class AxolotlInputConfig(
     load_best_model_at_end: Optional[bool] = False
     save_only_model: Optional[bool] = False
     use_tensorboard: Optional[bool] = None
+    profiler_steps: Optional[int] = None
 
     neftune_noise_alpha: Optional[float] = None
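
The field defaults to None, so profiling stays off unless explicitly set. A trimmed sketch of the default-off semantics that the builder's if self.cfg.profiler_steps: check relies on; MiniConfig is a hypothetical stand-in, assuming AxolotlInputConfig is a pydantic model:

    # Sketch: default-off semantics of the new field. MiniConfig is a
    # hypothetical stand-in for the much larger AxolotlInputConfig.
    from typing import Optional

    from pydantic import BaseModel

    class MiniConfig(BaseModel):
        profiler_steps: Optional[int] = None

    assert not MiniConfig().profiler_steps  # unset -> no profiler callback
    assert MiniConfig(profiler_steps=5).profiler_steps == 5  # first 5 steps
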