fix: fsdp2 init_sharded_param load int8/uint4 dtensor as
require_grad=true on init
This commit is contained in:
@@ -42,8 +42,10 @@ def apply_init_sharded_param_patch():
|
||||
)
|
||||
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
||||
else:
|
||||
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
||||
self.sharded_param.requires_grad_(param.requires_grad)"""
|
||||
self.sharded_param = nn.Parameter(
|
||||
self.to_sharded_dtensor(sharded_param),
|
||||
requires_grad=param.requires_grad,
|
||||
)"""
|
||||
|
||||
# Apply the replacement
|
||||
if original_param_creation in original_source:
|
||||
|
||||
Reference in New Issue
Block a user