fix: FSDP FULL_STATE_DICT oom from memory leak (#3635)

* memory clean up for fsdp full state dict

* Update src/axolotl/monkeypatch/accelerate/fsdp2.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Author: VED
Committed: 2026-05-05 20:52:35 +05:30 (via GitHub)
parent e4032fc90f
commit c15f6cffe2
2 changed files with 23 additions and 2 deletions


@@ -2,6 +2,7 @@
 from __future__ import annotations
+import gc
 import json
 import math
 import os
@@ -800,7 +801,14 @@ class AxolotlTrainer(
         with open(tokens_state_path, "w", encoding="utf-8") as f:
             json.dump(tokens_state, f)
-        return super()._save_checkpoint(model, trial, **kwargs)
+        result = super()._save_checkpoint(model, trial, **kwargs)
+        # Reclaim VRAM held by the FSDP full-state-dict gather.
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        return result

     # TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
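For context, the trainer-side change above is a small "save, then reclaim" pattern. The stand-alone sketch below only illustrates that pattern; the wrapper name checkpoint_with_cleanup and the save_fn parameter are made up here and are not axolotl or transformers API. gc.collect() drops Python references to the large temporaries the save created, and torch.cuda.empty_cache() then hands the now-unused blocks in the CUDA caching allocator back to the driver so the next training step does not OOM on memory that is technically free but still reserved.

import gc

import torch


def checkpoint_with_cleanup(save_fn, *args, **kwargs):
    """Run a checkpoint save, then reclaim GPU memory its temporaries held."""
    result = save_fn(*args, **kwargs)
    # Drop Python references to large temporaries created during the save ...
    gc.collect()
    # ... then release the caching allocator's unused blocks; without this
    # the memory stays reserved and can trigger an OOM on the next step.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result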

src/axolotl/monkeypatch/accelerate/fsdp2.py

@@ -4,6 +4,7 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interatio
 import copy
 import functools
+import gc
 import os
 import sys
@@ -161,6 +162,7 @@ def get_state_dict(self, model, unwrap=True):
         state_dict = {}
         sharded_state_dict = model.state_dict()
+        is_rank_zero = torch.distributed.get_rank() == 0
         for param_name, param in sharded_state_dict.items():
             if param.is_cpu:
                 param = param.to(torch.device("cuda"))
@@ -168,9 +170,20 @@
             if isinstance(param, DTensor):
                 param = param.full_tensor()
-            if torch.distributed.get_rank() == 0:
+            if is_rank_zero:
                 state_dict[param_name] = param.cpu()
+            # Drop the GPU-resident gathered tensor before the next iteration
+            # allocates the next one; otherwise the caching allocator holds
+            # both reservations and we accumulate ~model-size of VRAM.
+            del param
             torch.distributed.barrier()
+        # Release the sharded view and force the allocator to give back the
+        # gather buffers.
+        del sharded_state_dict
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
     elif self.distributed_type == DistributedType.FSDP:
         from torch.distributed.fsdp import (
             FullStateDictConfig,
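To see the allocator behavior those comments describe without the accelerate plumbing, here is a simplified, self-contained version of the patched gather loop. It assumes a DTensor-sharded module and an initialized process group; the function name gather_full_state_dict is illustrative, and the DTensor import path shown is the public one in recent PyTorch (older releases expose it as torch.distributed._tensor.DTensor). It is a sketch of the technique, not the accelerate code itself.

import gc

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor


def gather_full_state_dict(model: torch.nn.Module) -> dict:
    """Materialize a full state dict on rank 0, one parameter at a time."""
    state_dict = {}
    sharded_state_dict = model.state_dict()
    is_rank_zero = dist.get_rank() == 0
    for name, param in sharded_state_dict.items():
        if isinstance(param, DTensor):
            # Every rank participates in the all-gather; the result is a
            # full-size tensor allocated on the GPU.
            param = param.full_tensor()
        if is_rank_zero:
            # Offload the gathered copy to CPU before the next parameter
            # is gathered.
            state_dict[name] = param.cpu()
        # Drop the GPU-resident copy so at most one full tensor is live
        # at any point in the loop.
        del param
        dist.barrier()
    # Release the sharded view and hand the cached gather buffers back
    # to the driver.
    del sharded_state_dict
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return state_dict

Without the per-iteration del, the previous parameter's full tensor is still alive while the next one is being gathered, so the caching allocator holds both reservations at once; over the whole loop that adds up to roughly a model's worth of extra VRAM, which is the OOM this commit addresses.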