diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index d2f2e414c..2e4252123 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import gc
 import json
 import math
 import os
@@ -800,7 +801,14 @@ class AxolotlTrainer(
         with open(tokens_state_path, "w", encoding="utf-8") as f:
             json.dump(tokens_state, f)
 
-        return super()._save_checkpoint(model, trial, **kwargs)
+        result = super()._save_checkpoint(model, trial, **kwargs)
+
+        # Reclaim VRAM held by the FSDP full-state-dict gather.
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        return result
 
     # TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index e086885bb..20c863ee5 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -4,6 +4,7 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interatio
 
 import copy
 import functools
+import gc
 import os
 import sys
 
@@ -161,6 +162,7 @@ def get_state_dict(self, model, unwrap=True):
 
         state_dict = {}
         sharded_state_dict = model.state_dict()
+        is_rank_zero = torch.distributed.get_rank() == 0
         for param_name, param in sharded_state_dict.items():
             if param.is_cpu:
                 param = param.to(torch.device("cuda"))
@@ -168,9 +170,20 @@
 
             if isinstance(param, DTensor):
                 param = param.full_tensor()
-            if torch.distributed.get_rank() == 0:
+            if is_rank_zero:
                 state_dict[param_name] = param.cpu()
+            # Drop the GPU-resident gathered tensor before the next iteration
+            # allocates the next one; otherwise the caching allocator holds
+            # both reservations and we accumulate ~model-size of VRAM.
+            del param
             torch.distributed.barrier()
+
+        # Release the sharded view and force the allocator to give back the
+        # gather buffers.
+        del sharded_state_dict
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
    elif self.distributed_type == DistributedType.FSDP:
        from torch.distributed.fsdp import (
            FullStateDictConfig,
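
For reviewers, here is a minimal single-GPU sketch (not part of the patch; the tensor size and the `report` helper are made up for illustration) of the reclaim pattern both hunks rely on: drop the Python references to the gathered tensors, run `gc.collect()` to break any lingering reference cycles, then call `torch.cuda.empty_cache()` so the caching allocator hands its cached blocks back to the driver. `memory_allocated()` falls as soon as the references are gone, while `memory_reserved()` only shrinks after `empty_cache()`, which is the VRAM other processes (or the next training step) actually see freed.

```python
# Illustrative only: mimics freeing a large "gathered" tensor after checkpointing.
import gc

import torch


def report(tag: str) -> None:
    alloc = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"{tag}: allocated={alloc:.0f} MiB, reserved={reserved:.0f} MiB")


if torch.cuda.is_available():
    # Stand-in for the full-state-dict gather (~2 GiB of fp32).
    gathered = torch.empty(512, 1024, 1024, device="cuda")
    report("after gather")

    del gathered                   # drop the last Python reference
    gc.collect()                   # collect anything still pinning the tensor via a cycle
    report("after gc.collect()")   # allocated drops, reserved usually does not

    torch.cuda.empty_cache()       # hand cached blocks back to the CUDA driver
    report("after empty_cache()")  # reserved drops too, freeing VRAM for other consumers
```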