fix type annotations (#1941) [skip ci]
This commit is contained in:
@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
|
||||
def reset_optimizer(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
*,
|
||||
reset_params: list[str], # where str is the key to a torch.nn.Parameter
|
||||
optimizer_state_keys: list[str],
|
||||
reset_params: List[str], # where str is the key to a torch.nn.Parameter
|
||||
optimizer_state_keys: List[str],
|
||||
prune_ratio: float = 0.9,
|
||||
):
|
||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
||||
|
||||
Reference in New Issue
Block a user