[fix][fsdp2] clone sharded param so original full size shard can be gc'ed (#3597) [skip ci]

This commit is contained in:
Wing Lian
2026-04-11 20:22:35 -04:00
committed by GitHub
parent 122b50bad6
commit e2f69828d2

View File

@@ -60,6 +60,13 @@ def fsdp2_load_full_state_dict(
sharded_meta_param.placements,
src_data_rank=0,
)
# Clone the local shard to allow full_tensor to be freed.
if (
sharded_param._local_tensor.untyped_storage().size()
> sharded_param._local_tensor.nelement()
* sharded_param._local_tensor.element_size()
):
sharded_param = sharded_param.clone()
else:
# Non-sharded parameters
if _accelerator.is_main_process: