diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index d7270679c..19ccd6f21 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -254,9 +254,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: "offload_policy": fsdp2_plugin.cpu_offload, # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy` "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(), - "mesh": accelerator.state.device_mesh[ - accelerator.state.parallelism_config.model_shard_dim_names - ], + "mesh": accelerator.torch_device_mesh[tuple(accelerator.parallelism_config.model_shard_dim_names)], } model_has_params4bit = False diff --git a/src/axolotl/utils/schemas/distributed.py b/src/axolotl/utils/schemas/distributed.py new file mode 100644 index 000000000..e69de29bb