From bc2bc688d80f2d39ad21f06bf2e68f768e99db31 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Wed, 23 Jul 2025 16:53:03 +0100 Subject: [PATCH] update fsdp2 patch --- src/axolotl/monkeypatch/accelerate/fsdp2.py | 4 +--- src/axolotl/utils/schemas/distributed.py | 0 2 files changed, 1 insertion(+), 3 deletions(-) create mode 100644 src/axolotl/utils/schemas/distributed.py diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index d7270679c..19ccd6f21 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -254,9 +254,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: "offload_policy": fsdp2_plugin.cpu_offload, # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy` "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(), - "mesh": accelerator.state.device_mesh[ - accelerator.state.parallelism_config.model_shard_dim_names - ], + "mesh": accelerator.torch_device_mesh[tuple(accelerator.parallelism_config.model_shard_dim_names)], } model_has_params4bit = False diff --git a/src/axolotl/utils/schemas/distributed.py b/src/axolotl/utils/schemas/distributed.py new file mode 100644 index 000000000..e69de29bb