feature: raise on long sequence drop (#3321)

* feature: raise on long sequence drop

Silently dropping sequences from the dataset is sometimes undesirable, especially when the dataset has been carefully crafted and pre-fitted for the training context; in that case, an over-long sequence suggests that an error occurred somewhere in the process. This feature adds a third value for excess_length_strategy called 'raise', which raises a ValueError when a sequence is encountered that is too long and would otherwise have been dropped or truncated.

* tests: add excess_length_strategy tests

* doc: updated return value description for drop_long_seq_in_dataset

* add @enable_hf_offline

* fixed cfg modified after validate_config called

* hf offline fix

* fix tqdm desc when raise is used

* test: added test for non-batched case

* accidental code change revert

* test: use pytest.raises

* test: simplified drop_seq_len tests

* test: moved excess_length_strat test to test_data.py

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
This commit is contained in:
kallewoof
2025-12-23 03:59:49 +09:00
committed by GitHub
parent efeb5a4e41
commit 92ee4256f7
5 changed files with 67 additions and 7 deletions

View File

@@ -7,6 +7,7 @@ import unittest
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_streaming, md5
from axolotl.utils.trainer import drop_long_seq
from tests.hf_offline_utils import enable_hf_offline
@@ -63,6 +64,42 @@ class TestEncodePretraining(unittest.TestCase):
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
)
def test_excess_length_strategy(self):
    """Verify that drop_long_seq raises ValueError for over-long input when raise_on_drop is set.

    Exercises both a single sample and a batch of samples, and also checks
    the keep/drop flags returned when raise_on_drop is not passed.
    """
    # ---- single sample: 16 tokens ----
    sample = {"input_ids": list(range(1, 17))}

    # 16 tokens fit within a limit of 32, so no exception is expected.
    drop_long_seq(sample, 32, raise_on_drop=True)
    # Without raise_on_drop, a fitting sample is reported as True.
    self.assertTrue(drop_long_seq(sample, 32))

    # 16 tokens exceed a limit of 15, so the raising mode must fail.
    with self.assertRaises(ValueError):
        drop_long_seq(sample, 15, raise_on_drop=True)
    # Without raise_on_drop, the over-long sample is reported as False.
    self.assertFalse(drop_long_seq(sample, 15))

    # ---- batch: one 15-token row followed by one 16-token row ----
    short_row = list(range(1, 16))
    long_row = list(range(1, 17))
    batch = {"input_ids": [short_row, long_row]}

    # Both rows fit within a limit of 32, so no exception is expected.
    drop_long_seq(batch, 32, raise_on_drop=True)

    # The second row exceeds a limit of 15, so the raising mode must fail.
    with self.assertRaises(ValueError):
        drop_long_seq(batch, 15, raise_on_drop=True)

    # Without raise_on_drop, the first row is kept and the second is dropped.
    self.assertEqual(drop_long_seq(batch, 15), [True, False])
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
unittest.main()