Compare commits
2 Commits
08fc7de87e ... weight-sca

| Author | SHA1 | Date |
|---|---|---|
|  | f608d263a6 |  |
|  | 900eec7988 |  |

@@ -137,6 +137,50 @@ This means the policy has diverged significantly from the weights used by vLLM f

- Increase `gradient_accumulation_steps` to smooth out noisy batches.
- Check for NaN issues (see next section).

## MoE Weight Scale Drift

**Symptom**: The model works on short prompts but loses coherence over long conversations — repeating itself, "philosophizing", or generating broken code. This particularly affects MoE models with recurrent/SSM components (e.g. DeltaNet linear attention).

**Root cause**: In MoE models trained with AdamW, rarely-activated experts accumulate smaller second-moment estimates. This gives them a disproportionately large effective learning rate, causing their weights to drift to higher variance than the group median. In recurrent components like `conv1d` in DeltaNet layers, this amplifies short-range context and washes out long-range state.
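
To see why, recall that the Adam step is roughly `lr * m_hat / (sqrt(v_hat) + eps)`. A toy illustration with made-up numbers (illustrative Python, not Axolotl code; bias correction and the first moment are ignored for brevity):

```python
import math

lr, beta2, eps, g = 1e-4, 0.999, 1e-8, 1e-3

# Frequently-activated expert: v has converged to roughly g**2,
# so the step size is about lr.
v_hot = g**2
step_hot = lr * g / (math.sqrt(v_hot) + eps)

# Rarely-activated expert: v decayed toward zero between activations,
# so a fresh gradient contributes only (1 - beta2) * g**2.
v_cold = (1 - beta2) * g**2
step_cold = lr * g / (math.sqrt(v_cold) + eps)

print(step_cold / step_hot)  # ~31.6, i.e. 1 / sqrt(1 - beta2)
```
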
**Detection**: Use `normalize_weight_scales` with `dry_run: true` to scan for anomalies without modifying weights:

```yaml
normalize_weight_scales:
  - name_pattern: 'linear_attn\.conv1d\.weight'
    threshold: 1.3
    dry_run: true
```

This logs any tensor matching the pattern whose standard deviation differs from the group median by more than a factor of 1.3, in either direction. Example output:

```
normalize_weight_scales [DRY RUN]: pattern 'linear_attn\.conv1d\.weight' —
3/30 tensors outside 1.3x threshold (median std=0.062733):
  layers.36.linear_attn.conv1d.weight: std=0.101870 (1.62x median)
  layers.37.linear_attn.conv1d.weight: std=0.102362 (1.63x median)
  layers.38.linear_attn.conv1d.weight: std=0.089227 (1.42x median)
```
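
If you want to sanity-check a checkpoint outside the trainer, the same statistic is easy to compute by hand. A minimal standalone sketch (assumes a loaded `torch.nn.Module`; `scan_weight_scales` is not part of Axolotl's API):

```python
import re
import torch

def scan_weight_scales(model: torch.nn.Module, pattern: str, threshold: float = 1.3):
    """Print tensors whose std is an outlier relative to the group median."""
    regex = re.compile(pattern)
    stats = []
    with torch.no_grad():
        for name, param in model.named_parameters():
            if regex.search(name):
                stats.append((name, param.float().std().item()))
    if len(stats) < 3:  # too few tensors to define a meaningful median
        return
    stds = sorted(s for _, s in stats)
    median = stds[len(stds) // 2]
    for name, std in stats:
        ratio = std / median
        if ratio > threshold or ratio < 1.0 / threshold:
            print(f"{name}: std={std:.6f} ({ratio:.2f}x median)")

# scan_weight_scales(model, r"linear_attn\.conv1d\.weight")
```
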
Each rule accepts:

- `name_pattern`: regex matched against parameter names. All matching tensors form a group.
- `threshold`: flag tensors whose std deviates from the group median by more than this factor (default: 1.5).
- `dry_run`: when `true`, log anomalies without modifying weights (default: `false`).

Multiple rules can target different tensor patterns:

```yaml
normalize_weight_scales:
  - name_pattern: 'linear_attn\.conv1d\.weight'
    threshold: 1.3
  - name_pattern: 'experts\.gate_up_proj'
    threshold: 1.5
    dry_run: true  # just check these, don't fix
```

The transform runs after model loading but before adapter injection, so it modifies the base model weights directly.
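
In pseudocode, the intended ordering looks like this (a sketch with hypothetical helper names: `load_model` and `inject_adapters` stand in for the real loading and PEFT steps):

```python
model = load_model(cfg)  # base weights, fully materialized

# Rescale drifted tensors on the base model first...
if cfg.normalize_weight_scales:
    normalize_weight_scales(model, cfg.normalize_weight_scales)

# ...so that any LoRA/adapter injection that follows sees corrected scales.
model = inject_adapters(model, peft_config)
```
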
## NaN and Inf Handling

### Common Causes

@@ -160,29 +160,16 @@ class TelemetryManager:
         if not is_main_process():
             return False
 
-        # Parse relevant env vars
-        axolotl_do_not_track = os.getenv("AXOLOTL_DO_NOT_TRACK")
-        do_not_track = os.getenv("DO_NOT_TRACK")
+        def is_truthy_env(var_name: str) -> bool:
+            value = os.getenv(var_name)
+            if value is None:
+                return False
+            return value.strip().lower() in ("1", "true")
 
-        # Default to enabled (opt-out model)
-        if axolotl_do_not_track is None or axolotl_do_not_track.lower() not in (
-            "0",
-            "1",
-            "false",
-            "true",
-        ):
-            return True
-
-        if do_not_track is None:
-            do_not_track = "0"
-
-        # Respect AXOLOTL_DO_NOT_TRACK, DO_NOT_TRACK if enabled
-        enabled = axolotl_do_not_track.lower() not in (
-            "1",
-            "true",
-        ) and do_not_track.lower() not in ("1", "true")
-
-        return enabled
+        # Telemetry is enabled by default unless either opt-out var is set
+        return not (
+            is_truthy_env("AXOLOTL_DO_NOT_TRACK") or is_truthy_env("DO_NOT_TRACK")
+        )
 
     def _load_whitelist(self) -> dict:
         """Load HuggingFace Hub organization whitelist"""
@@ -38,6 +38,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.logging import get_logger
from axolotl.utils.normalize_weights import normalize_weight_scales
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint
from axolotl.utils.trainer import setup_trainer

@@ -105,6 +106,10 @@ def setup_model_and_tokenizer(
        event_type="peft-config-load", properties=peft_config.to_dict()
    )

    # Normalize weight scales for MoE/hybrid models with drifted expert weights
    if cfg.normalize_weight_scales:
        normalize_weight_scales(model, cfg.normalize_weight_scales)

    # Apply freezing if specified
    if cfg.unfrozen_parameters:
        freeze_layers_except(model, cfg.unfrozen_parameters)

src/axolotl/utils/normalize_weights.py (new file, 115 lines)
@@ -0,0 +1,115 @@
"""
Detect and fix weight scale anomalies in MoE/hybrid models.

In MoE models trained with AdamW, rarely-activated experts accumulate smaller
second-moment estimates, giving them a disproportionately large effective
learning rate. Over time this causes their weights to drift to higher
variance than the median for the same tensor across layers.

For recurrent / SSM / DeltaNet components (e.g. ``conv1d.weight`` in linear
attention layers), this drift corrupts the hidden state and degrades
long-context performance — the model "forgets" after a few tokens.

This module provides a configurable transform that detects outlier weight
scales per-tensor-pattern and rescales them to the group median.
"""

import re
from collections import defaultdict

import torch

from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def normalize_weight_scales(model, rules):
    """Normalize weight scales for tensor groups that have outlier variance.

    Parameters
    ----------
    model : torch.nn.Module
        The loaded model (before adapter injection).
    rules : list[dict]
        Each rule is a dict with keys:

        - ``name_pattern`` (str): regex matched against each named parameter.
          Parameters that match are grouped together, and outliers within the
          group are rescaled.
        - ``threshold`` (float, default 1.5): a parameter is flagged when its
          std deviates from the group median by more than this factor
          (ratio > threshold or ratio < 1/threshold).
        - ``dry_run`` (bool, default False): when True, log anomalies but do
          not modify weights.

    Returns
    -------
    int
        Number of tensors that were rescaled (0 in dry-run mode).
    """
    total_fixed = 0

    for rule in rules:
        pattern = rule.get("name_pattern")
        if not pattern:
            LOG.warning("normalize_weight_scales: rule missing 'name_pattern', skipping")
            continue

        threshold = float(rule.get("threshold", 1.5))
        dry_run = bool(rule.get("dry_run", False))
        regex = re.compile(pattern)

        # Collect matching tensors
        matches = []
        for name, param in model.named_parameters():
            if regex.search(name):
                with torch.no_grad():
                    std = param.data.float().std().item()
                matches.append((name, param, std))

        if len(matches) < 3:
            if is_main_process():
                LOG.info(
                    f"normalize_weight_scales: pattern '{pattern}' matched "
                    f"{len(matches)} tensors (need >=3 to detect outliers), skipping"
                )
            continue

        # Compute group median std
        stds = [s for _, _, s in matches]
        median_std = float(sorted(stds)[len(stds) // 2])

        if median_std < 1e-10:
            continue

        # Detect and fix outliers
        outliers = []
        for name, param, std in matches:
            ratio = std / median_std
            if ratio > threshold or ratio < (1.0 / threshold):
                outliers.append((name, param, std, ratio))
                if not dry_run:
                    scale_factor = median_std / std
                    param.data.mul_(scale_factor)
                    total_fixed += 1

        # Report
        if is_main_process() and outliers:
            mode = "DRY RUN" if dry_run else "FIXED"
            LOG.warning(
                f"normalize_weight_scales [{mode}]: pattern '{pattern}' — "
                f"{len(outliers)}/{len(matches)} tensors outside "
                f"{threshold:.1f}x threshold (median std={median_std:.6f}):"
            )
            for name, _, std, ratio in outliers:
                LOG.warning(f"  {name}: std={std:.6f} ({ratio:.2f}x median)")
        elif is_main_process():
            LOG.info(
                f"normalize_weight_scales: pattern '{pattern}' — "
                f"{len(matches)} tensors, all within {threshold:.1f}x threshold "
                f"(median std={median_std:.6f})"
            )

    return total_fixed
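
For one-off repairs outside a training run, the transform can also be invoked directly on a loaded model. A hypothetical usage sketch (the checkpoint id is a placeholder):

```python
from transformers import AutoModelForCausalLM

from axolotl.utils.normalize_weights import normalize_weight_scales

model = AutoModelForCausalLM.from_pretrained("org/moe-checkpoint")  # placeholder id

fixed = normalize_weight_scales(
    model,
    [{"name_pattern": r"linear_attn\.conv1d\.weight", "threshold": 1.3, "dry_run": True}],
)
print(f"{fixed} tensors rescaled")  # always 0 in dry-run mode
```
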
@@ -578,6 +578,19 @@ class AxolotlInputConfig(
        },
    )

    normalize_weight_scales: list[dict] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Detect and rescale outlier weight tensors caused by AdamW + rare "
            "MoE expert drift. Each entry is a rule with: "
            "'name_pattern' (regex matching parameter names to group), "
            "'threshold' (float, default 1.5 — flag tensors whose std deviates from the "
            "group median by more than this factor), "
            "'dry_run' (bool, default false — log anomalies without modifying weights). "
            "Example: [{name_pattern: 'linear_attn\\.conv1d\\.weight', threshold: 1.3}]"
        },
    )

    unfrozen_parameters: list[str] | None = Field(
        default=None,
        json_schema_extra={

@@ -65,47 +65,57 @@ def test_singleton_instance(telemetry_manager_class):
     assert telemetry_manager_class.get_instance() is first
 
 
-def test_telemetry_enabled_by_default(telemetry_manager_class):
-    """Test that telemetry is enabled by default (opt-out)"""
-    with (
-        patch.dict(os.environ, {"RANK": "0"}, clear=True),
-        patch("time.sleep"),
-        patch("logging.Logger.info"),
-    ):
-        manager = telemetry_manager_class()
-        assert manager.enabled
-
-
-def test_telemetry_enabled_with_explicit_opt_in(telemetry_manager_class):
-    """Test that telemetry is enabled when AXOLOTL_DO_NOT_TRACK=0"""
-    with (
-        patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "RANK": "0"}),
-        patch("time.sleep"),
-    ):
-        manager = telemetry_manager_class()
-        assert manager.enabled
-
-
-def test_telemetry_disabled_with_axolotl_do_not_track(telemetry_manager_class):
-    """Test that telemetry is disabled when AXOLOTL_DO_NOT_TRACK=1"""
-    with (
-        patch.dict(os.environ, {"AXOLOTL_DO_NOT_TRACK": "1", "RANK": "0"}),
-        patch("time.sleep"),
-    ):
-        manager = telemetry_manager_class()
-        assert not manager.enabled
-
-
-def test_telemetry_disabled_with_do_not_track(telemetry_manager_class):
-    """Test that telemetry is disabled when DO_NOT_TRACK=1"""
-    with (
-        patch.dict(
-            os.environ, {"AXOLOTL_DO_NOT_TRACK": "0", "DO_NOT_TRACK": "1", "RANK": "0"}
-        ),
-        patch("time.sleep"),
-    ):
-        manager = telemetry_manager_class()
-        assert not manager.enabled
+class TestTelemetryOptOut:
+    """
+    Telemetry is opt-out: enabled by default, disabled by AXOLOTL_DO_NOT_TRACK
+    or DO_NOT_TRACK. Each env var is checked independently — setting either one
+    to a truthy value ("1" or "true") disables telemetry.
+
+    The parametrized table below is the source of truth for expected behavior.
+    """
+
+    # fmt: off
+    #   AXOLOTL_DO_NOT_TRACK  DO_NOT_TRACK  expected
+    @pytest.mark.parametrize("axolotl_dnt, dnt, expected", [
+        # --- Neither var set: telemetry ON ---
+        (None, None, True),
+
+        # --- Only AXOLOTL_DO_NOT_TRACK set ---
+        ("0", None, True),       # explicit opt-in
+        ("false", None, True),   # explicit opt-in
+        ("1", None, False),      # opt-out
+        ("true", None, False),   # opt-out
+        (" 1 ", None, False),    # whitespace-padded opt-out
+
+        # --- Only DO_NOT_TRACK set (was broken before fix) ---
+        (None, "0", True),       # explicit opt-in
+        (None, "false", True),   # explicit opt-in
+        (None, "1", False),      # opt-out
+        (None, "true", False),   # opt-out
+
+        # --- Both set: either truthy → disabled ---
+        ("0", "1", False),       # DO_NOT_TRACK wins
+        ("1", "0", False),       # AXOLOTL_DO_NOT_TRACK wins
+        ("1", "1", False),       # both opt-out
+        ("0", "0", True),        # both opt-in
+    ])
+    # fmt: on
+    def test_do_not_track_env_vars(
+        self, telemetry_manager_class, axolotl_dnt, dnt, expected
+    ):
+        env = {"RANK": "0"}
+        if axolotl_dnt is not None:
+            env["AXOLOTL_DO_NOT_TRACK"] = axolotl_dnt
+        if dnt is not None:
+            env["DO_NOT_TRACK"] = dnt
+
+        with (
+            patch.dict(os.environ, env, clear=True),
+            patch("time.sleep"),
+            patch("logging.Logger.info"),
+        ):
+            manager = telemetry_manager_class()
+            assert manager.enabled is expected
 
 
 def test_telemetry_disabled_for_non_main_process(telemetry_manager_class):