add: qwen 3.5 (#3442)

* add: qwen 3.5

* test for qwen, patch

* lint

* qwen3 fix on main

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* moe config

* config moe

* configs and chore

* Update examples/qwen3.5/122b-a10b-moe-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/qwen3.5/35b-a3b-moe-qlora.yaml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* chore for qwen + vlm patch

* chore lint

* qwen lint

* 3_5_moe

* Update examples/qwen3.5/README.md

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Author: VED (committed by GitHub)
Date: 2026-03-06 20:01:00 +05:30
Commit: c119382337 (parent: 6c8c73e5a4)
13 changed files with 819 additions and 0 deletions

examples/qwen3.5/122b-a10b-moe-qlora.yaml

@@ -0,0 +1,71 @@
base_model: Qwen/Qwen3.5-122B-A10B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
#lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

examples/qwen3.5/27b-qlora.yaml

@@ -0,0 +1,72 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Note: Qwen3.5 is an early-fusion VLM (image+text). This config fine-tunes
# the text-only path. For multimodal (image+text) fine-tuning, add image
# columns to your dataset following axolotl's multimodal dataset format.
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
# Uncomment below to also target the linear attention projections.
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
# - linear_attn.in_proj_qkv
# - linear_attn.in_proj_z
# - linear_attn.out_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

examples/qwen3.5/35b-a3b-moe-qlora.yaml

@@ -0,0 +1,70 @@
base_model: Qwen/Qwen3.5-35B-A3B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
#lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

examples/qwen3.5/7b-lora-vision.yaml

@@ -0,0 +1,72 @@
base_model: Qwen/Qwen3.5-7B
processor_type: AutoProcessor
# Qwen3.5-7B and above are early-fusion VLMs (Qwen3_5ForConditionalGeneration).
# Vision and text tokens are processed together by the same transformer layers.
# Note: Qwen3.5-2B is a text-only model — the smallest VLM is Qwen3.5-7B.
# These 3 lines are required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen3_5
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
# Targets the language model attention and MLP layers.
# Qwen3.5 is early-fusion: all layers (including those seeing vision tokens) share
# the same transformer stack, so standard attention targets work for both modalities.
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
# Uncomment to also target the linear attention (GatedDeltaNet) projections:
# - linear_attn.in_proj_qkv
# - linear_attn.in_proj_z
# - linear_attn.out_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

examples/qwen3.5/README.md

@@ -0,0 +1,61 @@
# Finetune Qwen3.5 with Axolotl
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35-68452f3bc6e4b7cfb4e1c803) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. Models from 7B onwards are early-fusion vision-language models (`Qwen3_5ForConditionalGeneration`), meaning vision and text tokens are processed through the same transformer stack. The 2B variant is text-only.
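You can inspect the hybrid layout yourself from the HF config; a minimal sketch, assuming the config exposes a per-layer `layer_types` list with `"linear_attention"`/`"full_attention"` entries (the same values the patched decoder layers switch on):

```python
# Sketch: print which layers use Gated DeltaNet vs full attention.
# `layer_types` is an assumption inferred from the decoder's `layer_type`
# attribute; verify against your installed transformers version.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen3.5-27B")
print(getattr(config, "layer_types", "no layer_types attribute"))
```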
Available configs:
| Config | Model | Type |
|---|---|---|
| `27b-qlora.yaml` | Qwen3.5-27B | Dense VLM, text-only path |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only path |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only path |
| `7b-lora-vision.yaml` | Qwen3.5-7B | Vision+text (multimodal) |
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false`, so FLA is optional there. A quick import check is sketched after the run examples below.
4. Run a finetuning example:
```bash
# Dense 27B text-only (QLoRA, ~47 GiB VRAM with sample packing)
axolotl train examples/qwen3.5/27b-qlora.yaml
# MoE 35B-A3B text-only (QLoRA)
axolotl train examples/qwen3.5/35b-a3b-moe-qlora.yaml
# MoE 122B-A10B text-only (QLoRA)
axolotl train examples/qwen3.5/122b-a10b-moe-qlora.yaml
# 7B vision+text (LoRA, multimodal dataset)
axolotl train examples/qwen3.5/7b-lora-vision.yaml
```
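
If you are unsure whether FLA is set up correctly, the import probe below mirrors the try/except fallback in axolotl's Qwen3.5 monkeypatch (a sanity check, not part of training):

```python
# Probe for the FLA varlen conv kernel that the packing patch relies on.
try:
    from fla.modules.convolution import causal_conv1d  # FLA >= 0.4.1
except ImportError:
    try:
        from fla.modules.conv import causal_conv1d  # FLA < 0.4.1
    except ImportError:
        causal_conv1d = None

print("FLA causal_conv1d available:", causal_conv1d is not None)
```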
### TIPS
- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
- You can run a full finetuning by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template); a minimal record is sketched after this list.
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `7b-lora-vision.yaml`.
- The Gated DeltaNet linear attention layers (`linear_attn.*`) can optionally be added to `lora_target_modules` — they are commented out by default.
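
A minimal record in that format, matching the `field_messages: conversations` and `from`/`value` mapping used by the configs above (an illustrative sketch, not a row from the actual dataset):

```python
# One ShareGPT-style training record as `message_property_mappings` expects it:
# roles live under "from" and message text under "value".
record = {
    "conversations": [
        {"from": "human", "value": "What is the capital of France?"},
        {"from": "gpt", "value": "The capital of France is Paris."},
    ]
}
```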
## Optimization Guides
- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html)
## Related Resources
- [Qwen3.5 Blog](https://qwenlm.github.io/blog/qwen3.5/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)


@@ -12,6 +12,7 @@ MOE_ARCH_BLOCK = {
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"deepseek_v3": "DeepseekV3MoE",


@@ -246,6 +246,31 @@ class PatchManager:
patch_qwen3_next_modeling_packing()
if self.cfg.model_config_type == "qwen3_5" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_modeling_packing,
)
patch_qwen3_5_modeling_packing()
if self.cfg.model_config_type == "qwen3_5_moe" and self.cfg.sample_packing:
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_moe_modeling_packing,
)
patch_qwen3_5_moe_modeling_packing()
if (
self.cfg.model_config_type in ["qwen3_5", "qwen3_5_moe"]
and self.cfg.is_multimodal
and self.cfg.flash_attention
):
from axolotl.monkeypatch.models.qwen3_5.modeling import (
patch_qwen3_5_vlm_flash_attention,
)
patch_qwen3_5_vlm_flash_attention()
if self.cfg.model_config_type == "kimi_linear":
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_model,

src/axolotl/monkeypatch/models/qwen3_5/modeling.py

@@ -0,0 +1,291 @@
"""Monkeypatch for Qwen3_5 and Qwen3_5Moe models to pass position_ids to linear attention."""
import importlib
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
try:
from fla.modules.convolution import (
causal_conv1d as fla_causal_conv1d, # FLA >= 0.4.1
)
except ImportError:
try:
from fla.modules.conv import causal_conv1d as fla_causal_conv1d # FLA < 0.4.1
except ImportError:
fla_causal_conv1d = None
def get_cu_seqlens(position_ids):
"""
Compute cumulative sequence lengths from position_ids for FLA varlen kernels.
Adapted from transformers.modeling_flash_attention_utils.prepare_fa_kwargs_from_position_ids.
https://github.com/huggingface/transformers/blob/0f1b128d3359a26bd18be99c26d7f04fb3cba914/src/transformers/modeling_flash_attention_utils.py#L316
Qwen3.5 uses MRoPE: position_ids arrive as [axes, B, T]. All axes carry the
same temporal positions, so axis 0 is used to recover the [B, T] layout.
See: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
"""
if position_ids.ndim == 3:
position_ids = position_ids[0]
tensor_kwargs = {"dtype": torch.long, "device": position_ids.device}
position_ids = position_ids.reshape(-1)
indices_q = (position_ids == 0).nonzero().reshape(-1)
return torch.cat(
(
indices_q.to(**tensor_kwargs),
torch.tensor(position_ids.size(), **tensor_kwargs),
)
)
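# Worked example (illustrative): two packed sequences of lengths 3 and 2 give
#   position_ids = [0, 1, 2, 0, 1] -> indices_q = [0, 3] -> cu_seqlens = [0, 3, 5],
# i.e. the boundary offsets that FLA's varlen kernels expect.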
def _inject_fla_kernels(module) -> None:
"""Inject FLA kernels into a modeling module, bypassing is_flash_linear_attention_available."""
try:
from fla.modules import FusedRMSNormGated
from fla.ops.gated_delta_rule import (
chunk_gated_delta_rule,
fused_recurrent_gated_delta_rule,
)
module.FusedRMSNormGated = FusedRMSNormGated
module.chunk_gated_delta_rule = chunk_gated_delta_rule
module.fused_recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule
module.is_fast_path_available = True
except ImportError:
module.chunk_gated_delta_rule = None
module.fused_recurrent_gated_delta_rule = None
module.FusedRMSNormGated = None
def _patched_decoder_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values=None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> torch.FloatTensor:
"""Decoder layer forward that passes position_ids through to linear attention."""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if self.layer_type == "linear_attention":
hidden_states = self.linear_attn(
hidden_states=hidden_states,
cache_params=past_key_values,
cache_position=cache_position,
attention_mask=attention_mask,
position_ids=position_ids,
)
elif self.layer_type == "full_attention":
hidden_states, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if isinstance(hidden_states, tuple): # MoE returns (hidden_states, router_logits)
hidden_states, _ = hidden_states
hidden_states = residual + hidden_states
return hidden_states
def _make_qwen3_5_gated_delta_forward(apply_mask_fn):
"""Factory for patched Qwen3_5/Qwen3_5Moe GatedDeltaNet forward with packing support."""
def patched_forward(
self,
hidden_states: torch.Tensor,
cache_params=None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
):
hidden_states = apply_mask_fn(hidden_states, attention_mask)
batch_size, seq_len, _ = hidden_states.shape
use_precomputed_states = (
cache_params is not None
and cache_params.has_previous_state
and seq_len == 1
and cache_position is not None
)
cu_seqlens = None
if not use_precomputed_states and position_ids is not None:
cu_seqlens = get_cu_seqlens(position_ids=position_ids)
if cache_params is not None:
conv_state = cache_params.conv_states[self.layer_idx]
recurrent_state = cache_params.recurrent_states[self.layer_idx]
# mixed_qkv stays [B, T, D]; only transposed inside paths that require [B, D, T]
mixed_qkv = self.in_proj_qkv(hidden_states) # [B, T, D]
z = self.in_proj_z(hidden_states)
z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
b = self.in_proj_b(hidden_states)
a = self.in_proj_a(hidden_states)
if use_precomputed_states:
mixed_qkv = self.causal_conv1d_update(
mixed_qkv.transpose(1, 2),
conv_state,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
).transpose(1, 2)
else:
if cache_params is not None:
mixed_qkv_t = mixed_qkv.transpose(1, 2)
cache_params.conv_states[self.layer_idx] = F.pad(
mixed_qkv_t,
(self.conv_kernel_size - mixed_qkv_t.shape[-1], 0),
)
if fla_causal_conv1d is not None and cu_seqlens is not None:
# FLA varlen kernel for packed sequences; input must be contiguous [B, T, D]
mixed_qkv, _ = fla_causal_conv1d(
x=mixed_qkv,
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
cu_seqlens=cu_seqlens,
)
else:
if cu_seqlens is not None and fla_causal_conv1d is None:
raise RuntimeError(
"Packed sequences require fla.modules.convolution.causal_conv1d "
"(cu_seqlens support). Install flash-linear-attention or disable packing."
)
mixed_qkv = F.silu(
self.conv1d(mixed_qkv.transpose(1, 2))[:, :, :seq_len]
).transpose(1, 2)
query, key, value = torch.split(
mixed_qkv,
[self.key_dim, self.key_dim, self.value_dim],
dim=-1,
)
query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
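        # beta = sigmoid(b) in (0, 1) scales each delta-rule update, while
        # g = -exp(A_log) * softplus(a + dt_bias) <= 0 is the log-space decay
        # applied to the recurrent state (the "gated" part of the gated delta rule).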
beta = b.sigmoid()
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
if self.num_v_heads // self.num_k_heads > 1:
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
if not use_precomputed_states:
core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
query,
key,
value,
g=g.to(dtype=query.dtype),
beta=beta,
initial_state=None,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
# torch_chunk_gated_delta_rule fallback does not accept cu_seqlens
**({"cu_seqlens": cu_seqlens} if cu_seqlens is not None else {}),
)
else:
core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
query,
key,
value,
g=g.to(dtype=query.dtype),
beta=beta,
initial_state=recurrent_state,
output_final_state=cache_params is not None,
use_qk_l2norm_in_kernel=True,
)
if cache_params is not None:
cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
z = z.reshape(-1, self.head_v_dim)
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
return self.out_proj(core_attn_out)
return patched_forward
def _apply_packing_patches(model_type: str, cls_prefix: str, forward_factory) -> None:
module_name = f"transformers.models.{model_type}.modeling_{model_type}"
try:
module = importlib.import_module(module_name)
except ImportError:
LOG.warning(f"{model_type} not found in transformers, skipping packing patches")
return
_inject_fla_kernels(module)
getattr(module, f"{cls_prefix}DecoderLayer").forward = _patched_decoder_forward
gated_cls = getattr(module, f"{cls_prefix}GatedDeltaNet")
gated_cls.forward = forward_factory(module.apply_mask_to_padding_states)
LOG.info(
f"Applied {cls_prefix} packing patch "
f"(fla_causal_conv1d={'available' if fla_causal_conv1d else 'unavailable'})"
)
def patch_qwen3_5_modeling_packing():
_apply_packing_patches("qwen3_5", "Qwen3_5", _make_qwen3_5_gated_delta_forward)
def patch_qwen3_5_moe_modeling_packing():
_apply_packing_patches(
"qwen3_5_moe", "Qwen3_5Moe", _make_qwen3_5_gated_delta_forward
)
def patch_qwen3_5_vlm_flash_attention():
"""
Patch _is_packed_sequence to handle Qwen3.5's 3-D MRoPE position_ids.
transformers passes position_ids as [axes, B, T] to decoder layers, but
_is_packed_sequence only handles 2-D tensors and mis-classifies the 3-D
shape as a packed-sequence indicator, causing CUDA errors in the varlen path.
"""
try:
import transformers.modeling_flash_attention_utils as fa_utils
_original = fa_utils._is_packed_sequence
def _patched(position_ids, batch_size):
if position_ids is not None and position_ids.ndim != 2:
return False
return _original(position_ids, batch_size)
fa_utils._is_packed_sequence = _patched
LOG.info("Applied Qwen3.5 VLM flash-attention patch (3-D MRoPE position_ids)")
except Exception as exc: # pragma: no cover
LOG.warning(f"Failed to apply Qwen3.5 VLM flash-attention patch: {exc}")


@@ -22,6 +22,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"qwen3",
"qwen3_moe",
"qwen3_next",
"qwen3_5",
"qwen3_5_moe",
"falcon",
"phi",
"phi3",


@@ -258,6 +258,32 @@ class Qwen2VLProcessingStrategy(ProcessingStrategy):
)
class Qwen3_5ProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Qwen3.5 (early-fusion VLM)"""
def __init__(
self,
processor: ProcessorMixin,
chat_template: Optional[str] = None,
image_size: int | tuple[int, int] | None = None,
image_resize_algorithm: Resampling | None = None,
):
super().__init__(processor, chat_template, image_size, image_resize_algorithm)
self.image_token = "<|image_pad|>" # nosec
self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
self.image_token
)
self.video_token = "<|video_pad|>" # nosec
self.video_token_id = processor.tokenizer.convert_tokens_to_ids(
self.video_token
)
def process_labels(self, input_ids):
labels = super().process_labels(input_ids)
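        # -100 is the default ignore_index of PyTorch's cross-entropy loss, so
        # video pad tokens contribute nothing to the gradient.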
labels[labels == self.video_token_id] = -100
return labels
class Gemma3ProcessingStrategy(ProcessingStrategy):
"""Processing Strategy class for Gemma3"""
@@ -562,6 +588,10 @@ def get_processing_strategy(
return Qwen2VLProcessingStrategy(
**processing_kwargs,
)
if chat_template_type in ["qwen3_5", "qwen3_5_moe"]:
return Qwen3_5ProcessingStrategy(
**processing_kwargs,
)
if chat_template_type == "gemma3":
return Gemma3ProcessingStrategy(
**processing_kwargs,


@@ -0,0 +1,123 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{#- Determine the real last index: use provided value or default to messages length - 1 #}
{%- if real_last_index is defined and real_last_index is not none %}
{%- set ns.real_last_index = real_last_index %}
{%- else %}
{%- set ns.real_last_index = messages|length - 1 %}
{%- endif %}
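{#- Walk the messages in reverse to find the last genuine user query (user turns
    that merely wrap a <tool_response> do not count). Assistant turns after this
    index keep their <think> reasoning; earlier ones have it stripped below. -#}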
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if message['content'] is string %}
{%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- else %}
{%- if ns.multi_step_tool and message.role == "user" %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' }}
{%- if message['content'] is string %}
{{- message.content }}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
{%- elif content['type'] == 'video' or 'video' in content %}
{{- '<|vision_start|><|video_pad|><|vision_end|>' }}
{%- elif 'text' in content %}
{{- content['text'] }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "assistant" %}
{%- if message['content'] is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- for item in message['content'] %}
{%- if 'text' in item %}
{%- set content = content + item['text'] %}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is defined and message.reasoning_content is not none %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
    {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
    {%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{%- if loop.index0 == ns.real_last_index or (loop.index0 != ns.real_last_index and reasoning_content) %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- '}\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- message.content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- else %}
{{- '<think>\n\n' }}
{%- endif %}
{%- endif %}


@@ -59,6 +59,7 @@ class ChatTemplate(str, Enum):
jinja = "jinja"
qwen_25 = "qwen_25"
qwen3 = "qwen3"
qwen3_5 = "qwen3_5"
falcon_h1 = "falcon_h1"
tokenizer_default = "tokenizer_default"
exaone = "exaone"