* feat: upgrade cce for qwen3-next * feat: add sample qwen3 config * feat: add packing patch for chunk_gated_delta_rule * feat: add qwen3 link * fix: tuple name * feat: add tested qwen3 config * fix: improve log * feat: add patch for fla without packing * fix: remove fla patch for standard mode * feat: enable packing * feat: add qwen3-next tests * chore: move tests
112 lines
4.0 KiB
Python
112 lines
4.0 KiB
Python
"""Integration tests for Qwen3 Next modeling patches."""
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
# Skip the entire module if the transformers qwen3_next modeling module is not available
|
|
qwen3_next = pytest.importorskip("transformers.models.qwen3_next.modeling_qwen3_next")
|
|
|
|
|
|
class TestQwen3NextModelingPatchIntegration:
    """Test Qwen3 Next modeling patch integration.

    The patches under test replace attributes on classes from the shared
    ``transformers`` module. Each test therefore snapshots what it can see
    and restores it in a ``finally`` block, so a failing assertion (or a patch
    helper that returns no unpatch function) cannot leak patched state into
    other tests running in the same process.
    """

    @pytest.mark.integration
    def test_qwen3_next_decoder_layer_patch(self):
        """Test that Qwen3Next decoder layer patch can be applied."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_decoder_layer,
        )

        # Snapshot the class and its original method before patching
        decoder_cls = qwen3_next.Qwen3NextDecoderLayer
        original_forward = decoder_cls.forward

        try:
            # Apply patch and get unpatch function
            unpatch_fn = patch_qwen3_next_decoder_layer()

            # Verify patch was applied
            assert qwen3_next.Qwen3NextDecoderLayer.forward is not original_forward, (
                "decoder layer forward method was not patched"
            )

            # Verify the method is still callable
            assert callable(qwen3_next.Qwen3NextDecoderLayer.forward), (
                "Patched method is not callable"
            )

            # Test unpatch function
            if unpatch_fn:
                unpatch_fn()
                assert qwen3_next.Qwen3NextDecoderLayer.forward is original_forward, (
                    "unpatch function did not restore original method"
                )
        finally:
            # Always restore, even if an assertion failed or no unpatch
            # function was returned.
            decoder_cls.forward = original_forward

    @pytest.mark.integration
    def test_qwen3_next_gateddelta_layer_patch(self):
        """Test that Qwen3Next GatedDeltaNet patch can be applied."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_gateddelta_layer,
        )

        # Snapshot the class and its original method before patching
        gated_cls = qwen3_next.Qwen3NextGatedDeltaNet
        original_forward = gated_cls.forward

        try:
            # Apply patch and get unpatch function
            unpatch_fn = patch_qwen3_next_gateddelta_layer()

            # Verify patch was applied
            assert qwen3_next.Qwen3NextGatedDeltaNet.forward is not original_forward, (
                "GatedDeltaNet forward method was not patched"
            )

            # Verify the method is still callable
            assert callable(qwen3_next.Qwen3NextGatedDeltaNet.forward), (
                "Patched method is not callable"
            )

            # Test unpatch function
            if unpatch_fn:
                unpatch_fn()
                assert qwen3_next.Qwen3NextGatedDeltaNet.forward is original_forward, (
                    "unpatch function did not restore original method"
                )
        finally:
            # Always restore, even if an assertion failed or no unpatch
            # function was returned.
            gated_cls.forward = original_forward

    @pytest.mark.integration
    def test_qwen3_next_imports_patch(self):
        """Test that Qwen3Next imports patch can be applied without errors."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_imports,
        )

        # Apply patch - should not raise any exceptions even if modules unavailable
        unpatch_fn = patch_qwen3_next_imports()

        try:
            # Test that unpatch function is returned (or None if skipped)
            assert unpatch_fn is None or callable(unpatch_fn), (
                "patch_qwen3_next_imports should return None or callable unpatch function"
            )
        finally:
            # Revert the patch so it does not leak into other tests.
            if callable(unpatch_fn):
                unpatch_fn()

    @pytest.mark.integration
    def test_qwen3_next_modeling_packing_patch(self):
        """Test that all Qwen3Next modeling patches can be applied together."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_modeling_packing,
        )

        # Snapshot the forward methods that the individual patches in this
        # file replace, so they can be put back afterwards.
        decoder_cls = qwen3_next.Qwen3NextDecoderLayer
        gated_cls = qwen3_next.Qwen3NextGatedDeltaNet
        decoder_forward = decoder_cls.forward
        gated_forward = gated_cls.forward

        try:
            # This should not raise any exceptions
            patch_qwen3_next_modeling_packing()
        finally:
            # NOTE(review): only the two forward methods are restored here;
            # any other state the combined patch touches is not visible from
            # this file - confirm against the patch implementation.
            decoder_cls.forward = decoder_forward
            gated_cls.forward = gated_forward
|
|
|
|
|
|
@pytest.mark.integration
def test_get_cu_seqlens_utility():
    """Check get_cu_seqlens on a single row of position ids that restarts once."""
    from axolotl.monkeypatch.models.qwen3_next.modeling import get_cu_seqlens

    # Position ids count 0..2 and then restart at 0, i.e. runs of length 3 and 2.
    packed_positions = torch.tensor([[0, 1, 2, 0, 1]])

    # Expected result: the start offset of each run followed by the total length.
    expected = torch.tensor([0, 3, 5], dtype=torch.int32)

    cu_seqlens = get_cu_seqlens(packed_positions)

    assert cu_seqlens.dtype == torch.int32, "Should be int32 dtype"
    assert torch.equal(cu_seqlens, expected), f"Expected {expected}, got {cu_seqlens}"
|