diff --git a/docs/training_stability.qmd b/docs/training_stability.qmd index e2cd79f89..41a62e010 100644 --- a/docs/training_stability.qmd +++ b/docs/training_stability.qmd @@ -137,6 +137,50 @@ This means the policy has diverged significantly from the weights used by vLLM f - Increase `gradient_accumulation_steps` to smooth out noisy batches. - Check for NaN issues (see next section). +## MoE Weight Scale Drift + +**Symptom**: Model works on short prompts but loses coherence on long conversations — repeating itself, "philosophizing", or generating broken code. Particularly affects MoE models with recurrent/SSM components (e.g. DeltaNet linear attention). + +**Root cause**: In MoE models trained with AdamW, rarely-activated experts accumulate smaller second-moment estimates. This gives them a disproportionately large effective learning rate, causing their weights to drift to higher variance than is typical for the other tensors in their group. In recurrent components like `conv1d` in DeltaNet layers, this amplifies short-range context and washes out long-range state. + +**Detection**: Use `normalize_weight_scales` with `dry_run: true` to scan for anomalies without modifying weights: + +```yaml +normalize_weight_scales: + - name_pattern: 'linear_attn\.conv1d\.weight' + threshold: 1.3 + dry_run: true +``` + +This logs any tensors matching the pattern whose standard deviation is more than 1.3x the group median or less than 1/1.3x of it. Example output: + +``` +normalize_weight_scales [DRY RUN]: pattern 'linear_attn\.conv1d\.weight' — + 3/30 tensors outside 1.3x threshold (median std=0.062733): + layers.36.linear_attn.conv1d.weight: std=0.101870 (1.62x median) + layers.37.linear_attn.conv1d.weight: std=0.102362 (1.63x median) + layers.38.linear_attn.conv1d.weight: std=0.089227 (1.42x median) +``` + +Each rule accepts: + +- `name_pattern`: regex matched against parameter names. All matching tensors form a group. +- `threshold`: flag tensors whose std deviates from the group median by more than this factor (default: 1.5).
+- `dry_run`: when `true`, log anomalies without modifying weights (default: `false`). + +Multiple rules can target different tensor patterns: + +```yaml +normalize_weight_scales: + - name_pattern: 'linear_attn\.conv1d\.weight' + threshold: 1.3 + - name_pattern: 'experts\.gate_up_proj' + threshold: 1.5 + dry_run: true # just check these, don't fix +``` + +The transform runs after model loading but before adapter injection, so it modifies the base model weights directly. + ## NaN and Inf Handling ### Common Causes diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 774aa1cec..27d494d68 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -38,6 +38,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.logging import get_logger +from axolotl.utils.normalize_weights import normalize_weight_scales from axolotl.utils.schemas.enums import RLType from axolotl.utils.train import determine_last_checkpoint from axolotl.utils.trainer import setup_trainer @@ -105,6 +106,10 @@ def setup_model_and_tokenizer( event_type="peft-config-load", properties=peft_config.to_dict() ) + # Normalize weight scales for MoE/hybrid models with drifted expert weights + if cfg.normalize_weight_scales: + normalize_weight_scales(model, cfg.normalize_weight_scales) + # Apply freezing if specified if cfg.unfrozen_parameters: freeze_layers_except(model, cfg.unfrozen_parameters) diff --git a/src/axolotl/utils/normalize_weights.py b/src/axolotl/utils/normalize_weights.py new file mode 100644 index 000000000..06b1393c9 --- /dev/null +++ b/src/axolotl/utils/normalize_weights.py @@ -0,0 +1,115 @@ +""" +Detect and fix weight scale anomalies in MoE/hybrid models. + +In MoE models trained with AdamW, rarely-activated experts accumulate smaller +second-moment estimates, giving them a disproportionately large effective +learning rate. 
def normalize_weight_scales(model, rules):
    """Rescale outlier weight tensors to the median scale of their group.

    For every rule, all parameters whose name matches ``name_pattern`` form a
    group. The standard deviation of each tensor is compared with the group
    median; tensors whose ratio to the median is above ``threshold`` or below
    ``1 / threshold`` are reported and, unless ``dry_run`` is set, multiplied
    in place by ``median_std / std``.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose ``named_parameters()`` are scanned and possibly modified
        in place.
    rules : list[dict]
        Each rule is a dict with keys:

        - ``name_pattern`` (str): regex searched in each parameter name.
          Rules without it are skipped with a warning.
        - ``threshold`` (float, default 1.5): a parameter is flagged when
          ``std / median_std > threshold`` or ``< 1 / threshold``. Must be
          positive; non-positive values cause the rule to be skipped.
        - ``dry_run`` (bool, default False): when True, log anomalies but do
          not modify weights.

    Returns
    -------
    int
        Number of tensors that were rescaled (0 in dry-run mode).

    Notes
    -----
    Tensors that are not floating point, have fewer than two elements, or
    have a non-finite standard deviation are excluded from the group, because
    their std is undefined or cannot be meaningfully rescaled. Flagged tensors
    with a standard deviation of exactly zero are reported but left untouched,
    since no finite scale factor can bring them to the median.
    """
    # Local imports keep this function self-contained; both are stdlib.
    import math
    import statistics

    total_fixed = 0

    for rule in rules:
        pattern = rule.get("name_pattern")
        if not pattern:
            LOG.warning("normalize_weight_scales: rule missing 'name_pattern', skipping")
            continue

        threshold = float(rule.get("threshold", 1.5))
        if threshold <= 0:
            # 1 / threshold below would divide by zero (or flag everything
            # for negative values), so treat it like any other invalid rule.
            LOG.warning(
                f"normalize_weight_scales: pattern '{pattern}' has non-positive "
                f"threshold {threshold}, skipping"
            )
            continue

        dry_run = bool(rule.get("dry_run", False))
        regex = re.compile(pattern)

        # Collect matching tensors together with their standard deviation.
        matches = []
        with torch.no_grad():
            for name, param in model.named_parameters():
                if not regex.search(name):
                    continue
                # Integer tensors (presumably packed/quantized weights) cannot
                # be scaled by a float factor; a single element has no std.
                if not param.is_floating_point() or param.numel() < 2:
                    continue
                std = param.data.float().std().item()
                if not math.isfinite(std):
                    if is_main_process():
                        LOG.warning(
                            f"normalize_weight_scales: {name} has non-finite "
                            f"std ({std}), ignoring it"
                        )
                    continue
                matches.append((name, param, std))

        if len(matches) < 3:
            if is_main_process():
                LOG.info(
                    f"normalize_weight_scales: pattern '{pattern}' matched "
                    f"{len(matches)} tensors (need >=3 to detect outliers), skipping"
                )
            continue

        # True median: averages the two middle values for even-sized groups
        # instead of silently picking the upper one.
        median_std = float(statistics.median(std for _, _, std in matches))

        if median_std < 1e-10:
            if is_main_process():
                LOG.info(
                    f"normalize_weight_scales: pattern '{pattern}' has a "
                    f"near-zero median std ({median_std:.3e}), skipping"
                )
            continue

        # Detect and fix outliers.
        lower_bound = 1.0 / threshold
        outliers = []
        for name, param, std in matches:
            ratio = std / median_std
            if not (ratio > threshold or ratio < lower_bound):
                continue

            fixable = std > 0.0
            outliers.append((name, std, ratio, fixable))
            if dry_run or not fixable:
                # A constant tensor (std == 0) would need an infinite scale
                # factor; report it but leave the weights alone.
                continue

            with torch.no_grad():
                param.data.mul_(median_std / std)
            total_fixed += 1

        # Report (main process only, to avoid duplicated logs across ranks).
        if not is_main_process():
            continue

        if outliers:
            mode = "DRY RUN" if dry_run else "FIXED"
            LOG.warning(
                f"normalize_weight_scales [{mode}]: pattern '{pattern}' — "
                f"{len(outliers)}/{len(matches)} tensors outside "
                f"{threshold:g}x threshold (median std={median_std:.6f}):"
            )
            for name, std, ratio, fixable in outliers:
                note = "" if fixable or dry_run else " [zero std, not rescaled]"
                LOG.warning(f"  {name}: std={std:.6f} ({ratio:.2f}x median){note}")
        else:
            LOG.info(
                f"normalize_weight_scales: pattern '{pattern}' — "
                f"{len(matches)} tensors, all within {threshold:g}x threshold "
                f"(median std={median_std:.6f})"
            )

    return total_fixed
b/src/axolotl/utils/schemas/config.py @@ -578,6 +578,19 @@ class AxolotlInputConfig( }, ) + normalize_weight_scales: list[dict] | None = Field( + default=None, + json_schema_extra={ + "description": "Detect and rescale outlier weight tensors caused by AdamW + rare " + "MoE expert drift. Each entry is a rule with: " + "'name_pattern' (regex matching parameter names to group), " + "'threshold' (float, default 1.5 — flag tensors whose std deviates from the " + "group median by more than this factor), " + "'dry_run' (bool, default false — log anomalies without modifying weights). " + "Example: [{name_pattern: 'linear_attn\\.conv1d\\.weight', threshold: 1.3}]" + }, + ) + unfrozen_parameters: list[str] | None = Field( default=None, json_schema_extra={