diff --git a/examples/qwen3-next/README.md b/examples/qwen3-next/README.md
index 3c3a26a76..df87ca725 100644
--- a/examples/qwen3-next/README.md
+++ b/examples/qwen3-next/README.md
@@ -6,30 +6,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations

 ## Getting started

-1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).

-   Here is an example of how to install from main for pip:
-
-```bash
-# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
-git clone https://github.com/axolotl-ai-cloud/axolotl.git
-cd axolotl
-
-pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
-pip3 install --no-build-isolation -e '.[flash-attn]'
-
-# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
-python scripts/cutcrossentropy_install.py | sh
-```
-
-2. Install Qwen3-Next transformers commit
-```bash
-pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
-```
+2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.

 3. Install FLA for improved performance
 ```bash
-pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
+pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
 ```

 4. Run the finetuning example:
@@ -38,7 +21,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
 axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
 ```

-This config uses about 45.62 GiB VRAM.
+This config uses about 47 GiB VRAM (without targeting experts) or about 71 GiB (when targeting experts).

 Let us know how it goes. Happy finetuning! 🚀
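Note that the modeling patch further down only takes the cu_seqlens-aware Triton convolution path when `fla.modules.convolution.causal_conv1d` can be imported; otherwise it falls back to a plain PyTorch conv1d and rejects packed sequences. A minimal, illustrative check after installing FLA (not part of the patch itself) might be:

```python
# Illustrative sanity check (assumption: run in the training environment).
# The monkeypatch below imports exactly this symbol; without it, the PyTorch
# conv1d fallback is used and sample_packing cannot be enabled.
try:
    from fla.modules.convolution import causal_conv1d  # noqa: F401
    print("FLA causal_conv1d found: packed-sequence (cu_seqlens) conv path available.")
except ImportError:
    print("FLA causal_conv1d missing: PyTorch conv1d fallback, packing unsupported.")
```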
diff --git a/examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml b/examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
index db841beab..f63b1d1ce 100644
--- a/examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
+++ b/examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
@@ -9,6 +9,8 @@ plugins:
 load_in_8bit: false
 load_in_4bit: true

+quantize_moe_experts: true
+
 datasets:
   - path: fozziethebeat/alpaca_messages_2k_test
     type: chat_template
@@ -25,7 +27,7 @@ sample_packing: true

 lora_r: 16
 lora_alpha: 8
-lora_dropout: 0.05
+lora_dropout: 0
 lora_target_modules:
   - linear_attn.in_proj_ba
   - linear_attn.in_proj_qkvz
@@ -34,12 +36,19 @@ lora_target_modules:
   - shared_expert.down_proj
   - shared_expert.gate_proj
   - shared_expert_gate
-  - mlp.gate
   - q_proj
   - v_proj
   - k_proj
   - o_proj

+# lora_target_parameters:
+#   - mlp.experts.gate_up_proj
+#   - mlp.experts.down_proj
+
+lora_mlp_kernel: false
+lora_qkv_kernel: false
+lora_o_kernel: false
+
 wandb_project:
 wandb_entity:
 wandb_watch:
diff --git a/src/axolotl/monkeypatch/models/qwen3_next/modeling.py b/src/axolotl/monkeypatch/models/qwen3_next/modeling.py
index d68992d0e..48570ba42 100644
--- a/src/axolotl/monkeypatch/models/qwen3_next/modeling.py
+++ b/src/axolotl/monkeypatch/models/qwen3_next/modeling.py
@@ -9,6 +9,11 @@ from axolotl.utils.logging import get_logger

 LOG = get_logger(__name__)

+try:
+    from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
+except ImportError:
+    fla_causal_conv1d = None
+

 def get_cu_seqlens(position_ids):
     """
@@ -137,6 +142,11 @@ def patch_qwen3_next_gateddelta_layer():
             and cache_position is not None
         )

+        # Compute cu_seqlens early for use by both causal_conv1d and chunk_gated_delta_rule
+        cu_seqlens = None
+        if not use_precomputed_states and position_ids is not None:
+            cu_seqlens = get_cu_seqlens(position_ids=position_ids)
+
         # getting projected states from cache if it exists
         if cache_params is not None:
             conv_state = cache_params.conv_states[self.layer_idx]
@@ -151,12 +161,11 @@
             x.reshape(x.shape[0], x.shape[1], -1)
             for x in (query, key, value)
         )
-        mixed_qkv = torch.cat((query, key, value), dim=-1)
-        mixed_qkv = mixed_qkv.transpose(1, 2)
+        mixed_qkv = torch.cat((query, key, value), dim=-1)  # [B, T, D]

         if use_precomputed_states:
-            # 2. Convolution sequence transformation
-            # NOTE: the conv state is updated in `causal_conv1d_update`
+            # Inference single-token path: causal_conv1d_update expects [B, D, T]
+            mixed_qkv = mixed_qkv.transpose(1, 2)
             mixed_qkv = self.causal_conv1d_update(
                 mixed_qkv,
                 conv_state,
@@ -164,24 +173,41 @@
                 self.conv1d.bias,
                 self.activation,
             )
+            mixed_qkv = mixed_qkv.transpose(1, 2)
         else:
             if cache_params is not None:
+                # Cache state expects [B, D, T] for the inference update path
+                mixed_qkv_t = mixed_qkv.transpose(1, 2)
                 conv_state = F.pad(
-                    mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0)
+                    mixed_qkv_t,
+                    (self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
                 )
                 cache_params.conv_states[self.layer_idx] = conv_state
-            if self.causal_conv1d_fn is not None:
-                mixed_qkv = self.causal_conv1d_fn(
+
+            if fla_causal_conv1d is not None:
+                # FLA Triton causal_conv1d: [B, T, D] in/out, with cu_seqlens support
+                mixed_qkv, _ = fla_causal_conv1d(
                     x=mixed_qkv,
                     weight=self.conv1d.weight.squeeze(1),
                     bias=self.conv1d.bias,
                     activation=self.activation,
-                    seq_idx=None,
+                    cu_seqlens=cu_seqlens,
                 )
             else:
+                # PyTorch fallback (no cu_seqlens support)
+                if cu_seqlens is not None and cu_seqlens.shape[0] > batch_size + 1:
+                    raise RuntimeError(
+                        "Packed sequences require fla.modules.convolution.causal_conv1d "
+                        "(cu_seqlens support). Install flash-linear-attention or disable packing."
+                    )
+                LOG.warning_once(
+                    "FLA causal_conv1d not available. Falling back to PyTorch conv1d."
+                )
+                mixed_qkv = mixed_qkv.transpose(1, 2)
                 mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
+                mixed_qkv = mixed_qkv.transpose(1, 2)

-        mixed_qkv = mixed_qkv.transpose(1, 2)
+        # mixed_qkv is [B, T, D] in all paths
         query, key, value = torch.split(
             mixed_qkv,
             [
@@ -203,7 +229,6 @@
         key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

         if not use_precomputed_states:
-            cu_seqlens = get_cu_seqlens(position_ids=position_ids)
             core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
                 query,
                 key,
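For context on the cu_seqlens value threaded through the patch above: it marks where each packed sequence starts and ends so that the convolution and gated-delta kernels do not mix tokens across sequence boundaries. Below is a minimal sketch of the idea, assuming packed position_ids that restart at 0 for every sequence; the helper name is hypothetical and this is not the get_cu_seqlens implementation from the patched file:

```python
import torch


def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # Sketch: position_ids for a packed batch look like [[0, 1, 2, 0, 1, 0, 1, 2, 3]],
    # i.e. three sequences of lengths 3, 2 and 4 packed into one row.
    flat = position_ids.flatten()
    # Every 0 marks the start of a new packed sequence.
    starts = torch.nonzero(flat == 0, as_tuple=False).flatten()
    # Append the total token count so consecutive entries bound each sequence.
    total = torch.tensor([flat.numel()], device=flat.device)
    return torch.cat([starts, total]).to(torch.int32)


# cu_seqlens_from_position_ids(torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]]))
# -> tensor([0, 3, 5, 9], dtype=torch.int32)
```

For an unpacked batch of shape [B, T] this yields B + 1 entries, which is why the PyTorch fallback path above treats `cu_seqlens.shape[0] > batch_size + 1` as the signature of packed input.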