Distributed/ND-Parallel (#2977)
This commit is contained in:
@@ -67,7 +67,7 @@ class TestSequenceParallelism:
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_size": 2,
|
||||
"ring_attn_func": ring_attn_func,
|
||||
"save_first_step": False,
|
||||
}
|
||||
@@ -105,13 +105,13 @@ class TestSequenceParallelism:
|
||||
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
|
||||
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
|
||||
# (False, 2, True, "batch_zigzag", 2.5),
|
||||
(False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
|
||||
# (False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
|
||||
],
|
||||
ids=[
|
||||
"sample_packing, varlen_llama3 ring_attn_func",
|
||||
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
|
||||
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
# "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
],
|
||||
)
|
||||
def test_sequence_parallel_training(
|
||||
|
||||
@@ -298,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_size": 2,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
|
||||
@@ -13,7 +13,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
|
||||
from tests.e2e.utils import most_recent_subdir, require_hopper, require_torch_2_7_0
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
@@ -51,6 +51,7 @@ class TestFP8FSDP2:
|
||||
"""Test class for FP8 mixed precision with FSDP2 functionality."""
|
||||
|
||||
@require_torch_2_7_0
|
||||
@require_hopper
|
||||
def test_fp8_fsdp2_smoke(self, temp_dir):
|
||||
"""Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training"""
|
||||
cfg = DictDefault(
|
||||
|
||||
69
tests/e2e/multigpu/test_tp.py
Normal file
69
tests/e2e/multigpu/test_tp.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""multigpu e2e test for tensor parallelism."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard, require_torch_2_7_0
|
||||
|
||||
|
||||
class TestTensorParallel:
|
||||
"""Test class for Tensor Parallel functionality."""
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="TP doesn't work with models with tied weights (embeddings)"
|
||||
)
|
||||
@require_torch_2_7_0
|
||||
def test_fft_sft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"tensor_parallel_size": 2,
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
Reference in New Issue
Block a user