fix: fsdp2 init_sharded_param load int8/uint4 dtensor as
require_grad=true on init
This commit is contained in:
@@ -42,8 +42,10 @@ def apply_init_sharded_param_patch():
|
|||||||
)
|
)
|
||||||
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
self.sharded_param = self.to_sharded_dtensor(self.sharded_param)
|
||||||
else:
|
else:
|
||||||
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
self.sharded_param = nn.Parameter(
|
||||||
self.sharded_param.requires_grad_(param.requires_grad)"""
|
self.to_sharded_dtensor(sharded_param),
|
||||||
|
requires_grad=param.requires_grad,
|
||||||
|
)"""
|
||||||
|
|
||||||
# Apply the replacement
|
# Apply the replacement
|
||||||
if original_param_creation in original_source:
|
if original_param_creation in original_source:
|
||||||
|
|||||||
Reference in New Issue
Block a user