diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 58251b5e3..a20ad9ff2 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -175,7 +175,7 @@ def test_sequence_parallel_slicing( def test_config_validation_with_valid_inputs(cfg): """Test that valid sequence parallelism configurations pass validation.""" # Import the actual model class with appropriate mocks - from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig + from axolotl.utils.schemas.config import AxolotlInputConfig # Valid configuration: sequence_parallel_degree > 1 and flash_attention is True cfg = cfg | { @@ -191,7 +191,7 @@ def test_config_validation_with_valid_inputs(cfg): def test_config_validation_with_invalid_inputs(cfg): """Test that invalid sequence parallelism configurations fail validation.""" - from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig + from axolotl.utils.schemas.config import AxolotlInputConfig # Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False cfg = cfg | {