Files
axolotl/tests/e2e/multigpu/test_sequence_parallelism.py
Dan Saunders 51c326150b pytest
2025-03-21 16:35:10 +00:00

115 lines
3.6 KiB
Python

"""Tests for end-to-end sequence parallelism integration."""
import os
import tempfile
import pytest
import torch
import yaml
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
def test_integration_with_config():
    """Test end-to-end training configuration setup for sequence parallelism."""
    # Build the training config inline so the test has no fixture-file dependency.
    config_dict = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "tokenizer_type": "LlamaTokenizer",
        "is_llama_derived_model": True,
        "datasets": [
            {
                "path": "mhenrichsen/alpaca_2k_test",
                "type": "alpaca",
            }
        ],
        "load_in_8bit": False,
        "sequence_len": 1024,
        "sequence_parallel_size": 2,
        "flash_attention": True,
        "sample_packing": True,
        "pad_to_sequence_len": True,
        "micro_batch_size": 2,
        "num_epochs": 1,
        "max_steps": 10,
        "gradient_accumulation_steps": 1,
        "warmup_steps": 2,
        "optimizer": "adamw_bnb_8bit",
        "lr_scheduler": "cosine",
        "learning_rate": 2.0e-4,
        "weight_decay": 0.0,
        "val_set_size": 0.05,
        "eval_steps": 5,
        "save_steps": 10,
    }

    # All artifacts (output dir, serialized config) live in a throwaway directory.
    with tempfile.TemporaryDirectory() as temp_dir:
        config_dict["output_dir"] = temp_dir

        # Serialize the config to YAML as well, mirroring real-world usage.
        config_path = os.path.join(temp_dir, "sp_config.yml")
        with open(config_path, "w", encoding="utf-8") as handle:
            yaml.dump(config_dict, handle)

        # Run the config through axolotl's validation/normalization pipeline.
        cfg = validate_config(DictDefault(config_dict))
        normalize_config(cfg)

        # Sequence-parallelism settings should survive validation unchanged.
        assert cfg.sequence_parallel_size == 2
        assert cfg.flash_attention is True

        # The setting must also be accepted by the training-args dataclass.
        from axolotl.core.training_args import AxolotlTrainingArguments

        # pylint: disable=unexpected-keyword-arg
        training_args = AxolotlTrainingArguments(
            output_dir=temp_dir, sequence_parallel_size=cfg.sequence_parallel_size
        )
        assert training_args.sequence_parallel_size == 2
def test_ring_attn_group_creation():
    """Test that ring attention groups are properly created in a multi-GPU environment."""
    # Guard clause: this test is only meaningful under torch.distributed.
    if not torch.distributed.is_initialized():
        pytest.skip(
            "This test requires a properly initialized torch.distributed environment"
        )

    from axolotl.monkeypatch.attention.ring_attn import (
        get_ring_attn_group,
        register_ring_attn,
    )

    # Query the distributed topology for this process.
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()

    # Groups of size 2 require the world size to divide evenly.
    if world_size % 2 != 0:
        pytest.skip(f"Need an even number of GPUs, but got {world_size}")

    # Set up ring attention with a sequence-parallel degree of 2,
    # then fetch the process group it created.
    register_ring_attn(sequence_parallel_size=2)
    group = get_ring_attn_group()
    assert group is not None

    # Ranks are partitioned into consecutive pairs: (0,1), (2,3), ...
    # so this rank must fall inside its pair's two-element range.
    pair_start = (rank // 2) * 2
    assert rank in (pair_start, pair_start + 1)

    # Keep all processes in lockstep before the test exits.
    torch.distributed.barrier()