Multipack simplify for Mixtral (#1142)

This commit is contained in:
Wing Lian
2024-01-18 16:23:49 -05:00
committed by GitHub
parent 1d70f24b50
commit 6910e6a8ca
11 changed files with 201 additions and 430 deletions

View File

@@ -7,8 +7,6 @@ import os
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
@@ -60,12 +58,9 @@ class TestMixtral(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"sample_packing": True,
"bf16": "auto",
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -101,23 +96,16 @@ class TestMixtral(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"sample_packing": True,
"bf16": "auto",
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (
"axolotl.monkeypatch.mixtral.modeling_mixtral"
in model.model.layers[0].self_attn.__class__.__module__
)
assert (
"MixtralMultipackFlashAttention2"
"MixtralFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__
)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -52,11 +52,7 @@ class TestModelPatches(unittest.TestCase):
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
assert (
"axolotl.monkeypatch.mixtral.modeling_mixtral"
in model.model.layers[0].self_attn.__class__.__module__
)
assert (
"MixtralMultipackFlashAttention2"
"MixtralFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__
)

View File

@@ -5,7 +5,12 @@ import unittest
import torch
from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
from axolotl.monkeypatch.utils import (
get_cu_seqlens,
get_cu_seqlens_from_pos_ids,
get_max_seqlen_in_batch,
get_unpad_data,
)
class TestMonkeyPatchUtils(unittest.TestCase):
@@ -25,6 +30,70 @@ class TestMonkeyPatchUtils(unittest.TestCase):
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
)
def test_get_max_seqlen_in_batch(self):
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)
self.assertTrue(torch.allclose(get_max_seqlen_in_batch(attn_mask), target_res))
def test_get_unpad_data(self):
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
target_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
target_cu_seqlen = torch.tensor([0, 4, 7, 12, 14], dtype=torch.int32)
target_max_seqlen_in_batch = 5
indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
self.assertTrue(torch.allclose(target_indices, indices))
self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
attn_mask = torch.tensor(
[
[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0],
[1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5],
]
)
target_indices = torch.tensor(
[
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
]
)
target_cu_seqlen = torch.tensor(
[0, 4, 7, 12, 14, 17, 22, 24, 27, 30], dtype=torch.int32
)
target_max_seqlen_in_batch = 5
indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
self.assertTrue(torch.allclose(target_indices, indices))
self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
if __name__ == "__main__":
unittest.main()