Fixes comments from winglian
This commit is contained in:
155
tests/test_trainer_utils.py
Normal file
155
tests/test_trainer_utils.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
|
||||
# The function under test is imported from axolotl.utils.trainer
|
||||
from axolotl.utils.trainer import truncate_or_drop_long_seq
|
||||
|
||||
|
||||
# Test cases for truncate_or_drop_long_seq
|
||||
class TestTruncateOrDropLongSeq(unittest.TestCase):
    """
    Test suite for truncate_or_drop_long_seq function.

    Exercises both ``handling`` modes used here ("drop" and "truncate") with
    single examples (flat ``input_ids``) and batched examples (a list of
    ``input_ids`` lists).
    """

    def setUp(self):
        # Length bounds shared by every test in this suite.
        self.sequence_len = 10
        self.min_sequence_len = 3

    def _make_handler(self, handling):
        """Return truncate_or_drop_long_seq bound to the suite's length bounds.

        Args:
            handling: Value forwarded as the ``handling`` keyword
                ("drop" or "truncate" in these tests).
        """
        return partial(
            truncate_or_drop_long_seq,
            sequence_len=self.sequence_len,
            min_sequence_len=self.min_sequence_len,
            handling=handling,
        )

    def test_drop_mode_single(self):
        """Test drop mode with single examples."""
        handler = self._make_handler("drop")

        # (case name, input_ids, whether the sample is expected to be kept)
        cases = [
            ("too short", [1, 2], False),
            ("too long", list(range(self.sequence_len + 1)), False),
            ("just right", list(range(self.min_sequence_len)), True),
            ("empty", [], False),
        ]
        for name, input_ids, expect_kept in cases:
            # subTest keeps one failing case from hiding the remaining ones.
            with self.subTest(case=name):
                result = handler({"input_ids": input_ids})
                if expect_kept:
                    self.assertTrue(result)
                else:
                    self.assertFalse(result)

    def test_truncate_mode_single(self):
        """Test truncate mode with single examples."""
        handler = self._make_handler("truncate")

        with self.subTest(case="too short"):
            # In truncate mode the function is expected to hand a too-short
            # sample back untouched.
            # NOTE(review): the original author's note says dropping such
            # samples is left to an upstream filter -- confirm against the
            # implementation.
            sample_short = {"input_ids": [1, 2], "labels": [1, 2]}
            result_short = handler(sample_short)
            self.assertEqual(result_short["input_ids"], [1, 2])  # Unchanged

        with self.subTest(case="too long"):
            original_long = list(range(self.sequence_len + 5))
            sample_long = {
                "input_ids": list(original_long),
                "labels": list(original_long),
            }
            truncated = list(range(self.sequence_len))
            result_long = handler(sample_long)
            self.assertEqual(len(result_long["input_ids"]), self.sequence_len)
            self.assertEqual(result_long["input_ids"], truncated)
            self.assertEqual(len(result_long["labels"]), self.sequence_len)
            self.assertEqual(result_long["labels"], truncated)

        with self.subTest(case="just right"):
            sample_ok = {
                "input_ids": list(range(self.min_sequence_len)),
                "labels": list(range(self.min_sequence_len)),
            }
            # Built independently of sample_ok: comparing the result with the
            # very dict that was passed in would pass vacuously if the
            # function mutated it in place and returned the same object.
            expected_ok = {
                "input_ids": list(range(self.min_sequence_len)),
                "labels": list(range(self.min_sequence_len)),
            }
            result_ok = handler(sample_ok)
            self.assertEqual(len(result_ok["input_ids"]), self.min_sequence_len)
            self.assertEqual(result_ok, expected_ok)  # Should be unchanged

        with self.subTest(case="empty"):
            sample_empty = {"input_ids": [], "labels": []}
            # Independent expected value, for the same aliasing reason as above.
            expected_empty = {"input_ids": [], "labels": []}
            result_empty = handler(sample_empty)
            self.assertEqual(result_empty, expected_empty)  # Unchanged

    def test_drop_mode_batched(self):
        """Test drop mode with batched examples."""
        handler = self._make_handler("drop")
        sample = {
            "input_ids": [
                [1, 2],  # Too short
                list(range(self.sequence_len + 1)),  # Too long
                list(range(self.sequence_len)),  # OK (len = 10)
                list(range(self.min_sequence_len)),  # OK (len = 3)
                [],  # Empty
            ]
        }
        # One keep/drop flag per row, in row order.
        expected = [False, False, True, True, False]
        self.assertEqual(handler(sample), expected)

    def test_truncate_mode_batched(self):
        """Test truncate mode with batched examples."""
        handler = self._make_handler("truncate")
        sample = {
            "input_ids": [
                [1, 2],  # Too short
                list(range(self.sequence_len + 5)),  # Too long
                list(range(self.sequence_len)),  # OK
                list(range(self.min_sequence_len)),  # OK
                [],  # Empty
            ],
            "labels": [  # Add labels to test truncation
                [1, 2],
                list(range(self.sequence_len + 5)),
                list(range(self.sequence_len)),
                list(range(self.min_sequence_len)),
                [],
            ],
        }

        result = handler(sample)

        # Expected results after truncation: only the too-long row shrinks;
        # too-short and empty rows are expected to pass through unchanged.
        expected_input_ids = [
            [1, 2],  # Unchanged (too short)
            list(range(self.sequence_len)),  # Truncated
            list(range(self.sequence_len)),  # Unchanged (OK)
            list(range(self.min_sequence_len)),  # Unchanged (OK)
            [],  # Unchanged (Empty)
        ]
        expected_labels = [
            [1, 2],  # Unchanged (too short)
            list(range(self.sequence_len)),  # Truncated
            list(range(self.sequence_len)),  # Unchanged (OK)
            list(range(self.min_sequence_len)),  # Unchanged (OK)
            [],  # Unchanged (Empty)
        ]

        self.assertEqual(result["input_ids"], expected_input_ids)
        self.assertEqual(result["labels"], expected_labels)
|
||||
|
||||
|
||||
# Allow running this test module directly as a script, in addition to
# collection by a test runner.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user