"""Tests for end-to-end sequence parallelism integration."""
|
|
import os
|
|
import tempfile
|
|
|
|
import pytest
|
|
import torch
|
|
import yaml
|
|
|
|
from axolotl.utils.config import normalize_config, validate_config
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
|
|
def test_integration_with_config():
|
|
"""Test end-to-end training configuration setup for sequence parallelism."""
    # Define a test config directly in code instead of loading from file
    config_dict = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "tokenizer_type": "LlamaTokenizer",
        "is_llama_derived_model": True,
        "datasets": [
            {
                "path": "mhenrichsen/alpaca_2k_test",
                "type": "alpaca",
            }
        ],
        "load_in_8bit": False,
        "sequence_len": 1024,
        "sequence_parallel_degree": 2,
        "flash_attention": True,
        "sample_packing": True,
        "pad_to_sequence_len": True,
        "micro_batch_size": 2,
        "num_epochs": 1,
        "max_steps": 10,
        "gradient_accumulation_steps": 1,
        "warmup_steps": 2,
        "optimizer": "adamw_bnb_8bit",
        "lr_scheduler": "cosine",
        "learning_rate": 2.0e-4,
        "weight_decay": 0.0,
        "val_set_size": 0.05,
        "eval_steps": 5,
        "save_steps": 10,
    }

    # Create a temp dir for output
    with tempfile.TemporaryDirectory() as temp_dir:
        config_dict["output_dir"] = temp_dir

        # Also write to a file for completeness
        config_path = os.path.join(temp_dir, "sp_config.yml")
        with open(config_path, "w", encoding="utf-8") as f:
            yaml.dump(config_dict, f)

        # Convert to DictDefault and validate
        cfg = DictDefault(config_dict)
        cfg = validate_config(cfg)
        normalize_config(cfg)
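        # (normalize_config updates cfg in place, so no return value is captured)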

        # Verify sequence parallelism settings were properly processed
        assert cfg.sequence_parallel_degree == 2
        assert cfg.flash_attention is True

        # Check if the sequence_parallel_degree was propagated to the training args
        from axolotl.core.training_args import AxolotlTrainingArguments

        # pylint: disable=unexpected-keyword-arg
        training_args = AxolotlTrainingArguments(
            output_dir=temp_dir, sequence_parallel_degree=cfg.sequence_parallel_degree
        )
        assert training_args.sequence_parallel_degree == 2


def test_ring_attn_group_creation():
    """Test that ring attention groups are properly created in a multi-GPU environment."""
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group("nccl")

    from axolotl.monkeypatch.attention.ring_attn import (
        get_ring_attn_group,
        register_ring_attn,
    )

    # Get the current rank and world size
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    # Only run if we have an even number of GPUs
    if world_size % 2 != 0:
        pytest.skip(f"Need an even number of GPUs, but got {world_size}")

    # Register with sequence parallel size of 2
    register_ring_attn(sequence_parallel_degree=2)
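    # (Assumed behavior: this partitions ranks into contiguous groups of the given size)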

    # Get the ring attention group
    group = get_ring_attn_group()

    # Verify the group exists
    assert group is not None

    # Calculate expected group members
    group_id = rank // 2
    expected_start = group_id * 2
    expected_group = list(range(expected_start, expected_start + 2))
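    # e.g. with world_size=4: ranks 0-1 expect group [0, 1], ranks 2-3 expect [2, 3]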

    # Verify our rank is in the expected group
    assert rank in expected_group

    # Clean up by synchronizing all processes
    torch.distributed.barrier()