diff --git a/src/axolotl/monkeypatch/fsdp2_qlora.py b/src/axolotl/monkeypatch/fsdp2_qlora.py index 04d0d1971..8076e9b14 100644 --- a/src/axolotl/monkeypatch/fsdp2_qlora.py +++ b/src/axolotl/monkeypatch/fsdp2_qlora.py @@ -42,8 +42,10 @@ def apply_init_sharded_param_patch(): ) self.sharded_param = self.to_sharded_dtensor(self.sharded_param) else: - self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param)) - self.sharded_param.requires_grad_(param.requires_grad)""" + self.sharded_param = nn.Parameter( + self.to_sharded_dtensor(sharded_param), + requires_grad=param.requires_grad, + )""" # Apply the replacement if original_param_creation in original_source: