Feat: add qwen3-next (w packing+cce) (#3150)
* feat: upgrade cce for qwen3-next
* feat: add sample qwen3 config
* feat: add packing patch for chunk_gated_delta_rule
* feat: add qwen3 link
* fix: tuple name
* feat: add tested qwen3 config
* fix: improve log
* feat: add patch for fla without packing
* fix: remove fla patch for standard mode
* feat: enable packing
* feat: add qwen3-next tests
* chore: move tests
@@ -40,7 +40,7 @@
     "%%capture\n",
     "# This step can take ~5-10 minutes to install dependencies\n",
     "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc\""
+    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef\""
    ]
   },
   {

64  examples/qwen3-next/README.md  Normal file
@@ -0,0 +1,64 @@
# Finetune Qwen3-Next with Axolotl

[Qwen3-Next](https://huggingface.co/collections/Qwen/qwen3-next-68c25fd6838e585db8eeea9d) is the next generation of foundation models, optimized for extreme context lengths and large-scale parameter efficiency. The series introduces architectural innovations including Hybrid Attention (Gated DeltaNet + Gated Attention), high-sparsity MoE with a 1:50 activation ratio, and Multi-Token Prediction for enhanced performance and inference acceleration.

This guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking.

## Getting started

1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from `main`, as Qwen3-Next is only available on nightly builds, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).

Here is an example of how to install from `main` with pip:

```bash
# Ensure you have PyTorch installed (2.6.0 minimum)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl

pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'

# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```

2. Install the transformers commit with Qwen3-Next support:

```bash
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
```

3. Install flash-linear-attention (FLA) for improved performance:

```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
```

4. Run the finetuning example:

```bash
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
```

This config uses about 41.7 GiB of VRAM.

Let us know how it goes. Happy finetuning! 🚀

### Tips

- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
- You can run a full finetune by removing `adapter: qlora` and `load_in_4bit: true` from the config. See the [Multi-GPU](#optimization-guides) section below.
- Read more on loading your own dataset in the [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format, as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template); a minimal record is sketched below.
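
For illustration, a single training record in that format might look like the following (the values here are made-up placeholders, not taken from the example dataset):

```json
{"messages": [{"role": "user", "content": "What is the capital of France?"}, {"role": "assistant", "content": "The capital of France is Paris."}]}
```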

## Optimization Guides

- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)

## Related Resources

- [Qwen3-Next Blog](https://qwenlm.github.io/blog/qwen3_next/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

60  examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml  Normal file
@@ -0,0 +1,60 @@
base_model: Qwen/Qwen3-Next-80B-A3B-Instruct

# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

load_in_8bit: false
load_in_4bit: true

datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template

dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: true
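# sample_packing relies on the Qwen3-Next packing patch, which Axolotl's PatchManager applies automatically for this model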

lora_r: 16
lora_alpha: 8
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""

 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"'
 )
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh

 - If you are installing from pip
 ```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"
 ```

 ## Usage
@@ -65,6 +65,7 @@ plugins:
 - qwen2_5_vl
 - qwen3
 - qwen3_moe
+- qwen3_next
 - smollm3
 - seed_oss
 - voxtral
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)

 _CCE_INSTALL_MESSAGE = (
     "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
-    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc"`'
+    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"`'
 )
@@ -169,6 +169,13 @@ class PatchManager:

             patch_llama4_linearized_modeling()

+        if self.cfg.model_config_type == "qwen3_next" and self.cfg.sample_packing:
+            from axolotl.monkeypatch.models.qwen3_next.modeling import (
+                patch_qwen3_next_modeling_packing,
+            )
+
+            patch_qwen3_next_modeling_packing()
+
         if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
             from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
                 apply_mistral_tokenizer_image_patch,

1  src/axolotl/monkeypatch/models/qwen3_next/__init__.py  Normal file
@@ -0,0 +1 @@
"""Qwen3_Next model monkeypatches."""

317  src/axolotl/monkeypatch/models/qwen3_next/modeling.py  Normal file
@@ -0,0 +1,317 @@
"""Monkeypatch for Qwen3_Next model to pass position_ids to linear attention."""

from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def get_cu_seqlens(position_ids):
    """
    Adapted from transformers.modeling_flash_attention_utils.prepare_fa_kwargs_from_position_ids.

    https://github.com/huggingface/transformers/blob/0f1b128d3359a26bd18be99c26d7f04fb3cba914/src/transformers/modeling_flash_attention_utils.py#L316
    """
    tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}

    position_ids = position_ids.view(-1)
    indices_q = (position_ids == 0).nonzero().view(-1)

    cu_seq_lens_q = torch.cat(
        (
            indices_q.to(**tensor_kwargs),
            torch.tensor(position_ids.size(), **tensor_kwargs),
        )
    )

    return cu_seq_lens_q


def patch_qwen3_next_decoder_layer():
    """Patch Qwen3NextDecoderLayer to pass position_ids to linear attention."""
    try:
        from transformers.models.qwen3_next.modeling_qwen3_next import (
            Qwen3NextDecoderLayer,
        )
    except ImportError:
        LOG.warning("Qwen3Next model not found, skipping patch")
        return

    # Store original forward method
    original_decoder_forward = Qwen3NextDecoderLayer.forward

    def patched_decoder_forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Token Mixer
        if self.layer_type == "linear_attention":
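            # Unlike the upstream forward, position_ids is forwarded here so
            # the linear-attention path can recover packed-sequence boundaries.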
            hidden_states = self.linear_attn(
                hidden_states=hidden_states,
                cache_params=past_key_values,
                cache_position=cache_position,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
        elif self.layer_type == "full_attention":
            # Self Attention
            hidden_states, _ = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        # For the MoE layers, we need to unpack
        if isinstance(hidden_states, Tuple):
            hidden_states, _ = hidden_states
        hidden_states = residual + hidden_states

        return hidden_states

    # Apply the patches
    Qwen3NextDecoderLayer.forward = patched_decoder_forward

    def unpatch():
        """Restore the original forward method"""
        Qwen3NextDecoderLayer.forward = original_decoder_forward

    return unpatch


def patch_qwen3_next_gateddelta_layer():
    """Patch Qwen3NextGatedDeltaNet to parse cu_seqlens and pass to chunk_gated_delta_rule"""
    try:
        from transformers.models.qwen3_next.modeling_qwen3_next import (
            Qwen3NextDynamicCache,
            Qwen3NextGatedDeltaNet,
            apply_mask_to_padding_states,
        )
    except ImportError:
        LOG.warning("Qwen3Next model not found, skipping patch")
        return

    # Store original forward method
    original_gated_delta_net_forward = Qwen3NextGatedDeltaNet.forward

    def patched_gated_delta_net_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[Qwen3NextDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

        # Set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape

        use_precomputed_states = (
            cache_params is not None
            and cache_params.has_previous_state
            and seq_len == 1
            and cache_position is not None
        )

        # getting projected states from cache if it exists
        if cache_params is not None:
            conv_state = cache_params.conv_states[self.layer_idx]
            recurrent_state = cache_params.recurrent_states[self.layer_idx]

        projected_states_qkvz = self.in_proj_qkvz(hidden_states)
        projected_states_ba = self.in_proj_ba(hidden_states)
        query, key, value, z, b, a = self.fix_query_key_value_ordering(
            projected_states_qkvz, projected_states_ba
        )
        query, key, value = (
            x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
        )

        mixed_qkv = torch.cat((query, key, value), dim=-1)
        mixed_qkv = mixed_qkv.transpose(1, 2)

        if use_precomputed_states:
            # 2. Convolution sequence transformation
            # NOTE: the conv state is updated in `causal_conv1d_update`
            mixed_qkv = self.causal_conv1d_update(
                mixed_qkv,
                conv_state,
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )
        else:
            if cache_params is not None:
                conv_state = F.pad(
                    mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
                )
                cache_params.conv_states[self.layer_idx] = conv_state
            if self.causal_conv1d_fn is not None:
                mixed_qkv = self.causal_conv1d_fn(
                    x=mixed_qkv,
                    weight=self.conv1d.weight.squeeze(1),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    seq_idx=None,
                )
            else:
                mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])

        mixed_qkv = mixed_qkv.transpose(1, 2)
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim,
                self.key_dim,
                self.value_dim,
            ],
            dim=-1,
        )
        query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)
        key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)
        value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)

        beta = b.sigmoid()
        # If the model is loaded in fp16, without the .float() here, A might be -inf
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        if self.num_v_heads // self.num_k_heads > 1:
            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

        if not use_precomputed_states:
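            # With sample packing, position_ids restart at 0 for each packed
            # sub-sequence; cu_seqlens marks those boundaries so the chunked
            # kernel treats each packed sequence independently.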
            cu_seqlens = get_cu_seqlens(position_ids=position_ids)
            core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=None,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
                cu_seqlens=cu_seqlens,
            )

        else:
            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )

        # Update cache
        if cache_params is not None:
            cache_params.recurrent_states[self.layer_idx] = last_recurrent_state

        z_shape_og = z.shape
        # reshape input data into 2D tensor
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(z_shape_og)
        core_attn_out = core_attn_out.reshape(
            core_attn_out.shape[0], core_attn_out.shape[1], -1
        )

        output = self.out_proj(core_attn_out)
        return output

    # Apply the patches
    Qwen3NextGatedDeltaNet.forward = patched_gated_delta_net_forward

    def unpatch():
        """Restore the original forward method"""
        Qwen3NextGatedDeltaNet.forward = original_gated_delta_net_forward

    return unpatch


def patch_qwen3_next_imports():
    """Patch Qwen3Next imports to use try/except instead of is_flash_linear_attention_available."""
    try:
        import transformers.models.qwen3_next.modeling_qwen3_next as qwen3_modeling
    except ImportError:
        LOG.warning("Qwen3Next model not found, skipping import patch")
        return

    # Save original values for unpatch
    original_FusedRMSNormGated = getattr(qwen3_modeling, "FusedRMSNormGated", None)
    original_chunk_gated_delta_rule = getattr(
        qwen3_modeling, "chunk_gated_delta_rule", None
    )
    original_fused_recurrent_gated_delta_rule = getattr(
        qwen3_modeling, "fused_recurrent_gated_delta_rule", None
    )
    original_is_fast_path_available = getattr(
        qwen3_modeling, "is_fast_path_available", False
    )

    try:
        from fla.modules import FusedRMSNormGated
        from fla.ops.gated_delta_rule import (
            chunk_gated_delta_rule,
            fused_recurrent_gated_delta_rule,
        )

        qwen3_modeling.FusedRMSNormGated = FusedRMSNormGated
        qwen3_modeling.chunk_gated_delta_rule = chunk_gated_delta_rule
        qwen3_modeling.fused_recurrent_gated_delta_rule = (
            fused_recurrent_gated_delta_rule
        )

        # Force is_fast_path_available to be True
        # fla has triton kernels for causal_conv1d
        qwen3_modeling.is_fast_path_available = True
    except ImportError:
        qwen3_modeling.chunk_gated_delta_rule = None
        qwen3_modeling.fused_recurrent_gated_delta_rule = None
        qwen3_modeling.FusedRMSNormGated = None

    def unpatch():
        """Restore the original import values"""
        qwen3_modeling.FusedRMSNormGated = original_FusedRMSNormGated
        qwen3_modeling.chunk_gated_delta_rule = original_chunk_gated_delta_rule
        qwen3_modeling.fused_recurrent_gated_delta_rule = (
            original_fused_recurrent_gated_delta_rule
        )
        qwen3_modeling.is_fast_path_available = original_is_fast_path_available

    return unpatch


def patch_qwen3_next_modeling_packing():
    """Apply all Qwen3Next model patches."""
    patch_qwen3_next_imports()
    patch_qwen3_next_decoder_layer()
    patch_qwen3_next_gateddelta_layer()

    LOG.info("Applied Qwen3Next patch for packing")
@@ -21,6 +21,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "qwen2_moe",
     "qwen3",
     "qwen3_moe",
+    "qwen3_next",
     "falcon",
     "phi",
     "phi3",

111  tests/monkeypatch/test_qwen3_next_modeling_patch.py  Normal file
@@ -0,0 +1,111 @@
"""Integration tests for Qwen3 Next modeling patches."""

import pytest
import torch

# Skip entire module if qwen3_next not available
qwen3_next = pytest.importorskip("transformers.models.qwen3_next.modeling_qwen3_next")


class TestQwen3NextModelingPatchIntegration:
    """Test Qwen3 Next modeling patch integration."""

    @pytest.mark.integration
    def test_qwen3_next_decoder_layer_patch(self):
        """Test that Qwen3Next decoder layer patch can be applied."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_decoder_layer,
        )

        # Store original method
        original_forward = qwen3_next.Qwen3NextDecoderLayer.forward

        # Apply patch and get unpatch function
        unpatch_fn = patch_qwen3_next_decoder_layer()

        # Verify patch was applied
        assert qwen3_next.Qwen3NextDecoderLayer.forward != original_forward, (
            "decoder layer forward method was not patched"
        )

        # Verify the method is still callable
        assert callable(qwen3_next.Qwen3NextDecoderLayer.forward), (
            "Patched method is not callable"
        )

        # Test unpatch function
        if unpatch_fn:
            unpatch_fn()
            assert qwen3_next.Qwen3NextDecoderLayer.forward == original_forward, (
                "unpatch function did not restore original method"
            )

    @pytest.mark.integration
    def test_qwen3_next_gateddelta_layer_patch(self):
        """Test that Qwen3Next GatedDeltaNet patch can be applied."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_gateddelta_layer,
        )

        # Store original method
        original_forward = qwen3_next.Qwen3NextGatedDeltaNet.forward

        # Apply patch and get unpatch function
        unpatch_fn = patch_qwen3_next_gateddelta_layer()

        # Verify patch was applied
        assert qwen3_next.Qwen3NextGatedDeltaNet.forward != original_forward, (
            "GatedDeltaNet forward method was not patched"
        )

        # Verify the method is still callable
        assert callable(qwen3_next.Qwen3NextGatedDeltaNet.forward), (
            "Patched method is not callable"
        )

        # Test unpatch function
        if unpatch_fn:
            unpatch_fn()
            assert qwen3_next.Qwen3NextGatedDeltaNet.forward == original_forward, (
                "unpatch function did not restore original method"
            )

    @pytest.mark.integration
    def test_qwen3_next_imports_patch(self):
        """Test that Qwen3Next imports patch can be applied without errors."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_imports,
        )

        # Apply patch - should not raise any exceptions even if modules unavailable
        unpatch_fn = patch_qwen3_next_imports()

        # Test that unpatch function is returned (or None if skipped)
        assert unpatch_fn is None or callable(unpatch_fn), (
            "patch_qwen3_next_imports should return None or callable unpatch function"
        )

    @pytest.mark.integration
    def test_qwen3_next_modeling_packing_patch(self):
        """Test that all Qwen3Next modeling patches can be applied together."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_modeling_packing,
        )

        # This should not raise any exceptions
        patch_qwen3_next_modeling_packing()


@pytest.mark.integration
def test_get_cu_seqlens_utility():
    """Test the get_cu_seqlens utility function."""
    from axolotl.monkeypatch.models.qwen3_next.modeling import get_cu_seqlens

    # Test with simple position_ids
    position_ids = torch.tensor([[0, 1, 2, 0, 1]])
    cu_seqlens = get_cu_seqlens(position_ids)
    assert cu_seqlens.dtype == torch.int32, "Should be int32 dtype"

    # Should return tensor with start positions and total length
    expected = torch.tensor([0, 3, 5], dtype=torch.int32)
    assert torch.equal(cu_seqlens, expected), f"Expected {expected}, got {cu_seqlens}"