- Kernel: fused_rms_norm_rope crashed when cos.shape[-1] < x.shape[-1]. The Triton forward/backward kernels now take an n_rot runtime arg that restricts rotate_half to [0, n_rot) and treats trailing cols as RMSNorm-only pass-through (cos=1, sin=0 defaults). The wrapper also expands cos/sin that broadcast over batch.
- Forward: _make_fused_forward used a stale shared_kv_states kwarg that the current decoder layer no longer passes. It now mirrors stock attention, reading/writing past_key_values.shared_layers.
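
The partial-rotary behaviour described for the kernel can be sketched as a pure-Python reference for a single row. This is an illustration, not the Triton code: the name rms_norm_rope_ref is hypothetical, n_rot is inferred here from len(cos), and the plain `x / rms * weight` scaling and the eps default are assumptions about the RMSNorm variant.

```python
import math


def rms_norm_rope_ref(x, weight, cos, sin, eps=1e-6):
    """Hypothetical reference: RMSNorm one row, then RoPE on columns [0, n_rot)."""
    d = len(x)
    n_rot = len(cos)  # assumption: n_rot equals the cos/sin width
    if len(sin) != n_rot or n_rot > d or n_rot % 2:
        raise ValueError("cos/sin must share an even width no larger than len(x)")

    # RMSNorm over the full row (assumed form: x / rms * weight).
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / d + eps)
    y = [v * inv_rms * w for v, w in zip(x, weight)]

    half = n_rot // 2
    out = list(y)  # columns [n_rot, d) stay RMSNorm-only pass-through
    for i in range(n_rot):
        # rotate_half restricted to [0, n_rot): (-second half, first half)
        rot = -y[i + half] if i < half else y[i - half]
        out[i] = y[i] * cos[i] + rot * sin[i]
    return out


# Head dim 4 with only the first 2 columns rotated.
row = rms_norm_rope_ref([3.0, 4.0, 1.0, 2.0], [1.0] * 4, cos=[0.0, 0.0], sin=[1.0, 1.0])
```

A reference like this is mainly useful as a parity check against the fused kernel on small inputs where cos is narrower than x.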