fix: fsdp2 init_sharded_param loads int8/uint4 dtensor with

requires_grad=True on init
This commit is contained in:
NanoCode012
2026-02-26 18:21:32 +07:00
parent 5f6fcd1f7e
commit e0eed7542d

View File

@@ -42,8 +42,10 @@ def apply_init_sharded_param_patch():
)
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
else:
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
self.sharded_param.requires_grad_(param.requires_grad)"""
self.sharded_param = nn.Parameter(
self.to_sharded_dtensor(sharded_param),
requires_grad=param.requires_grad,
)"""
# Apply the replacement
if original_param_creation in original_source: