* Fix Axolotl ReLoRA optimizer reset scope * fix: make relora reset method honor relora_prune_ratio When relora_prune_method='reset' and relora_prune_ratio is explicitly set, the ratio was silently ignored and replaced with the hardcoded _FULL_RESET_RATIO (0.999). Fix by moving the default-ratio logic to ReLoRACallback.on_step_begin: None maps to _FULL_RESET_RATIO for reset and 0.9 for other methods. reset_optimizer now uses the same random pruning path for both 'random' and 'reset'. Also consolidate three-layer default mismatch: schema default for relora_prune_method is now 'magnitude' (single canonical source); dataclass defaults for both fields changed to None to eliminate the conflicting fallback layer. Tests updated: removed the test case that verified the old broken behavior (reset ignoring ratio), added two cases proving reset honors the passed ratio. E2E reset fixture now uses ratio=0.5 to make it unambiguous that the ratio is honored. * Fix ReLoRA uint8 pruning regression --------- Signed-off-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
187 lines
5.7 KiB
Python
187 lines
5.7 KiB
Python
"""Unit tests for axolotl.monkeypatch.relora.reset_optimizer."""
|
|
|
|
import math
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from axolotl.monkeypatch.relora import (
|
|
magnitude_pruning_,
|
|
random_pruning_,
|
|
reset_optimizer,
|
|
)
|
|
|
|
# Names of the per-parameter optimizer state entries that the tests below pass
# to reset_optimizer as ``optimizer_state_keys`` and inspect after a reset.
ADAM_KEYS = ["exp_avg", "exp_avg_sq"]
|
|
|
|
|
|
def _build_optimizer_with_state(seed: int = 0):
|
|
"""Build a tiny optimizer over LoRA-shaped + non-LoRA params with populated state."""
|
|
torch.manual_seed(seed)
|
|
lora_a = nn.Parameter(torch.randn(8, 32))
|
|
lora_b = nn.Parameter(torch.randn(32, 8))
|
|
extra = nn.Parameter(torch.randn(64, 32))
|
|
|
|
optimizer = torch.optim.AdamW([lora_a, lora_b, extra], lr=1e-3)
|
|
for _ in range(2):
|
|
loss = (
|
|
(lora_a * torch.randn_like(lora_a)).sum()
|
|
+ (lora_b * torch.randn_like(lora_b)).sum()
|
|
+ (extra * torch.randn_like(extra)).sum()
|
|
)
|
|
loss.backward()
|
|
optimizer.step()
|
|
optimizer.zero_grad()
|
|
return optimizer, lora_a, lora_b, extra
|
|
|
|
|
|
def test_reset_optimizer_only_touches_reset_params():
    """Optimizer state of a param left out of ``reset_params`` must not change."""
    optimizer, lora_a, lora_b, extra = _build_optimizer_with_state()

    # Snapshot both moment buffers of the parameter that is NOT being reset.
    untouched_state = optimizer.state[extra]
    exp_avg_snapshot = untouched_state["exp_avg"].clone()
    exp_avg_sq_snapshot = untouched_state["exp_avg_sq"].clone()

    reset_optimizer(
        optimizer,
        reset_params=[lora_a, lora_b],
        optimizer_state_keys=ADAM_KEYS,
        prune_method="magnitude",
        prune_ratio=0.9,
    )

    # Exact (element-wise) equality: no tolerance is allowed here.
    expected_by_key = (
        ("exp_avg", exp_avg_snapshot),
        ("exp_avg_sq", exp_avg_sq_snapshot),
    )
    for key, expected in expected_by_key:
        assert torch.equal(optimizer.state[extra][key], expected)
|
|
|
|
|
|
def test_reset_optimizer_actually_prunes_lora_state():
    """After a 0.9 magnitude reset, at least 85% of each reset state tensor is zero."""
    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()

    reset_optimizer(
        optimizer,
        reset_params=[lora_a, lora_b],
        optimizer_state_keys=ADAM_KEYS,
        prune_method="magnitude",
        prune_ratio=0.9,
    )

    # Lazily walk every (param, key) state tensor in the same order as a
    # nested loop would.
    reset_states = (
        optimizer.state[param][key] for param in (lora_a, lora_b) for key in ADAM_KEYS
    )
    for state_tensor in reset_states:
        zero_frac = (state_tensor == 0).float().mean().item()
        assert zero_frac >= 0.85
|
|
|
|
|
|
@pytest.mark.parametrize(
    "method,ratio,expected_zero_frac",
    [
        ("magnitude", 0.9, 0.9),
        ("magnitude", 0.99, 0.99),
        ("random", 0.9, 0.9),
        ("random", 0.5, 0.5),
        # reset uses random pruning; relora_prune_ratio must be honored, not ignored.
        ("reset", 0.9, 0.9),
        ("reset", 0.5, 0.5),
    ],
)
def test_prune_methods(method, ratio, expected_zero_frac):
    """The pooled zero fraction over all reset state tensors matches the ratio.

    The fraction is computed over every ``ADAM_KEYS`` tensor of both reset
    parameters together, with an absolute tolerance of 0.02 for "magnitude"
    and 0.05 for the other methods.
    """
    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state(seed=42)

    reset_optimizer(
        optimizer,
        reset_params=[lora_a, lora_b],
        optimizer_state_keys=ADAM_KEYS,
        prune_method=method,
        prune_ratio=ratio,
    )

    reset_states = [
        optimizer.state[param][key] for param in (lora_a, lora_b) for key in ADAM_KEYS
    ]
    element_count = sum(state.numel() for state in reset_states)
    zero_count = sum((state == 0).sum().item() for state in reset_states)
    observed_zero_frac = zero_count / element_count

    if method == "magnitude":
        tolerance = 0.02
    else:
        tolerance = 0.05
    assert math.isclose(observed_zero_frac, expected_zero_frac, abs_tol=tolerance)
|
|
|
|
|
|
def test_reset_optimizer_skips_keys_not_in_state_keys():
    """A state entry that is not listed in ``optimizer_state_keys`` is left alone."""
    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()

    # Only "exp_avg" is handed to reset_optimizer below, so this buffer must
    # come out unchanged.
    skipped_key = "exp_avg_sq"
    snapshot = optimizer.state[lora_a][skipped_key].clone()

    reset_optimizer(
        optimizer,
        reset_params=[lora_a, lora_b],
        optimizer_state_keys=["exp_avg"],
        prune_method="magnitude",
        prune_ratio=0.9,
    )

    assert torch.equal(optimizer.state[lora_a][skipped_key], snapshot)
|
|
|
|
|
|
def test_reset_optimizer_handles_param_with_empty_state():
    """A parameter the optimizer has no state for is accepted without error.

    After the call the orphan parameter must still have no state: either it
    is absent from ``optimizer.state`` or its entry is empty.
    """
    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
    # Never registered with the optimizer, so it has no state of its own.
    orphan = nn.Parameter(torch.randn(4, 4))
    params_to_reset = [lora_a, lora_b, orphan]

    reset_optimizer(
        optimizer,
        reset_params=params_to_reset,
        optimizer_state_keys=ADAM_KEYS,
        prune_method="magnitude",
        prune_ratio=0.9,
    )

    # The membership test short-circuits, so state[orphan] is only read when
    # an entry already exists.
    assert not (orphan in optimizer.state and optimizer.state[orphan])
|
|
|
|
|
|
def test_unknown_prune_method_raises():
    """An unrecognised ``prune_method`` is rejected with a ValueError."""
    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
    invalid_method = "bogus"
    expected_error = pytest.raises(ValueError, match="Unknown prune_method")

    with expected_error:
        reset_optimizer(
            optimizer,
            reset_params=[lora_a, lora_b],
            optimizer_state_keys=ADAM_KEYS,
            prune_method=invalid_method,  # type: ignore[arg-type]
            prune_ratio=0.9,
        )
|
|
|
|
|
|
def test_pruning_helpers_are_inplace():
    """magnitude_pruning_ and random_pruning_ must mutate the tensor they are given.

    Checking only that ``data_ptr()`` is unchanged is not sufficient: a helper
    that returned a pruned copy and left its argument alone would also keep
    the argument's storage address. So each helper is additionally required to
    have zeroed at least one element of the tensor that was passed in, with
    the return value ignored.
    """
    for prune_ in (magnitude_pruning_, random_pruning_):
        tensor = torch.randn(64)
        ptr_before = tensor.data_ptr()
        # Sanity check on the input, so any zero found afterwards was written
        # by the helper and was not already present.
        assert not (tensor == 0).any()

        prune_(tensor, 0.5)

        # Storage must not have been swapped out (e.g. via ``tensor.data = ...``).
        assert tensor.data_ptr() == ptr_before
        # ...and the original tensor itself must carry the pruning result.
        assert (tensor == 0).any(), (
            f"{prune_.__name__} left its input tensor unmodified"
        )
|
|
|
|
|
|
def test_pruning_helpers_support_uint8_tensors():
    """Both pruning helpers must handle ``torch.uint8`` tensors.

    Each helper is applied with ratio 0.9 to a uint8 tensor holding 1..128;
    the dtype must be preserved and between 85% and 95% of the elements must
    end up zero.
    """

    def zero_fraction(values):
        return (values == 0).float().mean().item()

    magnitude_input = torch.arange(1, 129, dtype=torch.uint8)
    magnitude_pruning_(magnitude_input, 0.9)

    assert magnitude_input.dtype == torch.uint8
    magnitude_zero_frac = zero_fraction(magnitude_input)
    assert magnitude_zero_frac >= 0.85 and magnitude_zero_frac <= 0.95

    random_input = torch.arange(1, 129, dtype=torch.uint8)
    # Seed inside a forked RNG scope so the fixed seed does not leak into the
    # global CPU generator used by other tests.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(1234)
        random_pruning_(random_input, 0.9)

    assert random_input.dtype == torch.uint8
    random_zero_frac = zero_fraction(random_input)
    assert random_zero_frac >= 0.85 and random_zero_frac <= 0.95
|