fix: qwen3-next to use fla causal-conv1d to support packing (#3437
* fix: qwen3-next to use fla causal-conv1d to support packing * fix: causal import and update doc for v5 * fix: hard fail for packing without fla
This commit is contained in:
@@ -9,6 +9,11 @@ from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
try:
|
||||
from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
|
||||
except ImportError:
|
||||
fla_causal_conv1d = None
|
||||
|
||||
|
||||
def get_cu_seqlens(position_ids):
|
||||
"""
|
||||
@@ -137,6 +142,11 @@ def patch_qwen3_next_gateddelta_layer():
|
||||
and cache_position is not None
|
||||
)
|
||||
|
||||
# Compute cu_seqlens early for use by both causal_conv1d and chunk_gated_delta_rule
|
||||
cu_seqlens = None
|
||||
if not use_precomputed_states and position_ids is not None:
|
||||
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
|
||||
|
||||
# getting projected states from cache if it exists
|
||||
if cache_params is not None:
|
||||
conv_state = cache_params.conv_states[self.layer_idx]
|
||||
@@ -151,12 +161,11 @@ def patch_qwen3_next_gateddelta_layer():
|
||||
x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
|
||||
)
|
||||
|
||||
mixed_qkv = torch.cat((query, key, value), dim=-1)
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||
mixed_qkv = torch.cat((query, key, value), dim=-1) # [B, T, D]
|
||||
|
||||
if use_precomputed_states:
|
||||
# 2. Convolution sequence transformation
|
||||
# NOTE: the conv state is updated in `causal_conv1d_update`
|
||||
# Inference single-token path: causal_conv1d_update expects [B, D, T]
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||
mixed_qkv = self.causal_conv1d_update(
|
||||
mixed_qkv,
|
||||
conv_state,
|
||||
@@ -164,24 +173,41 @@ def patch_qwen3_next_gateddelta_layer():
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
)
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||
else:
|
||||
if cache_params is not None:
|
||||
# Cache state expects [B, D, T] for the inference update path
|
||||
mixed_qkv_t = mixed_qkv.transpose(1, 2)
|
||||
conv_state = F.pad(
|
||||
mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
|
||||
mixed_qkv_t,
|
||||
(self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
|
||||
)
|
||||
cache_params.conv_states[self.layer_idx] = conv_state
|
||||
if self.causal_conv1d_fn is not None:
|
||||
mixed_qkv = self.causal_conv1d_fn(
|
||||
|
||||
if fla_causal_conv1d is not None:
|
||||
# FLA Triton causal_conv1d: [B, T, D] in/out, with cu_seqlens support
|
||||
mixed_qkv, _ = fla_causal_conv1d(
|
||||
x=mixed_qkv,
|
||||
weight=self.conv1d.weight.squeeze(1),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
seq_idx=None,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
else:
|
||||
# PyTorch fallback (no cu_seqlens support)
|
||||
if cu_seqlens is not None and cu_seqlens.shape[0] > batch_size + 1:
|
||||
raise RuntimeError(
|
||||
"Packed sequences require fla.modules.convolution.causal_conv1d "
|
||||
"(cu_seqlens support). Install flash-linear-attention or disable packing."
|
||||
)
|
||||
LOG.warning_once(
|
||||
"FLA causal_conv1d not available. Falling back to PyTorch conv1d."
|
||||
)
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||
# mixed_qkv is [B, T, D] in all paths
|
||||
query, key, value = torch.split(
|
||||
mixed_qkv,
|
||||
[
|
||||
@@ -203,7 +229,6 @@ def patch_qwen3_next_gateddelta_layer():
|
||||
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
||||
|
||||
if not use_precomputed_states:
|
||||
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
|
||||
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
|
||||
query,
|
||||
key,
|
||||
|
||||
Reference in New Issue
Block a user