diff --git a/tests/e2e/multigpu/test_sequence_parallelism.py b/tests/e2e/multigpu/test_sequence_parallelism.py index 9619cf690..c57c76caf 100644 --- a/tests/e2e/multigpu/test_sequence_parallelism.py +++ b/tests/e2e/multigpu/test_sequence_parallelism.py @@ -73,12 +73,8 @@ def test_integration_with_config(): def test_ring_attn_group_creation(): """Test that ring attention groups are properly created in a multi-GPU environment.""" - # First ensure we're in a distributed environment if not torch.distributed.is_initialized(): - # Skip this test if not in distributed mode - pytest.skip( - "This test requires a properly initialized torch.distributed environment" - ) + torch.distributed.init_process_group("nccl") from axolotl.monkeypatch.attention.ring_attn import ( get_ring_attn_group,