support for true batches with multipack (#1230)

* support for true batches with multipack

* patch the map dataset fetcher to handle batches with packed indexes

* patch 4d mask creation for sdp attention

* better handling for BetterTransformer

* patch general case for 4d mask

* setup forward patch. WIP

* fix patch file

* support for multipack w/o flash attention for llama

* cleanup

* add warning about bf16 vs fp16 for multipack with sdpa

* bugfixes

* add 4d multipack tests, refactor patches

* update tests and add warnings

* fix e2e file check

* skip sdpa test if not at least torch 2.1.1, update docs
This commit is contained in:
Wing Lian
2024-02-01 10:18:42 -05:00
committed by GitHub
parent c67fb71583
commit 00568c1539
24 changed files with 573 additions and 246 deletions

View File

@@ -0,0 +1,114 @@
"""
E2E tests for multipack fft llama using 4d attention masks
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import require_torch_2_1_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class Test4dMultipackLlama(unittest.TestCase):
"""
Test case for Llama models using 4d attention with multipack
"""
@require_torch_2_1_1
@with_temp_dir
def test_sdp_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"flash_attention": False,
"sdp_attention": True,
"sample_packing": True,
"pad_to_sequence_len": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"fp16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_torch_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"flash_attention": False,
"sdp_attention": False,
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"fp16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -33,6 +33,7 @@ class TestFusedLlama(unittest.TestCase):
{
"base_model": "JackFram/llama-68m",
"flash_attention": True,
"pad_to_sequence_len": True,
"flash_attn_fuse_qkv": True,
"flash_attn_fuse_mlp": True,
"sample_packing": True,

View File

@@ -4,7 +4,9 @@ helper utils for tests
import os
import shutil
import tempfile
import unittest
from functools import wraps
from importlib.metadata import version
from pathlib import Path
@@ -31,3 +33,15 @@ def most_recent_subdir(path):
subdir = max(subdirectories, key=os.path.getctime)
return subdir
def require_torch_2_1_1(test_case):
"""
Decorator marking a test that requires torch >= 2.1.1
"""
def is_min_2_1_1():
torch_version = version("torch")
return torch_version >= "2.1.1"
return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)

View File

@@ -30,6 +30,20 @@ class TestMonkeyPatchUtils(unittest.TestCase):
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
)
def test_get_cu_seqlens_from_pos_ids_2d(self):
position_ids = torch.tensor(
[
[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0],
[0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 0],
]
)
target_res = torch.tensor(
[[0, 4, 7, 12, 14, 16], [0, 5, 8, 15, 16, 16]], dtype=torch.int32
)
self.assertTrue(
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
)
def test_get_max_seqlen_in_batch(self):
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)

View File

@@ -0,0 +1,99 @@
"""Module for testing streaming dataset sequence packing"""
import pytest
from datasets import concatenate_datasets, load_dataset
from torch.utils.data import DataLoader, RandomSampler
from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = "</s>"
return tokenizer
@pytest.fixture(name="max_seq_length")
def fixture_max_seq_length():
return 4096
class TestBatchedSamplerPacking:
"""
Test class for packing streaming dataset sequences
"""
@pytest.mark.parametrize(
"batch_size, num_workers",
[
(1, 0),
(2, 0),
(1, 2),
(2, 2),
],
)
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
dataset = load_dataset(
"Trelis/tiny-shakespeare",
split="train",
)
cfg = DictDefault(
{
"train_on_inputs": True,
"sequence_len": max_seq_length,
}
)
ds_cfg = DictDefault(
{
"field": "Text",
}
)
completion_strategy = load(tokenizer, cfg, ds_cfg)
dataset_wrapper = TokenizedPromptDataset(
completion_strategy,
dataset,
)
train_dataset = concatenate_datasets([dataset_wrapper])
batch_sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
batch_size=batch_size,
drop_last=True,
batch_max_len=max_seq_length,
lengths=get_dataset_lengths(train_dataset),
)
loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=V2BatchSamplerDataCollatorForSeq2Seq( # pylint: disable=unexpected-keyword-arg
tokenizer=tokenizer,
padding=True,
pad_to_multiple_of=max_seq_length,
return_tensors="pt",
),
num_workers=num_workers,
)
inputs = next(iter(loader))
assert inputs["input_ids"].shape == (batch_size, max_seq_length)
assert inputs["labels"].shape == (batch_size, max_seq_length)
assert inputs["attention_mask"].shape == (batch_size, max_seq_length)
assert inputs["input_ids"].tolist()[0][0] == 2
assert inputs["labels"].tolist()[0][0] == -100
assert inputs["attention_mask"].tolist()[0][0] == 0
assert inputs["attention_mask"].tolist()[0][-1] > 1
if batch_size >= 2:
assert inputs["input_ids"].tolist()[1][0] == 2
assert inputs["labels"].tolist()[1][0] == -100
assert inputs["attention_mask"].tolist()[1][0] == 0
assert inputs["attention_mask"].tolist()[1][-1] > 1

View File

@@ -11,7 +11,7 @@ from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Se
from axolotl.utils.data import encode_packed_pretraining
class TestPacking(unittest.TestCase):
class TestPretrainingPacking(unittest.TestCase):
"""
Test class for packing streaming dataset sequences
"""