diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py index d8ba02cb2..493cb89c5 100644 --- a/src/axolotl/monkeypatch/accelerate/fsdp2.py +++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py @@ -368,6 +368,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module: # removing the call above leads to extra memory usage as explained in the comment above if hasattr(model, "tie_weights"): model.tie_weights() + model = model.to(torch.float32) return model