From e2f01de0e8ebf88cabea7281975bfcade8080693 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 9 May 2026 17:52:35 -0400
Subject: [PATCH] Fix Axolotl ReLoRA optimizer reset scope (#3646)

* Fix Axolotl ReLoRA optimizer reset scope

* fix: make relora reset method honor relora_prune_ratio

When relora_prune_method='reset' and relora_prune_ratio is explicitly
set, the ratio was silently ignored and replaced with the hardcoded
_FULL_RESET_RATIO (0.999). Fix by moving the default-ratio logic to
ReLoRACallback.on_step_begin: None maps to _FULL_RESET_RATIO for reset
and 0.9 for other methods. reset_optimizer now uses the same random
pruning path for both 'random' and 'reset'.

Also consolidate the three-layer default mismatch: the schema default
for relora_prune_method is now 'magnitude' (single canonical source);
the dataclass defaults for both fields changed to None to eliminate the
conflicting fallback layer.

Tests updated: removed the test case that verified the old broken
behavior (reset ignoring the ratio), added two cases proving reset
honors the passed ratio. The E2E reset fixture now uses ratio=0.5 to
make it unambiguous that the ratio is honored.

* Fix ReLoRA uint8 pruning regression

---------

Signed-off-by: Wing Lian
Co-authored-by: Axolotl Swarm
---
 .github/workflows/tests.yml            |   2 +-
 src/axolotl/core/builders/causal.py    |   6 +-
 src/axolotl/core/training_args_base.py |  17 ++-
 src/axolotl/monkeypatch/relora.py      |  98 +++++++++----
 src/axolotl/utils/schemas/peft.py      |  22 ++-
 tests/e2e/solo/test_relora_llama.py    |  67 ++++++++-
 tests/monkeypatch/test_relora.py       | 186 +++++++++++++++++++++++++
 7 files changed, 361 insertions(+), 37 deletions(-)
 create mode 100644 tests/monkeypatch/test_relora.py

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index e21e60ab5..6b298ade0 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -72,7 +72,7 @@ jobs:
       exclude:
         - python_version: "3.14"
           pytorch_version: "2.9.1"
-    timeout-minutes: 20
+    timeout-minutes: 25

     steps:
       - name: cleanup node
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index aa1678523..b5f365bce 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -286,10 +286,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             )

         if self.cfg.relora and self.cfg.jagged_restart_steps:
-            if self.cfg.relora_prune_ratio:
+            if self.cfg.relora_prune_ratio is not None:
                 training_arguments_kwargs["relora_prune_ratio"] = (
                     self.cfg.relora_prune_ratio
                 )
+            if self.cfg.relora_prune_method:
+                training_arguments_kwargs["relora_prune_method"] = (
+                    self.cfg.relora_prune_method
+                )

         if self.cfg.jagged_restart_steps:
             training_arguments_kwargs["jagged_restart_steps"] = (
diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py
index 427a80a46..e9727dbb1 100644
--- a/src/axolotl/core/training_args_base.py
+++ b/src/axolotl/core/training_args_base.py
@@ -83,13 +83,18 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "The number of processes to use for data processing"},
     )
-    relora_steps: Optional[int] = field(
-        default=None,
-        metadata={"help": "how often to reset for ReLoRA"},
-    )
     relora_prune_ratio: Optional[float] = field(
-        default=0.9,
-        metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
+        default=None,
+        metadata={
+            "help": (
+                "prune ratio for optimizer state pruning; "
+                "defaults to 0.999 for reset method, 0.9 for others"
+            )
+        },
+    )
+    relora_prune_method: Optional[str] = field(
+        default=None,
+        metadata={"help": "optimizer state pruning method: magnitude | random | reset"},
     )
     jagged_restart_steps: Optional[int] = field(
         default=None,
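A note on the builders/causal.py guard above: the switch from a truthiness check to `is not None` matters because a configured ratio of 0.0 is falsy in Python, so the old guard would silently drop an explicit zero. A minimal standalone illustration (the `cfg` dict here is hypothetical, not code from this patch):

    # Hypothetical config dict -- illustration only, not Axolotl code.
    cfg = {"relora_prune_ratio": 0.0}

    kwargs_truthy = {}
    if cfg["relora_prune_ratio"]:  # old guard: 0.0 is falsy, key is dropped
        kwargs_truthy["relora_prune_ratio"] = cfg["relora_prune_ratio"]

    kwargs_explicit = {}
    if cfg["relora_prune_ratio"] is not None:  # new guard: 0.0 passes through
        kwargs_explicit["relora_prune_ratio"] = cfg["relora_prune_ratio"]

    assert kwargs_truthy == {}
    assert kwargs_explicit == {"relora_prune_ratio": 0.0}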
metadata={"help": "optimizer state pruning method: magnitude | random | reset"}, ) jagged_restart_steps: Optional[int] = field( default=None, diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py index cf93c32dd..3d33ab204 100644 --- a/src/axolotl/monkeypatch/relora.py +++ b/src/axolotl/monkeypatch/relora.py @@ -6,9 +6,8 @@ import os.path import shutil from functools import partial from pathlib import Path -from typing import Dict, List, Union +from typing import Dict, List, Literal, Union -import bitsandbytes as bnb import peft import safetensors.torch as st import torch @@ -28,9 +27,15 @@ from axolotl.utils.logging import get_logger LOG = get_logger(__name__) +try: + import bitsandbytes as bnb +except ImportError: # pragma: no cover - optional dependency for 8-bit merge paths + bnb = None + @torch.no_grad() def magnitude_pruning_(tensor, prune_ratio): + """Zero the lowest ``prune_ratio`` fraction of values by absolute magnitude, in place.""" tensor_magnitude = torch.abs(tensor) threshold = torch.quantile( tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio @@ -40,15 +45,43 @@ def magnitude_pruning_(tensor, prune_ratio): tensor.mul_(mask.to(dtype=tensor.dtype)) +@torch.no_grad() +def random_pruning_(tensor, prune_ratio): + """Zero a random ``prune_ratio`` fraction of values, in place.""" + mask = ( + torch.rand(tensor.shape, dtype=torch.float32, device=tensor.device) + > prune_ratio + ) + tensor.mul_(mask.to(dtype=tensor.dtype)) + + +# 0.999 mirrors the reference implementation. True zeroing breaks +# ZeroRedundancyOptimizer.consolidate_state_dict; see Guitaricet/relora's +# peft_pretraining/training_utils.py for the original note on this. +_FULL_RESET_RATIO = 0.999 + + def reset_optimizer( optimizer: torch.optim.Optimizer, *, - reset_params: List[str], # where str is the key to a torch.nn.Parameter + reset_params: List[torch.nn.Parameter], optimizer_state_keys: List[str], - optimizer_magnitude_pruning: float = 0.9, + prune_method: Literal["magnitude", "random", "reset"] = "magnitude", + prune_ratio: float = 0.9, ): - # pylint:disable=unused-argument - pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning) + """Prune optimizer state for ``reset_params`` only.""" + if prune_method == "magnitude": + pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio) + elif prune_method in ("random", "reset"): + # "reset" is random pruning at a near-full ratio; the caller is responsible + # for supplying the appropriate prune_ratio (see ReLoRACallback.on_step_begin). 
@@ -40,15 +45,43 @@
     tensor.mul_(mask.to(dtype=tensor.dtype))


+@torch.no_grad()
+def random_pruning_(tensor, prune_ratio):
+    """Zero a random ``prune_ratio`` fraction of values, in place."""
+    mask = (
+        torch.rand(tensor.shape, dtype=torch.float32, device=tensor.device)
+        > prune_ratio
+    )
+    tensor.mul_(mask.to(dtype=tensor.dtype))
+
+
+# 0.999 mirrors the reference implementation. True zeroing breaks
+# ZeroRedundancyOptimizer.consolidate_state_dict; see Guitaricet/relora's
+# peft_pretraining/training_utils.py for the original note on this.
+_FULL_RESET_RATIO = 0.999
+
+
 def reset_optimizer(
     optimizer: torch.optim.Optimizer,
     *,
-    reset_params: List[str],  # where str is the key to a torch.nn.Parameter
+    reset_params: List[torch.nn.Parameter],
     optimizer_state_keys: List[str],
-    optimizer_magnitude_pruning: float = 0.9,
+    prune_method: Literal["magnitude", "random", "reset"] = "magnitude",
+    prune_ratio: float = 0.9,
 ):
-    # pylint:disable=unused-argument
-    pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
+    """Prune optimizer state for ``reset_params`` only."""
+    if prune_method == "magnitude":
+        pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
+    elif prune_method in ("random", "reset"):
+        # "reset" is random pruning at a near-full ratio; the caller is responsible
+        # for supplying the appropriate prune_ratio (see ReLoRACallback.on_step_begin).
+        pruning_fn = partial(random_pruning_, prune_ratio=prune_ratio)
+    else:
+        raise ValueError(
+            f"Unknown prune_method {prune_method!r}; expected one of "
+            "'magnitude', 'random', 'reset'"
+        )
+
     n_zeros = 0
     n_total = 0

@@ -56,22 +89,22 @@
     if isinstance(optimizer, ZeroRedundancyOptimizer):
         optimizer_state = optimizer.optim.state

-    for group in optimizer.param_groups:
-        for param in group["params"]:
-            state = optimizer_state[param]
-            for key, value in state.items():
-                if key not in optimizer_state_keys:
+    for param in reset_params:
+        state = optimizer_state.get(param, {})
+        if not state:
+            continue
+        for key in optimizer_state_keys:
+            value = state.get(key)
+            if value is None or not torch.is_tensor(value):
+                continue
+            try:
+                pruning_fn(value)
+                n_total += value.numel()
+                n_zeros += torch.sum(value == 0).item()
+            except RuntimeError as exc:
+                if "quantile() input tensor is too large" in str(exc):
                     continue
-                if torch.is_tensor(value):
-                    try:
-                        pruning_fn(value)
-                        n_total += value.numel()
-                        n_zeros += torch.sum(value == 0).item()
-                    except RuntimeError as exc:
-                        if "quantile() input tensor is too large" in str(exc):
-                            pass
-                        else:
-                            raise exc
+                raise

     _zeroed = n_zeros / (1e-7 + n_total) * 100
     LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
@@ -82,11 +115,12 @@ class ReLoRACallback(TrainerCallback):
     """Callback to merge LoRA weights into the base model and save full-weight checkpoints"""

     def __init__(self, cfg: DictDefault):
-        self.relora_steps = cfg.jagged_restart_steps
+        self.jagged_restart_steps = cfg.jagged_restart_steps
         self.cpu_offload = cfg.relora_cpu_offload
         self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
         self.last_full_model = cfg.base_model
         self.resume_from_checkpoint = cfg.resume_from_checkpoint
+        self.prune_method = cfg.relora_prune_method or "magnitude"

         if not os.path.exists(self.last_full_model):
             self.last_full_model = str(Path(snapshot_download(cfg.base_model)))
@@ -128,7 +162,7 @@
     ):
         if not optimizer:
             optimizer = state.optimizer
-        if state.global_step > 0 and state.global_step % self.relora_steps == 0:
+        if state.global_step > 0 and state.global_step % self.jagged_restart_steps == 0:
             checkpoint_folder = os.path.join(
                 args.output_dir,
                 f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
@@ -144,7 +178,7 @@
                 raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")

             lora_params = [
-                n
+                p
                 for n, p in model.named_parameters()
                 if p.requires_grad and "lora_" in n
             ]
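The one-character `n` to `p` change above is the heart of the scope fix: `torch.optim.Optimizer.state` is keyed by `Parameter` objects, so a list of names could never match any state entry, and the old `reset_optimizer` ignored `reset_params` entirely (note the deleted `pylint:disable=unused-argument`) and pruned every parameter in every param group. A standalone sketch of the lookup behavior (toy parameter, not Axolotl code):

    # Toy demonstration that optimizer state is keyed by Parameter objects.
    import torch

    param = torch.nn.Parameter(torch.randn(4))
    opt = torch.optim.AdamW([param], lr=1e-3)
    param.grad = torch.ones_like(param)
    opt.step()

    assert param in opt.state                  # the Parameter object is the key
    assert "exp_avg" in opt.state[param]       # AdamW state exists after step()
    assert opt.state.get("some_name") is None  # a string name matches nothing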
@@ -166,11 +200,19 @@
                     actually_save=is_main_process(),
                     cpu_offload=self.cpu_offload,
                 )
+                # When relora_prune_ratio is not set, use _FULL_RESET_RATIO for
+                # "reset" (paper-style near-full reset) and 0.9 for other methods.
+                prune_ratio = args.relora_prune_ratio
+                if prune_ratio is None:
+                    prune_ratio = (
+                        _FULL_RESET_RATIO if self.prune_method == "reset" else 0.9
+                    )
                 reset_optimizer(
                     optimizer,
                     reset_params=lora_params,
                     optimizer_state_keys=optimizer_state_keys,
-                    optimizer_magnitude_pruning=args.relora_prune_ratio,
+                    prune_method=self.prune_method,
+                    prune_ratio=prune_ratio,
                 )

             if self.quantized:
@@ -191,8 +233,8 @@
             args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
         )
         if (
-            state.global_step >= self.relora_steps
-            and state.global_step % self.relora_steps != 0
+            state.global_step >= self.jagged_restart_steps
+            and state.global_step % self.jagged_restart_steps != 0
         ):
             if self.quantized:
                 if is_main_process() and self.last_full_model != checkpoint_folder:
@@ -320,6 +362,8 @@ def update_weights(
             target.weight.data = new_weight.cpu()
             target.to(device)
         elif isinstance(target, peft.tuners.lora.Linear8bitLt):
+            if bnb is None:
+                raise ImportError("bitsandbytes is required to merge 8-bit LoRA weights")
             target.weight.data = (
                 bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
             )
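Taken together, the callback logic above resolves the effective ratio in one place: an explicit `relora_prune_ratio` always wins, and only `None` falls back to the per-method default. A standalone sketch of that resolution order (the helper name `resolve_prune_ratio` is hypothetical, for illustration only):

    # Hypothetical helper mirroring the on_step_begin logic -- illustration only.
    _FULL_RESET_RATIO = 0.999

    def resolve_prune_ratio(configured_ratio, prune_method):
        """Explicit ratio wins; None falls back to the per-method default."""
        if configured_ratio is not None:
            return configured_ratio
        return _FULL_RESET_RATIO if prune_method == "reset" else 0.9

    assert resolve_prune_ratio(None, "reset") == 0.999    # paper-style near-full reset
    assert resolve_prune_ratio(None, "magnitude") == 0.9  # legacy default
    assert resolve_prune_ratio(0.5, "reset") == 0.5       # the fix: no longer ignored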
diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py
index c60c548f0..42fa628e0 100644
--- a/src/axolotl/utils/schemas/peft.py
+++ b/src/axolotl/utils/schemas/peft.py
@@ -240,8 +240,28 @@ class ReLoRAConfig(BaseModel):
     )
     relora_prune_ratio: float | None = Field(
         default=None,
+        ge=0.0,
+        le=1.0,
         json_schema_extra={
-            "description": "threshold for optimizer magnitude when pruning"
+            "description": (
+                "Fraction of optimizer state values to zero on each ReLoRA restart. "
+                "When relora_prune_method='reset' and this is omitted, defaults to "
+                "0.999 (paper-style near-full reset). For other methods, defaults to 0.9."
+            )
         },
     )
+    relora_prune_method: Literal["magnitude", "random", "reset"] | None = Field(
+        default="magnitude",
+        json_schema_extra={
+            "description": (
+                "Optimizer state pruning method on each ReLoRA restart. "
+                "'magnitude' (default) zeros the lowest-magnitude relora_prune_ratio "
+                "fraction, keeping the largest values; 'random' zeros a random "
+                "relora_prune_ratio fraction; 'reset' uses near-full random pruning "
+                "(default ratio 0.999, honoring relora_prune_ratio when explicitly "
+                "set). Paper-style recipe: relora_prune_method='reset' with no "
+                "relora_prune_ratio, equivalent to 'random' with ratio=0.999."
+            )
+        },
+    )
     relora_cpu_offload: bool | None = Field(
diff --git a/tests/e2e/solo/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py
index 091bb90c6..895f32d99 100644
--- a/tests/e2e/solo/test_relora_llama.py
+++ b/tests/e2e/solo/test_relora_llama.py
@@ -56,7 +56,72 @@ class TestReLoraLlama(unittest.TestCase):
                 ],
                 "warmup_steps": 10,
                 "num_epochs": 2,
-                "max_steps": 105,  # at least 2x relora_steps
+                "max_steps": 105,  # at least 2x restart cadence
                 "micro_batch_size": 2,
                 "gradient_accumulation_steps": 1,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_8bit",
                 "lr_scheduler": "cosine",
                 "use_tensorboard": True,
                 "save_first_step": False,
             }
         )

         cfg = validate_config(cfg)
         normalize_config(cfg)
         dataset_meta = load_datasets(cfg=cfg)

         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
         assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists(), (
             "Relora model checkpoint not found"
         )

         check_tensorboard(
             temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
         )

+    @with_temp_dir
+    def test_relora_reset_method(self, temp_dir):
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sequence_len": 2048,
+                "sample_packing": True,
+                "pad_to_sequence_len": True,
+                "flash_attention": True,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_modules": ["q_proj", "v_proj"],
+                "relora": True,
+                "jagged_restart_steps": 50,
+                "jagged_restart_warmup_steps": 10,
+                "jagged_restart_anneal_steps": 10,
+                "relora_prune_ratio": 0.5,  # explicitly honored by reset (not ignored)
+                "relora_prune_method": "reset",
+                "relora_cpu_offload": True,
+                "val_set_size": 0.0,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "chat_template": "chatml",
+                "datasets": [
+                    {
+                        "path": "mlabonne/FineTome-100k",
+                        "type": "chat_template",
+                        "split": "train[:10%]",
+                        "field_messages": "conversations",
+                        "message_field_role": "from",
+                        "message_field_content": "value",
+                    },
+                ],
+                "warmup_steps": 10,
+                "num_epochs": 2,
+                "max_steps": 105,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
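The unit tests below need no GPU, model download, or dataset, unlike the e2e fixtures above; assuming a repository checkout with this patch applied, they can be run directly, for example:

    # Assumes a checkout with the patch applied; path taken from the diff below.
    import pytest

    raise SystemExit(pytest.main(["-q", "tests/monkeypatch/test_relora.py"]))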
diff --git a/tests/monkeypatch/test_relora.py b/tests/monkeypatch/test_relora.py
new file mode 100644
index 000000000..e722adf0c
--- /dev/null
+++ b/tests/monkeypatch/test_relora.py
@@ -0,0 +1,186 @@
+"""Unit tests for axolotl.monkeypatch.relora.reset_optimizer."""
+
+import math
+
+import pytest
+import torch
+import torch.nn as nn
+
+from axolotl.monkeypatch.relora import (
+    magnitude_pruning_,
+    random_pruning_,
+    reset_optimizer,
+)
+
+ADAM_KEYS = ["exp_avg", "exp_avg_sq"]
+
+
+def _build_optimizer_with_state(seed: int = 0):
+    """Build a tiny optimizer over LoRA-shaped + non-LoRA params with populated state."""
+    torch.manual_seed(seed)
+    lora_a = nn.Parameter(torch.randn(8, 32))
+    lora_b = nn.Parameter(torch.randn(32, 8))
+    extra = nn.Parameter(torch.randn(64, 32))
+
+    optimizer = torch.optim.AdamW([lora_a, lora_b, extra], lr=1e-3)
+    for _ in range(2):
+        loss = (
+            (lora_a * torch.randn_like(lora_a)).sum()
+            + (lora_b * torch.randn_like(lora_b)).sum()
+            + (extra * torch.randn_like(extra)).sum()
+        )
+        loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+    return optimizer, lora_a, lora_b, extra
+
+
+def test_reset_optimizer_only_touches_reset_params():
+    """State for params NOT in reset_params must be byte-identical after reset."""
+    optimizer, lora_a, lora_b, extra = _build_optimizer_with_state()
+
+    extra_avg_before = optimizer.state[extra]["exp_avg"].clone()
+    extra_avg_sq_before = optimizer.state[extra]["exp_avg_sq"].clone()
+
+    reset_optimizer(
+        optimizer,
+        reset_params=[lora_a, lora_b],
+        optimizer_state_keys=ADAM_KEYS,
+        prune_method="magnitude",
+        prune_ratio=0.9,
+    )
+
+    assert torch.equal(optimizer.state[extra]["exp_avg"], extra_avg_before)
+    assert torch.equal(optimizer.state[extra]["exp_avg_sq"], extra_avg_sq_before)
+
+
+def test_reset_optimizer_actually_prunes_lora_state():
+    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
+
+    reset_optimizer(
+        optimizer,
+        reset_params=[lora_a, lora_b],
+        optimizer_state_keys=ADAM_KEYS,
+        prune_method="magnitude",
+        prune_ratio=0.9,
+    )
+
+    for param in (lora_a, lora_b):
+        for key in ADAM_KEYS:
+            zero_frac = (optimizer.state[param][key] == 0).float().mean().item()
+            assert zero_frac >= 0.85
+
+
+@pytest.mark.parametrize(
+    "method,ratio,expected_zero_frac",
+    [
+        ("magnitude", 0.9, 0.9),
+        ("magnitude", 0.99, 0.99),
+        ("random", 0.9, 0.9),
+        ("random", 0.5, 0.5),
+        # reset uses random pruning; relora_prune_ratio must be honored, not ignored.
+        ("reset", 0.9, 0.9),
+        ("reset", 0.5, 0.5),
+    ],
+)
+def test_prune_methods(method, ratio, expected_zero_frac):
+    """Each method zeros approximately the expected fraction."""
+    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state(seed=42)
+
+    reset_optimizer(
+        optimizer,
+        reset_params=[lora_a, lora_b],
+        optimizer_state_keys=ADAM_KEYS,
+        prune_method=method,
+        prune_ratio=ratio,
+    )
+
+    total = 0
+    zeros = 0
+    for param in (lora_a, lora_b):
+        for key in ADAM_KEYS:
+            tensor = optimizer.state[param][key]
+            total += tensor.numel()
+            zeros += (tensor == 0).sum().item()
+
+    actual = zeros / total
+    tolerance = 0.02 if method == "magnitude" else 0.05
+    assert math.isclose(actual, expected_zero_frac, abs_tol=tolerance)
+
+
+def test_reset_optimizer_skips_keys_not_in_state_keys():
+    """Keys present in optimizer state but not in optimizer_state_keys are untouched."""
+    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
+
+    exp_avg_sq_before = optimizer.state[lora_a]["exp_avg_sq"].clone()
+
+    reset_optimizer(
+        optimizer,
+        reset_params=[lora_a, lora_b],
+        optimizer_state_keys=["exp_avg"],
+        prune_method="magnitude",
+        prune_ratio=0.9,
+    )
+
+    assert torch.equal(optimizer.state[lora_a]["exp_avg_sq"], exp_avg_sq_before)
+
+
+def test_reset_optimizer_handles_param_with_empty_state():
+    """Params with no optimizer state are skipped silently."""
+    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
+    orphan = nn.Parameter(torch.randn(4, 4))
+
+    reset_optimizer(
+        optimizer,
+        reset_params=[lora_a, lora_b, orphan],
+        optimizer_state_keys=ADAM_KEYS,
+        prune_method="magnitude",
+        prune_ratio=0.9,
+    )
+
+    assert orphan not in optimizer.state or not optimizer.state[orphan]
+
+
+def test_unknown_prune_method_raises():
+    optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
+
+    with pytest.raises(ValueError, match="Unknown prune_method"):
+        reset_optimizer(
+            optimizer,
+            reset_params=[lora_a, lora_b],
+            optimizer_state_keys=ADAM_KEYS,
+            prune_method="bogus",  # type: ignore[arg-type]
+            prune_ratio=0.9,
+        )
+
+
+def test_pruning_helpers_are_inplace():
+    """magnitude_pruning_ and random_pruning_ must mutate via tensor.mul_."""
+    tensor = torch.randn(64)
+    ptr_before = tensor.data_ptr()
+    magnitude_pruning_(tensor, 0.5)
+    assert tensor.data_ptr() == ptr_before
+
+    tensor = torch.randn(64)
+    ptr_before = tensor.data_ptr()
+    random_pruning_(tensor, 0.5)
+    assert tensor.data_ptr() == ptr_before
+
+
+def test_pruning_helpers_support_uint8_tensors():
+    """Both pruning helpers must work on uint8 optimizer state tensors."""
+    tensor = torch.arange(1, 129, dtype=torch.uint8)
+    magnitude_pruning_(tensor, 0.9)
+
+    assert tensor.dtype == torch.uint8
+    magnitude_zero_frac = (tensor == 0).float().mean().item()
+    assert 0.85 <= magnitude_zero_frac <= 0.95
+
+    tensor = torch.arange(1, 129, dtype=torch.uint8)
+    with torch.random.fork_rng(devices=[]):
+        torch.manual_seed(1234)
+        random_pruning_(tensor, 0.9)
+
+    assert tensor.dtype == torch.uint8
+    random_zero_frac = (tensor == 0).float().mean().item()
+    assert 0.85 <= random_zero_frac <= 0.95
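The uint8 case above pins down the regression named in the third commit bullet: `torch.quantile` accepts only floating-point input, so `magnitude_pruning_` must cast before thresholding, which it does via `.to(dtype=torch.float32)`. A minimal standalone reproduction of the failure mode (plain PyTorch, not patch code):

    # quantile rejects integer dtypes, so uint8 optimizer state (as produced
    # by 8-bit optimizers, e.g. the adamw_8bit used in the e2e fixtures above)
    # must be cast to float before computing the threshold.
    import torch

    state = torch.arange(1, 129, dtype=torch.uint8)

    try:
        torch.quantile(state, 0.9)  # RuntimeError: quantile needs a float dtype
    except RuntimeError:
        pass

    threshold = torch.quantile(state.to(torch.float32), 0.9)  # the patched path
    state.mul_((state > threshold).to(state.dtype))
    assert (state == 0).float().mean().item() > 0.8  # lowest ~90% zeroed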