Compare commits

...

1 Commits

Author SHA1 Message Date
Salman Mohammadi
1d0562dedd adding fp32 support 2025-09-26 16:32:09 +00:00

View File

@@ -368,6 +368,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
# removing the call above leads to extra memory usage as explained in the comment above # removing the call above leads to extra memory usage as explained in the comment above
if hasattr(model, "tie_weights"): if hasattr(model, "tie_weights"):
model.tie_weights() model.tie_weights()
model = model.to(torch.float32)
return model return model