Fix Axolotl ReLoRA optimizer reset scope (#3646)

* Fix Axolotl ReLoRA optimizer reset scope
* fix: make relora reset method honor relora_prune_ratio

When relora_prune_method='reset' and relora_prune_ratio is explicitly
set, the ratio was silently ignored and replaced with the hardcoded
_FULL_RESET_RATIO (0.999). Fix by moving the default-ratio logic to
ReLoRACallback.on_step_begin: None maps to _FULL_RESET_RATIO for reset
and 0.9 for other methods. reset_optimizer now uses the same random
pruning path for both 'random' and 'reset'.
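
As a standalone sketch of the resolved mapping (it mirrors the on_step_begin logic in the diff below; resolve_prune_ratio is a hypothetical helper name, not code from this commit):

    _FULL_RESET_RATIO = 0.999  # near-full reset; true zeroing breaks ZeRO state consolidation

    def resolve_prune_ratio(configured_ratio, prune_method):
        # An explicitly configured relora_prune_ratio always wins, for every method.
        if configured_ratio is not None:
            return configured_ratio
        # Otherwise: paper-style near-full reset for 'reset', 0.9 elsewhere.
        return _FULL_RESET_RATIO if prune_method == "reset" else 0.9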

Also consolidate the three-layer default mismatch: the schema default for
relora_prune_method is now 'magnitude' (the single canonical source), and
the dataclass defaults for both fields change to None, eliminating the
conflicting fallback layer.
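
With that in place, the paper-style recipe reduces to a config like the following (a hedged sketch in the same DictDefault shape the e2e tests use; the import path is assumed to match the test files):

    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "relora": True,
            "jagged_restart_steps": 50,
            "relora_prune_method": "reset",
            # relora_prune_ratio omitted on purpose: the callback now fills in
            # _FULL_RESET_RATIO (0.999) for 'reset' and 0.9 for other methods.
        }
    )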

Tests updated: removed the case that verified the old broken behavior
(reset ignoring the ratio) and added two cases proving reset honors the
passed ratio. The E2E reset fixture now uses ratio=0.5 to make the
honored ratio unambiguous.

* Fix ReLoRA uint8 pruning regression

---------

Signed-off-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
Wing Lian
2026-05-09 17:52:35 -04:00
committed by GitHub
parent 5352d41d32
commit e2f01de0e8
7 changed files with 361 additions and 37 deletions

View File

@@ -72,7 +72,7 @@ jobs:
       exclude:
         - python_version: "3.14"
           pytorch_version: "2.9.1"
-    timeout-minutes: 20
+    timeout-minutes: 25
     steps:
       - name: cleanup node

View File

@@ -286,10 +286,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             )
         if self.cfg.relora and self.cfg.jagged_restart_steps:
-            if self.cfg.relora_prune_ratio:
+            if self.cfg.relora_prune_ratio is not None:
                 training_arguments_kwargs["relora_prune_ratio"] = (
                     self.cfg.relora_prune_ratio
                 )
+            if self.cfg.relora_prune_method:
+                training_arguments_kwargs["relora_prune_method"] = (
+                    self.cfg.relora_prune_method
+                )
             if self.cfg.jagged_restart_steps:
                 training_arguments_kwargs["jagged_restart_steps"] = (

View File

@@ -83,13 +83,18 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "The number of processes to use for data processing"},
     )
-    relora_steps: Optional[int] = field(
-        default=None,
-        metadata={"help": "how often to reset for ReLoRA"},
-    )
     relora_prune_ratio: Optional[float] = field(
-        default=0.9,
-        metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
+        default=None,
+        metadata={
+            "help": (
+                "prune ratio for optimizer state pruning; "
+                "defaults to 0.999 for reset method, 0.9 for others"
+            )
+        },
+    )
+    relora_prune_method: Optional[str] = field(
+        default=None,
+        metadata={"help": "optimizer state pruning method: magnitude | random | reset"},
     )
     jagged_restart_steps: Optional[int] = field(
         default=None,

View File

@@ -6,9 +6,8 @@ import os.path
 import shutil
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Union
+from typing import Dict, List, Literal, Union

-import bitsandbytes as bnb
 import peft
 import safetensors.torch as st
 import torch
@@ -28,9 +27,15 @@ from axolotl.utils.logging import get_logger
 LOG = get_logger(__name__)

+try:
+    import bitsandbytes as bnb
+except ImportError:  # pragma: no cover - optional dependency for 8-bit merge paths
+    bnb = None
+

 @torch.no_grad()
 def magnitude_pruning_(tensor, prune_ratio):
+    """Zero the lowest ``prune_ratio`` fraction of values by absolute magnitude, in place."""
     tensor_magnitude = torch.abs(tensor)
     threshold = torch.quantile(
         tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
@@ -40,15 +45,43 @@ def magnitude_pruning_(tensor, prune_ratio):
     tensor.mul_(mask.to(dtype=tensor.dtype))


+@torch.no_grad()
+def random_pruning_(tensor, prune_ratio):
+    """Zero a random ``prune_ratio`` fraction of values, in place."""
+    mask = (
+        torch.rand(tensor.shape, dtype=torch.float32, device=tensor.device)
+        > prune_ratio
+    )
+    tensor.mul_(mask.to(dtype=tensor.dtype))
+
+
+# 0.999 mirrors the reference implementation. True zeroing breaks
+# ZeroRedundancyOptimizer.consolidate_state_dict; see Guitaricet/relora's
+# peft_pretraining/training_utils.py for the original note on this.
+_FULL_RESET_RATIO = 0.999
+
+
 def reset_optimizer(
     optimizer: torch.optim.Optimizer,
     *,
-    reset_params: List[str],  # where str is the key to a torch.nn.Parameter
+    reset_params: List[torch.nn.Parameter],
     optimizer_state_keys: List[str],
-    optimizer_magnitude_pruning: float = 0.9,
+    prune_method: Literal["magnitude", "random", "reset"] = "magnitude",
+    prune_ratio: float = 0.9,
 ):
-    # pylint:disable=unused-argument
-    pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
+    """Prune optimizer state for ``reset_params`` only."""
+    if prune_method == "magnitude":
+        pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
+    elif prune_method in ("random", "reset"):
+        # "reset" is random pruning at a near-full ratio; the caller is responsible
+        # for supplying the appropriate prune_ratio (see ReLoRACallback.on_step_begin).
+        pruning_fn = partial(random_pruning_, prune_ratio=prune_ratio)
+    else:
+        raise ValueError(
+            f"Unknown prune_method {prune_method!r}; expected one of "
+            "'magnitude', 'random', 'reset'"
+        )

     n_zeros = 0
     n_total = 0
@@ -56,22 +89,22 @@ def reset_optimizer(
     if isinstance(optimizer, ZeroRedundancyOptimizer):
         optimizer_state = optimizer.optim.state

-    for group in optimizer.param_groups:
-        for param in group["params"]:
-            state = optimizer_state[param]
-            for key, value in state.items():
-                if key not in optimizer_state_keys:
-                    continue
-                if torch.is_tensor(value):
-                    try:
-                        pruning_fn(value)
-                        n_total += value.numel()
-                        n_zeros += torch.sum(value == 0).item()
-                    except RuntimeError as exc:
-                        if "quantile() input tensor is too large" in str(exc):
-                            pass
-                        else:
-                            raise exc
+    for param in reset_params:
+        state = optimizer_state.get(param, {})
+        if not state:
+            continue
+        for key in optimizer_state_keys:
+            value = state.get(key)
+            if value is None or not torch.is_tensor(value):
+                continue
+            try:
+                pruning_fn(value)
+                n_total += value.numel()
+                n_zeros += torch.sum(value == 0).item()
+            except RuntimeError as exc:
+                if "quantile() input tensor is too large" in str(exc):
+                    continue
+                raise

     _zeroed = n_zeros / (1e-7 + n_total) * 100
     LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
@@ -82,11 +115,12 @@ class ReLoRACallback(TrainerCallback):
"""Callback to merge LoRA weights into the base model and save full-weight checkpoints""" """Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
def __init__(self, cfg: DictDefault): def __init__(self, cfg: DictDefault):
self.relora_steps = cfg.jagged_restart_steps self.jagged_restart_steps = cfg.jagged_restart_steps
self.cpu_offload = cfg.relora_cpu_offload self.cpu_offload = cfg.relora_cpu_offload
self.quantized = cfg.load_in_4bit or cfg.load_in_8bit self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
self.last_full_model = cfg.base_model self.last_full_model = cfg.base_model
self.resume_from_checkpoint = cfg.resume_from_checkpoint self.resume_from_checkpoint = cfg.resume_from_checkpoint
self.prune_method = cfg.relora_prune_method or "magnitude"
if not os.path.exists(self.last_full_model): if not os.path.exists(self.last_full_model):
self.last_full_model = str(Path(snapshot_download(cfg.base_model))) self.last_full_model = str(Path(snapshot_download(cfg.base_model)))
@@ -128,7 +162,7 @@ class ReLoRACallback(TrainerCallback):
     ):
         if not optimizer:
             optimizer = state.optimizer
-        if state.global_step > 0 and state.global_step % self.relora_steps == 0:
+        if state.global_step > 0 and state.global_step % self.jagged_restart_steps == 0:
             checkpoint_folder = os.path.join(
                 args.output_dir,
                 f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
@@ -144,7 +178,7 @@ class ReLoRACallback(TrainerCallback):
                 raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")

             lora_params = [
-                n
+                p
                 for n, p in model.named_parameters()
                 if p.requires_grad and "lora_" in n
             ]
@@ -166,11 +200,19 @@ class ReLoRACallback(TrainerCallback):
                 actually_save=is_main_process(),
                 cpu_offload=self.cpu_offload,
             )

+            # When relora_prune_ratio is not set, use _FULL_RESET_RATIO for
+            # "reset" (paper-style near-full reset) and 0.9 for other methods.
+            prune_ratio = args.relora_prune_ratio
+            if prune_ratio is None:
+                prune_ratio = (
+                    _FULL_RESET_RATIO if self.prune_method == "reset" else 0.9
+                )
             reset_optimizer(
                 optimizer,
                 reset_params=lora_params,
                 optimizer_state_keys=optimizer_state_keys,
-                optimizer_magnitude_pruning=args.relora_prune_ratio,
+                prune_method=self.prune_method,
+                prune_ratio=prune_ratio,
             )

             if self.quantized:
@@ -191,8 +233,8 @@ class ReLoRACallback(TrainerCallback):
             args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
         )
         if (
-            state.global_step >= self.relora_steps
-            and state.global_step % self.relora_steps != 0
+            state.global_step >= self.jagged_restart_steps
+            and state.global_step % self.jagged_restart_steps != 0
         ):
             if self.quantized:
                 if is_main_process() and self.last_full_model != checkpoint_folder:
@@ -320,6 +362,8 @@ def update_weights(
         target.weight.data = new_weight.cpu()
         target.to(device)
     elif isinstance(target, peft.tuners.lora.Linear8bitLt):
+        if bnb is None:
+            raise ImportError("bitsandbytes is required to merge 8-bit LoRA weights")
         target.weight.data = (
             bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
         )
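
For orientation, a minimal driving example of the reworked reset_optimizer (a sketch assuming a plain AdamW optimizer; it condenses what the new unit tests below exercise):

    import torch
    import torch.nn as nn

    from axolotl.monkeypatch.relora import reset_optimizer

    param = nn.Parameter(torch.randn(8, 32))
    opt = torch.optim.AdamW([param], lr=1e-3)
    (param * torch.randn_like(param)).sum().backward()
    opt.step()  # populate exp_avg / exp_avg_sq state

    reset_optimizer(
        opt,
        reset_params=[param],  # state is pruned for these params only
        optimizer_state_keys=["exp_avg", "exp_avg_sq"],
        prune_method="reset",  # routed to the random pruning path
        prune_ratio=0.5,       # honored as passed, no longer overridden
    )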

View File

@@ -240,8 +240,28 @@ class ReLoRAConfig(BaseModel):
     )
     relora_prune_ratio: float | None = Field(
         default=None,
+        ge=0.0,
+        le=1.0,
         json_schema_extra={
-            "description": "threshold for optimizer magnitude when pruning"
+            "description": (
+                "Fraction of optimizer state values to zero on each ReLoRA restart. "
+                "When relora_prune_method='reset' and this is omitted, defaults to "
+                "0.999 (paper-style near-full reset). For other methods, defaults to 0.9."
+            )
+        },
+    )
+    relora_prune_method: Literal["magnitude", "random", "reset"] | None = Field(
+        default="magnitude",
+        json_schema_extra={
+            "description": (
+                "Optimizer state pruning method on each ReLoRA restart. "
+                "'magnitude' (default) keeps top-k by absolute value; "
+                "'random' keeps a random subset at relora_prune_ratio; "
+                "'reset' uses near-full random pruning (default ratio 0.999, "
+                "honoring relora_prune_ratio when explicitly set). "
+                "Paper-style recipe: relora_prune_method='reset' with no "
+                "relora_prune_ratio, equivalent to 'random' with ratio=0.999."
+            )
         },
     )
     relora_cpu_offload: bool | None = Field(
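
The new ge/le bounds mean an out-of-range ratio now fails at validation time. A minimal pydantic v2 illustration using a local stand-in model (not the real ReLoRAConfig):

    from pydantic import BaseModel, Field, ValidationError

    class RatioOnly(BaseModel):  # stand-in mirroring just the bounded field
        relora_prune_ratio: float | None = Field(default=None, ge=0.0, le=1.0)

    RatioOnly(relora_prune_ratio=0.5)  # accepted
    try:
        RatioOnly(relora_prune_ratio=1.5)  # rejected by le=1.0
    except ValidationError as exc:
        print(exc.errors()[0]["type"])  # less_than_equal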

View File

@@ -56,7 +56,72 @@ class TestReLoraLlama(unittest.TestCase):
                 ],
                 "warmup_steps": 10,
                 "num_epochs": 2,
-                "max_steps": 105,  # at least 2x relora_steps
+                "max_steps": 105,  # at least 2x restart cadence
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_8bit",
+                "lr_scheduler": "cosine",
+                "use_tensorboard": True,
+                "save_first_step": False,
+            }
+        )
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        dataset_meta = load_datasets(cfg=cfg)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
+        assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists(), (
+            "Relora model checkpoint not found"
+        )
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
+        )
+
+    @with_temp_dir
+    def test_relora_reset_method(self, temp_dir):
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sequence_len": 2048,
+                "sample_packing": True,
+                "pad_to_sequence_len": True,
+                "flash_attention": True,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_modules": ["q_proj", "v_proj"],
+                "relora": True,
+                "jagged_restart_steps": 50,
+                "jagged_restart_warmup_steps": 10,
+                "jagged_restart_anneal_steps": 10,
+                "relora_prune_ratio": 0.5,  # explicitly honored by reset (not ignored)
+                "relora_prune_method": "reset",
+                "relora_cpu_offload": True,
+                "val_set_size": 0.0,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "chat_template": "chatml",
+                "datasets": [
+                    {
+                        "path": "mlabonne/FineTome-100k",
+                        "type": "chat_template",
+                        "split": "train[:10%]",
+                        "field_messages": "conversations",
+                        "message_field_role": "from",
+                        "message_field_content": "value",
+                    },
+                ],
+                "warmup_steps": 10,
+                "num_epochs": 2,
+                "max_steps": 105,
                 "micro_batch_size": 2,
                 "gradient_accumulation_steps": 1,
                 "output_dir": temp_dir,

View File

@@ -0,0 +1,186 @@
"""Unit tests for axolotl.monkeypatch.relora.reset_optimizer."""
import math
import pytest
import torch
import torch.nn as nn
from axolotl.monkeypatch.relora import (
magnitude_pruning_,
random_pruning_,
reset_optimizer,
)
ADAM_KEYS = ["exp_avg", "exp_avg_sq"]
def _build_optimizer_with_state(seed: int = 0):
"""Build a tiny optimizer over LoRA-shaped + non-LoRA params with populated state."""
torch.manual_seed(seed)
lora_a = nn.Parameter(torch.randn(8, 32))
lora_b = nn.Parameter(torch.randn(32, 8))
extra = nn.Parameter(torch.randn(64, 32))
optimizer = torch.optim.AdamW([lora_a, lora_b, extra], lr=1e-3)
for _ in range(2):
loss = (
(lora_a * torch.randn_like(lora_a)).sum()
+ (lora_b * torch.randn_like(lora_b)).sum()
+ (extra * torch.randn_like(extra)).sum()
)
loss.backward()
optimizer.step()
optimizer.zero_grad()
return optimizer, lora_a, lora_b, extra
def test_reset_optimizer_only_touches_reset_params():
"""State for params NOT in reset_params must be byte-identical after reset."""
optimizer, lora_a, lora_b, extra = _build_optimizer_with_state()
extra_avg_before = optimizer.state[extra]["exp_avg"].clone()
extra_avg_sq_before = optimizer.state[extra]["exp_avg_sq"].clone()
reset_optimizer(
optimizer,
reset_params=[lora_a, lora_b],
optimizer_state_keys=ADAM_KEYS,
prune_method="magnitude",
prune_ratio=0.9,
)
assert torch.equal(optimizer.state[extra]["exp_avg"], extra_avg_before)
assert torch.equal(optimizer.state[extra]["exp_avg_sq"], extra_avg_sq_before)
def test_reset_optimizer_actually_prunes_lora_state():
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
reset_optimizer(
optimizer,
reset_params=[lora_a, lora_b],
optimizer_state_keys=ADAM_KEYS,
prune_method="magnitude",
prune_ratio=0.9,
)
for param in (lora_a, lora_b):
for key in ADAM_KEYS:
zero_frac = (optimizer.state[param][key] == 0).float().mean().item()
assert zero_frac >= 0.85
@pytest.mark.parametrize(
"method,ratio,expected_zero_frac",
[
("magnitude", 0.9, 0.9),
("magnitude", 0.99, 0.99),
("random", 0.9, 0.9),
("random", 0.5, 0.5),
# reset uses random pruning; relora_prune_ratio must be honored, not ignored.
("reset", 0.9, 0.9),
("reset", 0.5, 0.5),
],
)
def test_prune_methods(method, ratio, expected_zero_frac):
"""Each method zeros approximately the expected fraction."""
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state(seed=42)
reset_optimizer(
optimizer,
reset_params=[lora_a, lora_b],
optimizer_state_keys=ADAM_KEYS,
prune_method=method,
prune_ratio=ratio,
)
total = 0
zeros = 0
for param in (lora_a, lora_b):
for key in ADAM_KEYS:
tensor = optimizer.state[param][key]
total += tensor.numel()
zeros += (tensor == 0).sum().item()
actual = zeros / total
tolerance = 0.02 if method == "magnitude" else 0.05
assert math.isclose(actual, expected_zero_frac, abs_tol=tolerance)
def test_reset_optimizer_skips_keys_not_in_state_keys():
"""Keys present in optimizer state but not in optimizer_state_keys are untouched."""
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
exp_avg_sq_before = optimizer.state[lora_a]["exp_avg_sq"].clone()
reset_optimizer(
optimizer,
reset_params=[lora_a, lora_b],
optimizer_state_keys=["exp_avg"],
prune_method="magnitude",
prune_ratio=0.9,
)
assert torch.equal(optimizer.state[lora_a]["exp_avg_sq"], exp_avg_sq_before)
def test_reset_optimizer_handles_param_with_empty_state():
"""Params with no optimizer state are skipped silently."""
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
orphan = nn.Parameter(torch.randn(4, 4))
reset_optimizer(
optimizer,
reset_params=[lora_a, lora_b, orphan],
optimizer_state_keys=ADAM_KEYS,
prune_method="magnitude",
prune_ratio=0.9,
)
assert orphan not in optimizer.state or not optimizer.state[orphan]
def test_unknown_prune_method_raises():
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
with pytest.raises(ValueError, match="Unknown prune_method"):
reset_optimizer(
optimizer,
reset_params=[lora_a, lora_b],
optimizer_state_keys=ADAM_KEYS,
prune_method="bogus", # type: ignore[arg-type]
prune_ratio=0.9,
)
def test_pruning_helpers_are_inplace():
"""magnitude_pruning_ and random_pruning_ must mutate via tensor.mul_."""
tensor = torch.randn(64)
ptr_before = tensor.data_ptr()
magnitude_pruning_(tensor, 0.5)
assert tensor.data_ptr() == ptr_before
tensor = torch.randn(64)
ptr_before = tensor.data_ptr()
random_pruning_(tensor, 0.5)
assert tensor.data_ptr() == ptr_before
def test_pruning_helpers_support_uint8_tensors():
"""Both pruning helpers must work on uint8 optimizer state tensors."""
tensor = torch.arange(1, 129, dtype=torch.uint8)
magnitude_pruning_(tensor, 0.9)
assert tensor.dtype == torch.uint8
magnitude_zero_frac = (tensor == 0).float().mean().item()
assert 0.85 <= magnitude_zero_frac <= 0.95
tensor = torch.arange(1, 129, dtype=torch.uint8)
with torch.random.fork_rng(devices=[]):
torch.manual_seed(1234)
random_pruning_(tensor, 0.9)
assert tensor.dtype == torch.uint8
random_zero_frac = (tensor == 0).float().mean().item()
assert 0.85 <= random_zero_frac <= 0.95
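
Background on the uint8 regression covered by the last test: torch.quantile rejects integer tensors, so the magnitude path has to compute its threshold in float32 and apply the mask in the original dtype. A standalone illustration (not code from this commit):

    import torch

    state = torch.arange(1, 129, dtype=torch.uint8)  # stand-in for 8-bit optimizer state
    # torch.quantile(state, 0.9) would raise: quantile expects a float or double tensor.
    threshold = torch.quantile(state.flatten().to(torch.float32), 0.9)
    mask = state.to(torch.float32).abs() > threshold
    state.mul_(mask.to(state.dtype))  # prune in place, dtype preserved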