update
This commit is contained in:
@@ -73,12 +73,8 @@ def test_integration_with_config():
|
|||||||
|
|
||||||
def test_ring_attn_group_creation():
|
def test_ring_attn_group_creation():
|
||||||
"""Test that ring attention groups are properly created in a multi-GPU environment."""
|
"""Test that ring attention groups are properly created in a multi-GPU environment."""
|
||||||
# First ensure we're in a distributed environment
|
|
||||||
if not torch.distributed.is_initialized():
|
if not torch.distributed.is_initialized():
|
||||||
# Skip this test if not in distributed mode
|
torch.distributed.init_process_group("nccl")
|
||||||
pytest.skip(
|
|
||||||
"This test requires a properly initialized torch.distributed environment"
|
|
||||||
)
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import (
|
from axolotl.monkeypatch.attention.ring_attn import (
|
||||||
get_ring_attn_group,
|
get_ring_attn_group,
|
||||||
|
|||||||
Reference in New Issue
Block a user