adding fp32 support

This commit is contained in:
Salman Mohammadi
2025-09-26 16:32:09 +00:00
parent 7fa8ac40cd
commit 1d0562dedd

View File

@@ -368,6 +368,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
# removing the call above leads to extra memory usage as explained in the comment above
if hasattr(model, "tie_weights"):
model.tie_weights()
model = model.to(torch.float32)
return model