Add unit tests for cumulative sequence lengths (cu_seqlens), add the ability to build cu_seqlens from position IDs, and fix the prompt test
This commit is contained in:
30
tests/monkeypatch/test_llama_attn_hijack_flash.py
Normal file
30
tests/monkeypatch/test_llama_attn_hijack_flash.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
Unit tests for the monkeypatch utils
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
|
||||
|
||||
|
||||
class TestMonkeyPatchUtils(unittest.TestCase):
    """
    Checks the cumulative-sequence-length helpers from the monkeypatch utils
    against a single hand-written packed row.
    """

    def test_get_cu_seqlens_1d(self):
        """
        A one-row mask whose entries label four runs (1..4) followed by two
        trailing zeros should yield the boundaries 0, 4, 7, 12, 14, 16.
        """
        # Run lengths are 4, 3, 5, 2, then two zeros (presumably padding).
        packed_mask = torch.tensor(
            [[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]]
        )
        expected = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)

        # Only the first element of the helper's return value is compared.
        actual = get_cu_seqlens(packed_mask)[0]
        matches = torch.allclose(actual, expected)
        self.assertTrue(matches)

    def test_get_cu_seqlens_from_pos_ids_1d(self):
        """
        A one-row tensor of position ids that restart at 0 for every run
        should yield the same boundaries 0, 4, 7, 12, 14, 16.
        """
        # Same run layout as the mask test, expressed as per-run positions;
        # the last two entries are both 0.
        pos_ids = torch.tensor(
            [[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]]
        )
        expected = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)

        # Only the first element of the helper's return value is compared.
        actual = get_cu_seqlens_from_pos_ids(pos_ids)[0]
        matches = torch.allclose(actual, expected)
        self.assertTrue(matches)
|
||||
|
||||
|
||||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
@@ -134,9 +134,15 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
||||
"output": "Hi! How can I help?",
|
||||
}
|
||||
example = strat.tokenize_prompt(sample)
|
||||
assert example["input_ids"][0:4] == [1, 835, 2184, 29901] # "<s>### System:"
|
||||
assert example["input_ids"][5:7] == [1509, 20118] # "use cot"
|
||||
assert example["input_ids"][9] == 11889 # USER
|
||||
assert example["input_ids"][0:5] == [
|
||||
1,
|
||||
28962,
|
||||
1254,
|
||||
12665,
|
||||
29901,
|
||||
] # "<s>SYSTEM:"
|
||||
assert example["input_ids"][5:7] == [671, 20118] # " use cot"
|
||||
assert example["input_ids"][8] == 11889 # USER
|
||||
|
||||
|
||||
class Llama2ChatTokenizationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user