fix: move memory usage log to trainer.log (#2996) [skip ci]
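This moves per-step GPU memory reporting out of the wandb-only GPUStatsCallback and into AxolotlTrainer.log: the builder no longer registers the callback, the callback class itself is deleted, and the trainer now attaches memory/max_memory_active, memory/max_memory_allocated, and memory/device_memory_reserved to the logs dict on the main process before delegating to the parent Trainer's log method.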
@@ -36,7 +36,6 @@ from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
     GCCallback,
-    GPUStatsCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveModelOnFirstStepCallback,
 )
@@ -141,8 +140,6 @@ class TrainerBuilderBase(abc.ABC):
         if self.cfg.save_first_step:
             callbacks.append(SaveModelOnFirstStepCallback())

-        callbacks.append(GPUStatsCallback(cfg=self.cfg))
-
         if self.cfg.profiler_steps:
             callbacks.append(
                 PytorchProfilerCallback(
@@ -38,6 +38,8 @@ from axolotl.core.trainers.utils import (
     sanitize_kwargs_for_tagging,
 )
 from axolotl.utils import get_not_null
+from axolotl.utils.bench import get_gpu_memory_usage
+from axolotl.utils.distributed import is_main_process
 from axolotl.utils.logging import get_logger
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

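The two added imports gate the new logging on rank zero. As a rough sketch only, assuming a torch.distributed-backed implementation (axolotl's actual axolotl.utils.distributed.is_main_process may differ in detail):

import torch.distributed as dist

def is_main_process() -> bool:
    # Single-process runs never initialize a process group; treat them as main.
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0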
@@ -560,6 +562,17 @@ class AxolotlTrainer(
         # Add averaged stored metrics to logs
         for key, metrics in self._stored_metrics[train_eval].items():
             logs[key] = torch.tensor(metrics).mean().item()
+
+        if is_main_process():
+            # Add memory usage
+            try:
+                active, allocated, reserved = get_gpu_memory_usage()
+                logs["memory/max_memory_active"] = active
+                logs["memory/max_memory_allocated"] = allocated
+                logs["memory/device_memory_reserved"] = reserved
+            except (ValueError, FileNotFoundError):
+                pass
+
         del self._stored_metrics[train_eval]

         return super().log(logs, start_time)
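For orientation, a minimal sketch of what a helper like get_gpu_memory_usage could compute, assuming it is backed by torch.cuda allocator statistics; the units and exact counters here are assumptions, and the real helper in axolotl.utils.bench may differ:

import torch

GB = 1024**3  # bytes per GiB

def get_gpu_memory_usage(device: int = 0) -> tuple[float, float, float]:
    # Peak "active" bytes, from the CUDA caching allocator's counters.
    active = torch.cuda.memory_stats(device).get("active_bytes.all.peak", 0) / GB
    # Peak bytes actually allocated to tensors since the counters were reset.
    allocated = torch.cuda.max_memory_allocated(device) / GB
    # Bytes currently reserved (cached) by the allocator on this device.
    reserved = torch.cuda.memory_reserved(device) / GB
    return active, allocated, reserved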
@@ -35,7 +35,6 @@ from transformers.trainer_utils import (
 from trl.models import unwrap_model_for_generation

 from axolotl.utils import is_comet_available, is_mlflow_available
-from axolotl.utils.bench import get_gpu_memory_usage, log_gpu_memory_usage
 from axolotl.utils.callbacks.perplexity import Perplexity
 from axolotl.utils.distributed import (
     barrier,
@@ -93,39 +92,6 @@ class SaveBetterTransformerModelCallback(
         return control


-class GPUStatsCallback(
-    TrainerCallback
-): # pylint: disable=too-few-public-methods disable=unused-argument
-    """Callback to track GPU utilization"""
-
-    def __init__(self, cfg):
-        self.cfg = cfg
-
-    def on_step_end(
-        self,
-        args: TrainingArguments,  # pylint: disable=unused-argument
-        state: TrainerState,
-        control: TrainerControl,
-        **kwargs,
-    ) -> TrainerControl:
-        if state.global_step > 0:
-            if self.cfg.use_wandb and state.is_world_process_zero:
-                try:
-                    active, allocated, reserved = get_gpu_memory_usage()
-                    wandb.log(
-                        {
-                            "memory/max_memory_active": active,
-                            "memory/max_memory_allocated": allocated,
-                            "memory/device_memory_reserved": reserved,
-                        },
-                        step=state.global_step,
-                    )
-                except ValueError:
-                    pass
-            log_gpu_memory_usage(LOG, "", self.cfg.device)
-        return control
-
-
 class LossWatchDogCallback(TrainerCallback):
     """Callback to track loss and stop training if loss is too high"""
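Net effect of deleting the callback: these metrics were previously pushed straight to wandb from on_step_end (and echoed to the process log), so other report backends never saw them; routing them through trainer.log means every configured reporter receives them. A hypothetical smoke check, not part of the PR, where trainer names a constructed AxolotlTrainer instance:

# HF's Trainer appends every logged dict to trainer.state.log_history, so on
# rank zero the new memory keys should show up there after a short run.
assert any("memory/max_memory_active" in entry for entry in trainer.state.log_history)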