fix: FSDP FULL_STATE_DICT oom from memory leak (#3635)
* memory clean up for fsdp full state dict

* Update src/axolotl/monkeypatch/accelerate/fsdp2.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
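For context on the fix: with FSDP2, `get_state_dict` walks the sharded state dict and calls `DTensor.full_tensor()` to all-gather each parameter onto the GPU. If the gathered tensor stays referenced while the next iteration gathers the next parameter, the caching allocator holds both, and reserved VRAM grows by roughly a full model's worth. Below is a minimal sketch of the corrected gather loop; it is illustrative only, not the patched Accelerate code, and assumes torch.distributed is initialized and the model's parameters are DTensors:

import gc

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor


def gather_full_state_dict(model: torch.nn.Module) -> dict[str, torch.Tensor]:
    """Gather a full (unsharded) state dict onto rank 0, CPU-resident."""
    state_dict: dict[str, torch.Tensor] = {}
    sharded = model.state_dict()
    is_rank_zero = dist.get_rank() == 0
    for name, param in sharded.items():
        if isinstance(param, DTensor):
            # All-gather the shards into one unsharded tensor on GPU.
            param = param.full_tensor()
        if is_rank_zero:
            # Keep only a CPU copy; the GPU tensor is freed below.
            state_dict[name] = param.cpu()
        del param  # drop the gathered GPU tensor before the next gather
        dist.barrier()
    del sharded
    gc.collect()
    torch.cuda.empty_cache()  # hand the gather buffers back to the driver
    return state_dict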
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import gc
 import json
 import math
 import os
@@ -800,7 +801,14 @@ class AxolotlTrainer(
         with open(tokens_state_path, "w", encoding="utf-8") as f:
             json.dump(tokens_state, f)
 
-        return super()._save_checkpoint(model, trial, **kwargs)
+        result = super()._save_checkpoint(model, trial, **kwargs)
+
+        # Reclaim VRAM held by the FSDP full-state-dict gather.
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        return result
 
     # TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
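A quick way to see what the post-checkpoint cleanup above reclaims is to compare the allocator's reserved bytes around it. The helper below is hypothetical, not part of the commit; it only wraps counters that PyTorch provides:

import gc

import torch


def report_reclaimed_vram() -> int:
    """Print and return how many reserved bytes empty_cache() releases."""
    before = torch.cuda.memory_reserved()
    gc.collect()
    torch.cuda.empty_cache()
    after = torch.cuda.memory_reserved()
    print(f"reserved VRAM: {before / 2**30:.2f} GiB -> {after / 2**30:.2f} GiB")
    return before - after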
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -4,6 +4,7 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interatio
 
 import copy
 import functools
+import gc
 import os
 import sys
 
@@ -161,6 +162,7 @@ def get_state_dict(self, model, unwrap=True):
 
         state_dict = {}
         sharded_state_dict = model.state_dict()
+        is_rank_zero = torch.distributed.get_rank() == 0
        for param_name, param in sharded_state_dict.items():
             if param.is_cpu:
                 param = param.to(torch.device("cuda"))
@@ -168,9 +170,20 @@ def get_state_dict(self, model, unwrap=True):
             if isinstance(param, DTensor):
                 param = param.full_tensor()
 
-            if torch.distributed.get_rank() == 0:
+            if is_rank_zero:
                 state_dict[param_name] = param.cpu()
+            # Drop the GPU-resident gathered tensor before the next iteration
+            # allocates the next one; otherwise the caching allocator holds
+            # both reservations and we accumulate ~model-size of VRAM.
+            del param
             torch.distributed.barrier()
+
+        # Release the sharded view and force the allocator to give back the
+        # gather buffers.
+        del sharded_state_dict
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
     elif self.distributed_type == DistributedType.FSDP:
         from torch.distributed.fsdp import (
             FullStateDictConfig,
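To confirm the patched gather no longer accumulates roughly model-size VRAM, peak-allocation counters can bracket the call. A sketch under the assumption that `accelerator` and `model` are already set up in scope and the FSDP2 code path above is active:

import torch
import torch.distributed as dist

# `accelerator` / `model`: assumed to exist already (see lead-in above).
torch.cuda.reset_peak_memory_stats()
state_dict = accelerator.get_state_dict(model)  # takes the patched FSDP2 path
peak_gib = torch.cuda.max_memory_allocated() / 2**30
if dist.get_rank() == 0:
    # Expect the peak to track the largest gathered parameter on top of the
    # resident shards, not an extra full model's worth of VRAM.
    print(f"peak VRAM during gather: {peak_gib:.2f} GiB")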