Feat: add qwen3-next (w packing+cce) (#3150)

* feat: upgrade cce for qwen3-next

* feat: add sample qwen3 config

* feat: add packing patch for chunk_gated_delta_rule

* feat: add qwen3 link

* fix: tuple name

* feat: add tested qwen3 config

* fix: improve log

* feat: add patch for fla without packing

* fix: remove fla patch for standard mode

* feat: enable packing

* feat: add qwen3-next tests

* chore: move tests
NanoCode012 committed via GitHub on 2025-09-23 11:31:15 +07:00
parent 7be8740c5c · commit 08d831c3d5
11 changed files with 566 additions and 4 deletions

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef\""
]
},
{

View File

@@ -0,0 +1,64 @@
# Finetune Qwen3-Next with Axolotl
[Qwen3-Next](https://huggingface.co/collections/Qwen/qwen3-next-68c25fd6838e585db8eeea9d) is Qwen's next-generation family of foundation models, optimized for extreme context length and large-scale parameter efficiency. The series introduces architectural innovations including Hybrid Attention (Gated DeltaNet + Gated Attention), a high-sparsity MoE with a 1:50 activation ratio, and Multi-Token Prediction for improved performance and faster inference.
This guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from `main`, as Qwen3-Next support is only available in nightly builds, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from `main` with pip:
```bash
# Ensure you have PyTorch installed (PyTorch >= 2.6.0)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Install the `transformers` commit that adds Qwen3-Next support:
```bash
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
```
3. Install flash-linear-attention (FLA) for improved performance:
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
```
4. Run the finetuning example:
```bash
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
```
This config uses about 41.7 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### Tips
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
- You can run a full fine-tune by removing `adapter: qlora` and `load_in_4bit: true` from the config. See the [Multi-GPU](#optimization-guides) section below.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template); a minimal example is sketched after this list.
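For reference, here is a minimal sketch of a single training row in that format. The row is invented for illustration; only the structure matters:
```python
# Illustrative example of one row for a `type: chat_template` dataset in
# the OpenAI Messages format. The content is made up; only the shape matters.
example_row = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize sample packing in one line."},
        {
            "role": "assistant",
            "content": "Sample packing concatenates short examples into one "
            "sequence to reduce padding waste.",
        },
    ]
}
```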
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [Qwen3-Next Blog](https://qwenlm.github.io/blog/qwen3_next/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -0,0 +1,60 @@
base_model: Qwen/Qwen3-Next-80B-A3B-Instruct
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 16
lora_alpha: 8
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
- + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc"'
+ + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"'
)

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"
```
## Usage
@@ -65,6 +65,7 @@ plugins:
- qwen2_5_vl
- qwen3
- qwen3_moe
+- qwen3_next
- smollm3
- seed_oss
- voxtral

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"`'
)

View File

@@ -169,6 +169,13 @@ class PatchManager:
patch_llama4_linearized_modeling()
if self.cfg.model_config_type == "qwen3_next" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_next.modeling import (
patch_qwen3_next_modeling_packing,
)
patch_qwen3_next_modeling_packing()
if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
apply_mistral_tokenizer_image_patch,

View File

@@ -0,0 +1 @@
"""Qwen3_Next model monkeypatches."""

View File

@@ -0,0 +1,317 @@
"""Monkeypatch for Qwen3_Next model to pass position_ids to linear attention."""
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def get_cu_seqlens(position_ids):
"""
Adapted from transformers.modeling_flash_attention_utils.prepare_fa_kwargs_from_position_ids.
https://github.com/huggingface/transformers/blob/0f1b128d3359a26bd18be99c26d7f04fb3cba914/src/transformers/modeling_flash_attention_utils.py#L316
"""
tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
position_ids = position_ids.view(-1)
indices_q = (position_ids == 0).nonzero().view(-1)
cu_seq_lens_q = torch.cat(
(
indices_q.to(**tensor_kwargs),
torch.tensor(position_ids.size(), **tensor_kwargs),
)
)
return cu_seq_lens_q
def patch_qwen3_next_decoder_layer():
"""Patch Qwen3NextDecoderLayer to pass position_ids to linear attention."""
try:
from transformers.models.qwen3_next.modeling_qwen3_next import (
Qwen3NextDecoderLayer,
)
except ImportError:
LOG.warning("Qwen3Next model not found, skipping patch")
return
# Store original forward method
original_decoder_forward = Qwen3NextDecoderLayer.forward
def patched_decoder_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[torch.Tensor]] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Token Mixer
if self.layer_type == "linear_attention":
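# Unlike the upstream forward, position_ids is threaded through here so
# the patched GatedDeltaNet can recover packed-sequence boundaries.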
hidden_states = self.linear_attn(
hidden_states=hidden_states,
cache_params=past_key_values,
cache_position=cache_position,
attention_mask=attention_mask,
position_ids=position_ids,
)
elif self.layer_type == "full_attention":
# Self Attention
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
# For the MoE layers, we need to unpack
if isinstance(hidden_states, Tuple):
hidden_states, _ = hidden_states
hidden_states = residual + hidden_states
return hidden_states
# Apply the patches
Qwen3NextDecoderLayer.forward = patched_decoder_forward
def unpatch():
"""Restore the original forward method"""
Qwen3NextDecoderLayer.forward = original_decoder_forward
return unpatch
def patch_qwen3_next_gateddelta_layer():
"""Patch Qwen3NextGatedDeltaNet to parse cu_seqlens and pass to chunk_gated_delta_rule"""
try:
from transformers.models.qwen3_next.modeling_qwen3_next import (
Qwen3NextDynamicCache,
Qwen3NextGatedDeltaNet,
apply_mask_to_padding_states,
)
except ImportError:
LOG.warning("Qwen3Next model not found, skipping patch")
return
# Store original forward method
original_gated_delta_net_forward = Qwen3NextGatedDeltaNet.forward
def patched_gated_delta_net_forward(
self,
hidden_states: torch.Tensor,
cache_params: Optional[Qwen3NextDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
):
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
# Set up dimensions for reshapes later
batch_size, seq_len, _ = hidden_states.shape
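# Cached conv/recurrent states only apply to single-token decode steps;
# training forwards (including packed batches) take the chunked path.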
use_precomputed_states = (
cache_params is not None
and cache_params.has_previous_state
and seq_len == 1
and cache_position is not None
)
# getting projected states from cache if it exists
if cache_params is not None:
conv_state = cache_params.conv_states[self.layer_idx]
recurrent_state = cache_params.recurrent_states[self.layer_idx]
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
projected_states_ba = self.in_proj_ba(hidden_states)
query, key, value, z, b, a = self.fix_query_key_value_ordering(
projected_states_qkvz, projected_states_ba
)
query, key, value = (
x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
)
mixed_qkv = torch.cat((query, key, value), dim=-1)
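# q, k and v are concatenated channel-wise so a single causal conv1d
# pass processes all three projections at once.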
mixed_qkv = mixed_qkv.transpose(1, 2)
if use_precomputed_states:
# 2. Convolution sequence transformation
# NOTE: the conv state is updated in `causal_conv1d_update`
mixed_qkv = self.causal_conv1d_update(
mixed_qkv,
conv_state,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
)
else:
if cache_params is not None:
conv_state = F.pad(
mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
)
cache_params.conv_states[self.layer_idx] = conv_state
if self.causal_conv1d_fn is not None:
mixed_qkv = self.causal_conv1d_fn(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
seq_idx=None,
)
else:
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
mixed_qkv = mixed_qkv.transpose(1, 2)
query, key, value = torch.split(
mixed_qkv,
[
self.key_dim,
self.key_dim,
self.value_dim,
],
dim=-1,
)
query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)
key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)
value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)
beta = b.sigmoid()
# If the model is loaded in fp16, without the .float() here, A might be -inf
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
if self.num_v_heads // self.num_k_heads > 1:
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
if not use_precomputed_states:
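# Key change for packing: cu_seqlens tells chunk_gated_delta_rule where
# each packed sample starts, so state is not carried across samples.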
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=None,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
cu_seqlens=cu_seqlens,
)
else:
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
query,
key,
value,
g=g,
beta=beta,
initial_state=recurrent_state,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
)
# Update cache
if cache_params is not None:
cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
z_shape_og = z.shape
# reshape input data into 2D tensor
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(z_shape_og)
core_attn_out = core_attn_out.reshape(
core_attn_out.shape[0], core_attn_out.shape[1], -1
)
output = self.out_proj(core_attn_out)
return output
# Apply the patches
Qwen3NextGatedDeltaNet.forward = patched_gated_delta_net_forward
def unpatch():
"""Restore the original forward method"""
Qwen3NextGatedDeltaNet.forward = original_gated_delta_net_forward
return unpatch
def patch_qwen3_next_imports():
"""Patch Qwen3Next imports to use try/except instead of is_flash_linear_attention_available."""
try:
import transformers.models.qwen3_next.modeling_qwen3_next as qwen3_modeling
except ImportError:
LOG.warning("Qwen3Next model not found, skipping import patch")
return
# Save original values for unpatch
original_FusedRMSNormGated = getattr(qwen3_modeling, "FusedRMSNormGated", None)
original_chunk_gated_delta_rule = getattr(
qwen3_modeling, "chunk_gated_delta_rule", None
)
original_fused_recurrent_gated_delta_rule = getattr(
qwen3_modeling, "fused_recurrent_gated_delta_rule", None
)
original_is_fast_path_available = getattr(
qwen3_modeling, "is_fast_path_available", False
)
try:
from fla.modules import FusedRMSNormGated
from fla.ops.gated_delta_rule import (
chunk_gated_delta_rule,
fused_recurrent_gated_delta_rule,
)
qwen3_modeling.FusedRMSNormGated = FusedRMSNormGated
qwen3_modeling.chunk_gated_delta_rule = chunk_gated_delta_rule
qwen3_modeling.fused_recurrent_gated_delta_rule = (
fused_recurrent_gated_delta_rule
)
# Force is_fast_path_available to be True
# fla has triton kernels for causal_conv1d
qwen3_modeling.is_fast_path_available = True
except ImportError:
qwen3_modeling.chunk_gated_delta_rule = None
qwen3_modeling.fused_recurrent_gated_delta_rule = None
qwen3_modeling.FusedRMSNormGated = None
def unpatch():
"""Restore the original import values"""
qwen3_modeling.FusedRMSNormGated = original_FusedRMSNormGated
qwen3_modeling.chunk_gated_delta_rule = original_chunk_gated_delta_rule
qwen3_modeling.fused_recurrent_gated_delta_rule = (
original_fused_recurrent_gated_delta_rule
)
qwen3_modeling.is_fast_path_available = original_is_fast_path_available
return unpatch
def patch_qwen3_next_modeling_packing():
"""Apply all Qwen3Next model patches."""
patch_qwen3_next_imports()
patch_qwen3_next_decoder_layer()
patch_qwen3_next_gateddelta_layer()
LOG.info("Applied Qwen3Next patch for packing")

View File

@@ -21,6 +21,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"qwen2_moe",
"qwen3",
"qwen3_moe",
"qwen3_next",
"falcon",
"phi",
"phi3",

View File

@@ -0,0 +1,111 @@
"""Integration tests for Qwen3 Next modeling patches."""
import pytest
import torch
# Skip entire module if qwen3_next not available
qwen3_next = pytest.importorskip("transformers.models.qwen3_next.modeling_qwen3_next")
class TestQwen3NextModelingPatchIntegration:
"""Test Qwen3 Next modeling patch integration."""
@pytest.mark.integration
def test_qwen3_next_decoder_layer_patch(self):
"""Test that Qwen3Next decoder layer patch can be applied."""
from axolotl.monkeypatch.models.qwen3_next.modeling import (
patch_qwen3_next_decoder_layer,
)
# Store original method
original_forward = qwen3_next.Qwen3NextDecoderLayer.forward
# Apply patch and get unpatch function
unpatch_fn = patch_qwen3_next_decoder_layer()
# Verify patch was applied
assert qwen3_next.Qwen3NextDecoderLayer.forward != original_forward, (
"decoder layer forward method was not patched"
)
# Verify the method is still callable
assert callable(qwen3_next.Qwen3NextDecoderLayer.forward), (
"Patched method is not callable"
)
# Test unpatch function
if unpatch_fn:
unpatch_fn()
assert qwen3_next.Qwen3NextDecoderLayer.forward == original_forward, (
"unpatch function did not restore original method"
)
@pytest.mark.integration
def test_qwen3_next_gateddelta_layer_patch(self):
"""Test that Qwen3Next GatedDeltaNet patch can be applied."""
from axolotl.monkeypatch.models.qwen3_next.modeling import (
patch_qwen3_next_gateddelta_layer,
)
# Store original method
original_forward = qwen3_next.Qwen3NextGatedDeltaNet.forward
# Apply patch and get unpatch function
unpatch_fn = patch_qwen3_next_gateddelta_layer()
# Verify patch was applied
assert qwen3_next.Qwen3NextGatedDeltaNet.forward != original_forward, (
"GatedDeltaNet forward method was not patched"
)
# Verify the method is still callable
assert callable(qwen3_next.Qwen3NextGatedDeltaNet.forward), (
"Patched method is not callable"
)
# Test unpatch function
if unpatch_fn:
unpatch_fn()
assert qwen3_next.Qwen3NextGatedDeltaNet.forward == original_forward, (
"unpatch function did not restore original method"
)
@pytest.mark.integration
def test_qwen3_next_imports_patch(self):
"""Test that Qwen3Next imports patch can be applied without errors."""
from axolotl.monkeypatch.models.qwen3_next.modeling import (
patch_qwen3_next_imports,
)
# Apply patch - should not raise any exceptions even if modules unavailable
unpatch_fn = patch_qwen3_next_imports()
# Test that unpatch function is returned (or None if skipped)
assert unpatch_fn is None or callable(unpatch_fn), (
"patch_qwen3_next_imports should return None or callable unpatch function"
)
@pytest.mark.integration
def test_qwen3_next_modeling_packing_patch(self):
"""Test that all Qwen3Next modeling patches can be applied together."""
from axolotl.monkeypatch.models.qwen3_next.modeling import (
patch_qwen3_next_modeling_packing,
)
# This should not raise any exceptions
patch_qwen3_next_modeling_packing()
@pytest.mark.integration
def test_get_cu_seqlens_utility():
"""Test the get_cu_seqlens utility function."""
from axolotl.monkeypatch.models.qwen3_next.modeling import get_cu_seqlens
# Test with simple position_ids
position_ids = torch.tensor([[0, 1, 2, 0, 1]])
cu_seqlens = get_cu_seqlens(position_ids)
assert cu_seqlens.dtype == torch.int32, "Should be int32 dtype"
# Should return tensor with start positions and total length
expected = torch.tensor([0, 3, 5], dtype=torch.int32)
assert torch.equal(cu_seqlens, expected), f"Expected {expected}, got {cu_seqlens}"