rename file, delete another
This commit is contained in:
@@ -1,110 +0,0 @@
|
||||
"""Tests for end-to-end sequence parallelism integration."""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def test_integration_with_config():
    """Test end-to-end training configuration setup for sequence parallelism."""
    # Build the test configuration in code rather than loading it from disk.
    raw_config = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "tokenizer_type": "LlamaTokenizer",
        "is_llama_derived_model": True,
        "datasets": [
            {
                "path": "mhenrichsen/alpaca_2k_test",
                "type": "alpaca",
            }
        ],
        "load_in_8bit": False,
        "sequence_len": 1024,
        "sequence_parallel_degree": 2,
        "flash_attention": True,
        "sample_packing": True,
        "pad_to_sequence_len": True,
        "micro_batch_size": 2,
        "num_epochs": 1,
        "max_steps": 10,
        "gradient_accumulation_steps": 1,
        "warmup_steps": 2,
        "optimizer": "adamw_bnb_8bit",
        "lr_scheduler": "cosine",
        "learning_rate": 2.0e-4,
        "weight_decay": 0.0,
        "val_set_size": 0.05,
        "eval_steps": 5,
        "save_steps": 10,
    }

    # All artifacts live in a throwaway directory cleaned up on exit.
    with tempfile.TemporaryDirectory() as temp_dir:
        raw_config["output_dir"] = temp_dir

        # Also write to a file for completeness
        config_path = os.path.join(temp_dir, "sp_config.yml")
        with open(config_path, "w", encoding="utf-8") as config_file:
            yaml.dump(raw_config, config_file)

        # Run the config through the project's validation/normalization pipeline.
        cfg = validate_config(DictDefault(raw_config))
        normalize_config(cfg)

        # The sequence-parallelism knobs must survive validation untouched.
        assert cfg.sequence_parallel_degree == 2
        assert cfg.flash_attention is True

        # Check if the sequence_parallel_degree was propagated to the training args
        from axolotl.core.training_args import AxolotlTrainingArguments

        # pylint: disable=unexpected-keyword-arg
        training_args = AxolotlTrainingArguments(
            output_dir=temp_dir, sequence_parallel_degree=cfg.sequence_parallel_degree
        )
        assert training_args.sequence_parallel_degree == 2
|
||||
|
||||
|
||||
def test_ring_attn_group_creation():
    """Test that ring attention groups are properly created in a multi-GPU environment.

    Skips (rather than errors) when the environment cannot support an NCCL
    process group: no CUDA devices, missing distributed launch env vars, or an
    odd world size.
    """
    if not torch.distributed.is_initialized():
        # NCCL backend requires CUDA; without it init_process_group raises and
        # the test would ERROR instead of SKIP on CPU-only machines.
        if not torch.cuda.is_available():
            pytest.skip("CUDA is not available; cannot initialize NCCL process group")
        try:
            torch.distributed.init_process_group("nccl")
        except (RuntimeError, ValueError) as err:
            # e.g. MASTER_ADDR/RANK/WORLD_SIZE not set because the test was not
            # launched via torchrun — treat as an unsupported environment.
            pytest.skip(f"Could not initialize distributed process group: {err}")

    from axolotl.monkeypatch.attention.ring_attn import (
        get_ring_attn_group,
        register_ring_attn,
    )

    # Get the current rank and world size
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    # Groups are built from consecutive pairs of ranks, so the world size
    # must be even (this also skips the single-GPU case, world_size == 1).
    if world_size % 2 != 0:
        pytest.skip(f"Need an even number of GPUs, but got {world_size}")

    # Register with sequence parallel size of 2
    register_ring_attn(sequence_parallel_degree=2)

    # Get the ring attention group
    group = get_ring_attn_group()

    # Verify the group exists
    assert group is not None

    # Calculate expected group members: rank pairs (0,1), (2,3), ...
    group_id = rank // 2
    expected_start = group_id * 2
    expected_group = list(range(expected_start, expected_start + 2))

    # Verify our rank is in the expected group
    assert rank in expected_group

    # Clean up by synchronizing all processes
    torch.distributed.barrier()
|
||||
Reference in New Issue
Block a user