fix type annotations (#1941) [skip ci]

This commit is contained in:
aarush gupta
2024-10-09 13:03:16 -07:00
committed by GitHub
parent a560593b1d
commit dee77232fe

View File

@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
def reset_optimizer(
optimizer: torch.optim.Optimizer,
*,
reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str],
reset_params: List[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: List[str],
prune_ratio: float = 0.9,
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)