From 2b6f4a6c9b9acccf9eb1c9f4e730a3a2f25b9488 Mon Sep 17 00:00:00 2001 From: Robert Ronan Date: Tue, 24 Feb 2026 23:31:11 -0500 Subject: [PATCH] Fix: excess_length_strategy truncation method (#3401) * Add test cases to verify that the problem exists in the underlying * Update the handle_long_sequences function to correctly use Map instead of filter for the truncation strategy. Also remove the minimal length filtering from the truncate_long_samples function, and run it separately and before. * fix: refactor and add test truncate for non-input id fields * fix: refactor long seq handling fn * fix: refactor duplicate fn and simplify route * add additional tests and make them work on mac * handle logging exception on empty datasets --------- Co-authored-by: 2ndset bot Co-authored-by: NanoCode012 Co-authored-by: Wing Lian --- src/axolotl/utils/data/utils.py | 235 +++++++++----- src/axolotl/utils/trainer.py | 13 +- tests/test_data.py | 20 +- tests/utils/data/test_utils.py | 545 ++++++++++++++++++++++++++++++++ 4 files changed, 722 insertions(+), 91 deletions(-) create mode 100644 tests/utils/data/test_utils.py diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 319e27f6f..f2cdcac38 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -15,7 +15,7 @@ from datasets import Dataset, IterableDataset from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from axolotl.utils.samplers.utils import get_dataset_lengths -from axolotl.utils.trainer import drop_long_seq +from axolotl.utils.trainer import filter_sequences_by_length LOG = get_logger(__name__) @@ -148,22 +148,33 @@ def deduplicate_and_log_datasets( return dataset, other_dataset -def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2): +def keep_min_len(sample, min_sequence_len=2): """ - Truncate samples whose sequence length is too long (> sequence_len) - or drop those too short (< min_sequence_len). 
+ Batched filter function that keeps only samples with sequence length >= min_sequence_len. + Returns a list of booleans indicating which samples to keep. """ min_sequence_len = min_sequence_len or 2 input_ids = sample["input_ids"] + + # Batched (input_ids is a list of lists) results = [] + for seq in input_ids: + results.append(len(seq) >= min_sequence_len) + return results + + +def truncate_long_seq(sample, sequence_len=2048): + """ + Truncate samples whose sequence length is too long (> sequence_len). + Modifies the sample in-place and returns the modified sample. + """ + input_ids = sample["input_ids"] # Batched (input_ids is a list of lists) for i, seq in enumerate(input_ids): length = len(seq) - if length < min_sequence_len: - results.append(False) - elif length > sequence_len: + if length > sequence_len: sample["input_ids"][i] = seq[:sequence_len] if "attention_mask" in sample: sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len] @@ -171,10 +182,133 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2): sample["labels"][i] = sample["labels"][i][:sequence_len] if "position_ids" in sample: sample["position_ids"][i] = sample["position_ids"][i][:sequence_len] - results.append(True) - else: - results.append(True) - return results + return sample + + +def _should_skip_processing(dataset: Dataset) -> bool: + """Check if dataset should skip long sequence handling.""" + if ( + hasattr(dataset, "column_names") + and dataset.column_names + and "input_ids" not in dataset.column_names + ): + LOG.warning( + "Dataset does not contain 'input_ids' column. Skip drop long seq. This is " + "expected for reward modeling." 
+ ) + return True + elif not hasattr(dataset, "column_names") or dataset.column_names is None: + LOG.info( + "Dataset is streaming (IterableDataset), skipping long sequence handling" + ) + return True + return False + + +def _log_dataset_stats(dataset: Dataset) -> None: + """Log min/max sequence lengths for debugging.""" + with contextlib.suppress(AttributeError, ValueError): + ds_lengths = get_dataset_lengths(dataset, from_arrow=True) + LOG.info(f"min_input_len: {np.min(ds_lengths)}") + LOG.info(f"max_input_len: {np.max(ds_lengths)}") + + +def _build_filter_kwargs(dataset: Dataset, cfg: DictDefault) -> dict: + """Build kwargs for dataset filter/map operations.""" + kwargs = {} + if not isinstance(dataset, IterableDataset): + kwargs["num_proc"] = cfg.dataset_num_proc + kwargs["load_from_cache_file"] = not cfg.is_preprocess + return kwargs + + +def _filter_short_sequences( + dataset: Dataset, min_len: int, filter_kwargs: dict +) -> tuple[Dataset, int]: + """Filter out sequences shorter than min_len. 
Returns (dataset, num_dropped)."""
+    prior_len = len(dataset) if hasattr(dataset, "__len__") else None
+
+    desc_kwargs = {}
+    if filter_kwargs:
+        desc_kwargs["desc"] = f"Filtering Short Sequences (<{min_len})"
+
+    dataset = dataset.filter(
+        functools.partial(keep_min_len, min_sequence_len=min_len),
+        batched=True,
+        **filter_kwargs,
+        **desc_kwargs,
+    )
+
+    dropped = 0
+    if prior_len:
+        dropped = prior_len - len(dataset)
+        if dropped > 0:
+            LOG.info(f"Dropped {dropped} short sequences (<{min_len} tokens)")
+
+    return dataset, dropped
+
+
+def _truncate_long_sequences(
+    dataset: Dataset, max_len: int, map_kwargs: dict
+) -> Dataset:
+    """Truncate sequences longer than max_len."""
+    desc_kwargs = {}
+    if map_kwargs:
+        desc_kwargs["desc"] = f"Truncating Sequences (target_len={max_len})"
+
+    dataset = dataset.map(
+        functools.partial(truncate_long_seq, sequence_len=max_len),
+        batched=True,
+        **map_kwargs,
+        **desc_kwargs,
+    )
+    LOG.info(f"Truncated long sequences to max length {max_len}")
+    return dataset
+
+
+def _drop_outside_range(
+    dataset: Dataset,
+    max_len: int,
+    min_len: int,
+    raise_on_long: bool,
+    filter_kwargs: dict,
+) -> tuple[Dataset, int]:
+    """Drop sequences outside valid length range [min_len, max_len].
+
+    Returns (dataset, num_dropped)."""
+    prior_len = len(dataset) if hasattr(dataset, "__len__") else None
+
+    desc_kwargs = {}
+    if filter_kwargs:
+        action = (
+            "Checking Sequence Lengths"
+            if raise_on_long
+            else "Dropping Invalid Sequences"
+        )
+        desc_kwargs["desc"] = f"{action} (<{min_len} or >{max_len})"
+
+    dataset = dataset.filter(
+        functools.partial(
+            filter_sequences_by_length,
+            sequence_len=max_len,
+            min_sequence_len=min_len,
+            raise_on_drop=raise_on_long,
+        ),
+        batched=True,
+        **filter_kwargs,
+        **desc_kwargs,
+    )
+
+    dropped = 0
+    if not raise_on_long and prior_len:
+        dropped = prior_len - len(dataset)
+        if dropped > 0:
+            LOG.info(
+                f"Dropped {dropped} sequences outside valid range "
+                f"([{min_len}, {max_len}])"
+            )
+
+    return dataset, dropped
 
 
 def handle_long_seq_in_dataset(
@@ -193,80 +327,25 @@
     'truncate' truncates them down to sequence_len
     'raise' raises a ValueError if any sequence was found that was longer
     than sequence_len
     """
-    if (
-        hasattr(dataset, "column_names")
-        and dataset.column_names
-        and "input_ids" not in dataset.column_names
-    ):
-        LOG.warning(
-            "Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
-            "expected for reward modeling."
- ) - return dataset - elif not hasattr(dataset, "column_names") or dataset.column_names is None: - LOG.info( - "Dataset is streaming (IterableDataset), skipping long sequence handling" - ) + # Early returns for special cases + if _should_skip_processing(dataset): return dataset excess_length_strategy = (cfg.excess_length_strategy or "drop").lower() - drop_long = functools.partial( - drop_long_seq, - sequence_len=sequence_len, - min_sequence_len=cfg.min_sample_len, - raise_on_drop=excess_length_strategy == "raise", - ) + _log_dataset_stats(dataset) - with contextlib.suppress(AttributeError): - ds_lengths = get_dataset_lengths(dataset, from_arrow=True) - min_input_len = np.min(ds_lengths) - LOG.info(f"min_input_len: {min_input_len}") - max_input_len = np.max(ds_lengths) - LOG.info(f"max_input_len: {max_input_len}") - - prior_len = len(dataset) if hasattr(dataset, "__len__") else None - - filter_map_kwargs = {} - if not isinstance(dataset, IterableDataset): - filter_map_kwargs["num_proc"] = cfg.dataset_num_proc - filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess - - drop_long_kwargs = {} - if filter_map_kwargs: - action = ( - "Checking Sequence Lengths" - if excess_length_strategy == "raise" - else "Dropping Long Sequences" - ) - drop_long_kwargs["desc"] = f"{action} (>{sequence_len})" + # Setup kwargs + filter_kwargs = _build_filter_kwargs(dataset, cfg) + # Handle sequences based on strategy if excess_length_strategy == "truncate": - process_fn = functools.partial( - truncate_long_seq, - sequence_len=sequence_len, - min_sequence_len=cfg.min_sample_len, - ) - drop_long_kwargs["desc"] = ( - f"Truncating/Filtering Sequences (target_len={sequence_len})" - ) + dataset, _ = _filter_short_sequences(dataset, cfg.min_sample_len, filter_kwargs) + dataset = _truncate_long_sequences(dataset, sequence_len, filter_kwargs) else: - process_fn = drop_long - - dataset = dataset.filter( - process_fn, - batched=True, - **filter_map_kwargs, - **drop_long_kwargs, - ) - if 
prior_len: - dropped = prior_len - len(dataset) - if dropped: - action = ( - "truncated/filtered" - if excess_length_strategy == "truncate" - else "dropped" - ) - LOG.warning(f"{action.title()} {dropped} samples from dataset") + raise_on_long = excess_length_strategy == "raise" + dataset, _ = _drop_outside_range( + dataset, sequence_len, cfg.min_sample_len, raise_on_long, filter_kwargs + ) return dataset diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 621ef8785..d97a74f6f 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -205,10 +205,13 @@ def add_length(sample): return sample -def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False): +def filter_sequences_by_length( + sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False +): """ - Drop samples whose sequence length is either too long (> sequence_len) - or too short (< min_sequence_len). + Filter sequences outside valid length range [min_sequence_len, sequence_len]. + + Drops samples that are either too short (< min_sequence_len) or too long (> sequence_len). Works for both single-example (list[int]) or batched (list[list[int]]). 
@@ -383,10 +386,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): def process_pretraining_datasets_for_packing( train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False ): - drop_long = partial(drop_long_seq, sequence_len=sequence_len) + drop_outside_range = partial(filter_sequences_by_length, sequence_len=sequence_len) train_dataset = train_dataset.filter( - drop_long, + drop_outside_range, desc="Dropping Long Sequences", load_from_cache_file=False, ) diff --git a/tests/test_data.py b/tests/test_data.py index ad76bbf6e..01f60e897 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -7,7 +7,7 @@ import unittest from transformers import LlamaTokenizer from axolotl.utils.data import encode_streaming, md5 -from axolotl.utils.trainer import drop_long_seq +from axolotl.utils.trainer import filter_sequences_by_length from tests.hf_offline_utils import enable_hf_offline @@ -70,17 +70,19 @@ class TestEncodePretraining(unittest.TestCase): # -- single sequence -- # This should work data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]} - drop_long_seq(data, 32, raise_on_drop=True) + filter_sequences_by_length(data, 32, raise_on_drop=True) # This should return True, since data fits - dropped = drop_long_seq(data, 32) + dropped = filter_sequences_by_length(data, 32) self.assertTrue(dropped) # This should raise - self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True) + self.assertRaises( + ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True + ) # This should return False, since data doesn't fit - dropped = drop_long_seq(data, 15) + dropped = filter_sequences_by_length(data, 15) self.assertFalse(dropped) # -- batch sequence -- @@ -91,13 +93,15 @@ class TestEncodePretraining(unittest.TestCase): [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], ] } - drop_long_seq(data, 32, raise_on_drop=True) + filter_sequences_by_length(data, 32, raise_on_drop=True) # This 
should raise - self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True) + self.assertRaises( + ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True + ) # This should keep the first but drop the second entry - dropped = drop_long_seq(data, 15) + dropped = filter_sequences_by_length(data, 15) self.assertEqual(dropped, [True, False]) diff --git a/tests/utils/data/test_utils.py b/tests/utils/data/test_utils.py new file mode 100644 index 000000000..357447b47 --- /dev/null +++ b/tests/utils/data/test_utils.py @@ -0,0 +1,545 @@ +""" +Unit tests for data utility functions +""" + +import unittest +from unittest.mock import MagicMock + +from datasets import Dataset + +from axolotl.utils.data.utils import handle_long_seq_in_dataset +from axolotl.utils.dict import DictDefault + + +class TestHandleLongSeqInDataset(unittest.TestCase): + """ + Test class for handle_long_seq_in_dataset function + """ + + def test_drop_strategy_removes_long_sequences(self): + """Test that 'drop' strategy removes sequences longer than sequence_len""" + # Create dataset with mixed length sequences + dataset = Dataset.from_dict( + { + "input_ids": [ + [1, 2, 3], # length 3 - keep + [1, 2, 3, 4, 5], # length 5 - keep + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 - drop + [1, 2], # length 2 - keep + ] + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "drop", + "min_sample_len": 2, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + # Should have dropped the sequence with length 11 + self.assertEqual(len(result), 3) + self.assertEqual(len(result[0]["input_ids"]), 3) + self.assertEqual(len(result[1]["input_ids"]), 5) + self.assertEqual(len(result[2]["input_ids"]), 2) + + def test_drop_strategy_is_default(self): + """Test that 'drop' is the default strategy when not specified""" + dataset = Dataset.from_dict( + { + "input_ids": [ + [1, 2, 3], + [1, 2, 3, 4, 5, 6, 7, 
8, 9, 10, 11], # length 11 - should drop + ] + } + ) + + cfg = DictDefault( + { + "min_sample_len": 2, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + # Should have dropped the long sequence + self.assertEqual(len(result), 1) + + def test_truncate_strategy_truncates_long_sequences(self): + """Test that 'truncate' strategy truncates sequences to sequence_len""" + dataset = Dataset.from_dict( + { + "input_ids": [ + [1, 2, 3], # length 3 - keep as is + [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + ], # length 12 - truncate to 10 + ] + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "truncate", + "min_sample_len": 2, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + # Should have 2 samples + self.assertEqual(len(result), 2) + # First sample unchanged + self.assertEqual(len(result[0]["input_ids"]), 3) + # Second sample truncated to 10 + self.assertEqual(len(result[1]["input_ids"]), 10) + self.assertEqual(result[1]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + def test_truncate_strategy_truncates_all_auxiliary_fields(self): + """Test that truncation applies to all auxiliary fields consistently""" + dataset = Dataset.from_dict( + { + "input_ids": [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + ], + "attention_mask": [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + ], + "labels": [ + [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + ], + "position_ids": [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ], + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "truncate", + "min_sample_len": 2, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + # All fields should be truncated to 10 + self.assertEqual(len(result[0]["input_ids"]), 10) + 
self.assertEqual(len(result[0]["attention_mask"]), 10) + self.assertEqual(len(result[0]["labels"]), 10) + self.assertEqual(len(result[0]["position_ids"]), 10) + + # Verify content is correct + self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + self.assertEqual(result[0]["attention_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + self.assertEqual(result[0]["labels"], [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10]) + self.assertEqual(result[0]["position_ids"], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + + def test_raise_strategy_raises_on_long_sequences(self): + """Test that 'raise' strategy raises ValueError when encountering long sequences""" + dataset = Dataset.from_dict( + { + "input_ids": [ + [1, 2, 3], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 - should raise + ] + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "raise", + "min_sample_len": 2, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + with self.assertRaises(ValueError): + handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + def test_min_sequence_len_filters_short_sequences(self): + """Test that sequences shorter than min_sample_len are filtered out""" + dataset = Dataset.from_dict( + { + "input_ids": [ + [1], # length 1 - drop (< min_sample_len=3) + [1, 2], # length 2 - drop + [1, 2, 3], # length 3 - keep + [1, 2, 3, 4, 5], # length 5 - keep + ] + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "drop", + "min_sample_len": 3, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + # Should only keep sequences with length >= 3 + self.assertEqual(len(result), 2) + self.assertEqual(len(result[0]["input_ids"]), 3) + self.assertEqual(len(result[1]["input_ids"]), 5) + + def test_dataset_without_input_ids_column(self): + """Test that datasets without 'input_ids' column are returned unchanged""" + dataset = Dataset.from_dict( + { + "chosen": [1, 2, 3], + "rejected": 
[4, 5, 6], + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "drop", + "min_sample_len": 2, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + # Dataset should be unchanged + self.assertEqual(len(result), len(dataset)) + self.assertListEqual(list(result.column_names), ["chosen", "rejected"]) + + def test_truncate_filters_short_before_truncating(self): + """Test that truncate strategy filters short sequences before truncating long ones + + This is important for efficiency - we should not waste time truncating + sequences that will be filtered out anyway. + """ + dataset = Dataset.from_dict( + { + "input_ids": [ + [1], # length 1 - filter out first + [1, 2, 3], # length 3 - keep, no truncation needed + [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + ], # length 12 - keep and truncate + ] + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "truncate", + "min_sample_len": 2, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + # Should have filtered out the first (short) sequence + self.assertEqual(len(result), 2) + # Second sample unchanged + self.assertEqual(len(result[0]["input_ids"]), 3) + # Third sample truncated to 10 + self.assertEqual(len(result[1]["input_ids"]), 10) + + def test_case_insensitive_strategy(self): + """Test that excess_length_strategy is case-insensitive""" + dataset = Dataset.from_dict( + { + "input_ids": [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + ] + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "TRUNCATE", # uppercase + "min_sample_len": 2, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + # Should still truncate + self.assertEqual(len(result[0]["input_ids"]), 10) + + def test_raise_strategy_silently_drops_short_sequences(self): + """Test that 'raise' strategy drops short 
sequences without raising""" + dataset = Dataset.from_dict( + { + "input_ids": [ + [1], # length 1 - too short, should be dropped silently + [1, 2, 3, 4, 5], # length 5 - keep + ] + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "raise", + "min_sample_len": 3, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + # Should NOT raise, just silently drop the short sequence + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + self.assertEqual(len(result), 1) + self.assertEqual(len(result[0]["input_ids"]), 5) + + def test_drop_boundary_sequence_equal_to_sequence_len(self): + """Test that drop strategy keeps sequences with length exactly equal to sequence_len""" + dataset = Dataset.from_dict( + { + "input_ids": [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], # length 10 == sequence_len + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 > sequence_len + ] + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "drop", + "min_sample_len": 2, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + # Exactly equal should be kept, one over should be dropped + self.assertEqual(len(result), 1) + self.assertEqual(len(result[0]["input_ids"]), 10) + + def test_truncate_boundary_sequence_equal_to_sequence_len(self): + """Test that truncate strategy leaves sequences with length exactly equal to sequence_len unchanged""" + dataset = Dataset.from_dict( + { + "input_ids": [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], # length 10 == sequence_len + ] + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "truncate", + "min_sample_len": 2, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + # Should be unchanged - not truncated + self.assertEqual(len(result), 1) + self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + def test_empty_dataset(self): + 
"""Test that an empty dataset is handled gracefully""" + dataset = Dataset.from_dict({"input_ids": []}) + + cfg = DictDefault( + { + "excess_length_strategy": "drop", + "min_sample_len": 2, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + self.assertEqual(len(result), 0) + + def test_all_sequences_dropped_returns_empty_dataset(self): + """Test that dropping all sequences results in an empty dataset""" + dataset = Dataset.from_dict( + { + "input_ids": [ + [1], # too short + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # too long + ] + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "drop", + "min_sample_len": 5, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + self.assertEqual(len(result), 0) + + def test_iterable_dataset_skips_processing(self): + """Test that streaming datasets (column_names is None) are returned unchanged. + + The skip check in _should_skip_processing triggers when column_names is + None, which happens with true streaming datasets loaded via + load_dataset(..., streaming=True). 
+ """ + mock_dataset = MagicMock() + mock_dataset.column_names = None + + cfg = DictDefault( + { + "excess_length_strategy": "drop", + "min_sample_len": 2, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(mock_dataset, sequence_len=10, cfg=cfg) + + # Should be returned unchanged (same object) + self.assertIs(result, mock_dataset) + + def test_truncate_with_partial_auxiliary_fields(self): + """Test truncation when only some auxiliary fields are present""" + dataset = Dataset.from_dict( + { + "input_ids": [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + ], + "labels": [ + [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + ], + # No attention_mask or position_ids + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "truncate", + "min_sample_len": 2, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + self.assertEqual(len(result[0]["input_ids"]), 10) + self.assertEqual(len(result[0]["labels"]), 10) + self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + self.assertEqual(result[0]["labels"], [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10]) + # Confirm no extra columns were introduced + self.assertListEqual(sorted(result.column_names), ["input_ids", "labels"]) + + def test_min_sample_len_defaults_to_two_when_not_set(self): + """Test that min_sample_len defaults to 2 when not specified in config""" + dataset = Dataset.from_dict( + { + "input_ids": [ + [1], # length 1 - should be dropped (< default 2) + [1, 2], # length 2 - should be kept (>= default 2) + [1, 2, 3], # length 3 - should be kept + ] + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "drop", + # min_sample_len not set + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + self.assertEqual(len(result), 2) + self.assertEqual(len(result[0]["input_ids"]), 2) + 
self.assertEqual(len(result[1]["input_ids"]), 3) + + def test_invalid_strategy_falls_through_to_drop(self): + """Test that an unrecognized strategy value falls through to drop behavior""" + dataset = Dataset.from_dict( + { + "input_ids": [ + [1, 2, 3], # keep + [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + ], # length 11 - should be dropped + ] + } + ) + + cfg = DictDefault( + { + "excess_length_strategy": "not_a_real_strategy", + "min_sample_len": 2, + "dataset_num_proc": None, + "is_preprocess": False, + } + ) + + result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg) + + # Should behave like 'drop' + self.assertEqual(len(result), 1) + self.assertEqual(len(result[0]["input_ids"]), 3) + + +if __name__ == "__main__": + unittest.main()