Docstrings generation was requested by @mhenrichsen. * https://github.com/axolotl-ai-cloud/axolotl/pull/2662#issuecomment-2883401776 The following files were modified: * `src/axolotl/utils/data/pretraining.py` * `src/axolotl/utils/data/rl.py` * `src/axolotl/utils/data/utils.py` * `src/axolotl/utils/trainer.py` * `tests/test_data.py` * `tests/test_trainer_utils.py`
176 lines
6.7 KiB
Python
176 lines
6.7 KiB
Python
"""Module containing tests for trainer utility functions."""
|
|
|
|
import unittest
|
|
from functools import partial
|
|
|
|
from axolotl.utils.trainer import truncate_or_drop_long_seq
|
|
|
|
|
|
# Test cases for truncate_or_drop_long_seq
|
|
class TestTruncateOrDropLongSeq(unittest.TestCase):
|
|
"""
|
|
Test suite for truncate_or_drop_long_seq function.
|
|
"""
|
|
|
|
def setUp(self):
|
|
# Example sequence length settings
|
|
"""
|
|
Sets up default sequence length parameters for the test cases.
|
|
"""
|
|
self.sequence_len = 10
|
|
self.min_sequence_len = 3
|
|
|
|
def test_drop_mode_single(self):
|
|
"""
|
|
Verifies that 'drop' mode correctly filters single sequence examples based on length.
|
|
|
|
Tests that sequences shorter than the minimum, longer than the maximum, or empty are dropped,
|
|
while sequences within the valid length range are kept.
|
|
"""
|
|
handler = partial(
|
|
truncate_or_drop_long_seq,
|
|
sequence_len=self.sequence_len,
|
|
min_sequence_len=self.min_sequence_len,
|
|
handling="drop",
|
|
)
|
|
|
|
# Too short
|
|
sample_short = {"input_ids": [1, 2]}
|
|
self.assertFalse(handler(sample_short))
|
|
|
|
# Too long
|
|
sample_long = {"input_ids": list(range(self.sequence_len + 1))}
|
|
self.assertFalse(handler(sample_long))
|
|
|
|
# Just right
|
|
sample_ok = {"input_ids": list(range(self.min_sequence_len))}
|
|
self.assertTrue(handler(sample_ok))
|
|
|
|
# Empty
|
|
sample_empty = {"input_ids": []}
|
|
self.assertFalse(handler(sample_empty))
|
|
|
|
def test_truncate_mode_single(self):
|
|
"""
|
|
Tests that 'truncate_or_drop_long_seq' correctly truncates or preserves single examples in "truncate" mode.
|
|
|
|
Verifies that sequences longer than the maximum length are truncated, while sequences that are too short, empty, or within the valid range remain unchanged.
|
|
"""
|
|
handler = partial(
|
|
truncate_or_drop_long_seq,
|
|
sequence_len=self.sequence_len,
|
|
min_sequence_len=self.min_sequence_len,
|
|
handling="truncate",
|
|
)
|
|
|
|
# Too short (should still be dropped implicitly by filter/map logic upstream,
|
|
# but the function itself might return the sample or False based on impl.)
|
|
# Current impl returns the original sample for map if too short, assuming upstream filters.
|
|
# Let's refine this test - the function *itself* returns the sample if too short when truncating.
|
|
sample_short = {"input_ids": [1, 2], "labels": [1, 2]}
|
|
result_short = handler(sample_short)
|
|
self.assertEqual(result_short["input_ids"], [1, 2]) # Unchanged
|
|
|
|
# Too long
|
|
original_long = list(range(self.sequence_len + 5))
|
|
sample_long = {"input_ids": list(original_long), "labels": list(original_long)}
|
|
result_long = handler(sample_long)
|
|
self.assertEqual(len(result_long["input_ids"]), self.sequence_len)
|
|
self.assertEqual(result_long["input_ids"], list(range(self.sequence_len)))
|
|
self.assertEqual(len(result_long["labels"]), self.sequence_len)
|
|
self.assertEqual(result_long["labels"], list(range(self.sequence_len)))
|
|
|
|
# Just right
|
|
sample_ok = {
|
|
"input_ids": list(range(self.min_sequence_len)),
|
|
"labels": list(range(self.min_sequence_len)),
|
|
}
|
|
result_ok = handler(sample_ok)
|
|
self.assertEqual(len(result_ok["input_ids"]), self.min_sequence_len)
|
|
self.assertEqual(result_ok, sample_ok) # Should be unchanged
|
|
|
|
# Empty
|
|
sample_empty = {"input_ids": [], "labels": []}
|
|
result_empty = handler(sample_empty)
|
|
self.assertEqual(result_empty, sample_empty) # Unchanged
|
|
|
|
def test_drop_mode_batched(self):
|
|
"""
|
|
Tests that the "drop" handling mode correctly filters batched input sequences based on length constraints.
|
|
|
|
Verifies that sequences shorter than the minimum length, longer than the maximum length, or empty are dropped (returns False), while sequences within the valid range are kept (returns True).
|
|
"""
|
|
handler = partial(
|
|
truncate_or_drop_long_seq,
|
|
sequence_len=self.sequence_len,
|
|
min_sequence_len=self.min_sequence_len,
|
|
handling="drop",
|
|
)
|
|
sample = {
|
|
"input_ids": [
|
|
[1, 2], # Too short
|
|
list(range(self.sequence_len + 1)), # Too long
|
|
list(range(self.sequence_len)), # OK (len = 10)
|
|
list(range(self.min_sequence_len)), # OK (len = 3)
|
|
[], # Empty
|
|
]
|
|
}
|
|
expected = [False, False, True, True, False]
|
|
self.assertEqual(handler(sample), expected)
|
|
|
|
def test_truncate_mode_batched(self):
|
|
"""
|
|
Tests that batched examples are correctly truncated in "truncate" mode.
|
|
|
|
Verifies that sequences in both "input_ids" and "labels" longer than the maximum
|
|
allowed length are truncated, while sequences that are too short or empty remain
|
|
unchanged.
|
|
"""
|
|
handler = partial(
|
|
truncate_or_drop_long_seq,
|
|
sequence_len=self.sequence_len,
|
|
min_sequence_len=self.min_sequence_len,
|
|
handling="truncate",
|
|
)
|
|
sample = {
|
|
"input_ids": [
|
|
[1, 2], # Too short
|
|
list(range(self.sequence_len + 5)), # Too long
|
|
list(range(self.sequence_len)), # OK
|
|
list(range(self.min_sequence_len)), # OK
|
|
[], # Empty
|
|
],
|
|
"labels": [ # Add labels to test truncation
|
|
[1, 2],
|
|
list(range(self.sequence_len + 5)),
|
|
list(range(self.sequence_len)),
|
|
list(range(self.min_sequence_len)),
|
|
[],
|
|
],
|
|
}
|
|
|
|
result = handler(sample)
|
|
|
|
# Expected results after truncation (too short and empty remain unchanged by this function)
|
|
expected_input_ids = [
|
|
[1, 2], # Unchanged (too short)
|
|
list(range(self.sequence_len)), # Truncated
|
|
list(range(self.sequence_len)), # Unchanged (OK)
|
|
list(range(self.min_sequence_len)), # Unchanged (OK)
|
|
[], # Unchanged (Empty)
|
|
]
|
|
expected_labels = [
|
|
[1, 2], # Unchanged (too short)
|
|
list(range(self.sequence_len)), # Truncated
|
|
list(range(self.sequence_len)), # Unchanged (OK)
|
|
list(range(self.min_sequence_len)), # Unchanged (OK)
|
|
[], # Unchanged (Empty)
|
|
]
|
|
|
|
self.assertEqual(result["input_ids"], expected_input_ids)
|
|
self.assertEqual(result["labels"], expected_labels)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|