fix: move memory usage log to trainer.log (#2996) [skip ci]

This commit is contained in:
NanoCode012
2025-08-02 00:21:43 +07:00
committed by GitHub
parent 02a37199ee
commit 41709822a7
3 changed files with 13 additions and 37 deletions

View File

@@ -36,7 +36,6 @@ from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
GCCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveModelOnFirstStepCallback,
)
@@ -141,8 +140,6 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.save_first_step:
callbacks.append(SaveModelOnFirstStepCallback())
callbacks.append(GPUStatsCallback(cfg=self.cfg))
if self.cfg.profiler_steps:
callbacks.append(
PytorchProfilerCallback(

View File

@@ -38,6 +38,8 @@ from axolotl.core.trainers.utils import (
sanitize_kwargs_for_tagging,
)
from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -560,6 +562,17 @@ class AxolotlTrainer(
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
if is_main_process():
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_memory_active"] = active
logs["memory/max_memory_allocated"] = allocated
logs["memory/device_memory_reserved"] = reserved
except (ValueError, FileNotFoundError):
pass
del self._stored_metrics[train_eval]
return super().log(logs, start_time)

View File

@@ -35,7 +35,6 @@ from transformers.trainer_utils import (
from trl.models import unwrap_model_for_generation
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.bench import get_gpu_memory_usage, log_gpu_memory_usage
from axolotl.utils.callbacks.perplexity import Perplexity
from axolotl.utils.distributed import (
barrier,
@@ -93,39 +92,6 @@ class SaveBetterTransformerModelCallback(
return control
class GPUStatsCallback(
    TrainerCallback
):  # pylint: disable=too-few-public-methods disable=unused-argument
    """Callback to track GPU utilization"""

    def __init__(self, cfg):
        # Keep the full axolotl config around; we read `use_wandb` and `device`.
        self.cfg = cfg

    def on_step_end(
        self,
        args: TrainingArguments,  # pylint: disable=unused-argument
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        """Emit GPU memory stats after each optimizer step (skips step 0)."""
        # Nothing meaningful to report before the first real step.
        if state.global_step <= 0:
            return control

        # Only the main process pushes metrics to wandb, and only when enabled.
        if self.cfg.use_wandb and state.is_world_process_zero:
            try:
                active_mem, allocated_mem, reserved_mem = get_gpu_memory_usage()
                payload = {
                    "memory/max_memory_active": active_mem,
                    "memory/max_memory_allocated": allocated_mem,
                    "memory/device_memory_reserved": reserved_mem,
                }
                wandb.log(payload, step=state.global_step)
            except ValueError:
                # Best-effort: memory stats may be unavailable on some devices.
                pass

        # Local log line is emitted on every rank regardless of wandb.
        log_gpu_memory_usage(LOG, "", self.cfg.device)
        return control
class LossWatchDogCallback(TrainerCallback):
"""Callback to track loss and stop training if loss is too high"""