configurable weight scale normalization for MoE expert drift
This commit is contained in:
@@ -137,6 +137,50 @@ This means the policy has diverged significantly from the weights used by vLLM f
|
||||
- Increase `gradient_accumulation_steps` to smooth out noisy batches.
|
||||
- Check for NaN issues (see next section).
|
||||
|
||||
## MoE Weight Scale Drift
|
||||
|
||||
**Symptom**: Model works on short prompts but loses coherence on long conversations — repeating itself, "philosophizing", or generating broken code. Particularly affects MoE models with recurrent/SSM components (e.g. DeltaNet linear attention).
|
||||
|
||||
**Root cause**: In MoE models trained with AdamW, rarely-activated experts accumulate smaller second-moment estimates. This gives them a disproportionately large effective learning rate, causing their weights to drift to higher variance than the group norm. In recurrent components like `conv1d` in DeltaNet layers, this amplifies short-range context and washes out long-range state.
|
||||
|
||||
**Detection**: Use `normalize_weight_scales` with `dry_run: true` to scan for anomalies without modifying weights:
|
||||
|
||||
```yaml
|
||||
normalize_weight_scales:
|
||||
- name_pattern: 'linear_attn\.conv1d\.weight'
|
||||
threshold: 1.3
|
||||
dry_run: true
|
||||
```
|
||||
|
||||
This logs any tensors matching the pattern whose standard deviation exceeds 1.3x the group median. Example output:
|
||||
|
||||
```
|
||||
normalize_weight_scales [DRY RUN]: pattern 'linear_attn\.conv1d\.weight' —
|
||||
3/30 tensors outside 1.3x threshold (median std=0.062733):
|
||||
layers.36.linear_attn.conv1d.weight: std=0.101870 (1.62x median)
|
||||
layers.37.linear_attn.conv1d.weight: std=0.102362 (1.63x median)
|
||||
layers.38.linear_attn.conv1d.weight: std=0.089227 (1.42x median)
|
||||
```
|
||||
|
||||
Each rule accepts:
|
||||
|
||||
- `name_pattern`: regex matched against parameter names. All matching tensors form a group.
|
||||
- `threshold`: flag tensors whose std deviates from the group median by more than this factor (default: 1.5).
|
||||
- `dry_run`: when `true`, log anomalies without modifying weights (default: `false`).
|
||||
|
||||
Multiple rules can target different tensor patterns:
|
||||
|
||||
```yaml
|
||||
normalize_weight_scales:
|
||||
- name_pattern: 'linear_attn\.conv1d\.weight'
|
||||
threshold: 1.3
|
||||
- name_pattern: 'experts\.gate_up_proj'
|
||||
threshold: 1.5
|
||||
dry_run: true # just check these, don't fix
|
||||
```
|
||||
|
||||
The transform runs after model loading but before adapter injection, so it modifies the base model weights directly.
|
||||
|
||||
## NaN and Inf Handling
|
||||
|
||||
### Common Causes
|
||||
|
||||
Reference in New Issue
Block a user