fix tests
This commit is contained in:
committed by
Dan Saunders
parent
22cfa42961
commit
ab3b36339a
@@ -170,6 +170,7 @@ def test_sequence_parallel_slicing(
|
||||
assert torch.all(result["input_ids"] == expected_input_ids)
|
||||
|
||||
|
||||
@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()})
|
||||
def test_config_validation_with_valid_inputs(cfg):
|
||||
"""Test that valid sequence parallelism configurations pass validation."""
|
||||
# Import the actual model class with appropriate mocks
|
||||
|
||||
Reference in New Issue
Block a user