diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py index e4352cbe3..9d246cb17 100644 --- a/src/axolotl/monkeypatch/relora.py +++ b/src/axolotl/monkeypatch/relora.py @@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio): def reset_optimizer( optimizer: torch.optim.Optimizer, *, - reset_params: list[str], # where str is the key to a torch.nn.Parameter - optimizer_state_keys: list[str], + reset_params: List[str], # where str is the key to a torch.nn.Parameter + optimizer_state_keys: List[str], prune_ratio: float = 0.9, ): pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)