Fix: excess_length_strategy truncation method (#3401)
* Add test cases to verify that the problem exists in the underlying function. * Update the handle_long_sequences function to correctly use `map` instead of `filter` for the truncation strategy. Also remove the minimal-length filtering from the truncate_long_samples function, and run it separately, before truncation. * fix: refactor and add a test for truncation of non-input-id fields * fix: refactor the long-sequence handling function * fix: refactor a duplicated function and simplify its code path * add additional tests and make them work on macOS * handle the logging exception raised on empty datasets --------- Co-authored-by: 2ndset bot <bot@2ndset.ai> Co-authored-by: NanoCode012 <nano@axolotl.ai> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -7,7 +7,7 @@ import unittest
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
from axolotl.utils.data import encode_streaming, md5
|
||||
from axolotl.utils.trainer import drop_long_seq
|
||||
from axolotl.utils.trainer import filter_sequences_by_length
|
||||
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
@@ -70,17 +70,19 @@ class TestEncodePretraining(unittest.TestCase):
|
||||
# -- single sequence --
|
||||
# This should work
|
||||
data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}
|
||||
drop_long_seq(data, 32, raise_on_drop=True)
|
||||
filter_sequences_by_length(data, 32, raise_on_drop=True)
|
||||
|
||||
# This should return True, since data fits
|
||||
dropped = drop_long_seq(data, 32)
|
||||
dropped = filter_sequences_by_length(data, 32)
|
||||
self.assertTrue(dropped)
|
||||
|
||||
# This should raise
|
||||
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
|
||||
self.assertRaises(
|
||||
ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
|
||||
)
|
||||
|
||||
# This should return False, since data doesn't fit
|
||||
dropped = drop_long_seq(data, 15)
|
||||
dropped = filter_sequences_by_length(data, 15)
|
||||
self.assertFalse(dropped)
|
||||
|
||||
# -- batch sequence --
|
||||
@@ -91,13 +93,15 @@ class TestEncodePretraining(unittest.TestCase):
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
||||
]
|
||||
}
|
||||
drop_long_seq(data, 32, raise_on_drop=True)
|
||||
filter_sequences_by_length(data, 32, raise_on_drop=True)
|
||||
|
||||
# This should raise
|
||||
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
|
||||
self.assertRaises(
|
||||
ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
|
||||
)
|
||||
|
||||
# This should keep the first but drop the second entry
|
||||
dropped = drop_long_seq(data, 15)
|
||||
dropped = filter_sequences_by_length(data, 15)
|
||||
self.assertEqual(dropped, [True, False])
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user