adding fp32 support
This commit is contained in:
@@ -368,6 +368,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
|||||||
# removing the call above leads to extra memory usage as explained in the comment above
|
# removing the call above leads to extra memory usage as explained in the comment above
|
||||||
if hasattr(model, "tie_weights"):
|
if hasattr(model, "tie_weights"):
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
|
model = model.to(torch.float32)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user