fix tests

This commit is contained in:
Dan Saunders
2025-03-20 12:04:22 -04:00
committed by Dan Saunders
parent 22cfa42961
commit ab3b36339a
4 changed files with 6 additions and 11 deletions

View File

@@ -170,6 +170,7 @@ def test_sequence_parallel_slicing(
assert torch.all(result["input_ids"] == expected_input_ids)
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
def test_config_validation_with_valid_inputs(cfg):
"""Test that valid sequence parallelism configurations pass validation."""
# Import the actual model class with appropriate mocks