Add unit tests for cumulative sequence lengths (cu_seqlens), add the ability to build cu_seqlens from position IDs, and fix the prompt test
This commit is contained in:
@@ -17,47 +17,7 @@ except ImportError:
|
|||||||
|
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_cu_seqlens
|
||||||
def get_cu_seqlens(attn_mask):
    """Build cumulative sequence offsets for flash attention from a packed mask.

    Non-zero entries of ``attn_mask`` are sequence ids; a new sequence starts
    wherever consecutive non-zero ids differ. Zero entries are dropped before
    boundaries are searched, and their total count is reported as one extra
    trailing segment so the final offset equals the row length.

    Args:
        attn_mask: 1-D tensor of length ``seq_len`` or 2-D tensor of shape
            ``(1, seq_len)``. NOTE(review): a mask with more than one row is
            flattened before boundaries are searched, so rows are not kept
            separate -- callers appear to be expected to pass a single row.

    Returns:
        A ``(cu_seqlens, max_seq_len)`` tuple: ``cu_seqlens`` is a 1-D int32
        tensor of offsets starting at 0, and ``max_seq_len`` is a 0-dim tensor
        holding the largest gap between consecutive offsets (the padding
        segment, when present, is included in that maximum).
    """
    # Accept a bare 1-D mask; indexing shape[1] below would otherwise fail.
    if attn_mask.dim() == 1:
        attn_mask = attn_mask.unsqueeze(0)

    device = attn_mask.device
    row_length = attn_mask.shape[1]

    # Drop zeros so padding positions never count as a sequence boundary.
    tokens = attn_mask[attn_mask != 0]
    num_tokens = tokens.shape[0]

    # Index (within the compacted tokens) at which each later sequence starts.
    later_starts = torch.nonzero(tokens[1:] != tokens[:-1], as_tuple=True)[0] + 1

    # All pieces share int64 so the concatenation needs no dtype promotion.
    pieces = [
        torch.zeros(1, dtype=torch.int64, device=device),
        later_starts,
        torch.tensor([num_tokens], dtype=torch.int64, device=device),
    ]
    # Whatever is left after the real tokens becomes one closing segment.
    if row_length != num_tokens:
        pieces.append(torch.tensor([row_length], dtype=torch.int64, device=device))

    cu_seqlens = torch.cat(pieces)
    max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

    return cu_seqlens.to(dtype=torch.int32), max_seq_len
|
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
103
src/axolotl/monkeypatch/utils.py
Normal file
103
src/axolotl/monkeypatch/utils.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
"""
|
||||||
|
Shared utils for the monkeypatches
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_cu_seqlens(attn_mask):
    """Generate cumulative sequence offsets for flash attention from an attn mask.

    Each row of ``attn_mask`` is handled independently. Non-zero entries are
    sequence ids; a new sequence starts wherever consecutive non-zero ids
    differ. Zero entries are dropped before boundaries are searched, and
    their total count is reported as one extra trailing segment so the last
    real offset of every row equals the row length.

    Args:
        attn_mask: 1-D tensor (treated as a single row) or 2-D tensor of
            shape ``(batch, seq_len)``.

    Returns:
        A ``(cu_seqlens, max_seq_lens)`` tuple. ``cu_seqlens`` is an int32
        tensor of shape ``(batch, n_offsets)`` whose rows start at 0.
        ``max_seq_lens`` has shape ``(batch,)`` and holds, per row, the
        largest gap between consecutive offsets (the padding segment, when
        present, is included in that maximum).

        Rows containing fewer segments than the widest row are right-padded
        by repeating their final offset, which is equivalent to appending
        zero-length segments; this lets rows with different numbers of
        packed sequences be stacked instead of raising.
    """
    if attn_mask.dim() == 1:
        attn_mask = attn_mask.unsqueeze(0)

    device = attn_mask.device
    results = []
    max_seq_lens = []

    for row in attn_mask:
        row_length = row.shape[0]

        # Drop zeros so padding positions never count as a sequence boundary.
        tokens = row[row != 0]
        num_tokens = tokens.shape[0]

        # Index (within the compacted tokens) where each later sequence starts.
        later_starts = (
            torch.nonzero(tokens[1:] != tokens[:-1], as_tuple=True)[0] + 1
        )

        # All pieces share int64 so the concatenation needs no dtype promotion.
        pieces = [
            torch.zeros(1, dtype=torch.int64, device=device),
            later_starts,
            torch.tensor([num_tokens], dtype=torch.int64, device=device),
        ]
        # Whatever is left after the real tokens becomes one closing segment.
        if row_length != num_tokens:
            pieces.append(
                torch.tensor([row_length], dtype=torch.int64, device=device)
            )

        cu_seqlens = torch.cat(pieces)
        results.append(cu_seqlens)
        max_seq_lens.append((cu_seqlens[1:] - cu_seqlens[:-1]).max())

    # Equalise widths so torch.stack works when rows hold different numbers
    # of sequences; repeated final offsets add only zero-length segments and
    # therefore leave each row's max_seq_len untouched.
    width = max((offsets.shape[0] for offsets in results), default=0)
    padded = []
    for offsets in results:
        missing = width - offsets.shape[0]
        if missing:
            offsets = torch.cat([offsets, offsets[-1:].expand(missing)])
        padded.append(offsets)

    return torch.stack(padded).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cu_seqlens_from_pos_ids(position_ids):
    """Generate cumulative sequence offsets for flash attention from position ids.

    Each row of ``position_ids`` is handled independently. The run of zeros
    at the right end of a row is treated as padding; within the remaining
    prefix, a new sequence starts at index 0 and wherever the position id
    is 0 again.

    Because padding is detected purely as trailing zeros, single-token
    sequences (position id 0) at the very end of a row cannot be told apart
    from padding and are folded into the padding segment.

    Args:
        position_ids: 1-D tensor (treated as a single row) or 2-D tensor of
            shape ``(batch, seq_len)``.

    Returns:
        A ``(cu_seqlens, max_seq_lens)`` tuple. ``cu_seqlens`` is an int32
        tensor of shape ``(batch, n_offsets)`` whose rows start at 0; when a
        row has trailing padding, the row length is appended as a final
        offset. ``max_seq_lens`` has shape ``(batch,)`` and holds, per row,
        the largest gap between consecutive offsets (the padding segment,
        when present, is included in that maximum).

        Rows containing fewer segments than the widest row are right-padded
        by repeating their final offset, which is equivalent to appending
        zero-length segments; this lets rows with different numbers of
        packed sequences be stacked instead of raising.
    """
    if position_ids.dim() == 1:
        position_ids = position_ids.unsqueeze(0)

    device = position_ids.device
    results = []
    max_seq_lens = []

    for row in position_ids:
        row_length = row.shape[0]

        # Count consecutive zeros from the right: after flipping, the
        # running product stays 1 only until the first non-zero position id.
        padding_length = int(
            (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum()
        )
        content_length = row_length - padding_length
        content = row[:content_length]

        # A position id of 0 after the first token marks a new sequence.
        later_starts = torch.nonzero(content[1:] == 0, as_tuple=True)[0] + 1

        # All pieces share int64 so the concatenation needs no dtype promotion.
        pieces = [
            torch.zeros(1, dtype=torch.int64, device=device),
            later_starts,
            torch.tensor([content_length], dtype=torch.int64, device=device),
        ]
        # Close the row with the padding segment, if there is one.
        if padding_length:
            pieces.append(
                torch.tensor([row_length], dtype=torch.int64, device=device)
            )

        cu_seqlens = torch.cat(pieces)
        results.append(cu_seqlens)
        max_seq_lens.append((cu_seqlens[1:] - cu_seqlens[:-1]).max())

    # Equalise widths so torch.stack works when rows hold different numbers
    # of sequences; repeated final offsets add only zero-length segments and
    # therefore leave each row's max_seq_len untouched.
    width = max((offsets.shape[0] for offsets in results), default=0)
    padded = []
    for offsets in results:
        missing = width - offsets.shape[0]
        if missing:
            offsets = torch.cat([offsets, offsets[-1:].expand(missing)])
        padded.append(offsets)

    return torch.stack(padded).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||||
@@ -45,7 +45,7 @@ class AlpacaPrompter:
|
|||||||
if self.prompt_style == PromptStyle.CHAT.value:
|
if self.prompt_style == PromptStyle.CHAT.value:
|
||||||
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
||||||
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
||||||
self.system_format = "SYSTEM:{system}\n"
|
self.system_format = "SYSTEM: {system}\n"
|
||||||
if self.prompt_style == PromptStyle.CHATML.value:
|
if self.prompt_style == PromptStyle.CHATML.value:
|
||||||
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
|
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
self.turn_no_input_format = (
|
self.turn_no_input_format = (
|
||||||
|
|||||||
30
tests/monkeypatch/test_llama_attn_hijack_flash.py
Normal file
30
tests/monkeypatch/test_llama_attn_hijack_flash.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the monkeypatch utils
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
|
|
||||||
|
class TestMonkeyPatchUtils(unittest.TestCase):
    """
    Unit test class for monkeypatch utils
    """

    def test_get_cu_seqlens_1d(self):
        """A packed attention mask yields one row of cumulative offsets."""
        attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
        # The expected value keeps the batch dimension: torch.equal compares
        # shape as well as values, whereas allclose would broadcast a 1-D
        # target against the (1, n) result and hide a shape regression.
        target_res = torch.tensor([[0, 4, 7, 12, 14, 16]], dtype=torch.int32)

        cu_seqlens, max_seq_lens = get_cu_seqlens(attn_mask)
        self.assertTrue(torch.equal(cu_seqlens, target_res))
        # Longest segment is the run of five 3s.
        self.assertEqual(max_seq_lens.tolist(), [5])

        # A genuinely 1-D mask is treated as a single-row batch.
        cu_seqlens_1d, max_seq_lens_1d = get_cu_seqlens(attn_mask[0])
        self.assertTrue(torch.equal(cu_seqlens_1d, target_res))
        self.assertEqual(max_seq_lens_1d.tolist(), [5])

    def test_get_cu_seqlens_from_pos_ids_1d(self):
        """Position ids that reset to 0 yield the same cumulative offsets."""
        position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
        target_res = torch.tensor([[0, 4, 7, 12, 14, 16]], dtype=torch.int32)

        cu_seqlens, max_seq_lens = get_cu_seqlens_from_pos_ids(position_ids)
        self.assertTrue(torch.equal(cu_seqlens, target_res))
        # Longest segment is the run with position ids 0..4.
        self.assertEqual(max_seq_lens.tolist(), [5])

        # A genuinely 1-D input is treated as a single-row batch.
        cu_seqlens_1d, max_seq_lens_1d = get_cu_seqlens_from_pos_ids(
            position_ids[0]
        )
        self.assertTrue(torch.equal(cu_seqlens_1d, target_res))
        self.assertEqual(max_seq_lens_1d.tolist(), [5])


if __name__ == "__main__":
    unittest.main()
|
||||||
@@ -134,9 +134,15 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
|||||||
"output": "Hi! How can I help?",
|
"output": "Hi! How can I help?",
|
||||||
}
|
}
|
||||||
example = strat.tokenize_prompt(sample)
|
example = strat.tokenize_prompt(sample)
|
||||||
assert example["input_ids"][0:4] == [1, 835, 2184, 29901] # "<s>### System:"
|
assert example["input_ids"][0:5] == [
|
||||||
assert example["input_ids"][5:7] == [1509, 20118] # "use cot"
|
1,
|
||||||
assert example["input_ids"][9] == 11889 # USER
|
28962,
|
||||||
|
1254,
|
||||||
|
12665,
|
||||||
|
29901,
|
||||||
|
] # "<s>SYSTEM:"
|
||||||
|
assert example["input_ids"][5:7] == [671, 20118] # " use cot"
|
||||||
|
assert example["input_ids"][8] == 11889 # USER
|
||||||
|
|
||||||
|
|
||||||
class Llama2ChatTokenizationTest(unittest.TestCase):
|
class Llama2ChatTokenizationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user