fix type annotations (#1941) [skip ci]

This commit is contained in:
aarush gupta
2024-10-09 13:03:16 -07:00
committed by GitHub
parent a560593b1d
commit dee77232fe

View File

@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
def reset_optimizer( def reset_optimizer(
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
*, *,
reset_params: list[str], # where str is the key to a torch.nn.Parameter reset_params: List[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str], optimizer_state_keys: List[str],
prune_ratio: float = 0.9, prune_ratio: float = 0.9,
): ):
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio) pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)