fix: move memory usage log to trainer.log (#2996) [skip ci]
This commit is contained in:
@@ -36,7 +36,6 @@ from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
|||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
GCCallback,
|
GCCallback,
|
||||||
GPUStatsCallback,
|
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveModelOnFirstStepCallback,
|
SaveModelOnFirstStepCallback,
|
||||||
)
|
)
|
||||||
@@ -141,8 +140,6 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if self.cfg.save_first_step:
|
if self.cfg.save_first_step:
|
||||||
callbacks.append(SaveModelOnFirstStepCallback())
|
callbacks.append(SaveModelOnFirstStepCallback())
|
||||||
|
|
||||||
callbacks.append(GPUStatsCallback(cfg=self.cfg))
|
|
||||||
|
|
||||||
if self.cfg.profiler_steps:
|
if self.cfg.profiler_steps:
|
||||||
callbacks.append(
|
callbacks.append(
|
||||||
PytorchProfilerCallback(
|
PytorchProfilerCallback(
|
||||||
|
|||||||
@@ -38,6 +38,8 @@ from axolotl.core.trainers.utils import (
|
|||||||
sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
)
|
)
|
||||||
from axolotl.utils import get_not_null
|
from axolotl.utils import get_not_null
|
||||||
|
from axolotl.utils.bench import get_gpu_memory_usage
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
@@ -560,6 +562,17 @@ class AxolotlTrainer(
|
|||||||
# Add averaged stored metrics to logs
|
# Add averaged stored metrics to logs
|
||||||
for key, metrics in self._stored_metrics[train_eval].items():
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
logs[key] = torch.tensor(metrics).mean().item()
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
|
|
||||||
|
if is_main_process():
|
||||||
|
# Add memory usage
|
||||||
|
try:
|
||||||
|
active, allocated, reserved = get_gpu_memory_usage()
|
||||||
|
logs["memory/max_memory_active"] = active
|
||||||
|
logs["memory/max_memory_allocated"] = allocated
|
||||||
|
logs["memory/device_memory_reserved"] = reserved
|
||||||
|
except (ValueError, FileNotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
|
|
||||||
return super().log(logs, start_time)
|
return super().log(logs, start_time)
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ from transformers.trainer_utils import (
|
|||||||
from trl.models import unwrap_model_for_generation
|
from trl.models import unwrap_model_for_generation
|
||||||
|
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.bench import get_gpu_memory_usage, log_gpu_memory_usage
|
|
||||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
barrier,
|
barrier,
|
||||||
@@ -93,39 +92,6 @@ class SaveBetterTransformerModelCallback(
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
class GPUStatsCallback(
|
|
||||||
TrainerCallback
|
|
||||||
): # pylint: disable=too-few-public-methods disable=unused-argument
|
|
||||||
"""Callback to track GPU utilization"""
|
|
||||||
|
|
||||||
def __init__(self, cfg):
|
|
||||||
self.cfg = cfg
|
|
||||||
|
|
||||||
def on_step_end(
|
|
||||||
self,
|
|
||||||
args: TrainingArguments, # pylint: disable=unused-argument
|
|
||||||
state: TrainerState,
|
|
||||||
control: TrainerControl,
|
|
||||||
**kwargs,
|
|
||||||
) -> TrainerControl:
|
|
||||||
if state.global_step > 0:
|
|
||||||
if self.cfg.use_wandb and state.is_world_process_zero:
|
|
||||||
try:
|
|
||||||
active, allocated, reserved = get_gpu_memory_usage()
|
|
||||||
wandb.log(
|
|
||||||
{
|
|
||||||
"memory/max_memory_active": active,
|
|
||||||
"memory/max_memory_allocated": allocated,
|
|
||||||
"memory/device_memory_reserved": reserved,
|
|
||||||
},
|
|
||||||
step=state.global_step,
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
log_gpu_memory_usage(LOG, "", self.cfg.device)
|
|
||||||
return control
|
|
||||||
|
|
||||||
|
|
||||||
class LossWatchDogCallback(TrainerCallback):
|
class LossWatchDogCallback(TrainerCallback):
|
||||||
"""Callback to track loss and stop training if loss is too high"""
|
"""Callback to track loss and stop training if loss is too high"""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user