From e0eed7542d0ec9853d3f47915d69f8a4994dd01d Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Thu, 26 Feb 2026 18:21:32 +0700
Subject: [PATCH] fix: fsdp2 init_sharded_param loads int8/uint4 dtensor with
 requires_grad=True on init

---
 src/axolotl/monkeypatch/fsdp2_qlora.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/monkeypatch/fsdp2_qlora.py b/src/axolotl/monkeypatch/fsdp2_qlora.py
index 04d0d1971..8076e9b14 100644
--- a/src/axolotl/monkeypatch/fsdp2_qlora.py
+++ b/src/axolotl/monkeypatch/fsdp2_qlora.py
@@ -42,8 +42,10 @@ def apply_init_sharded_param_patch():
             )
             self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
         else:
-            self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
-            self.sharded_param.requires_grad_(param.requires_grad)"""
+            self.sharded_param = nn.Parameter(
+                self.to_sharded_dtensor(sharded_param),
+                requires_grad=param.requires_grad,
+            )"""
 
     # Apply the replacement
     if original_param_creation in original_source:
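
A minimal sketch of the failure mode this patch fixes, using a plain int8
tensor as a stand-in for the sharded int8/uint4 DTensor (variable names below
are illustrative, not from the patched source):

    import torch
    import torch.nn as nn

    # Stand-in for a frozen quantized shard (e.g. int8 base weights in QLoRA).
    quantized_shard = torch.zeros(4, dtype=torch.int8)

    # Old path: nn.Parameter defaults to requires_grad=True, so wrapping an
    # integer tensor raises before the follow-up requires_grad_(False) can run.
    try:
        nn.Parameter(quantized_shard)
    except RuntimeError as err:
        print(err)  # only floating point and complex dtypes can require gradients

    # New path (as in the patch): pass the flag at construction time, so the
    # integer parameter is never created with requires_grad=True.
    param = nn.Parameter(quantized_shard, requires_grad=False)
    assert not param.requires_grad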