- Kernel: fused_rms_norm_rope crashed when cos.shape[-1] < x.shape[-1]
(partial rotary). The Triton forward and backward kernels now take an
n_rot runtime arg that restricts rotate_half to columns [0, n_rot) and
treats the trailing columns as RMSNorm-only pass-through (cos=1, sin=0
defaults). The wrapper also expands cos/sin that broadcast over batch
to the full batch size.
- Forward: _make_fused_forward relied on a stale shared_kv_states kwarg
that the current decoder layer no longer passes. It now mirrors stock
attention, reading and writing past_key_values.shared_layers.
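The intended per-row semantics of the kernel fix can be checked against a plain-Python reference. This is a minimal sketch, not the Triton kernel. It assumes Llama-style rotate_half (split the rotary slice into halves, return [-x2, x1]), RMSNorm with a per-column weight, and n_rot taken from cos.shape[-1]. The helper names (rms_norm_partial_rope, expand_over_batch) are hypothetical.

```python
import math


def rotate_half(v):
    # Llama-style rotate_half (assumed): split in half, return [-x2, x1].
    h = len(v) // 2
    return [-t for t in v[h:]] + v[:h]


def rms_norm(x, weight, eps=1e-6):
    # RMSNorm over the full row, scaled by a per-column weight.
    inv = 1.0 / math.sqrt(sum(t * t for t in x) / len(x) + eps)
    return [t * inv * w for t, w in zip(x, weight)]


def rms_norm_partial_rope(x, weight, cos, sin, eps=1e-6):
    """Hypothetical reference for one row: RMSNorm on all D cols, RoPE on [0, n_rot)."""
    d, n_rot = len(x), len(cos)  # assumption: n_rot == cos.shape[-1]
    assert len(sin) == n_rot and n_rot <= d and n_rot % 2 == 0
    y = rms_norm(x, weight, eps)
    # Trailing columns use the pass-through defaults cos=1, sin=0.
    cos_full = list(cos) + [1.0] * (d - n_rot)
    sin_full = list(sin) + [0.0] * (d - n_rot)
    # rotate_half only sees [0, n_rot); trailing slots get rot=0.
    rot = rotate_half(y[:n_rot]) + [0.0] * (d - n_rot)
    return [yi * c + ri * s for yi, ri, c, s in zip(y, rot, cos_full, sin_full)]


def expand_over_batch(t, batch):
    # t is nested [B_t][S][n_rot]; B_t == 1 means it broadcasts over batch.
    if len(t) == batch:
        return t
    if len(t) == 1:
        return [t[0]] * batch
    raise ValueError("cos/sin batch dim must be 1 or match x")
```

With D=6 and n_rot=4, columns 4 and 5 come out equal to the plain RMSNorm output, which is the pass-through behavior described for the trailing columns.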