Improve VRAM use with gradient checkpointing (#1167) [skip ci]

This commit is contained in:
Wing Lian
2024-01-22 19:48:22 -05:00
committed by GitHub
parent b8e5603467
commit 802f9667a2

View File

@@ -159,6 +159,13 @@ def normalize_config(cfg):
if isinstance(cfg.pretraining_dataset, dict):
cfg.pretraining_dataset = [cfg.pretraining_dataset]
if (
cfg.gradient_checkpointing
and cfg.unfrozen_parameters is None
and cfg.gradient_checkpointing_kwargs is None
):
cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
log_gpu_memory_usage(LOG, "baseline", cfg.device)