diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb
index e63632e7c..b48331063 100644
--- a/examples/colab-notebooks/colab-axolotl-example.ipynb
+++ b/examples/colab-notebooks/colab-axolotl-example.ipynb
@@ -40,7 +40,7 @@
     "%%capture\n",
     "# This step can take ~5-10 minutes to install dependencies\n",
     "!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
-    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc\""
+    "!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef\""
   ]
  },
  {
diff --git a/examples/qwen3-next/README.md b/examples/qwen3-next/README.md
new file mode 100644
index 000000000..eb0d5fd28
--- /dev/null
+++ b/examples/qwen3-next/README.md
@@ -0,0 +1,64 @@
+# Finetune Qwen3-Next with Axolotl
+
+[Qwen3-Next](https://huggingface.co/collections/Qwen/qwen3-next-68c25fd6838e585db8eeea9d) represents the next generation of foundation models, optimized for extreme context length and large-scale parameter efficiency. The series introduces architectural innovations including Hybrid Attention (Gated DeltaNet + Gated Attention), High-Sparsity MoE with a 1:50 activation ratio, and Multi-Token Prediction for enhanced performance and inference acceleration.
+
+This guide shows how to fine-tune it with Axolotl on multi-turn conversations with proper masking.
+
+## Getting started
+
+1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main, since Qwen3-Next support is only available in nightly builds, or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
+
+   Here is an example of how to install from main with pip:
+
+```bash
+# Ensure you have PyTorch installed (PyTorch 2.6.0 minimum)
+git clone https://github.com/axolotl-ai-cloud/axolotl.git
+cd axolotl
+
+pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
+pip3 install --no-build-isolation -e '.[flash-attn]'
+
+# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
+python scripts/cutcrossentropy_install.py | sh
+```
+
+2. Install the pinned transformers commit with Qwen3-Next support:
+```bash
+pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
+```
+
+3. Install flash-linear-attention (FLA) for improved performance:
+```bash
+pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
+```
+
+4. Run the finetuning example:
+
+```bash
+axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
+```
+
+This config uses about 41.7 GiB of VRAM.
+
+Let us know how it goes. Happy finetuning! 🚀
+
+### TIPS
+
+- For inference, you can experiment with `temperature: 0.7`, `top_p: 0.8`, `top_k: 20`, and `min_p: 0`.
+- You can run a full finetune by removing `adapter: qlora` and `load_in_4bit: true` from the config. See the [Multi-GPU](#optimization-guides) section below.
+- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
+- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template); a minimal sketch is shown below.
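+
+As a minimal sketch of that format (the file name `my_conversations.jsonl` and the sample contents below are placeholders, not part of this example), each line of a JSONL dataset holds one conversation under a `messages` key:
+
+```python
+# Sketch: write a tiny JSONL dataset in the OpenAI messages format that the
+# `chat_template` dataset type expects; each line holds one conversation.
+import json
+
+sample = {
+    "messages": [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "What is Qwen3-Next?"},
+        {"role": "assistant", "content": "A hybrid-attention MoE model from the Qwen team."},
+    ]
+}
+
+with open("my_conversations.jsonl", "w", encoding="utf-8") as f:
+    f.write(json.dumps(sample) + "\n")
+```
+
+Point the `datasets` entry in the config at your own data as described in the dataset loading docs linked above.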
+ +## Optimization Guides + +- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html) +- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html) +- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html) + +## Related Resources + +- [Qwen3-Next Blog](https://qwenlm.github.io/blog/qwen3_next/) +- [Axolotl Docs](https://docs.axolotl.ai) +- [Axolotl Website](https://axolotl.ai) +- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl) +- [Axolotl Discord](https://discord.gg/7m9sfhzaf3) diff --git a/examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml b/examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml new file mode 100644 index 000000000..11481dcd3 --- /dev/null +++ b/examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml @@ -0,0 +1,60 @@ +base_model: Qwen/Qwen3-Next-80B-A3B-Instruct + +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +load_in_8bit: false +load_in_4bit: true + +datasets: + - path: fozziethebeat/alpaca_messages_2k_test + type: chat_template + +dataset_prepared_path: last_run_prepared +val_set_size: 0.1 +output_dir: ./outputs/lora-out + +adapter: qlora +lora_model_dir: + +sequence_len: 2048 +sample_packing: true + +lora_r: 16 +lora_alpha: 8 +lora_dropout: 0.05 +lora_target_modules: + - q_proj + - v_proj + - k_proj + - o_proj + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 2 +micro_batch_size: 2 +num_epochs: 1 +optimizer: adamw_bnb_8bit +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: false + +gradient_checkpointing: true +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_ratio: 0.1 +evals_per_epoch: 1 +saves_per_epoch: 1 + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py index ada574805..dc117604a 100644 --- a/scripts/cutcrossentropy_install.py +++ b/scripts/cutcrossentropy_install.py @@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else "" print( UNINSTALL_PREFIX - + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc"' + + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"' ) diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index 2361dde4a..cc73eebb7 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh - If you are installing from pip ```bash -pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc" +pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef" ``` ## Usage @@ -65,6 +65,7 @@ plugins: - qwen2_5_vl - qwen3 - qwen3_moe +- qwen3_next - smollm3 - seed_oss - voxtral diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py index dad3f7f89..812baf33f 100644 --- a/src/axolotl/integrations/cut_cross_entropy/__init__.py +++ 
b/src/axolotl/integrations/cut_cross_entropy/__init__.py @@ -35,7 +35,7 @@ LOG = get_logger(__name__) _CCE_INSTALL_MESSAGE = ( "Please install Axolotl's fork of cut_cross_entropy with transformers support using " - '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c564afc"`' + '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c5aa3ef"`' ) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index a78f8b965..3d4b7b96b 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -169,6 +169,13 @@ class PatchManager: patch_llama4_linearized_modeling() + if self.cfg.model_config_type == "qwen3_next" and self.cfg.sample_packing: + from axolotl.monkeypatch.models.qwen3_next.modeling import ( + patch_qwen3_next_modeling_packing, + ) + + patch_qwen3_next_modeling_packing() + if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type: from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import ( apply_mistral_tokenizer_image_patch, diff --git a/src/axolotl/monkeypatch/models/qwen3_next/__init__.py b/src/axolotl/monkeypatch/models/qwen3_next/__init__.py new file mode 100644 index 000000000..39bcd4115 --- /dev/null +++ b/src/axolotl/monkeypatch/models/qwen3_next/__init__.py @@ -0,0 +1 @@ +"""Qwen3_Next model monkeypatches.""" diff --git a/src/axolotl/monkeypatch/models/qwen3_next/modeling.py b/src/axolotl/monkeypatch/models/qwen3_next/modeling.py new file mode 100644 index 000000000..d68992d0e --- /dev/null +++ b/src/axolotl/monkeypatch/models/qwen3_next/modeling.py @@ -0,0 +1,317 @@ +"""Monkeypatch for Qwen3_Next model to pass position_ids to linear attention.""" + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def get_cu_seqlens(position_ids): + """ + Adapted from transformers.modeling_flash_attention_utils.prepare_fa_kwargs_from_position_ids. 
+ + https://github.com/huggingface/transformers/blob/0f1b128d3359a26bd18be99c26d7f04fb3cba914/src/transformers/modeling_flash_attention_utils.py#L316 + """ + tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device} + + position_ids = position_ids.view(-1) + indices_q = (position_ids == 0).nonzero().view(-1) + + cu_seq_lens_q = torch.cat( + ( + indices_q.to(**tensor_kwargs), + torch.tensor(position_ids.size(), **tensor_kwargs), + ) + ) + + return cu_seq_lens_q + + +def patch_qwen3_next_decoder_layer(): + """Patch Qwen3NextDecoderLayer to pass position_ids to linear attention.""" + try: + from transformers.models.qwen3_next.modeling_qwen3_next import ( + Qwen3NextDecoderLayer, + ) + except ImportError: + LOG.warning("Qwen3Next model not found, skipping patch") + return + + # Store original forward method + original_decoder_forward = Qwen3NextDecoderLayer.forward + + def patched_decoder_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> torch.FloatTensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Token Mixer + if self.layer_type == "linear_attention": + hidden_states = self.linear_attn( + hidden_states=hidden_states, + cache_params=past_key_values, + cache_position=cache_position, + attention_mask=attention_mask, + position_ids=position_ids, + ) + elif self.layer_type == "full_attention": + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + # For the MoE layers, we need to unpack + if isinstance(hidden_states, Tuple): + hidden_states, _ = hidden_states + hidden_states = residual + hidden_states + + return hidden_states + + # Apply the patches + Qwen3NextDecoderLayer.forward = patched_decoder_forward + + def unpatch(): + """Restore the original forward method""" + Qwen3NextDecoderLayer.forward = original_decoder_forward + + return unpatch + + +def patch_qwen3_next_gateddelta_layer(): + """Patch Qwen3NextGatedDeltaNet to parse cu_seqlens and pass to chunk_gated_delta_rule""" + try: + from transformers.models.qwen3_next.modeling_qwen3_next import ( + Qwen3NextDynamicCache, + Qwen3NextGatedDeltaNet, + apply_mask_to_padding_states, + ) + except ImportError: + LOG.warning("Qwen3Next model not found, skipping patch") + return + + # Store original forward method + original_gated_delta_net_forward = Qwen3NextGatedDeltaNet.forward + + def patched_gated_delta_net_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[Qwen3NextDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + ): + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + + use_precomputed_states = ( + cache_params is not None + and 
cache_params.has_previous_state + and seq_len == 1 + and cache_position is not None + ) + + # getting projected states from cache if it exists + if cache_params is not None: + conv_state = cache_params.conv_states[self.layer_idx] + recurrent_state = cache_params.recurrent_states[self.layer_idx] + + projected_states_qkvz = self.in_proj_qkvz(hidden_states) + projected_states_ba = self.in_proj_ba(hidden_states) + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba + ) + query, key, value = ( + x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value) + ) + + mixed_qkv = torch.cat((query, key, value), dim=-1) + mixed_qkv = mixed_qkv.transpose(1, 2) + + if use_precomputed_states: + # 2. Convolution sequence transformation + # NOTE: the conv state is updated in `causal_conv1d_update` + mixed_qkv = self.causal_conv1d_update( + mixed_qkv, + conv_state, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + else: + if cache_params is not None: + conv_state = F.pad( + mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx] = conv_state + if self.causal_conv1d_fn is not None: + mixed_qkv = self.causal_conv1d_fn( + x=mixed_qkv, + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + seq_idx=None, + ) + else: + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim) + key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim) + value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + if not use_precomputed_states: + cu_seqlens = get_cu_seqlens(position_ids=position_ids) + core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + cu_seqlens=cu_seqlens, + ) + + else: + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + # Update cache + if cache_params is not None: + cache_params.recurrent_states[self.layer_idx] = last_recurrent_state + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = core_attn_out.reshape( + core_attn_out.shape[0], core_attn_out.shape[1], -1 + ) + + output = self.out_proj(core_attn_out) + return output + + # Apply the patches + Qwen3NextGatedDeltaNet.forward = patched_gated_delta_net_forward + + def unpatch(): + """Restore the original forward method""" + Qwen3NextGatedDeltaNet.forward = 
original_gated_delta_net_forward + + return unpatch + + +def patch_qwen3_next_imports(): + """Patch Qwen3Next imports to use try/except instead of is_flash_linear_attention_available.""" + try: + import transformers.models.qwen3_next.modeling_qwen3_next as qwen3_modeling + except ImportError: + LOG.warning("Qwen3Next model not found, skipping import patch") + return + + # Save original values for unpatch + original_FusedRMSNormGated = getattr(qwen3_modeling, "FusedRMSNormGated", None) + original_chunk_gated_delta_rule = getattr( + qwen3_modeling, "chunk_gated_delta_rule", None + ) + original_fused_recurrent_gated_delta_rule = getattr( + qwen3_modeling, "fused_recurrent_gated_delta_rule", None + ) + original_is_fast_path_available = getattr( + qwen3_modeling, "is_fast_path_available", False + ) + + try: + from fla.modules import FusedRMSNormGated + from fla.ops.gated_delta_rule import ( + chunk_gated_delta_rule, + fused_recurrent_gated_delta_rule, + ) + + qwen3_modeling.FusedRMSNormGated = FusedRMSNormGated + qwen3_modeling.chunk_gated_delta_rule = chunk_gated_delta_rule + qwen3_modeling.fused_recurrent_gated_delta_rule = ( + fused_recurrent_gated_delta_rule + ) + + # Force is_fast_path_available to be True + # fla has triton kernels for causal_conv1d + qwen3_modeling.is_fast_path_available = True + except ImportError: + qwen3_modeling.chunk_gated_delta_rule = None + qwen3_modeling.fused_recurrent_gated_delta_rule = None + qwen3_modeling.FusedRMSNormGated = None + + def unpatch(): + """Restore the original import values""" + qwen3_modeling.FusedRMSNormGated = original_FusedRMSNormGated + qwen3_modeling.chunk_gated_delta_rule = original_chunk_gated_delta_rule + qwen3_modeling.fused_recurrent_gated_delta_rule = ( + original_fused_recurrent_gated_delta_rule + ) + qwen3_modeling.is_fast_path_available = original_is_fast_path_available + + return unpatch + + +def patch_qwen3_next_modeling_packing(): + """Apply all Qwen3Next model patches.""" + patch_qwen3_next_imports() + patch_qwen3_next_decoder_layer() + patch_qwen3_next_gateddelta_layer() + + LOG.info("Applied Qwen3Next patch for packing") diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 726e60111..4741245e1 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -21,6 +21,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [ "qwen2_moe", "qwen3", "qwen3_moe", + "qwen3_next", "falcon", "phi", "phi3", diff --git a/tests/monkeypatch/test_qwen3_next_modeling_patch.py b/tests/monkeypatch/test_qwen3_next_modeling_patch.py new file mode 100644 index 000000000..91d9fc1cf --- /dev/null +++ b/tests/monkeypatch/test_qwen3_next_modeling_patch.py @@ -0,0 +1,111 @@ +"""Integration tests for Qwen3 Next modeling patches.""" + +import pytest +import torch + +# Skip entire module if qwen3_next not available +qwen3_next = pytest.importorskip("transformers.models.qwen3_next.modeling_qwen3_next") + + +class TestQwen3NextModelingPatchIntegration: + """Test Qwen3 Next modeling patch integration.""" + + @pytest.mark.integration + def test_qwen3_next_decoder_layer_patch(self): + """Test that Qwen3Next decoder layer patch can be applied.""" + from axolotl.monkeypatch.models.qwen3_next.modeling import ( + patch_qwen3_next_decoder_layer, + ) + + # Store original method + original_forward = qwen3_next.Qwen3NextDecoderLayer.forward + + # Apply patch and get unpatch function + unpatch_fn = patch_qwen3_next_decoder_layer() + + # Verify patch was applied + assert 
qwen3_next.Qwen3NextDecoderLayer.forward != original_forward, ( + "decoder layer forward method was not patched" + ) + + # Verify the method is still callable + assert callable(qwen3_next.Qwen3NextDecoderLayer.forward), ( + "Patched method is not callable" + ) + + # Test unpatch function + if unpatch_fn: + unpatch_fn() + assert qwen3_next.Qwen3NextDecoderLayer.forward == original_forward, ( + "unpatch function did not restore original method" + ) + + @pytest.mark.integration + def test_qwen3_next_gateddelta_layer_patch(self): + """Test that Qwen3Next GatedDeltaNet patch can be applied.""" + from axolotl.monkeypatch.models.qwen3_next.modeling import ( + patch_qwen3_next_gateddelta_layer, + ) + + # Store original method + original_forward = qwen3_next.Qwen3NextGatedDeltaNet.forward + + # Apply patch and get unpatch function + unpatch_fn = patch_qwen3_next_gateddelta_layer() + + # Verify patch was applied + assert qwen3_next.Qwen3NextGatedDeltaNet.forward != original_forward, ( + "GatedDeltaNet forward method was not patched" + ) + + # Verify the method is still callable + assert callable(qwen3_next.Qwen3NextGatedDeltaNet.forward), ( + "Patched method is not callable" + ) + + # Test unpatch function + if unpatch_fn: + unpatch_fn() + assert qwen3_next.Qwen3NextGatedDeltaNet.forward == original_forward, ( + "unpatch function did not restore original method" + ) + + @pytest.mark.integration + def test_qwen3_next_imports_patch(self): + """Test that Qwen3Next imports patch can be applied without errors.""" + from axolotl.monkeypatch.models.qwen3_next.modeling import ( + patch_qwen3_next_imports, + ) + + # Apply patch - should not raise any exceptions even if modules unavailable + unpatch_fn = patch_qwen3_next_imports() + + # Test that unpatch function is returned (or None if skipped) + assert unpatch_fn is None or callable(unpatch_fn), ( + "patch_qwen3_next_imports should return None or callable unpatch function" + ) + + @pytest.mark.integration + def test_qwen3_next_modeling_packing_patch(self): + """Test that all Qwen3Next modeling patches can be applied together.""" + from axolotl.monkeypatch.models.qwen3_next.modeling import ( + patch_qwen3_next_modeling_packing, + ) + + # This should not raise any exceptions + patch_qwen3_next_modeling_packing() + + +@pytest.mark.integration +def test_get_cu_seqlens_utility(): + """Test the get_cu_seqlens utility function.""" + from axolotl.monkeypatch.models.qwen3_next.modeling import get_cu_seqlens + + # Test with simple position_ids + position_ids = torch.tensor([[0, 1, 2, 0, 1]]) + cu_seqlens = get_cu_seqlens(position_ids) + assert cu_seqlens.dtype == torch.int32, "Should be int32 dtype" + + # Should return tensor with start positions and total length + expected = torch.tensor([0, 3, 5], dtype=torch.int32) + assert torch.equal(cu_seqlens, expected), f"Expected {expected}, got {cu_seqlens}"
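+
+
+# Additional illustrative check (a sketch added for clarity): get_cu_seqlens
+# flattens batched position_ids, so sequence boundaries from every row end up
+# in a single cumulative-length tensor.
+@pytest.mark.integration
+def test_get_cu_seqlens_batched_packed():
+    """Illustrative example of get_cu_seqlens on a batch of packed sequences.
+
+    The position_ids below are hypothetical sample values, not taken from a
+    real dataset.
+    """
+    from axolotl.monkeypatch.models.qwen3_next.modeling import get_cu_seqlens
+
+    # Row 0 packs sequences of lengths 4 and 2; row 1 packs lengths 2 and 4.
+    # After flattening, sequence boundaries fall at offsets 0, 4, 6, and 10.
+    position_ids = torch.tensor([[0, 1, 2, 3, 0, 1], [0, 1, 0, 1, 2, 3]])
+    cu_seqlens = get_cu_seqlens(position_ids)
+
+    expected = torch.tensor([0, 4, 6, 10], dtype=torch.int32)
+    assert torch.equal(cu_seqlens, expected), f"Expected {expected}, got {cu_seqlens}"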