batch api HF adapter for ring-flash-attn; cleanup and improvements (#2520)

* batch api HF adapter for ring-flash-attn; cleanup and improvements

* update

* adding all batch ring-flash-attn methods via single adapter

* removing pad_to_sequence_len=False for now

* fix

* updating docs to include batch SP

* review comments

* fixes for batch API funcs, simplify

* fixes

* fix

* updates

* add batch_zigzag smoke test
This commit is contained in:
Dan Saunders
2025-04-16 13:50:48 -04:00
committed by GitHub
parent 682a9cf79b
commit b8c633aa97
13 changed files with 397 additions and 49 deletions

View File

@@ -3,6 +3,7 @@
import os
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
@@ -17,8 +18,15 @@ os.environ["WANDB_DISABLED"] = "true"
class TestSequenceParallelism:
"""Test case for training with sequence parallelism enabled"""
def test_sequence_parallel_training(self, temp_dir):
# pylint: disable=duplicate-code
def _run_sequence_parallel_test(
self,
temp_dir,
sample_packing=True,
micro_batch_size=1,
pad_to_sequence_len=True,
ring_attn_func=None,
):
"""Helper method to run sequence parallel tests with different configurations"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -27,9 +35,9 @@ class TestSequenceParallelism:
"strict": False,
"sequence_len": 2048,
"adapter": "qlora",
"sample_packing": True,
"eval_sample_packing": True,
"pad_to_sequence_len": True,
"sample_packing": sample_packing,
"eval_sample_packing": sample_packing,
"pad_to_sequence_len": pad_to_sequence_len,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
@@ -45,7 +53,7 @@ class TestSequenceParallelism:
],
"num_epochs": 1,
"max_steps": 8,
"micro_batch_size": 1,
"micro_batch_size": micro_batch_size,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
@@ -61,6 +69,7 @@ class TestSequenceParallelism:
"weight_decay": 0.0,
"use_tensorboard": True,
"sequence_parallel_degree": 2,
"ring_attn_func": ring_attn_func,
}
)
@@ -86,3 +95,35 @@ class TestSequenceParallelism:
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.6, "Train Loss is too high"
)
@pytest.mark.parametrize(
"sample_packing, micro_batch_size, pad_to_sequence_len, ring_attn_func",
[
(True, 1, True, None), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None), # defaults to batch_ring ring_attn_func
(False, 2, True, "batch_zigzag"),
# (False, 2, False), # not yet working
],
ids=[
"sample_packing, varlen_llama3 ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
# "no sample_packing, pad_to_sequence_len", # not yet working
],
)
def test_sequence_parallel_training(
self,
temp_dir,
sample_packing,
micro_batch_size,
pad_to_sequence_len,
ring_attn_func,
):
"""Test sequence parallel training with different configurations"""
self._run_sequence_parallel_test(
temp_dir,
sample_packing=sample_packing,
micro_batch_size=micro_batch_size,
pad_to_sequence_len=pad_to_sequence_len,
ring_attn_func=ring_attn_func,
)

View File

@@ -73,7 +73,10 @@ class TestRingAttention:
self, mock_world_size, mock_rank, mock_new_group, partial_state
):
"""Test that ring attention groups are created correctly."""
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
from axolotl.monkeypatch.attention.ring_attn import (
RingAttnFunc,
register_ring_attn,
)
# Setup mocks
mock_world_size.return_value = 8 # 8 GPUs total
@@ -82,7 +85,11 @@ class TestRingAttention:
mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4
register_ring_attn(sequence_parallel_degree=4, heads_k_stride=1)
register_ring_attn(
sequence_parallel_degree=4,
heads_k_stride=1,
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
)
# Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2