From c119382337645796dbc8e82e2b212f14a58e69bd Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Fri, 6 Mar 2026 20:01:00 +0530 Subject: [PATCH] add: qwen 3.5 (#3442) * add: qwen 3.5 * test for qwen , patch * lint * qwen3 fix on main * Apply suggestions from code review Co-authored-by: NanoCode012 * moe config * config moe * configs and chore * Update examples/qwen3.5/122b-a10b-moe-qlora.yaml Co-authored-by: NanoCode012 * Update examples/qwen3.5/35b-a3b-moe-qlora.yaml Co-authored-by: NanoCode012 * chore for qwen + vlm patch * chore lint * qwen lint * 3_5_moe * Update examples/qwen3.5/README.md --------- Co-authored-by: NanoCode012 --- examples/qwen3.5/122b-a10b-moe-qlora.yaml | 71 +++++ examples/qwen3.5/27b-qlora.yaml | 72 +++++ examples/qwen3.5/35b-a3b-moe-qlora.yaml | 70 +++++ examples/qwen3.5/7b-lora-vision.yaml | 72 +++++ examples/qwen3.5/README.md | 61 ++++ src/axolotl/common/architectures.py | 1 + src/axolotl/loaders/patch_manager.py | 25 ++ .../monkeypatch/models/qwen3_5/__init__.py | 0 .../monkeypatch/models/qwen3_5/modeling.py | 291 ++++++++++++++++++ src/axolotl/monkeypatch/multipack.py | 2 + src/axolotl/processing_strategies.py | 30 ++ .../chat_templates/templates/qwen3_5.jinja | 123 ++++++++ src/axolotl/utils/schemas/enums.py | 1 + 13 files changed, 819 insertions(+) create mode 100644 examples/qwen3.5/122b-a10b-moe-qlora.yaml create mode 100644 examples/qwen3.5/27b-qlora.yaml create mode 100644 examples/qwen3.5/35b-a3b-moe-qlora.yaml create mode 100644 examples/qwen3.5/7b-lora-vision.yaml create mode 100644 examples/qwen3.5/README.md create mode 100644 src/axolotl/monkeypatch/models/qwen3_5/__init__.py create mode 100644 src/axolotl/monkeypatch/models/qwen3_5/modeling.py create mode 100644 src/axolotl/utils/chat_templates/templates/qwen3_5.jinja diff --git a/examples/qwen3.5/122b-a10b-moe-qlora.yaml b/examples/qwen3.5/122b-a10b-moe-qlora.yaml new file mode 100644 index 000000000..e9cbf80ce --- /dev/null +++ b/examples/qwen3.5/122b-a10b-moe-qlora.yaml @@ -0,0 +1,71 @@ +base_model: Qwen/Qwen3.5-122B-A10B + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin +strict: false + +chat_template: qwen3_5 +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train[:20%] + field_messages: conversations + message_property_mappings: + role: from + content: value +val_set_size: 0.0 +output_dir: ./outputs/out +dataset_prepared_path: last_run_prepared + +sequence_len: 2048 +sample_packing: true + +load_in_4bit: true +quantize_moe_experts: true +adapter: qlora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + +#lora_target_parameters: +# - mlp.experts.gate_up_proj +# - mlp.experts.down_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_4bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + + +lora_mlp_kernel: false +lora_qkv_kernel: false +lora_o_kernel: false + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 4 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/examples/qwen3.5/27b-qlora.yaml b/examples/qwen3.5/27b-qlora.yaml new file mode 100644 index 000000000..2ba1c4ed7 --- /dev/null +++ b/examples/qwen3.5/27b-qlora.yaml @@ -0,0 +1,72 @@ 
+base_model: Qwen/Qwen3.5-27B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name +# Note: Qwen3.5 is an early-fusion VLM (image+text). This config fine-tunes +# the text-only path. For multimodal (image+text) fine-tuning, add image +# columns to your dataset following axolotl's multimodal dataset format. + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin +strict: false + +chat_template: qwen3_5 +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train[:20%] + field_messages: conversations + message_property_mappings: + role: from + content: value +val_set_size: 0.0 +output_dir: ./outputs/out +dataset_prepared_path: last_run_prepared + +sequence_len: 2048 +sample_packing: true + +load_in_4bit: true +adapter: qlora +lora_r: 16 +lora_alpha: 32 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - down_proj + - up_proj + # Uncomment below to also target the linear attention projections. + # These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific). + # - linear_attn.in_proj_qkv + # - linear_attn.in_proj_z + # - linear_attn.out_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_4bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 4 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/examples/qwen3.5/35b-a3b-moe-qlora.yaml b/examples/qwen3.5/35b-a3b-moe-qlora.yaml new file mode 100644 index 000000000..462babf0b --- /dev/null +++ b/examples/qwen3.5/35b-a3b-moe-qlora.yaml @@ -0,0 +1,70 @@ +base_model: Qwen/Qwen3.5-35B-A3B + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin +strict: false + +chat_template: qwen3_5 +datasets: + - path: mlabonne/FineTome-100k + type: chat_template + split: train[:20%] + field_messages: conversations + message_property_mappings: + role: from + content: value +val_set_size: 0.0 +output_dir: ./outputs/out +dataset_prepared_path: last_run_prepared + +sequence_len: 2048 +sample_packing: true + +load_in_4bit: true +quantize_moe_experts: true +adapter: qlora +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0 +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + +#lora_target_parameters: +# - mlp.experts.gate_up_proj +# - mlp.experts.down_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_4bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +lora_mlp_kernel: false +lora_qkv_kernel: false +lora_o_kernel: false + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 4 +saves_per_epoch: 1 +weight_decay: 0.0 +special_tokens: diff --git a/examples/qwen3.5/7b-lora-vision.yaml b/examples/qwen3.5/7b-lora-vision.yaml new file mode 100644 index 000000000..79179ec96 --- /dev/null +++ b/examples/qwen3.5/7b-lora-vision.yaml @@ -0,0 +1,72 @@ +base_model: Qwen/Qwen3.5-7B +processor_type: AutoProcessor + +# Qwen3.5-7B and above are early-fusion VLMs (Qwen3_5ForConditionalGeneration). 
+# Vision and text tokens are processed together by the same transformer layers. +# Note: Qwen3.5-2B is a text-only model — the smallest VLM is Qwen3.5-7B. + +# These 3 lines are required for vision/multimodal training +skip_prepare_dataset: true +remove_unused_columns: false +sample_packing: false + +chat_template: qwen3_5 +datasets: + - path: HuggingFaceH4/llava-instruct-mix-vsft + type: chat_template + split: train[:1%] + +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +adapter: lora +lora_model_dir: + +sequence_len: 8192 +pad_to_sequence_len: false + +lora_r: 32 +lora_alpha: 16 +lora_dropout: 0.05 +# Targets the language model attention and MLP layers. +# Qwen3.5 is early-fusion: all layers (including those seeing vision tokens) share +# the same transformer stack, so standard attention targets work for both modalities. +lora_target_modules: + - q_proj + - k_proj + - v_proj + - o_proj + - down_proj + - up_proj + # Uncomment to also target the linear attention (GatedDeltaNet) projections: + # - linear_attn.in_proj_qkv + # - linear_attn.in_proj_z + # - linear_attn.out_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: true +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 +weight_decay: 0.0 diff --git a/examples/qwen3.5/README.md b/examples/qwen3.5/README.md new file mode 100644 index 000000000..8a2f9b4bd --- /dev/null +++ b/examples/qwen3.5/README.md @@ -0,0 +1,61 @@ +# Finetune Qwen3.5 with Axolotl + +[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-68452f3bc6e4b7cfb4e1c803) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. Models from 7B onwards are early-fusion vision-language models (`Qwen3_5ForConditionalGeneration`), meaning vision and text tokens are processed through the same transformer stack. The 2B variant is text-only. + +Available configs: + +| Config | Model | Type | +|---|---|---| +| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only path | +| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only path | +| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only path | +| `7b-lora-vision.yaml` | Qwen3.5-7B | Vision+text (multimodal) | + +## Getting started + +1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). + +2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage. + +3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers: +```bash +pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1 +``` +> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there. + +4. 
Run a finetuning example: + +```bash +# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing) +axolotl train examples/qwen3.5/27b-qlora.yaml + +# MoE 35B-A3B text-only (QLoRA) +axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml + +# MoE 122B-A10B text-only (QLoRA) +axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml + +# 7B vision+text (LoRA, multimodal dataset) +axolotl train examples/qwen3.5/7b-lora-vision.yaml +``` + +### TIPS + +- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`. +- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below. +- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html). +- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template). +- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `7b-lora-vision.yaml`. +- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules` — they are commented out by default. + +## Optimization Guides + +- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html) + +## Related Resources + +- [Qwen3.5 Blog](https://qwenlm.github.io/blog/qwen3.5/) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/src/axolotl/common/architectures.py b/src/axolotl/common/architectures.py index a409ed9f4..0e1de3017 100644 --- a/src/axolotl/common/architectures.py +++ b/src/axolotl/common/architectures.py @@ -12,6 +12,7 @@ MOE_ARCH_BLOCK = { "mixtral": "MixtralSparseMoeBlock", "qwen2_moe": "Qwen2MoeSparseMoeBlock", "qwen3_moe": "Qwen3MoeSparseMoeBlock", + "qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock", "qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock", "deepseek_v2": "DeepseekV2MoE", "deepseek_v3": "DeepseekV3MoE", diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 51bcaeba4..f94626ec7 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -246,6 +246,31 @@ class PatchManager: patch_qwen3_next_modeling_packing() + if self.cfg.model_config_type == "qwen3_5" and self.cfg.sample_packing: + from axolotl.monkeypatch.models.qwen3_5.modeling import ( + patch_qwen3_5_modeling_packing, + ) + + patch_qwen3_5_modeling_packing() + + if self.cfg.model_config_type == "qwen3_5_moe" and self.cfg.sample_packing: + from axolotl.monkeypatch.models.qwen3_5.modeling import ( + patch_qwen3_5_moe_modeling_packing, + ) + + patch_qwen3_5_moe_modeling_packing() + + if ( + self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"] + and self.cfg.is_multimodal + and self.cfg.flash_attention + ): + from axolotl.monkeypatch.models.qwen3_5.modeling import ( + patch_qwen3_5_vlm_flash_attention, + ) + + patch_qwen3_5_vlm_flash_attention() + if self.cfg.model_config_type == "kimi_linear": from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import ( patch_kimi_model, diff --git a/src/axolotl/monkeypatch/models/qwen3_5/__init__.py b/src/axolotl/monkeypatch/models/qwen3_5/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/monkeypatch/models/qwen3_5/modeling.py 
b/src/axolotl/monkeypatch/models/qwen3_5/modeling.py new file mode 100644 index 000000000..f88f60555 --- /dev/null +++ b/src/axolotl/monkeypatch/models/qwen3_5/modeling.py @@ -0,0 +1,291 @@ +"""Monkeypatch for Qwen3_5 and Qwen3_5Moe models to pass position_ids to linear attention.""" + +import importlib +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +try: + from fla.modules.convolution import ( + causal_conv1d as fla_causal_conv1d, # FLA >= 0.4.1 + ) +except ImportError: + try: + from fla.modules.conv import causal_conv1d as fla_causal_conv1d # FLA < 0.4.1 + except ImportError: + fla_causal_conv1d = None + + +def get_cu_seqlens(position_ids): + """ + Compute cumulative sequence lengths from position_ids for FLA varlen kernels. + + Adapted from transformers.modeling_flash_attention_utils.prepare_fa_kwargs_from_position_ids. + https://github.com/huggingface/transformers/blob/0f1b128d3359a26bd18be99c26d7f04fb3cba914/src/transformers/modeling_flash_attention_utils.py#L316 + + Qwen3.5 uses MRoPE: position_ids arrive as [axes, B, T]. All axes carry the + same temporal positions, so axis 0 is used to recover the [B, T] layout. + See: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_5/modeling_qwen3_5.py + """ + if position_ids.ndim == 3: + position_ids = position_ids[0] + + tensor_kwargs = {"dtype": torch.long, "device": position_ids.device} + position_ids = position_ids.reshape(-1) + indices_q = (position_ids == 0).nonzero().reshape(-1) + return torch.cat( + ( + indices_q.to(**tensor_kwargs), + torch.tensor(position_ids.size(), **tensor_kwargs), + ) + ) + + +def _inject_fla_kernels(module) -> None: + """Inject FLA kernels into a modeling module, bypassing is_flash_linear_attention_available.""" + try: + from fla.modules import FusedRMSNormGated + from fla.ops.gated_delta_rule import ( + chunk_gated_delta_rule, + fused_recurrent_gated_delta_rule, + ) + + module.FusedRMSNormGated = FusedRMSNormGated + module.chunk_gated_delta_rule = chunk_gated_delta_rule + module.fused_recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule + module.is_fast_path_available = True + except ImportError: + module.chunk_gated_delta_rule = None + module.fused_recurrent_gated_delta_rule = None + module.FusedRMSNormGated = None + + +def _patched_decoder_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values=None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> torch.FloatTensor: + """Decoder layer forward that passes position_ids through to linear attention.""" + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn( + hidden_states=hidden_states, + cache_params=past_key_values, + cache_position=cache_position, + attention_mask=attention_mask, + position_ids=position_ids, + ) + elif self.layer_type == "full_attention": + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = 
self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if isinstance(hidden_states, tuple): # MoE returns (hidden_states, router_logits) + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states + + return hidden_states + + +def _make_qwen3_5_gated_delta_forward(apply_mask_fn): + """Factory for patched Qwen3_5/Qwen3_5Moe GatedDeltaNet forward with packing support.""" + + def patched_forward( + self, + hidden_states: torch.Tensor, + cache_params=None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ): + hidden_states = apply_mask_fn(hidden_states, attention_mask) + + batch_size, seq_len, _ = hidden_states.shape + + use_precomputed_states = ( + cache_params is not None + and cache_params.has_previous_state + and seq_len == 1 + and cache_position is not None + ) + + cu_seqlens = None + if not use_precomputed_states and position_ids is not None: + cu_seqlens = get_cu_seqlens(position_ids=position_ids) + + if cache_params is not None: + conv_state = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + + # mixed_qkv stays [B, T, D]; only transposed inside paths that require [B, D, T] + mixed_qkv = self.in_proj_qkv(hidden_states) # [B, T, D] + + z = self.in_proj_z(hidden_states) + z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) + + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + if use_precomputed_states: + mixed_qkv = self.causal_conv1d_update( + mixed_qkv.transpose(1, 2), + conv_state, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ).transpose(1, 2) + else: + if cache_params is not None: + mixed_qkv_t = mixed_qkv.transpose(1, 2) + cache_params.conv_states[self.layer_idx] = F.pad( + mixed_qkv_t, + (self.conv_kernel_size - mixed_qkv_t.shape[-1], 0), + ) + + if fla_causal_conv1d is not None and cu_seqlens is not None: + # FLA varlen kernel for packed sequences; input must be contiguous [B, T, D] + mixed_qkv, _ = fla_causal_conv1d( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + cu_seqlens=cu_seqlens, + ) + else: + if cu_seqlens is not None and fla_causal_conv1d is None: + raise RuntimeError( + "Packed sequences require fla.modules.convolution.causal_conv1d " + "(cu_seqlens support). Install flash-linear-attention or disable packing." 
+ ) + mixed_qkv = F.silu( + self.conv1d(mixed_qkv.transpose(1, 2))[:, :, :seq_len] + ).transpose(1, 2) + + query, key, value = torch.split( + mixed_qkv, + [self.key_dim, self.key_dim, self.value_dim], + dim=-1, + ) + query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) + key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) + value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) + + beta = b.sigmoid() + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + if not use_precomputed_states: + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g.to(dtype=query.dtype), + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + # torch_chunk_gated_delta_rule fallback does not accept cu_seqlens + **({"cu_seqlens": cu_seqlens} if cu_seqlens is not None else {}), + ) + else: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + query, + key, + value, + g=g.to(dtype=query.dtype), + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + if cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + + core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) + z = z.reshape(-1, self.head_v_dim) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) + + return self.out_proj(core_attn_out) + + return patched_forward + + +def _apply_packing_patches(model_type: str, cls_prefix: str, forward_factory) -> None: + module_name = f"transformers.models.{model_type}.modeling_{model_type}" + + try: + module = importlib.import_module(module_name) + except ImportError: + LOG.warning(f"{model_type} not found in transformers, skipping packing patches") + return + + _inject_fla_kernels(module) + getattr(module, f"{cls_prefix}DecoderLayer").forward = _patched_decoder_forward + gated_cls = getattr(module, f"{cls_prefix}GatedDeltaNet") + gated_cls.forward = forward_factory(module.apply_mask_to_padding_states) + + LOG.info( + f"Applied {cls_prefix} packing patch " + f"(fla_causal_conv1d={'available' if fla_causal_conv1d else 'unavailable'})" + ) + + +def patch_qwen3_5_modeling_packing(): + _apply_packing_patches("qwen3_5", "Qwen3_5", _make_qwen3_5_gated_delta_forward) + + +def patch_qwen3_5_moe_modeling_packing(): + _apply_packing_patches( + "qwen3_5_moe", "Qwen3_5Moe", _make_qwen3_5_gated_delta_forward + ) + + +def patch_qwen3_5_vlm_flash_attention(): + """ + Patch _is_packed_sequence to handle Qwen3.5's 3-D MRoPE position_ids. + + transformers passes position_ids as [axes, B, T] to decoder layers, but + _is_packed_sequence only handles 2-D tensors and mis-classifies the 3-D + shape as a packed-sequence indicator, causing CUDA errors in the varlen path. 
+    """
+    try:
+        import transformers.modeling_flash_attention_utils as fa_utils
+
+        _original = fa_utils._is_packed_sequence
+
+        def _patched(position_ids, batch_size):
+            if position_ids is not None and position_ids.ndim != 2:
+                return False
+            return _original(position_ids, batch_size)
+
+        fa_utils._is_packed_sequence = _patched
+        LOG.info("Applied Qwen3.5 VLM flash-attention patch (3-D MRoPE position_ids)")
+    except Exception as exc:  # pragma: no cover
+        LOG.warning(f"Failed to apply Qwen3.5 VLM flash-attention patch: {exc}")
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
index 3208325eb..cad6039bd 100644
--- a/src/axolotl/monkeypatch/multipack.py
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -22,6 +22,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "qwen3",
     "qwen3_moe",
     "qwen3_next",
+    "qwen3_5",
+    "qwen3_5_moe",
     "falcon",
     "phi",
     "phi3",
diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py
index c8b153e6d..cb1f9d984 100644
--- a/src/axolotl/processing_strategies.py
+++ b/src/axolotl/processing_strategies.py
@@ -258,6 +258,32 @@ class Qwen2VLProcessingStrategy(ProcessingStrategy):
         )
 
 
+class Qwen3_5ProcessingStrategy(ProcessingStrategy):
+    """Processing Strategy class for Qwen3.5 (early-fusion VLM)"""
+
+    def __init__(
+        self,
+        processor: ProcessorMixin,
+        chat_template: Optional[str] = None,
+        image_size: int | tuple[int, int] | None = None,
+        image_resize_algorithm: Resampling | None = None,
+    ):
+        super().__init__(processor, chat_template, image_size, image_resize_algorithm)
+        self.image_token = "<|image_pad|>"  # nosec
+        self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
+            self.image_token
+        )
+        self.video_token = "<|video_pad|>"  # nosec
+        self.video_token_id = processor.tokenizer.convert_tokens_to_ids(
+            self.video_token
+        )
+
+    def process_labels(self, input_ids):
+        labels = super().process_labels(input_ids)
+        labels[labels == self.video_token_id] = -100
+        return labels
+
+
 class Gemma3ProcessingStrategy(ProcessingStrategy):
     """Processing Strategy class for Gemma3"""
 
@@ -562,6 +588,10 @@ def get_processing_strategy(
         return Qwen2VLProcessingStrategy(
             **processing_kwargs,
         )
+    if chat_template_type in ["qwen3_5", "qwen3_5_moe"]:
+        return Qwen3_5ProcessingStrategy(
+            **processing_kwargs,
+        )
     if chat_template_type == "gemma3":
         return Gemma3ProcessingStrategy(
             **processing_kwargs,
         )
diff --git a/src/axolotl/utils/chat_templates/templates/qwen3_5.jinja b/src/axolotl/utils/chat_templates/templates/qwen3_5.jinja
new file mode 100644
index 000000000..21f5733ed
--- /dev/null
+++ b/src/axolotl/utils/chat_templates/templates/qwen3_5.jinja
@@ -0,0 +1,123 @@
+{%- if tools %}
+    {{- '<|im_start|>system\n' }}
+    {%- if messages[0].role == 'system' %}
+        {{- messages[0].content + '\n\n' }}
+    {%- endif %}
+    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
+    {%- for tool in tools %}
+        {{- "\n" }}
+        {{- tool | tojson }}
+    {%- endfor %}
+    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
+{%- else %}
+    {%- if messages[0].role == 'system' %}
+        {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
+    {%- endif %}
+{%- endif %}
+{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
+{#- Determine the real last index: use provided value or default to messages length - 1 #}
+{%- if real_last_index is defined and real_last_index is not none %}
+    {%- set ns.real_last_index = real_last_index %}
+{%- else %}
+    {%- set ns.real_last_index = messages|length - 1 %}
+{%- endif %}
+{%- for message in messages[::-1] %}
+    {%- set index = (messages|length - 1) - loop.index0 %}
+    {%- if message['content'] is string %}
+        {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
+            {%- set ns.multi_step_tool = false %}
+            {%- set ns.last_query_index = index %}
+        {%- endif %}
+    {%- else %}
+        {%- if ns.multi_step_tool and message.role == "user" %}
+            {%- set ns.multi_step_tool = false %}
+            {%- set ns.last_query_index = index %}
+        {%- endif %}
+    {%- endif %}
+{%- endfor %}
+{%- for message in messages %}
+    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
+        {{- '<|im_start|>' + message.role + '\n' }}
+        {%- if message['content'] is string %}
+            {{- message.content }}
+        {%- else %}
+            {%- for content in message['content'] %}
+                {%- if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
+                    {{- '<|vision_start|><|image_pad|><|vision_end|>' }}
+                {%- elif content['type'] == 'video' or 'video' in content %}
+                    {{- '<|vision_start|><|video_pad|><|vision_end|>' }}
+                {%- elif 'text' in content %}
+                    {{- content['text'] }}
+                {%- endif %}
+            {%- endfor %}
+        {%- endif %}
+        {{- '<|im_end|>\n' }}
+    {%- elif message.role == "assistant" %}
+        {%- if message['content'] is string %}
+            {%- set content = message.content %}
+        {%- else %}
+            {%- set content = '' %}
+            {%- for item in message['content'] %}
+                {%- if 'text' in item %}
+                    {%- set content = content + item['text'] %}
+                {%- endif %}
+            {%- endfor %}
+        {%- endif %}
+        {%- set reasoning_content = '' %}
+        {%- if message.reasoning_content is defined and message.reasoning_content is not none %}
+            {%- set reasoning_content = message.reasoning_content %}
+        {%- else %}
+            {%- if '</think>' in content %}
+                {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
+                {%- set content = content.split('</think>')[-1].lstrip('\n') %}
+            {%- endif %}
+        {%- endif %}
+        {%- if loop.index0 > ns.last_query_index %}
+            {%- if loop.index0 == ns.real_last_index or (loop.index0 != ns.real_last_index and reasoning_content) %}
+                {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
+            {%- else %}
+                {{- '<|im_start|>' + message.role + '\n' + content }}
+            {%- endif %}
+        {%- else %}
+            {{- '<|im_start|>' + message.role + '\n' + content }}
+        {%- endif %}
+        {%- if message.tool_calls %}
+            {%- for tool_call in message.tool_calls %}
+                {%- if (loop.first and content) or (not loop.first) %}
+                    {{- '\n' }}
+                {%- endif %}
+                {%- if tool_call.function %}
+                    {%- set tool_call = tool_call.function %}
+                {%- endif %}
+                {{- '<tool_call>\n{"name": "' }}
+                {{- tool_call.name }}
+                {{- '", "arguments": ' }}
+                {%- if tool_call.arguments is string %}
+                    {{- tool_call.arguments }}
+                {%- else %}
+                    {{- tool_call.arguments | tojson }}
+                {%- endif %}
+                {{- '}\n</tool_call>' }}
+            {%- endfor %}
+        {%- endif %}
+        {{- '<|im_end|>\n' }}
+    {%- elif message.role == "tool" %}
+        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
+            {{- '<|im_start|>user' }}
+        {%- endif %}
+        {{- '\n<tool_response>\n' }}
+        {{- message.content }}
+        {{- '\n</tool_response>' }}
+        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
+            {{- '<|im_end|>\n' }}
+        {%- endif %}
+    {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\n' }}
+    {%- if enable_thinking is defined and enable_thinking is false %}
+        {{- '<think>\n\n</think>\n\n' }}
+    {%- else %}
+        {{- '<think>\n\n' }}
+    {%- endif %}
+{%- endif %}
diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py
index 893f23288..792f6f6de 100644
--- a/src/axolotl/utils/schemas/enums.py
+++ b/src/axolotl/utils/schemas/enums.py
@@ -59,6 +59,7 @@ class ChatTemplate(str, Enum):
     jinja = "jinja"
     qwen_25 = "qwen_25"
     qwen3 = "qwen3"
+    qwen3_5 = "qwen3_5"
    falcon_h1 = "falcon_h1"
     tokenizer_default = "tokenizer_default"
     exaone = "exaone"
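
The packing patch hinges on recovering per-sequence boundaries from packed `position_ids`: each packed sub-sequence restarts its positions at 0, and the zero positions are turned into the `cu_seqlens` offsets that the FLA varlen kernels consume. The standalone sketch below mirrors what `get_cu_seqlens` in `modeling.py` computes; it is illustrative only (the name `cu_seqlens_from_position_ids` is not part of the patch).

```python
import torch


def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """Recover cumulative sequence lengths from packed position_ids."""
    # MRoPE position_ids may arrive as [axes, B, T]; axis 0 carries the temporal positions.
    if position_ids.ndim == 3:
        position_ids = position_ids[0]
    flat = position_ids.reshape(-1)
    # Every packed sub-sequence restarts its positions at 0.
    starts = (flat == 0).nonzero().reshape(-1)
    total = torch.tensor([flat.numel()], dtype=torch.long, device=flat.device)
    return torch.cat((starts.to(torch.long), total))


# Two sequences of length 4 and 3 packed into a single row of length 7:
packed = torch.tensor([[0, 1, 2, 3, 0, 1, 2]])
print(cu_seqlens_from_position_ids(packed))  # tensor([0, 4, 7])
```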