[fix][fsdp2] clone sharded param so original full size shard can be gc'ed (#3597) [skip ci]
This commit is contained in:
@@ -60,6 +60,13 @@ def fsdp2_load_full_state_dict(
|
||||
sharded_meta_param.placements,
|
||||
src_data_rank=0,
|
||||
)
|
||||
# Clone the local shard to allow full_tensor to be freed.
|
||||
if (
|
||||
sharded_param._local_tensor.untyped_storage().size()
|
||||
> sharded_param._local_tensor.nelement()
|
||||
* sharded_param._local_tensor.element_size()
|
||||
):
|
||||
sharded_param = sharded_param.clone()
|
||||
else:
|
||||
# Non-sharded parameters
|
||||
if _accelerator.is_main_process:
|
||||
|
||||
Reference in New Issue
Block a user