From e2f69828d21afc969b5a2824cc666813147aa101 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 11 Apr 2026 20:22:35 -0400 Subject: [PATCH] [fix][fsdp2] clone sharded param so original full size shard can be gc'ed (#3597) [skip ci] --- src/axolotl/monkeypatch/accelerate/fsdp2.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index cf8056ca0..e086885bb 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -60,6 +60,13 @@ def fsdp2_load_full_state_dict( sharded_meta_param.placements, src_data_rank=0, ) + # Clone the local shard to allow full_tensor to be freed. + if ( + sharded_param._local_tensor.untyped_storage().size() + > sharded_param._local_tensor.nelement() + * sharded_param._local_tensor.element_size() + ): + sharded_param = sharded_param.clone() else: # Non-sharded parameters if _accelerator.is_main_process: