Fix Axolotl ReLoRA optimizer reset scope (#3646)
* Fix Axolotl ReLoRA optimizer reset scope * fix: make relora reset method honor relora_prune_ratio When relora_prune_method='reset' and relora_prune_ratio is explicitly set, the ratio was silently ignored and replaced with the hardcoded _FULL_RESET_RATIO (0.999). Fix by moving the default-ratio logic to ReLoRACallback.on_step_begin: None maps to _FULL_RESET_RATIO for reset and 0.9 for other methods. reset_optimizer now uses the same random pruning path for both 'random' and 'reset'. Also consolidate three-layer default mismatch: schema default for relora_prune_method is now 'magnitude' (single canonical source); dataclass defaults for both fields changed to None to eliminate the conflicting fallback layer. Tests updated: removed the test case that verified the old broken behavior (reset ignoring ratio), added two cases proving reset honors the passed ratio. E2E reset fixture now uses ratio=0.5 to make it unambiguous that the ratio is honored. * Fix ReLoRA uint8 pruning regression --------- Signed-off-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
This commit is contained in:
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
|||||||
exclude:
|
exclude:
|
||||||
- python_version: "3.14"
|
- python_version: "3.14"
|
||||||
pytorch_version: "2.9.1"
|
pytorch_version: "2.9.1"
|
||||||
timeout-minutes: 20
|
timeout-minutes: 25
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: cleanup node
|
- name: cleanup node
|
||||||
|
|||||||
@@ -286,10 +286,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.relora and self.cfg.jagged_restart_steps:
|
if self.cfg.relora and self.cfg.jagged_restart_steps:
|
||||||
if self.cfg.relora_prune_ratio:
|
if self.cfg.relora_prune_ratio is not None:
|
||||||
training_arguments_kwargs["relora_prune_ratio"] = (
|
training_arguments_kwargs["relora_prune_ratio"] = (
|
||||||
self.cfg.relora_prune_ratio
|
self.cfg.relora_prune_ratio
|
||||||
)
|
)
|
||||||
|
if self.cfg.relora_prune_method:
|
||||||
|
training_arguments_kwargs["relora_prune_method"] = (
|
||||||
|
self.cfg.relora_prune_method
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.jagged_restart_steps:
|
if self.cfg.jagged_restart_steps:
|
||||||
training_arguments_kwargs["jagged_restart_steps"] = (
|
training_arguments_kwargs["jagged_restart_steps"] = (
|
||||||
|
|||||||
@@ -83,13 +83,18 @@ class AxolotlTrainingMixins:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "The number of processes to use for data processing"},
|
metadata={"help": "The number of processes to use for data processing"},
|
||||||
)
|
)
|
||||||
relora_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how often to reset for ReLoRA"},
|
|
||||||
)
|
|
||||||
relora_prune_ratio: Optional[float] = field(
|
relora_prune_ratio: Optional[float] = field(
|
||||||
default=0.9,
|
default=None,
|
||||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
metadata={
|
||||||
|
"help": (
|
||||||
|
"prune ratio for optimizer state pruning; "
|
||||||
|
"defaults to 0.999 for reset method, 0.9 for others"
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
relora_prune_method: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "optimizer state pruning method: magnitude | random | reset"},
|
||||||
)
|
)
|
||||||
jagged_restart_steps: Optional[int] = field(
|
jagged_restart_steps: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
@@ -6,9 +6,8 @@ import os.path
|
|||||||
import shutil
|
import shutil
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Literal, Union
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
|
||||||
import peft
|
import peft
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
@@ -28,9 +27,15 @@ from axolotl.utils.logging import get_logger
|
|||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
except ImportError: # pragma: no cover - optional dependency for 8-bit merge paths
|
||||||
|
bnb = None
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def magnitude_pruning_(tensor, prune_ratio):
|
def magnitude_pruning_(tensor, prune_ratio):
|
||||||
|
"""Zero the lowest ``prune_ratio`` fraction of values by absolute magnitude, in place."""
|
||||||
tensor_magnitude = torch.abs(tensor)
|
tensor_magnitude = torch.abs(tensor)
|
||||||
threshold = torch.quantile(
|
threshold = torch.quantile(
|
||||||
tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
|
tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
|
||||||
@@ -40,15 +45,43 @@ def magnitude_pruning_(tensor, prune_ratio):
|
|||||||
tensor.mul_(mask.to(dtype=tensor.dtype))
|
tensor.mul_(mask.to(dtype=tensor.dtype))
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def random_pruning_(tensor, prune_ratio):
|
||||||
|
"""Zero a random ``prune_ratio`` fraction of values, in place."""
|
||||||
|
mask = (
|
||||||
|
torch.rand(tensor.shape, dtype=torch.float32, device=tensor.device)
|
||||||
|
> prune_ratio
|
||||||
|
)
|
||||||
|
tensor.mul_(mask.to(dtype=tensor.dtype))
|
||||||
|
|
||||||
|
|
||||||
|
# 0.999 mirrors the reference implementation. True zeroing breaks
|
||||||
|
# ZeroRedundancyOptimizer.consolidate_state_dict; see Guitaricet/relora's
|
||||||
|
# peft_pretraining/training_utils.py for the original note on this.
|
||||||
|
_FULL_RESET_RATIO = 0.999
|
||||||
|
|
||||||
|
|
||||||
def reset_optimizer(
|
def reset_optimizer(
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
*,
|
*,
|
||||||
reset_params: List[str], # where str is the key to a torch.nn.Parameter
|
reset_params: List[torch.nn.Parameter],
|
||||||
optimizer_state_keys: List[str],
|
optimizer_state_keys: List[str],
|
||||||
optimizer_magnitude_pruning: float = 0.9,
|
prune_method: Literal["magnitude", "random", "reset"] = "magnitude",
|
||||||
|
prune_ratio: float = 0.9,
|
||||||
):
|
):
|
||||||
# pylint:disable=unused-argument
|
"""Prune optimizer state for ``reset_params`` only."""
|
||||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
|
if prune_method == "magnitude":
|
||||||
|
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
||||||
|
elif prune_method in ("random", "reset"):
|
||||||
|
# "reset" is random pruning at a near-full ratio; the caller is responsible
|
||||||
|
# for supplying the appropriate prune_ratio (see ReLoRACallback.on_step_begin).
|
||||||
|
pruning_fn = partial(random_pruning_, prune_ratio=prune_ratio)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown prune_method {prune_method!r}; expected one of "
|
||||||
|
"'magnitude', 'random', 'reset'"
|
||||||
|
)
|
||||||
|
|
||||||
n_zeros = 0
|
n_zeros = 0
|
||||||
n_total = 0
|
n_total = 0
|
||||||
|
|
||||||
@@ -56,22 +89,22 @@ def reset_optimizer(
|
|||||||
if isinstance(optimizer, ZeroRedundancyOptimizer):
|
if isinstance(optimizer, ZeroRedundancyOptimizer):
|
||||||
optimizer_state = optimizer.optim.state
|
optimizer_state = optimizer.optim.state
|
||||||
|
|
||||||
for group in optimizer.param_groups:
|
for param in reset_params:
|
||||||
for param in group["params"]:
|
state = optimizer_state.get(param, {})
|
||||||
state = optimizer_state[param]
|
if not state:
|
||||||
for key, value in state.items():
|
continue
|
||||||
if key not in optimizer_state_keys:
|
for key in optimizer_state_keys:
|
||||||
|
value = state.get(key)
|
||||||
|
if value is None or not torch.is_tensor(value):
|
||||||
continue
|
continue
|
||||||
if torch.is_tensor(value):
|
|
||||||
try:
|
try:
|
||||||
pruning_fn(value)
|
pruning_fn(value)
|
||||||
n_total += value.numel()
|
n_total += value.numel()
|
||||||
n_zeros += torch.sum(value == 0).item()
|
n_zeros += torch.sum(value == 0).item()
|
||||||
except RuntimeError as exc:
|
except RuntimeError as exc:
|
||||||
if "quantile() input tensor is too large" in str(exc):
|
if "quantile() input tensor is too large" in str(exc):
|
||||||
pass
|
continue
|
||||||
else:
|
raise
|
||||||
raise exc
|
|
||||||
|
|
||||||
_zeroed = n_zeros / (1e-7 + n_total) * 100
|
_zeroed = n_zeros / (1e-7 + n_total) * 100
|
||||||
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
|
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
|
||||||
@@ -82,11 +115,12 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
"""Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
|
"""Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
|
||||||
|
|
||||||
def __init__(self, cfg: DictDefault):
|
def __init__(self, cfg: DictDefault):
|
||||||
self.relora_steps = cfg.jagged_restart_steps
|
self.jagged_restart_steps = cfg.jagged_restart_steps
|
||||||
self.cpu_offload = cfg.relora_cpu_offload
|
self.cpu_offload = cfg.relora_cpu_offload
|
||||||
self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
|
self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
|
||||||
self.last_full_model = cfg.base_model
|
self.last_full_model = cfg.base_model
|
||||||
self.resume_from_checkpoint = cfg.resume_from_checkpoint
|
self.resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||||
|
self.prune_method = cfg.relora_prune_method or "magnitude"
|
||||||
|
|
||||||
if not os.path.exists(self.last_full_model):
|
if not os.path.exists(self.last_full_model):
|
||||||
self.last_full_model = str(Path(snapshot_download(cfg.base_model)))
|
self.last_full_model = str(Path(snapshot_download(cfg.base_model)))
|
||||||
@@ -128,7 +162,7 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
):
|
):
|
||||||
if not optimizer:
|
if not optimizer:
|
||||||
optimizer = state.optimizer
|
optimizer = state.optimizer
|
||||||
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
if state.global_step > 0 and state.global_step % self.jagged_restart_steps == 0:
|
||||||
checkpoint_folder = os.path.join(
|
checkpoint_folder = os.path.join(
|
||||||
args.output_dir,
|
args.output_dir,
|
||||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||||
@@ -144,7 +178,7 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
|
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
|
||||||
|
|
||||||
lora_params = [
|
lora_params = [
|
||||||
n
|
p
|
||||||
for n, p in model.named_parameters()
|
for n, p in model.named_parameters()
|
||||||
if p.requires_grad and "lora_" in n
|
if p.requires_grad and "lora_" in n
|
||||||
]
|
]
|
||||||
@@ -166,11 +200,19 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
actually_save=is_main_process(),
|
actually_save=is_main_process(),
|
||||||
cpu_offload=self.cpu_offload,
|
cpu_offload=self.cpu_offload,
|
||||||
)
|
)
|
||||||
|
# When relora_prune_ratio is not set, use _FULL_RESET_RATIO for
|
||||||
|
# "reset" (paper-style near-full reset) and 0.9 for other methods.
|
||||||
|
prune_ratio = args.relora_prune_ratio
|
||||||
|
if prune_ratio is None:
|
||||||
|
prune_ratio = (
|
||||||
|
_FULL_RESET_RATIO if self.prune_method == "reset" else 0.9
|
||||||
|
)
|
||||||
reset_optimizer(
|
reset_optimizer(
|
||||||
optimizer,
|
optimizer,
|
||||||
reset_params=lora_params,
|
reset_params=lora_params,
|
||||||
optimizer_state_keys=optimizer_state_keys,
|
optimizer_state_keys=optimizer_state_keys,
|
||||||
optimizer_magnitude_pruning=args.relora_prune_ratio,
|
prune_method=self.prune_method,
|
||||||
|
prune_ratio=prune_ratio,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.quantized:
|
if self.quantized:
|
||||||
@@ -191,8 +233,8 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
|
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
state.global_step >= self.relora_steps
|
state.global_step >= self.jagged_restart_steps
|
||||||
and state.global_step % self.relora_steps != 0
|
and state.global_step % self.jagged_restart_steps != 0
|
||||||
):
|
):
|
||||||
if self.quantized:
|
if self.quantized:
|
||||||
if is_main_process() and self.last_full_model != checkpoint_folder:
|
if is_main_process() and self.last_full_model != checkpoint_folder:
|
||||||
@@ -320,6 +362,8 @@ def update_weights(
|
|||||||
target.weight.data = new_weight.cpu()
|
target.weight.data = new_weight.cpu()
|
||||||
target.to(device)
|
target.to(device)
|
||||||
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
||||||
|
if bnb is None:
|
||||||
|
raise ImportError("bitsandbytes is required to merge 8-bit LoRA weights")
|
||||||
target.weight.data = (
|
target.weight.data = (
|
||||||
bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
|
bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -240,8 +240,28 @@ class ReLoRAConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
relora_prune_ratio: float | None = Field(
|
relora_prune_ratio: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "threshold for optimizer magnitude when pruning"
|
"description": (
|
||||||
|
"Fraction of optimizer state values to zero on each ReLoRA restart. "
|
||||||
|
"When relora_prune_method='reset' and this is omitted, defaults to "
|
||||||
|
"0.999 (paper-style near-full reset). For other methods, defaults to 0.9."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
relora_prune_method: Literal["magnitude", "random", "reset"] | None = Field(
|
||||||
|
default="magnitude",
|
||||||
|
json_schema_extra={
|
||||||
|
"description": (
|
||||||
|
"Optimizer state pruning method on each ReLoRA restart. "
|
||||||
|
"'magnitude' (default) keeps top-k by absolute value; "
|
||||||
|
"'random' keeps a random subset at relora_prune_ratio; "
|
||||||
|
"'reset' uses near-full random pruning (default ratio 0.999, "
|
||||||
|
"honoring relora_prune_ratio when explicitly set). "
|
||||||
|
"Paper-style recipe: relora_prune_method='reset' with no "
|
||||||
|
"relora_prune_ratio, equivalent to 'random' with ratio=0.999."
|
||||||
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
relora_cpu_offload: bool | None = Field(
|
relora_cpu_offload: bool | None = Field(
|
||||||
|
|||||||
@@ -56,7 +56,72 @@ class TestReLoraLlama(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
"warmup_steps": 10,
|
"warmup_steps": 10,
|
||||||
"num_epochs": 2,
|
"num_epochs": 2,
|
||||||
"max_steps": 105, # at least 2x relora_steps
|
"max_steps": 105, # at least 2x restart cadence
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"save_first_step": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
dataset_meta = load_datasets(cfg=cfg)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
|
||||||
|
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists(), (
|
||||||
|
"Relora model checkpoint not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
|
||||||
|
)
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_relora_reset_method(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"sample_packing": True,
|
||||||
|
"pad_to_sequence_len": True,
|
||||||
|
"flash_attention": True,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_modules": ["q_proj", "v_proj"],
|
||||||
|
"relora": True,
|
||||||
|
"jagged_restart_steps": 50,
|
||||||
|
"jagged_restart_warmup_steps": 10,
|
||||||
|
"jagged_restart_anneal_steps": 10,
|
||||||
|
"relora_prune_ratio": 0.5, # explicitly honored by reset (not ignored)
|
||||||
|
"relora_prune_method": "reset",
|
||||||
|
"relora_cpu_offload": True,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"chat_template": "chatml",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mlabonne/FineTome-100k",
|
||||||
|
"type": "chat_template",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"message_field_role": "from",
|
||||||
|
"message_field_content": "value",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"warmup_steps": 10,
|
||||||
|
"num_epochs": 2,
|
||||||
|
"max_steps": 105,
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
|||||||
186
tests/monkeypatch/test_relora.py
Normal file
186
tests/monkeypatch/test_relora.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""Unit tests for axolotl.monkeypatch.relora.reset_optimizer."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.relora import (
|
||||||
|
magnitude_pruning_,
|
||||||
|
random_pruning_,
|
||||||
|
reset_optimizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
ADAM_KEYS = ["exp_avg", "exp_avg_sq"]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_optimizer_with_state(seed: int = 0):
|
||||||
|
"""Build a tiny optimizer over LoRA-shaped + non-LoRA params with populated state."""
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
lora_a = nn.Parameter(torch.randn(8, 32))
|
||||||
|
lora_b = nn.Parameter(torch.randn(32, 8))
|
||||||
|
extra = nn.Parameter(torch.randn(64, 32))
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW([lora_a, lora_b, extra], lr=1e-3)
|
||||||
|
for _ in range(2):
|
||||||
|
loss = (
|
||||||
|
(lora_a * torch.randn_like(lora_a)).sum()
|
||||||
|
+ (lora_b * torch.randn_like(lora_b)).sum()
|
||||||
|
+ (extra * torch.randn_like(extra)).sum()
|
||||||
|
)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
return optimizer, lora_a, lora_b, extra
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_optimizer_only_touches_reset_params():
|
||||||
|
"""State for params NOT in reset_params must be byte-identical after reset."""
|
||||||
|
optimizer, lora_a, lora_b, extra = _build_optimizer_with_state()
|
||||||
|
|
||||||
|
extra_avg_before = optimizer.state[extra]["exp_avg"].clone()
|
||||||
|
extra_avg_sq_before = optimizer.state[extra]["exp_avg_sq"].clone()
|
||||||
|
|
||||||
|
reset_optimizer(
|
||||||
|
optimizer,
|
||||||
|
reset_params=[lora_a, lora_b],
|
||||||
|
optimizer_state_keys=ADAM_KEYS,
|
||||||
|
prune_method="magnitude",
|
||||||
|
prune_ratio=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.equal(optimizer.state[extra]["exp_avg"], extra_avg_before)
|
||||||
|
assert torch.equal(optimizer.state[extra]["exp_avg_sq"], extra_avg_sq_before)
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_optimizer_actually_prunes_lora_state():
|
||||||
|
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
|
||||||
|
|
||||||
|
reset_optimizer(
|
||||||
|
optimizer,
|
||||||
|
reset_params=[lora_a, lora_b],
|
||||||
|
optimizer_state_keys=ADAM_KEYS,
|
||||||
|
prune_method="magnitude",
|
||||||
|
prune_ratio=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
|
for param in (lora_a, lora_b):
|
||||||
|
for key in ADAM_KEYS:
|
||||||
|
zero_frac = (optimizer.state[param][key] == 0).float().mean().item()
|
||||||
|
assert zero_frac >= 0.85
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"method,ratio,expected_zero_frac",
|
||||||
|
[
|
||||||
|
("magnitude", 0.9, 0.9),
|
||||||
|
("magnitude", 0.99, 0.99),
|
||||||
|
("random", 0.9, 0.9),
|
||||||
|
("random", 0.5, 0.5),
|
||||||
|
# reset uses random pruning; relora_prune_ratio must be honored, not ignored.
|
||||||
|
("reset", 0.9, 0.9),
|
||||||
|
("reset", 0.5, 0.5),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_prune_methods(method, ratio, expected_zero_frac):
|
||||||
|
"""Each method zeros approximately the expected fraction."""
|
||||||
|
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state(seed=42)
|
||||||
|
|
||||||
|
reset_optimizer(
|
||||||
|
optimizer,
|
||||||
|
reset_params=[lora_a, lora_b],
|
||||||
|
optimizer_state_keys=ADAM_KEYS,
|
||||||
|
prune_method=method,
|
||||||
|
prune_ratio=ratio,
|
||||||
|
)
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
zeros = 0
|
||||||
|
for param in (lora_a, lora_b):
|
||||||
|
for key in ADAM_KEYS:
|
||||||
|
tensor = optimizer.state[param][key]
|
||||||
|
total += tensor.numel()
|
||||||
|
zeros += (tensor == 0).sum().item()
|
||||||
|
|
||||||
|
actual = zeros / total
|
||||||
|
tolerance = 0.02 if method == "magnitude" else 0.05
|
||||||
|
assert math.isclose(actual, expected_zero_frac, abs_tol=tolerance)
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_optimizer_skips_keys_not_in_state_keys():
|
||||||
|
"""Keys present in optimizer state but not in optimizer_state_keys are untouched."""
|
||||||
|
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
|
||||||
|
|
||||||
|
exp_avg_sq_before = optimizer.state[lora_a]["exp_avg_sq"].clone()
|
||||||
|
|
||||||
|
reset_optimizer(
|
||||||
|
optimizer,
|
||||||
|
reset_params=[lora_a, lora_b],
|
||||||
|
optimizer_state_keys=["exp_avg"],
|
||||||
|
prune_method="magnitude",
|
||||||
|
prune_ratio=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert torch.equal(optimizer.state[lora_a]["exp_avg_sq"], exp_avg_sq_before)
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_optimizer_handles_param_with_empty_state():
|
||||||
|
"""Params with no optimizer state are skipped silently."""
|
||||||
|
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
|
||||||
|
orphan = nn.Parameter(torch.randn(4, 4))
|
||||||
|
|
||||||
|
reset_optimizer(
|
||||||
|
optimizer,
|
||||||
|
reset_params=[lora_a, lora_b, orphan],
|
||||||
|
optimizer_state_keys=ADAM_KEYS,
|
||||||
|
prune_method="magnitude",
|
||||||
|
prune_ratio=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert orphan not in optimizer.state or not optimizer.state[orphan]
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_prune_method_raises():
|
||||||
|
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Unknown prune_method"):
|
||||||
|
reset_optimizer(
|
||||||
|
optimizer,
|
||||||
|
reset_params=[lora_a, lora_b],
|
||||||
|
optimizer_state_keys=ADAM_KEYS,
|
||||||
|
prune_method="bogus", # type: ignore[arg-type]
|
||||||
|
prune_ratio=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pruning_helpers_are_inplace():
|
||||||
|
"""magnitude_pruning_ and random_pruning_ must mutate via tensor.mul_."""
|
||||||
|
tensor = torch.randn(64)
|
||||||
|
ptr_before = tensor.data_ptr()
|
||||||
|
magnitude_pruning_(tensor, 0.5)
|
||||||
|
assert tensor.data_ptr() == ptr_before
|
||||||
|
|
||||||
|
tensor = torch.randn(64)
|
||||||
|
ptr_before = tensor.data_ptr()
|
||||||
|
random_pruning_(tensor, 0.5)
|
||||||
|
assert tensor.data_ptr() == ptr_before
|
||||||
|
|
||||||
|
|
||||||
|
def test_pruning_helpers_support_uint8_tensors():
|
||||||
|
"""Both pruning helpers must work on uint8 optimizer state tensors."""
|
||||||
|
tensor = torch.arange(1, 129, dtype=torch.uint8)
|
||||||
|
magnitude_pruning_(tensor, 0.9)
|
||||||
|
|
||||||
|
assert tensor.dtype == torch.uint8
|
||||||
|
magnitude_zero_frac = (tensor == 0).float().mean().item()
|
||||||
|
assert 0.85 <= magnitude_zero_frac <= 0.95
|
||||||
|
|
||||||
|
tensor = torch.arange(1, 129, dtype=torch.uint8)
|
||||||
|
with torch.random.fork_rng(devices=[]):
|
||||||
|
torch.manual_seed(1234)
|
||||||
|
random_pruning_(tensor, 0.9)
|
||||||
|
|
||||||
|
assert tensor.dtype == torch.uint8
|
||||||
|
random_zero_frac = (tensor == 0).float().mean().item()
|
||||||
|
assert 0.85 <= random_zero_frac <= 0.95
|
||||||
Reference in New Issue
Block a user