diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index cf8056ca0..e086885bb 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -60,6 +60,13 @@ def fsdp2_load_full_state_dict( sharded_meta_param.placements, src_data_rank=0, ) + # Clone the local shard to allow full_tensor to be freed. + if ( + sharded_param._local_tensor.untyped_storage().size() + > sharded_param._local_tensor.nelement() + * sharded_param._local_tensor.element_size() + ): + sharded_param = sharded_param.clone() else: # Non-sharded parameters if _accelerator.is_main_process: