diff --git a/tests/e2e/multigpu/test_sequence_parallelism.py b/tests/e2e/multigpu/test_sequence_parallelism.py
deleted file mode 100644
index 19edfb0ad..000000000
--- a/tests/e2e/multigpu/test_sequence_parallelism.py
+++ /dev/null
@@ -1,110 +0,0 @@
-"""Tests for end-to-end sequence parallelism integration."""
-import os
-import tempfile
-
-import pytest
-import torch
-import yaml
-
-from axolotl.utils.config import normalize_config, validate_config
-from axolotl.utils.dict import DictDefault
-
-
-def test_integration_with_config():
-    """Test end-to-end training configuration setup for sequence parallelism."""
-    # Define a test config directly in code instead of loading from file
-    config_dict = {
-        "base_model": "HuggingFaceTB/SmolLM2-135M",
-        "tokenizer_type": "LlamaTokenizer",
-        "is_llama_derived_model": True,
-        "datasets": [
-            {
-                "path": "mhenrichsen/alpaca_2k_test",
-                "type": "alpaca",
-            }
-        ],
-        "load_in_8bit": False,
-        "sequence_len": 1024,
-        "sequence_parallel_degree": 2,
-        "flash_attention": True,
-        "sample_packing": True,
-        "pad_to_sequence_len": True,
-        "micro_batch_size": 2,
-        "num_epochs": 1,
-        "max_steps": 10,
-        "gradient_accumulation_steps": 1,
-        "warmup_steps": 2,
-        "optimizer": "adamw_bnb_8bit",
-        "lr_scheduler": "cosine",
-        "learning_rate": 2.0e-4,
-        "weight_decay": 0.0,
-        "val_set_size": 0.05,
-        "eval_steps": 5,
-        "save_steps": 10,
-    }
-
-    # Create a temp dir for output
-    with tempfile.TemporaryDirectory() as temp_dir:
-        config_dict["output_dir"] = temp_dir
-
-        # Also write to a file for completeness
-        config_path = os.path.join(temp_dir, "sp_config.yml")
-        with open(config_path, "w", encoding="utf-8") as f:
-            yaml.dump(config_dict, f)
-
-        # Convert to DictDefault and validate
-        cfg = DictDefault(config_dict)
-        cfg = validate_config(cfg)
-        normalize_config(cfg)
-
-        # Verify sequence parallelism settings were properly processed
-        assert cfg.sequence_parallel_degree == 2
-        assert cfg.flash_attention is True
-
-        # Check if the sequence_parallel_degree was propagated to the training args
-        from axolotl.core.training_args import AxolotlTrainingArguments
-
-        # pylint: disable=unexpected-keyword-arg
-        training_args = AxolotlTrainingArguments(
-            output_dir=temp_dir, sequence_parallel_degree=cfg.sequence_parallel_degree
-        )
-        assert training_args.sequence_parallel_degree == 2
-
-
-def test_ring_attn_group_creation():
-    """Test that ring attention groups are properly created in a multi-GPU environment."""
-    if not torch.distributed.is_initialized():
-        torch.distributed.init_process_group("nccl")
-
-    from axolotl.monkeypatch.attention.ring_attn import (
-        get_ring_attn_group,
-        register_ring_attn,
-    )
-
-    # Get the current rank and world size
-    rank = torch.distributed.get_rank()
-    world_size = torch.distributed.get_world_size()
-
-    # Only run if we have an even number of GPUs
-    if world_size % 2 != 0:
-        pytest.skip(f"Need an even number of GPUs, but got {world_size}")
-
-    # Register with sequence parallel size of 2
-    register_ring_attn(sequence_parallel_degree=2)
-
-    # Get the ring attention group
-    group = get_ring_attn_group()
-
-    # Verify the group exists
-    assert group is not None
-
-    # Calculate expected group members
-    group_id = rank // 2
-    expected_start = group_id * 2
-    expected_group = list(range(expected_start, expected_start + 2))
-
-    # Verify our rank is in the expected group
-    assert rank in expected_group
-
-    # Clean up by synchronizing all processes
-    torch.distributed.barrier()
diff --git a/tests/e2e/patched/test_sequence_parallelism.py b/tests/e2e/patched/test_sp.py
similarity index 100%
rename from tests/e2e/patched/test_sequence_parallelism.py
rename to tests/e2e/patched/test_sp.py
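
For context on what the deleted multi-GPU test asserted: with sequence_parallel_degree=2, consecutive ranks were expected to land in the same ring-attention group (rank // 2 identifies the group). Below is a minimal standalone sketch, outside the diff, of that rank-grouping arithmetic; the helper name expected_sp_groups is illustrative only and is not part of axolotl's API.

# Sketch of the group layout the removed test checked, assuming ranks are
# partitioned into consecutive blocks of size `sequence_parallel_degree`.
def expected_sp_groups(world_size: int, sequence_parallel_degree: int) -> list[list[int]]:
    """Partition consecutive ranks into sequence-parallel (ring-attention) groups."""
    assert world_size % sequence_parallel_degree == 0
    return [
        list(range(start, start + sequence_parallel_degree))
        for start in range(0, world_size, sequence_parallel_degree)
    ]


if __name__ == "__main__":
    # With 4 GPUs and degree 2, ranks 0-1 and 2-3 each form a ring-attention group.
    print(expected_sp_groups(world_size=4, sequence_parallel_degree=2))
    # -> [[0, 1], [2, 3]]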