fix: FSDP FULL_STATE_DICT OOM from memory leak (#3635)
* memory cleanup for FSDP full state dict * Update src/axolotl/monkeypatch/accelerate/fsdp2.py Co-authored-by: Wing Lian <wing.lian@gmail.com> --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import gc
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -800,7 +801,14 @@ class AxolotlTrainer(
|
|||||||
with open(tokens_state_path, "w", encoding="utf-8") as f:
|
with open(tokens_state_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(tokens_state, f)
|
json.dump(tokens_state, f)
|
||||||
|
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
result = super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
|
# Reclaim VRAM held by the FSDP full-state-dict gather.
|
||||||
|
gc.collect()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
|
||||||
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
def _save(self, output_dir: Optional[str] = None, state_dict=None):
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import functools
|
import functools
|
||||||
|
import gc
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@@ -161,6 +162,7 @@ def get_state_dict(self, model, unwrap=True):
|
|||||||
|
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
sharded_state_dict = model.state_dict()
|
sharded_state_dict = model.state_dict()
|
||||||
|
is_rank_zero = torch.distributed.get_rank() == 0
|
||||||
for param_name, param in sharded_state_dict.items():
|
for param_name, param in sharded_state_dict.items():
|
||||||
if param.is_cpu:
|
if param.is_cpu:
|
||||||
param = param.to(torch.device("cuda"))
|
param = param.to(torch.device("cuda"))
|
||||||
@@ -168,9 +170,20 @@ def get_state_dict(self, model, unwrap=True):
|
|||||||
if isinstance(param, DTensor):
|
if isinstance(param, DTensor):
|
||||||
param = param.full_tensor()
|
param = param.full_tensor()
|
||||||
|
|
||||||
if torch.distributed.get_rank() == 0:
|
if is_rank_zero:
|
||||||
state_dict[param_name] = param.cpu()
|
state_dict[param_name] = param.cpu()
|
||||||
|
# Drop the GPU-resident gathered tensor before the next iteration
|
||||||
|
# allocates the next one; otherwise the caching allocator holds
|
||||||
|
# both reservations and we accumulate ~model-size of VRAM.
|
||||||
|
del param
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
|
|
||||||
|
# Release the sharded view and force the allocator to give back the
|
||||||
|
# gather buffers.
|
||||||
|
del sharded_state_dict
|
||||||
|
gc.collect()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
elif self.distributed_type == DistributedType.FSDP:
|
elif self.distributed_type == DistributedType.FSDP:
|
||||||
from torch.distributed.fsdp import (
|
from torch.distributed.fsdp import (
|
||||||
FullStateDictConfig,
|
FullStateDictConfig,
|
||||||
|
|||||||
Reference in New Issue
Block a user