Fix: excess_length_strategy truncation method (#3401)

* Add test cases to verify that the problem exists in the underlying code

* Update the handle_long_sequences function to correctly use map instead of filter for the truncation strategy. Also remove the minimum-length filtering from the truncate_long_samples function, and run that filtering as a separate step before truncation.

* fix: refactor and add a truncation test for fields other than input_ids

* fix: refactor long seq handling fn

* fix: refactor duplicate fn and simplify route

* add additional tests and make them work on mac

* handle logging exception on empty datasets

---------

Co-authored-by: 2ndset bot <bot@2ndset.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Robert Ronan
2026-02-24 23:31:11 -05:00
committed by GitHub
parent 8f54b4eb25
commit 2b6f4a6c9b
4 changed files with 722 additions and 91 deletions

View File

@@ -7,7 +7,7 @@ import unittest
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_streaming, md5
from axolotl.utils.trainer import drop_long_seq
from axolotl.utils.trainer import filter_sequences_by_length
from tests.hf_offline_utils import enable_hf_offline
@@ -70,17 +70,19 @@ class TestEncodePretraining(unittest.TestCase):
# -- single sequence --
# This should work
data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}
drop_long_seq(data, 32, raise_on_drop=True)
filter_sequences_by_length(data, 32, raise_on_drop=True)
# This should return True, since data fits
dropped = drop_long_seq(data, 32)
dropped = filter_sequences_by_length(data, 32)
self.assertTrue(dropped)
# This should raise
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
self.assertRaises(
ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
)
# This should return False, since data doesn't fit
dropped = drop_long_seq(data, 15)
dropped = filter_sequences_by_length(data, 15)
self.assertFalse(dropped)
# -- batch sequence --
@@ -91,13 +93,15 @@ class TestEncodePretraining(unittest.TestCase):
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
]
}
drop_long_seq(data, 32, raise_on_drop=True)
filter_sequences_by_length(data, 32, raise_on_drop=True)
# This should raise
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
self.assertRaises(
ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
)
# This should keep the first but drop the second entry
dropped = drop_long_seq(data, 15)
dropped = filter_sequences_by_length(data, 15)
self.assertEqual(dropped, [True, False])