diff --git a/src/axolotl/core/builders/base.py b/src/axolotl/core/builders/base.py
index 32b228e21..dbdda7a7c 100644
--- a/src/axolotl/core/builders/base.py
+++ b/src/axolotl/core/builders/base.py
@@ -36,7 +36,6 @@ from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
     GCCallback,
-    GPUStatsCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveModelOnFirstStepCallback,
 )
@@ -141,8 +140,6 @@ class TrainerBuilderBase(abc.ABC):
         if self.cfg.save_first_step:
             callbacks.append(SaveModelOnFirstStepCallback())
 
-        callbacks.append(GPUStatsCallback(cfg=self.cfg))
-
         if self.cfg.profiler_steps:
             callbacks.append(
                 PytorchProfilerCallback(
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index e3818ca7c..f739d19e9 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -38,6 +38,8 @@ from axolotl.core.trainers.utils import (
     sanitize_kwargs_for_tagging,
 )
 from axolotl.utils import get_not_null
+from axolotl.utils.bench import get_gpu_memory_usage
+from axolotl.utils.distributed import is_main_process
 from axolotl.utils.logging import get_logger
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 
@@ -560,6 +562,17 @@ class AxolotlTrainer(
         # Add averaged stored metrics to logs
         for key, metrics in self._stored_metrics[train_eval].items():
             logs[key] = torch.tensor(metrics).mean().item()
+
+        if is_main_process():
+            # Add memory usage
+            try:
+                active, allocated, reserved = get_gpu_memory_usage()
+                logs["memory/max_memory_active"] = active
+                logs["memory/max_memory_allocated"] = allocated
+                logs["memory/device_memory_reserved"] = reserved
+            except (ValueError, FileNotFoundError):
+                pass
+
         del self._stored_metrics[train_eval]
         return super().log(logs, start_time)
 
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 63799c734..d3f3126b5 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -35,7 +35,6 @@ from transformers.trainer_utils import (
 from trl.models import unwrap_model_for_generation
 
 from axolotl.utils import is_comet_available, is_mlflow_available
-from axolotl.utils.bench import get_gpu_memory_usage, log_gpu_memory_usage
 from axolotl.utils.callbacks.perplexity import Perplexity
 from axolotl.utils.distributed import (
     barrier,
@@ -93,39 +92,6 @@ class SaveBetterTransformerModelCallback(
         return control
 
 
-class GPUStatsCallback(
-    TrainerCallback
-):  # pylint: disable=too-few-public-methods disable=unused-argument
-    """Callback to track GPU utilization"""
-
-    def __init__(self, cfg):
-        self.cfg = cfg
-
-    def on_step_end(
-        self,
-        args: TrainingArguments,  # pylint: disable=unused-argument
-        state: TrainerState,
-        control: TrainerControl,
-        **kwargs,
-    ) -> TrainerControl:
-        if state.global_step > 0:
-            if self.cfg.use_wandb and state.is_world_process_zero:
-                try:
-                    active, allocated, reserved = get_gpu_memory_usage()
-                    wandb.log(
-                        {
-                            "memory/max_memory_active": active,
-                            "memory/max_memory_allocated": allocated,
-                            "memory/device_memory_reserved": reserved,
-                        },
-                        step=state.global_step,
-                    )
-                except ValueError:
-                    pass
-            log_gpu_memory_usage(LOG, "", self.cfg.device)
-        return control
-
-
 class LossWatchDogCallback(TrainerCallback):
     """Callback to track loss and stop training if loss is too high"""
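
For context (not part of the patch): the change retires the W&B-only `GPUStatsCallback` and instead folds GPU memory stats into `AxolotlTrainer.log`, so the metrics flow to every configured reporter via `super().log(...)`. Below is a minimal, self-contained sketch of that pattern against a plain `transformers.Trainer`. Here `read_gpu_memory_gib` is a hypothetical stand-in for axolotl's `get_gpu_memory_usage`, which may compute its numbers differently, and the rank check uses `args.process_index` in place of axolotl's `is_main_process` helper.

```python
# Sketch only: GPU memory metrics merged into Trainer.log() instead of a
# separate per-step callback that writes straight to wandb.
import torch
from transformers import Trainer


def read_gpu_memory_gib() -> tuple[float, float, float]:
    """Return (active, allocated, reserved) GPU memory in GiB.

    Hypothetical stand-in for axolotl's get_gpu_memory_usage; the real
    helper may sample different counters.
    """
    if not torch.cuda.is_available():
        raise ValueError("no CUDA device to sample")
    gib = 1024**3
    stats = torch.cuda.memory_stats()
    return (
        stats.get("active_bytes.all.peak", 0) / gib,
        torch.cuda.max_memory_allocated() / gib,
        torch.cuda.memory_reserved() / gib,
    )


class MemoryLoggingTrainer(Trainer):
    def log(self, logs, start_time=None):
        # Only rank 0 samples memory so multi-GPU runs don't emit duplicates.
        if self.args.process_index == 0:
            try:
                active, allocated, reserved = read_gpu_memory_gib()
                logs["memory/max_memory_active"] = active
                logs["memory/max_memory_allocated"] = allocated
                logs["memory/device_memory_reserved"] = reserved
            except (ValueError, FileNotFoundError):
                # Sampling is best-effort; never fail a logging step over it.
                pass
        # Standard path fans the merged dict out to wandb/mlflow/tensorboard/etc.
        return super().log(logs, start_time)
```

Compared with the removed callback, which fired on every `on_step_end` and called `wandb.log` directly, hooking `log()` samples memory only when metrics are actually emitted (every `logging_steps`) and reaches all reporting backends, not just W&B.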