Fix Axolotl ReLoRA optimizer reset scope (#3646)
* Fix Axolotl ReLoRA optimizer reset scope * fix: make relora reset method honor relora_prune_ratio When relora_prune_method='reset' and relora_prune_ratio is explicitly set, the ratio was silently ignored and replaced with the hardcoded _FULL_RESET_RATIO (0.999). Fix by moving the default-ratio logic to ReLoRACallback.on_step_begin: None maps to _FULL_RESET_RATIO for reset and 0.9 for other methods. reset_optimizer now uses the same random pruning path for both 'random' and 'reset'. Also consolidate three-layer default mismatch: schema default for relora_prune_method is now 'magnitude' (single canonical source); dataclass defaults for both fields changed to None to eliminate the conflicting fallback layer. Tests updated: removed the test case that verified the old broken behavior (reset ignoring ratio), added two cases proving reset honors the passed ratio. E2E reset fixture now uses ratio=0.5 to make it unambiguous that the ratio is honored. * Fix ReLoRA uint8 pruning regression --------- Signed-off-by: Wing Lian <wing@axolotl.ai> Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
This commit is contained in:
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
||||
exclude:
|
||||
- python_version: "3.14"
|
||||
pytorch_version: "2.9.1"
|
||||
timeout-minutes: 20
|
||||
timeout-minutes: 25
|
||||
|
||||
steps:
|
||||
- name: cleanup node
|
||||
|
||||
@@ -286,10 +286,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
)
|
||||
|
||||
if self.cfg.relora and self.cfg.jagged_restart_steps:
|
||||
if self.cfg.relora_prune_ratio:
|
||||
if self.cfg.relora_prune_ratio is not None:
|
||||
training_arguments_kwargs["relora_prune_ratio"] = (
|
||||
self.cfg.relora_prune_ratio
|
||||
)
|
||||
if self.cfg.relora_prune_method:
|
||||
training_arguments_kwargs["relora_prune_method"] = (
|
||||
self.cfg.relora_prune_method
|
||||
)
|
||||
|
||||
if self.cfg.jagged_restart_steps:
|
||||
training_arguments_kwargs["jagged_restart_steps"] = (
|
||||
|
||||
@@ -83,13 +83,18 @@ class AxolotlTrainingMixins:
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for data processing"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_prune_ratio: Optional[float] = field(
|
||||
default=0.9,
|
||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"prune ratio for optimizer state pruning; "
|
||||
"defaults to 0.999 for reset method, 0.9 for others"
|
||||
)
|
||||
},
|
||||
)
|
||||
relora_prune_method: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "optimizer state pruning method: magnitude | random | reset"},
|
||||
)
|
||||
jagged_restart_steps: Optional[int] = field(
|
||||
default=None,
|
||||
|
||||
@@ -6,9 +6,8 @@ import os.path
|
||||
import shutil
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Union
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import peft
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
@@ -28,9 +27,15 @@ from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError: # pragma: no cover - optional dependency for 8-bit merge paths
|
||||
bnb = None
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def magnitude_pruning_(tensor, prune_ratio):
|
||||
"""Zero the lowest ``prune_ratio`` fraction of values by absolute magnitude, in place."""
|
||||
tensor_magnitude = torch.abs(tensor)
|
||||
threshold = torch.quantile(
|
||||
tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
|
||||
@@ -40,15 +45,43 @@ def magnitude_pruning_(tensor, prune_ratio):
|
||||
tensor.mul_(mask.to(dtype=tensor.dtype))
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def random_pruning_(tensor, prune_ratio):
|
||||
"""Zero a random ``prune_ratio`` fraction of values, in place."""
|
||||
mask = (
|
||||
torch.rand(tensor.shape, dtype=torch.float32, device=tensor.device)
|
||||
> prune_ratio
|
||||
)
|
||||
tensor.mul_(mask.to(dtype=tensor.dtype))
|
||||
|
||||
|
||||
# 0.999 mirrors the reference implementation. True zeroing breaks
|
||||
# ZeroRedundancyOptimizer.consolidate_state_dict; see Guitaricet/relora's
|
||||
# peft_pretraining/training_utils.py for the original note on this.
|
||||
_FULL_RESET_RATIO = 0.999
|
||||
|
||||
|
||||
def reset_optimizer(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
*,
|
||||
reset_params: List[str], # where str is the key to a torch.nn.Parameter
|
||||
reset_params: List[torch.nn.Parameter],
|
||||
optimizer_state_keys: List[str],
|
||||
optimizer_magnitude_pruning: float = 0.9,
|
||||
prune_method: Literal["magnitude", "random", "reset"] = "magnitude",
|
||||
prune_ratio: float = 0.9,
|
||||
):
|
||||
# pylint:disable=unused-argument
|
||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
|
||||
"""Prune optimizer state for ``reset_params`` only."""
|
||||
if prune_method == "magnitude":
|
||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
||||
elif prune_method in ("random", "reset"):
|
||||
# "reset" is random pruning at a near-full ratio; the caller is responsible
|
||||
# for supplying the appropriate prune_ratio (see ReLoRACallback.on_step_begin).
|
||||
pruning_fn = partial(random_pruning_, prune_ratio=prune_ratio)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown prune_method {prune_method!r}; expected one of "
|
||||
"'magnitude', 'random', 'reset'"
|
||||
)
|
||||
|
||||
n_zeros = 0
|
||||
n_total = 0
|
||||
|
||||
@@ -56,22 +89,22 @@ def reset_optimizer(
|
||||
if isinstance(optimizer, ZeroRedundancyOptimizer):
|
||||
optimizer_state = optimizer.optim.state
|
||||
|
||||
for group in optimizer.param_groups:
|
||||
for param in group["params"]:
|
||||
state = optimizer_state[param]
|
||||
for key, value in state.items():
|
||||
if key not in optimizer_state_keys:
|
||||
for param in reset_params:
|
||||
state = optimizer_state.get(param, {})
|
||||
if not state:
|
||||
continue
|
||||
for key in optimizer_state_keys:
|
||||
value = state.get(key)
|
||||
if value is None or not torch.is_tensor(value):
|
||||
continue
|
||||
if torch.is_tensor(value):
|
||||
try:
|
||||
pruning_fn(value)
|
||||
n_total += value.numel()
|
||||
n_zeros += torch.sum(value == 0).item()
|
||||
except RuntimeError as exc:
|
||||
if "quantile() input tensor is too large" in str(exc):
|
||||
pass
|
||||
else:
|
||||
raise exc
|
||||
continue
|
||||
raise
|
||||
|
||||
_zeroed = n_zeros / (1e-7 + n_total) * 100
|
||||
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
|
||||
@@ -82,11 +115,12 @@ class ReLoRACallback(TrainerCallback):
|
||||
"""Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
|
||||
|
||||
def __init__(self, cfg: DictDefault):
|
||||
self.relora_steps = cfg.jagged_restart_steps
|
||||
self.jagged_restart_steps = cfg.jagged_restart_steps
|
||||
self.cpu_offload = cfg.relora_cpu_offload
|
||||
self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
|
||||
self.last_full_model = cfg.base_model
|
||||
self.resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
self.prune_method = cfg.relora_prune_method or "magnitude"
|
||||
|
||||
if not os.path.exists(self.last_full_model):
|
||||
self.last_full_model = str(Path(snapshot_download(cfg.base_model)))
|
||||
@@ -128,7 +162,7 @@ class ReLoRACallback(TrainerCallback):
|
||||
):
|
||||
if not optimizer:
|
||||
optimizer = state.optimizer
|
||||
if state.global_step > 0 and state.global_step % self.relora_steps == 0:
|
||||
if state.global_step > 0 and state.global_step % self.jagged_restart_steps == 0:
|
||||
checkpoint_folder = os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
@@ -144,7 +178,7 @@ class ReLoRACallback(TrainerCallback):
|
||||
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
|
||||
|
||||
lora_params = [
|
||||
n
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if p.requires_grad and "lora_" in n
|
||||
]
|
||||
@@ -166,11 +200,19 @@ class ReLoRACallback(TrainerCallback):
|
||||
actually_save=is_main_process(),
|
||||
cpu_offload=self.cpu_offload,
|
||||
)
|
||||
# When relora_prune_ratio is not set, use _FULL_RESET_RATIO for
|
||||
# "reset" (paper-style near-full reset) and 0.9 for other methods.
|
||||
prune_ratio = args.relora_prune_ratio
|
||||
if prune_ratio is None:
|
||||
prune_ratio = (
|
||||
_FULL_RESET_RATIO if self.prune_method == "reset" else 0.9
|
||||
)
|
||||
reset_optimizer(
|
||||
optimizer,
|
||||
reset_params=lora_params,
|
||||
optimizer_state_keys=optimizer_state_keys,
|
||||
optimizer_magnitude_pruning=args.relora_prune_ratio,
|
||||
prune_method=self.prune_method,
|
||||
prune_ratio=prune_ratio,
|
||||
)
|
||||
|
||||
if self.quantized:
|
||||
@@ -191,8 +233,8 @@ class ReLoRACallback(TrainerCallback):
|
||||
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
|
||||
)
|
||||
if (
|
||||
state.global_step >= self.relora_steps
|
||||
and state.global_step % self.relora_steps != 0
|
||||
state.global_step >= self.jagged_restart_steps
|
||||
and state.global_step % self.jagged_restart_steps != 0
|
||||
):
|
||||
if self.quantized:
|
||||
if is_main_process() and self.last_full_model != checkpoint_folder:
|
||||
@@ -320,6 +362,8 @@ def update_weights(
|
||||
target.weight.data = new_weight.cpu()
|
||||
target.to(device)
|
||||
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
||||
if bnb is None:
|
||||
raise ImportError("bitsandbytes is required to merge 8-bit LoRA weights")
|
||||
target.weight.data = (
|
||||
bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
|
||||
)
|
||||
|
||||
@@ -240,8 +240,28 @@ class ReLoRAConfig(BaseModel):
|
||||
)
|
||||
relora_prune_ratio: float | None = Field(
|
||||
default=None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
json_schema_extra={
|
||||
"description": "threshold for optimizer magnitude when pruning"
|
||||
"description": (
|
||||
"Fraction of optimizer state values to zero on each ReLoRA restart. "
|
||||
"When relora_prune_method='reset' and this is omitted, defaults to "
|
||||
"0.999 (paper-style near-full reset). For other methods, defaults to 0.9."
|
||||
)
|
||||
},
|
||||
)
|
||||
relora_prune_method: Literal["magnitude", "random", "reset"] | None = Field(
|
||||
default="magnitude",
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"Optimizer state pruning method on each ReLoRA restart. "
|
||||
"'magnitude' (default) keeps top-k by absolute value; "
|
||||
"'random' keeps a random subset at relora_prune_ratio; "
|
||||
"'reset' uses near-full random pruning (default ratio 0.999, "
|
||||
"honoring relora_prune_ratio when explicitly set). "
|
||||
"Paper-style recipe: relora_prune_method='reset' with no "
|
||||
"relora_prune_ratio, equivalent to 'random' with ratio=0.999."
|
||||
)
|
||||
},
|
||||
)
|
||||
relora_cpu_offload: bool | None = Field(
|
||||
|
||||
@@ -56,7 +56,72 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
],
|
||||
"warmup_steps": 10,
|
||||
"num_epochs": 2,
|
||||
"max_steps": 105, # at least 2x relora_steps
|
||||
"max_steps": 105, # at least 2x restart cadence
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
|
||||
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists(), (
|
||||
"Relora model checkpoint not found"
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_relora_reset_method(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"flash_attention": True,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_modules": ["q_proj", "v_proj"],
|
||||
"relora": True,
|
||||
"jagged_restart_steps": 50,
|
||||
"jagged_restart_warmup_steps": 10,
|
||||
"jagged_restart_anneal_steps": 10,
|
||||
"relora_prune_ratio": 0.5, # explicitly honored by reset (not ignored)
|
||||
"relora_prune_method": "reset",
|
||||
"relora_cpu_offload": True,
|
||||
"val_set_size": 0.0,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mlabonne/FineTome-100k",
|
||||
"type": "chat_template",
|
||||
"split": "train[:10%]",
|
||||
"field_messages": "conversations",
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
},
|
||||
],
|
||||
"warmup_steps": 10,
|
||||
"num_epochs": 2,
|
||||
"max_steps": 105,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
|
||||
186
tests/monkeypatch/test_relora.py
Normal file
186
tests/monkeypatch/test_relora.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Unit tests for axolotl.monkeypatch.relora.reset_optimizer."""
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from axolotl.monkeypatch.relora import (
|
||||
magnitude_pruning_,
|
||||
random_pruning_,
|
||||
reset_optimizer,
|
||||
)
|
||||
|
||||
ADAM_KEYS = ["exp_avg", "exp_avg_sq"]
|
||||
|
||||
|
||||
def _build_optimizer_with_state(seed: int = 0):
|
||||
"""Build a tiny optimizer over LoRA-shaped + non-LoRA params with populated state."""
|
||||
torch.manual_seed(seed)
|
||||
lora_a = nn.Parameter(torch.randn(8, 32))
|
||||
lora_b = nn.Parameter(torch.randn(32, 8))
|
||||
extra = nn.Parameter(torch.randn(64, 32))
|
||||
|
||||
optimizer = torch.optim.AdamW([lora_a, lora_b, extra], lr=1e-3)
|
||||
for _ in range(2):
|
||||
loss = (
|
||||
(lora_a * torch.randn_like(lora_a)).sum()
|
||||
+ (lora_b * torch.randn_like(lora_b)).sum()
|
||||
+ (extra * torch.randn_like(extra)).sum()
|
||||
)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
return optimizer, lora_a, lora_b, extra
|
||||
|
||||
|
||||
def test_reset_optimizer_only_touches_reset_params():
|
||||
"""State for params NOT in reset_params must be byte-identical after reset."""
|
||||
optimizer, lora_a, lora_b, extra = _build_optimizer_with_state()
|
||||
|
||||
extra_avg_before = optimizer.state[extra]["exp_avg"].clone()
|
||||
extra_avg_sq_before = optimizer.state[extra]["exp_avg_sq"].clone()
|
||||
|
||||
reset_optimizer(
|
||||
optimizer,
|
||||
reset_params=[lora_a, lora_b],
|
||||
optimizer_state_keys=ADAM_KEYS,
|
||||
prune_method="magnitude",
|
||||
prune_ratio=0.9,
|
||||
)
|
||||
|
||||
assert torch.equal(optimizer.state[extra]["exp_avg"], extra_avg_before)
|
||||
assert torch.equal(optimizer.state[extra]["exp_avg_sq"], extra_avg_sq_before)
|
||||
|
||||
|
||||
def test_reset_optimizer_actually_prunes_lora_state():
|
||||
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
|
||||
|
||||
reset_optimizer(
|
||||
optimizer,
|
||||
reset_params=[lora_a, lora_b],
|
||||
optimizer_state_keys=ADAM_KEYS,
|
||||
prune_method="magnitude",
|
||||
prune_ratio=0.9,
|
||||
)
|
||||
|
||||
for param in (lora_a, lora_b):
|
||||
for key in ADAM_KEYS:
|
||||
zero_frac = (optimizer.state[param][key] == 0).float().mean().item()
|
||||
assert zero_frac >= 0.85
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method,ratio,expected_zero_frac",
|
||||
[
|
||||
("magnitude", 0.9, 0.9),
|
||||
("magnitude", 0.99, 0.99),
|
||||
("random", 0.9, 0.9),
|
||||
("random", 0.5, 0.5),
|
||||
# reset uses random pruning; relora_prune_ratio must be honored, not ignored.
|
||||
("reset", 0.9, 0.9),
|
||||
("reset", 0.5, 0.5),
|
||||
],
|
||||
)
|
||||
def test_prune_methods(method, ratio, expected_zero_frac):
|
||||
"""Each method zeros approximately the expected fraction."""
|
||||
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state(seed=42)
|
||||
|
||||
reset_optimizer(
|
||||
optimizer,
|
||||
reset_params=[lora_a, lora_b],
|
||||
optimizer_state_keys=ADAM_KEYS,
|
||||
prune_method=method,
|
||||
prune_ratio=ratio,
|
||||
)
|
||||
|
||||
total = 0
|
||||
zeros = 0
|
||||
for param in (lora_a, lora_b):
|
||||
for key in ADAM_KEYS:
|
||||
tensor = optimizer.state[param][key]
|
||||
total += tensor.numel()
|
||||
zeros += (tensor == 0).sum().item()
|
||||
|
||||
actual = zeros / total
|
||||
tolerance = 0.02 if method == "magnitude" else 0.05
|
||||
assert math.isclose(actual, expected_zero_frac, abs_tol=tolerance)
|
||||
|
||||
|
||||
def test_reset_optimizer_skips_keys_not_in_state_keys():
|
||||
"""Keys present in optimizer state but not in optimizer_state_keys are untouched."""
|
||||
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
|
||||
|
||||
exp_avg_sq_before = optimizer.state[lora_a]["exp_avg_sq"].clone()
|
||||
|
||||
reset_optimizer(
|
||||
optimizer,
|
||||
reset_params=[lora_a, lora_b],
|
||||
optimizer_state_keys=["exp_avg"],
|
||||
prune_method="magnitude",
|
||||
prune_ratio=0.9,
|
||||
)
|
||||
|
||||
assert torch.equal(optimizer.state[lora_a]["exp_avg_sq"], exp_avg_sq_before)
|
||||
|
||||
|
||||
def test_reset_optimizer_handles_param_with_empty_state():
|
||||
"""Params with no optimizer state are skipped silently."""
|
||||
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
|
||||
orphan = nn.Parameter(torch.randn(4, 4))
|
||||
|
||||
reset_optimizer(
|
||||
optimizer,
|
||||
reset_params=[lora_a, lora_b, orphan],
|
||||
optimizer_state_keys=ADAM_KEYS,
|
||||
prune_method="magnitude",
|
||||
prune_ratio=0.9,
|
||||
)
|
||||
|
||||
assert orphan not in optimizer.state or not optimizer.state[orphan]
|
||||
|
||||
|
||||
def test_unknown_prune_method_raises():
|
||||
optimizer, lora_a, lora_b, _extra = _build_optimizer_with_state()
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown prune_method"):
|
||||
reset_optimizer(
|
||||
optimizer,
|
||||
reset_params=[lora_a, lora_b],
|
||||
optimizer_state_keys=ADAM_KEYS,
|
||||
prune_method="bogus", # type: ignore[arg-type]
|
||||
prune_ratio=0.9,
|
||||
)
|
||||
|
||||
|
||||
def test_pruning_helpers_are_inplace():
|
||||
"""magnitude_pruning_ and random_pruning_ must mutate via tensor.mul_."""
|
||||
tensor = torch.randn(64)
|
||||
ptr_before = tensor.data_ptr()
|
||||
magnitude_pruning_(tensor, 0.5)
|
||||
assert tensor.data_ptr() == ptr_before
|
||||
|
||||
tensor = torch.randn(64)
|
||||
ptr_before = tensor.data_ptr()
|
||||
random_pruning_(tensor, 0.5)
|
||||
assert tensor.data_ptr() == ptr_before
|
||||
|
||||
|
||||
def test_pruning_helpers_support_uint8_tensors():
|
||||
"""Both pruning helpers must work on uint8 optimizer state tensors."""
|
||||
tensor = torch.arange(1, 129, dtype=torch.uint8)
|
||||
magnitude_pruning_(tensor, 0.9)
|
||||
|
||||
assert tensor.dtype == torch.uint8
|
||||
magnitude_zero_frac = (tensor == 0).float().mean().item()
|
||||
assert 0.85 <= magnitude_zero_frac <= 0.95
|
||||
|
||||
tensor = torch.arange(1, 129, dtype=torch.uint8)
|
||||
with torch.random.fork_rng(devices=[]):
|
||||
torch.manual_seed(1234)
|
||||
random_pruning_(tensor, 0.9)
|
||||
|
||||
assert tensor.dtype == torch.uint8
|
||||
random_zero_frac = (tensor == 0).float().mean().item()
|
||||
assert 0.85 <= random_zero_frac <= 0.95
|
||||
Reference in New Issue
Block a user