update fsdp2 patch

This commit is contained in:
Salman Mohammadi
2025-07-23 16:53:03 +01:00
parent b3c04dd9fe
commit bc2bc688d8
2 changed files with 1 additions and 3 deletions

View File

@@ -254,9 +254,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
"offload_policy": fsdp2_plugin.cpu_offload,
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
"mesh": accelerator.state.device_mesh[
accelerator.state.parallelism_config.model_shard_dim_names
],
"mesh": accelerator.torch_device_mesh[tuple(accelerator.parallelism_config.model_shard_dim_names)],
}
model_has_params4bit = False

View File