configurable weight scale normalization for MoE expert drift

Wing Lian
2026-04-09 15:37:16 +00:00
parent 900eec7988
commit f608d263a6
4 changed files with 177 additions and 0 deletions


@@ -137,6 +137,50 @@ This means the policy has diverged significantly from the weights used by vLLM f
- Increase `gradient_accumulation_steps` to smooth out noisy batches.
- Check for NaN issues (see next section).
## MoE Weight Scale Drift
**Symptom**: The model works on short prompts but loses coherence in long conversations: it repeats itself, "philosophizes", or generates broken code. This particularly affects MoE models with recurrent/SSM components (e.g. DeltaNet linear attention).
**Root cause**: In MoE models trained with AdamW, rarely activated experts accumulate smaller second-moment estimates. Because AdamW divides each update by the square root of that estimate, these experts get a disproportionately large effective learning rate, and their weights drift to a higher variance than is typical for their group. In recurrent components such as `conv1d` in DeltaNet layers, this amplifies short-range context and washes out long-range state.
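The effect can be illustrated with a toy scalar model. The sketch below is not the training code: it runs a plain Adam update (no weight decay) on two hypothetical gradient streams, one for an expert that is routed to on every step and one for an expert routed to on 1 step in 100, and compares how far the parameter moves per gradient it actually receives. The function names and the gradient schedule are assumptions made for illustration.

```python
import math


def adam_updates(grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Run scalar Adam with bias correction and return the update applied at each step."""
    m = v = 0.0
    updates = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g          # first-moment estimate
        v = beta2 * v + (1 - beta2) * g * g      # second-moment estimate
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        updates.append(lr * m_hat / (math.sqrt(v_hat) + eps))
    return updates


def movement_per_gradient(grads):
    """Total parameter movement divided by the number of non-zero gradients seen."""
    events = sum(1 for g in grads if g != 0.0)
    return sum(abs(u) for u in adam_updates(grads)) / events


steps = 2000
dense = [1.0] * steps                                            # active on every step
sparse = [1.0 if t % 100 == 99 else 0.0 for t in range(steps)]   # active on 1 step in 100

dense_move = movement_per_gradient(dense)
sparse_move = movement_per_gradient(sparse)
print(f"dense expert:  {dense_move:.2e} per gradient")
print(f"sparse expert: {sparse_move:.2e} per gradient ({sparse_move / dense_move:.1f}x)")
```

In this toy setup the sparse stream has a much smaller second-moment estimate, so each gradient it receives moves the parameter several times further than in the dense case.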
**Detection**: Use `normalize_weight_scales` with `dry_run: true` to scan for anomalies without modifying weights:
```yaml
normalize_weight_scales:
  - name_pattern: 'linear_attn\.conv1d\.weight'
    threshold: 1.3
    dry_run: true
```
This logs any tensors matching the pattern whose standard deviation exceeds 1.3x the group median. Example output:
```
normalize_weight_scales [DRY RUN]: pattern 'linear_attn\.conv1d\.weight' —
3/30 tensors outside 1.3x threshold (median std=0.062733):
layers.36.linear_attn.conv1d.weight: std=0.101870 (1.62x median)
layers.37.linear_attn.conv1d.weight: std=0.102362 (1.63x median)
layers.38.linear_attn.conv1d.weight: std=0.089227 (1.42x median)
```
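The scan described above can be sketched in a few lines. This is a minimal illustration, not the actual implementation: `scan_weight_scales` is a hypothetical helper, tensors are stood in for by flat lists of floats, the population standard deviation is used, the pattern is matched with `re.search`, and the check is assumed to be two-sided (too large or too small relative to the median).

```python
import re
import statistics


def scan_weight_scales(named_weights, name_pattern, threshold=1.5):
    """Return (median_std, outliers) for the group of tensors matching name_pattern.

    named_weights maps parameter names to flat lists of floats (a stand-in for tensors).
    outliers maps each flagged name to its std as a multiple of the group median.
    """
    regex = re.compile(name_pattern)
    stds = {
        name: statistics.pstdev(values)
        for name, values in named_weights.items()
        if regex.search(name)
    }
    if not stds:
        return None, {}
    median_std = statistics.median(stds.values())
    if median_std == 0.0:
        return median_std, {}
    outliers = {}
    for name, std in stds.items():
        ratio = std / median_std
        # Assumption: a tensor is flagged if it is off by more than `threshold` in either direction.
        if ratio > threshold or ratio < 1.0 / threshold:
            outliers[name] = ratio
    return median_std, outliers


# Hypothetical toy weights; only the three conv1d tensors match the pattern.
base = [-1.0, 1.0, -1.0, 1.0]
weights = {
    "layers.0.linear_attn.conv1d.weight": [0.06 * x for x in base],
    "layers.1.linear_attn.conv1d.weight": [0.06 * x for x in base],
    "layers.2.linear_attn.conv1d.weight": [0.10 * x for x in base],
    "layers.0.mlp.experts.gate_up_proj": [0.50 * x for x in base],
}
median_std, outliers = scan_weight_scales(
    weights, r"linear_attn\.conv1d\.weight", threshold=1.3
)
for name, ratio in outliers.items():
    print(f"{name}: {ratio:.2f}x median (median std={median_std:.6f})")
```

Comparing against the group median rather than the mean keeps a few drifted tensors from shifting the baseline they are measured against.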
Each rule accepts:
- `name_pattern`: regex matched against parameter names. All matching tensors form a group.
- `threshold`: flag tensors whose std deviates from the group median by more than this factor (default: 1.5).
- `dry_run`: when `true`, log anomalies without modifying weights (default: `false`).
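As a rough sketch of how these fields and defaults fit together, the rules could be modelled as a small dataclass. `WeightScaleRule` and `parse_rules` are hypothetical names used only for illustration, and the check that `threshold` is greater than 1 is an assumption, not documented behaviour.

```python
import re
from dataclasses import dataclass


@dataclass(frozen=True)
class WeightScaleRule:
    name_pattern: str           # regex matched against parameter names
    threshold: float = 1.5      # documented default
    dry_run: bool = False       # documented default

    def __post_init__(self):
        re.compile(self.name_pattern)  # raises re.error early on an invalid regex
        # Assumption: a factor at or below 1.0 would flag every tensor, so reject it.
        if self.threshold <= 1.0:
            raise ValueError("threshold must be greater than 1.0")


def parse_rules(config):
    """Build rules from an already-loaded config dict (hypothetical helper)."""
    return [WeightScaleRule(**entry) for entry in config.get("normalize_weight_scales", [])]


config = {
    "normalize_weight_scales": [
        {"name_pattern": r"linear_attn\.conv1d\.weight", "threshold": 1.3},
        {"name_pattern": r"experts\.gate_up_proj", "dry_run": True},
    ]
}
rules = parse_rules(config)
for rule in rules:
    print(rule)
```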
Multiple rules can target different tensor patterns:
```yaml
normalize_weight_scales:
  - name_pattern: 'linear_attn\.conv1d\.weight'
    threshold: 1.3
  - name_pattern: 'experts\.gate_up_proj'
    threshold: 1.5
    dry_run: true  # just check these, don't fix
```
The transform runs after model loading but before adapter injection, so it modifies the base model weights directly.
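This section does not spell out how a flagged tensor is corrected. One plausible reading, shown here purely as an assumption, is that the tensor is multiplied by a constant so that its standard deviation matches the group median. `rescale_to_target_std` is a hypothetical helper operating on a flat list of floats rather than a real tensor.

```python
import statistics


def rescale_to_target_std(values, target_std):
    """Scale values by a constant so their population std equals target_std.

    Assumption: this is one possible normalization, not necessarily what the transform does.
    """
    current = statistics.pstdev(values)
    if current == 0.0:
        return list(values)  # a constant tensor has no scale to correct
    scale = target_std / current
    return [v * scale for v in values]


drifted = [-0.10, 0.10, -0.10, 0.10]   # hypothetical tensor with std 0.10
fixed = rescale_to_target_std(drifted, target_std=0.06)
print(f"before: {statistics.pstdev(drifted):.4f}, after: {statistics.pstdev(fixed):.4f}")
```

A uniform scale factor preserves the relative pattern of the weights and only changes their overall magnitude.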
## NaN and Inf Handling
### Common Causes