Files
axolotl/tests/test_data.py
Robert Ronan 2b6f4a6c9b Fix: excess_length_strategy truncation method (#3401)
* Add test cases to verify that the problem exists in the underlying function

* Update the handle_long_sequences function to correctly use map instead of filter for the truncation strategy. Also remove the minimal-length filtering from the truncate_long_samples function and run it separately, beforehand (see the map-vs-filter sketch after this commit message).

* fix: refactor and add truncation test for non-input_ids fields

* fix: refactor long seq handling fn

* fix: refactor duplicate fn and simplify route

* add additional tests and make them work on macOS

* handle logging exception on empty datasets

---------

Co-authored-by: 2ndset bot <bot@2ndset.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-02-25 11:31:11 +07:00
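The second bullet above is the core of the fix: in Hugging Face datasets, filter() removes whole rows while map() rewrites them in place, so routing the truncation strategy through filter() silently discarded over-length samples instead of shortening them. A minimal sketch of the distinction, assuming a datasets.Dataset with an "input_ids" column (the helper names here are hypothetical, not the actual axolotl implementation):

# Illustrative sketch only: assumes a datasets.Dataset with an "input_ids"
# column; these helpers are hypothetical, not the axolotl implementation.
from datasets import Dataset

def truncate_long_samples(dataset: Dataset, max_len: int) -> Dataset:
    # map() rewrites each sample in place; the number of rows is unchanged
    return dataset.map(lambda sample: {"input_ids": sample["input_ids"][:max_len]})

def drop_long_samples(dataset: Dataset, max_len: int) -> Dataset:
    # filter() removes whole rows; using it for a "truncate" strategy would
    # silently drop over-length samples instead of shortening them
    return dataset.filter(lambda sample: len(sample["input_ids"]) <= max_len)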

110 lines
3.7 KiB
Python

"""
test module for the axolotl.utils.data module
"""

import unittest

from transformers import LlamaTokenizer

from axolotl.utils.data import encode_streaming, md5
from axolotl.utils.trainer import filter_sequences_by_length

from tests.hf_offline_utils import enable_hf_offline


class TestEncodePretraining(unittest.TestCase):
"""
test class for encode pretraining and md5 helper
"""
@enable_hf_offline
def setUp(self):
self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"eos_token": "</s>",
"bos_token": "<s>",
"unk_token": "<unk>",
"pad_token": "<pad>",
}
)
self.max_tokens = 15 # set a small number for easy inspection

    def test_encode_pretraining(self):
        examples = {
            "text": [
                "Hello, world!",
                "Nice to meet you.",
                "lorem ipsum dolor sit amet.",
                "Nice to meet you again!.",
                "hello, hello",
            ]
        }
        result = encode_streaming(examples, self.tokenizer, self.max_tokens)
        # The five short texts pack into three sequences of max_tokens each
        self.assertEqual(len(result["input_ids"]), 3)
        # Assert the length of input_ids and attention_mask is correct
        self.assertEqual(len(result["input_ids"][0]), self.max_tokens)
        self.assertEqual(len(result["attention_mask"][0]), self.max_tokens)
        # Assert EOS and PAD tokens are correctly added.
        # "Hello, world!" is 4 tokens: BOS at index 0, text at indices 1-4,
        # so EOS lands at index 5 and padding starts at index 6
        self.assertEqual(result["input_ids"][0][0], self.tokenizer.bos_token_id)
        self.assertEqual(result["input_ids"][0][5], self.tokenizer.eos_token_id)
        self.assertEqual(result["input_ids"][0][6], self.tokenizer.pad_token_id)
        # The second text (5 tokens) is packed into the same row: BOS at
        # index 7, EOS at index 13, padding at index 14
        self.assertEqual(result["input_ids"][0][7], self.tokenizer.bos_token_id)
        self.assertEqual(result["input_ids"][0][13], self.tokenizer.eos_token_id)
        self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)

    def test_md5(self):
        # The md5 helper accepts an optional encoding argument; the default
        # and an explicit "utf-8" should hash identically
        self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
        self.assertEqual(
            md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
        )

    def test_excess_length_strategy(self):
        """
        Test that filter_sequences_by_length returns a keep mask, and raises a
        ValueError when raise_on_drop=True (the 'raise' excess_length_strategy)
        and a sequence exceeds the limit.
        """
        # -- single sequence --
        # A 16-token sequence fits within a limit of 32, so this should not raise
        data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}
        filter_sequences_by_length(data, 32, raise_on_drop=True)
        # This should return True, since the data fits and is kept
        kept = filter_sequences_by_length(data, 32)
        self.assertTrue(kept)
        # This should raise: 16 tokens exceed the limit of 15
        self.assertRaises(
            ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
        )
        # This should return False, since the data doesn't fit and is dropped
        kept = filter_sequences_by_length(data, 15)
        self.assertFalse(kept)
        # -- batched sequences --
        # Both sequences fit within a limit of 32, so this should not raise
        data = {
            "input_ids": [
                [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
                [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
            ]
        }
        filter_sequences_by_length(data, 32, raise_on_drop=True)
        # This should raise: the second sequence exceeds the limit of 15
        self.assertRaises(
            ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
        )
        # This should keep the first entry (15 tokens) but drop the second (16)
        kept = filter_sequences_by_length(data, 15)
        self.assertEqual(kept, [True, False])


if __name__ == "__main__":
    unittest.main()