Fix: excess_length_strategy truncation method (#3401)

* Add test cases to verify that the problem exists in the underlying

* Update the handle_long_sequences function to correctly use `map` instead of `filter` for the truncation strategy. Also remove the minimal-length filtering from the truncate_long_samples function, and run it separately, before truncation.

* fix: refactor and add test truncate for non-input id fields

* fix: refactor long seq handling fn

* fix: refactor duplicate fn and simplify route

* add additional tests and make them work on mac

* handle logging exception on empty datasets

---------

Co-authored-by: 2ndset bot <bot@2ndset.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Robert Ronan
2026-02-24 23:31:11 -05:00
committed by GitHub
parent 8f54b4eb25
commit 2b6f4a6c9b
4 changed files with 722 additions and 91 deletions

View File

@@ -15,7 +15,7 @@ from datasets import Dataset, IterableDataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import drop_long_seq
from axolotl.utils.trainer import filter_sequences_by_length
LOG = get_logger(__name__)
@@ -148,22 +148,33 @@ def deduplicate_and_log_datasets(
return dataset, other_dataset
def keep_min_len(sample, min_sequence_len=2):
    """
    Batched filter function that keeps only samples with sequence length >= min_sequence_len.
    Returns a list of booleans indicating which samples to keep.

    Args:
        sample: Batched sample dict; ``sample["input_ids"]`` is a list of token lists.
        min_sequence_len: Minimum length (inclusive) for a sample to be kept.
            Falsy values (None/0) fall back to the default of 2.
    """
    min_sequence_len = min_sequence_len or 2
    # Batched (input_ids is a list of lists)
    return [len(seq) >= min_sequence_len for seq in sample["input_ids"]]
def truncate_long_seq(sample, sequence_len=2048):
    """
    Truncate samples whose sequence length is too long (> sequence_len).
    Modifies the sample in-place and returns the modified sample.

    Args:
        sample: Batched sample dict; ``sample["input_ids"]`` is a list of token lists.
            Optional parallel per-token fields (``attention_mask``, ``labels``,
            ``position_ids``) are truncated consistently with ``input_ids``.
        sequence_len: Maximum allowed sequence length; sequences exactly at this
            length are left unchanged.
    """
    input_ids = sample["input_ids"]
    # Batched (input_ids is a list of lists)
    for i, seq in enumerate(input_ids):
        if len(seq) > sequence_len:
            sample["input_ids"][i] = seq[:sequence_len]
            # Keep auxiliary per-token fields aligned with the truncated input_ids.
            if "attention_mask" in sample:
                sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
            if "labels" in sample:
                sample["labels"][i] = sample["labels"][i][:sequence_len]
            if "position_ids" in sample:
                sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
    return sample
def _should_skip_processing(dataset: Dataset) -> bool:
    """Check if dataset should skip long sequence handling."""
    # Streaming datasets expose column_names=None (or lack the attribute).
    columns = getattr(dataset, "column_names", None)
    if columns and "input_ids" not in columns:
        LOG.warning(
            "Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
            "expected for reward modeling."
        )
        return True
    if columns is None:
        LOG.info(
            "Dataset is streaming (IterableDataset), skipping long sequence handling"
        )
        return True
    return False
def _log_dataset_stats(dataset: Dataset) -> None:
    """Log min/max sequence lengths for debugging."""
    # Best-effort: empty or length-less datasets raise here and are ignored.
    try:
        ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
        LOG.info(f"min_input_len: {np.min(ds_lengths)}")
        LOG.info(f"max_input_len: {np.max(ds_lengths)}")
    except (AttributeError, ValueError):
        pass
def _build_filter_kwargs(dataset: Dataset, cfg: DictDefault) -> dict:
    """Build kwargs for dataset filter/map operations."""
    # Streaming datasets accept neither num_proc nor cache control.
    if isinstance(dataset, IterableDataset):
        return {}
    return {
        "num_proc": cfg.dataset_num_proc,
        "load_from_cache_file": not cfg.is_preprocess,
    }
def _filter_short_sequences(
    dataset: Dataset, min_len: int, filter_kwargs: dict
) -> tuple[Dataset, int]:
    """Filter out sequences shorter than min_len. Returns (dataset, num_dropped)."""
    size_before = len(dataset) if hasattr(dataset, "__len__") else None

    # A progress description is only valid for non-streaming datasets,
    # signalled by filter_kwargs being non-empty.
    extra_kwargs = (
        {"desc": f"Filtering Short Sequences (<{min_len})"} if filter_kwargs else {}
    )
    dataset = dataset.filter(
        functools.partial(keep_min_len, min_sequence_len=min_len),
        batched=True,
        **filter_kwargs,
        **extra_kwargs,
    )

    num_dropped = size_before - len(dataset) if size_before else 0
    if num_dropped > 0:
        LOG.info(f"Dropped {num_dropped} short sequences (<{min_len} tokens)")
    return dataset, num_dropped
def _truncate_long_sequences(
    dataset: Dataset, max_len: int, map_kwargs: dict
) -> Dataset:
    """Truncate sequences longer than max_len.

    Args:
        dataset: Tokenized dataset with an ``input_ids`` column.
        max_len: Maximum allowed sequence length after truncation.
        map_kwargs: Extra kwargs for ``Dataset.map`` (num_proc / cache control);
            empty for streaming datasets, which also reject ``desc``.

    Returns:
        The dataset with all sequences truncated to at most ``max_len`` tokens.
    """
    desc_kwargs = {}
    if map_kwargs:
        # Bug fix: interpolate the actual limit instead of a hard-coded length.
        desc_kwargs["desc"] = f"Truncating Sequences (target_len={max_len})"
    dataset = dataset.map(
        functools.partial(truncate_long_seq, sequence_len=max_len),
        batched=True,
        **map_kwargs,
        **desc_kwargs,
    )
    LOG.info(f"Truncated long sequences to max length {max_len}")
    return dataset
def _drop_outside_range(
    dataset: Dataset,
    max_len: int,
    min_len: int,
    raise_on_long: bool,
    filter_kwargs: dict,
) -> tuple[Dataset, int]:
    """Drop sequences outside valid length range [min_len, max_len].

    Args:
        dataset: Tokenized dataset with an ``input_ids`` column.
        max_len: Maximum allowed sequence length (inclusive).
        min_len: Minimum allowed sequence length (inclusive).
        raise_on_long: When True, a too-long sequence raises instead of being dropped.
        filter_kwargs: Extra kwargs for ``Dataset.filter``; empty for streaming
            datasets, which also reject ``desc``.

    Returns (dataset, num_dropped)."""
    prior_len = len(dataset) if hasattr(dataset, "__len__") else None
    desc_kwargs = {}
    if filter_kwargs:
        action = (
            "Checking Sequence Lengths"
            if raise_on_long
            else "Dropping Invalid Sequences"
        )
        # Bug fix: interpolate the actual limit instead of a hard-coded length.
        desc_kwargs["desc"] = f"{action} (<{min_len} or >{max_len})"
    dataset = dataset.filter(
        functools.partial(
            filter_sequences_by_length,
            sequence_len=max_len,
            min_sequence_len=min_len,
            raise_on_drop=raise_on_long,
        ),
        batched=True,
        **filter_kwargs,
        **desc_kwargs,
    )
    dropped = 0
    # In 'raise' mode nothing is silently dropped for length, so skip the count.
    if not raise_on_long and prior_len:
        dropped = prior_len - len(dataset)
        if dropped > 0:
            LOG.info(
                f"Dropped {dropped} sequences outside valid range "
                f"([{min_len}, {max_len}])"
            )
    return dataset, dropped
def handle_long_seq_in_dataset(
@@ -193,80 +327,25 @@ def handle_long_seq_in_dataset(
'truncate' truncates them down to sequence_len
'raise' raises a ValueError if any sequence was found that was longer than sequence_len
"""
if (
hasattr(dataset, "column_names")
and dataset.column_names
and "input_ids" not in dataset.column_names
):
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return dataset
elif not hasattr(dataset, "column_names") or dataset.column_names is None:
LOG.info(
"Dataset is streaming (IterableDataset), skipping long sequence handling"
)
# Early returns for special cases
if _should_skip_processing(dataset):
return dataset
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
drop_long = functools.partial(
drop_long_seq,
sequence_len=sequence_len,
min_sequence_len=cfg.min_sample_len,
raise_on_drop=excess_length_strategy == "raise",
)
_log_dataset_stats(dataset)
with contextlib.suppress(AttributeError):
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
min_input_len = np.min(ds_lengths)
LOG.info(f"min_input_len: {min_input_len}")
max_input_len = np.max(ds_lengths)
LOG.info(f"max_input_len: {max_input_len}")
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
filter_map_kwargs = {}
if not isinstance(dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {}
if filter_map_kwargs:
action = (
"Checking Sequence Lengths"
if excess_length_strategy == "raise"
else "Dropping Long Sequences"
)
drop_long_kwargs["desc"] = f"{action} (>{sequence_len})"
# Setup kwargs
filter_kwargs = _build_filter_kwargs(dataset, cfg)
# Handle sequences based on strategy
if excess_length_strategy == "truncate":
process_fn = functools.partial(
truncate_long_seq,
sequence_len=sequence_len,
min_sequence_len=cfg.min_sample_len,
)
drop_long_kwargs["desc"] = (
f"Truncating/Filtering Sequences (target_len={sequence_len})"
)
dataset, _ = _filter_short_sequences(dataset, cfg.min_sample_len, filter_kwargs)
dataset = _truncate_long_sequences(dataset, sequence_len, filter_kwargs)
else:
process_fn = drop_long
dataset = dataset.filter(
process_fn,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
action = (
"truncated/filtered"
if excess_length_strategy == "truncate"
else "dropped"
)
LOG.warning(f"{action.title()} {dropped} samples from dataset")
raise_on_long = excess_length_strategy == "raise"
dataset, _ = _drop_outside_range(
dataset, sequence_len, cfg.min_sample_len, raise_on_long, filter_kwargs
)
return dataset

View File

@@ -205,10 +205,13 @@ def add_length(sample):
return sample
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
def filter_sequences_by_length(
sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False
):
"""
Drop samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
Filter sequences outside valid length range [min_sequence_len, sequence_len].
Drops samples that are either too short (< min_sequence_len) or too long (> sequence_len).
Works for both single-example (list[int]) or batched (list[list[int]]).
@@ -383,10 +386,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
def process_pretraining_datasets_for_packing(
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
):
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
drop_outside_range = partial(filter_sequences_by_length, sequence_len=sequence_len)
train_dataset = train_dataset.filter(
drop_long,
drop_outside_range,
desc="Dropping Long Sequences",
load_from_cache_file=False,
)

View File

@@ -7,7 +7,7 @@ import unittest
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_streaming, md5
from axolotl.utils.trainer import drop_long_seq
from axolotl.utils.trainer import filter_sequences_by_length
from tests.hf_offline_utils import enable_hf_offline
@@ -70,17 +70,19 @@ class TestEncodePretraining(unittest.TestCase):
# -- single sequence --
# This should work
data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}
drop_long_seq(data, 32, raise_on_drop=True)
filter_sequences_by_length(data, 32, raise_on_drop=True)
# This should return True, since data fits
dropped = drop_long_seq(data, 32)
dropped = filter_sequences_by_length(data, 32)
self.assertTrue(dropped)
# This should raise
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
self.assertRaises(
ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
)
# This should return False, since data doesn't fit
dropped = drop_long_seq(data, 15)
dropped = filter_sequences_by_length(data, 15)
self.assertFalse(dropped)
# -- batch sequence --
@@ -91,13 +93,15 @@ class TestEncodePretraining(unittest.TestCase):
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
]
}
drop_long_seq(data, 32, raise_on_drop=True)
filter_sequences_by_length(data, 32, raise_on_drop=True)
# This should raise
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
self.assertRaises(
ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
)
# This should keep the first but drop the second entry
dropped = drop_long_seq(data, 15)
dropped = filter_sequences_by_length(data, 15)
self.assertEqual(dropped, [True, False])

View File

@@ -0,0 +1,545 @@
"""
Unit tests for data utility functions
"""
import unittest
from unittest.mock import MagicMock
from datasets import Dataset
from axolotl.utils.data.utils import handle_long_seq_in_dataset
from axolotl.utils.dict import DictDefault
class TestHandleLongSeqInDataset(unittest.TestCase):
    """
    Test class for handle_long_seq_in_dataset function
    """

    # NOTE: each test builds a small in-memory Dataset and a DictDefault cfg;
    # dataset_num_proc=None / is_preprocess=False keep filter/map single-process
    # and cache-enabled so the tests run the same everywhere (including mac).

    def test_drop_strategy_removes_long_sequences(self):
        """Test that 'drop' strategy removes sequences longer than sequence_len"""
        # Create dataset with mixed length sequences
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1, 2, 3],  # length 3 - keep
                    [1, 2, 3, 4, 5],  # length 5 - keep
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],  # length 11 - drop
                    [1, 2],  # length 2 - keep
                ]
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "drop",
                "min_sample_len": 2,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)

        # Should have dropped the sequence with length 11
        self.assertEqual(len(result), 3)
        self.assertEqual(len(result[0]["input_ids"]), 3)
        self.assertEqual(len(result[1]["input_ids"]), 5)
        self.assertEqual(len(result[2]["input_ids"]), 2)

    def test_drop_strategy_is_default(self):
        """Test that 'drop' is the default strategy when not specified"""
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1, 2, 3],
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],  # length 11 - should drop
                ]
            }
        )
        # excess_length_strategy deliberately omitted from cfg
        cfg = DictDefault(
            {
                "min_sample_len": 2,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)

        # Should have dropped the long sequence
        self.assertEqual(len(result), 1)

    def test_truncate_strategy_truncates_long_sequences(self):
        """Test that 'truncate' strategy truncates sequences to sequence_len"""
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1, 2, 3],  # length 3 - keep as is
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],  # length 12 - truncate to 10
                ]
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "truncate",
                "min_sample_len": 2,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)

        # Should have 2 samples
        self.assertEqual(len(result), 2)
        # First sample unchanged
        self.assertEqual(len(result[0]["input_ids"]), 3)
        # Second sample truncated to 10
        self.assertEqual(len(result[1]["input_ids"]), 10)
        self.assertEqual(result[1]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

    def test_truncate_strategy_truncates_all_auxiliary_fields(self):
        """Test that truncation applies to all auxiliary fields consistently"""
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                ],
                "attention_mask": [
                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                ],
                "labels": [
                    [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                ],
                "position_ids": [
                    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
                ],
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "truncate",
                "min_sample_len": 2,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)

        # All fields should be truncated to 10
        self.assertEqual(len(result[0]["input_ids"]), 10)
        self.assertEqual(len(result[0]["attention_mask"]), 10)
        self.assertEqual(len(result[0]["labels"]), 10)
        self.assertEqual(len(result[0]["position_ids"]), 10)

        # Verify content is correct
        self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
        self.assertEqual(result[0]["attention_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
        self.assertEqual(result[0]["labels"], [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10])
        self.assertEqual(result[0]["position_ids"], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    def test_raise_strategy_raises_on_long_sequences(self):
        """Test that 'raise' strategy raises ValueError when encountering long sequences"""
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1, 2, 3],
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],  # length 11 - should raise
                ]
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "raise",
                "min_sample_len": 2,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        with self.assertRaises(ValueError):
            handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)

    def test_min_sequence_len_filters_short_sequences(self):
        """Test that sequences shorter than min_sample_len are filtered out"""
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1],  # length 1 - drop (< min_sample_len=3)
                    [1, 2],  # length 2 - drop
                    [1, 2, 3],  # length 3 - keep
                    [1, 2, 3, 4, 5],  # length 5 - keep
                ]
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "drop",
                "min_sample_len": 3,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)

        # Should only keep sequences with length >= 3
        self.assertEqual(len(result), 2)
        self.assertEqual(len(result[0]["input_ids"]), 3)
        self.assertEqual(len(result[1]["input_ids"]), 5)

    def test_dataset_without_input_ids_column(self):
        """Test that datasets without 'input_ids' column are returned unchanged"""
        # e.g. reward-modeling datasets only carry chosen/rejected columns
        dataset = Dataset.from_dict(
            {
                "chosen": [1, 2, 3],
                "rejected": [4, 5, 6],
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "drop",
                "min_sample_len": 2,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)

        # Dataset should be unchanged
        self.assertEqual(len(result), len(dataset))
        self.assertListEqual(list(result.column_names), ["chosen", "rejected"])

    def test_truncate_filters_short_before_truncating(self):
        """Test that truncate strategy filters short sequences before truncating long ones

        This is important for efficiency - we should not waste time truncating
        sequences that will be filtered out anyway.
        """
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1],  # length 1 - filter out first
                    [1, 2, 3],  # length 3 - keep, no truncation needed
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],  # length 12 - keep and truncate
                ]
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "truncate",
                "min_sample_len": 2,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)

        # Should have filtered out the first (short) sequence
        self.assertEqual(len(result), 2)
        # Second sample unchanged
        self.assertEqual(len(result[0]["input_ids"]), 3)
        # Third sample truncated to 10
        self.assertEqual(len(result[1]["input_ids"]), 10)

    def test_case_insensitive_strategy(self):
        """Test that excess_length_strategy is case-insensitive"""
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                ]
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "TRUNCATE",  # uppercase
                "min_sample_len": 2,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)

        # Should still truncate
        self.assertEqual(len(result[0]["input_ids"]), 10)

    def test_raise_strategy_silently_drops_short_sequences(self):
        """Test that 'raise' strategy drops short sequences without raising"""
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1],  # length 1 - too short, should be dropped silently
                    [1, 2, 3, 4, 5],  # length 5 - keep
                ]
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "raise",
                "min_sample_len": 3,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        # Should NOT raise, just silently drop the short sequence
        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
        self.assertEqual(len(result), 1)
        self.assertEqual(len(result[0]["input_ids"]), 5)

    def test_drop_boundary_sequence_equal_to_sequence_len(self):
        """Test that drop strategy keeps sequences with length exactly equal to sequence_len"""
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],  # length 10 == sequence_len
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],  # length 11 > sequence_len
                ]
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "drop",
                "min_sample_len": 2,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)

        # Exactly equal should be kept, one over should be dropped
        self.assertEqual(len(result), 1)
        self.assertEqual(len(result[0]["input_ids"]), 10)

    def test_truncate_boundary_sequence_equal_to_sequence_len(self):
        """Test that truncate strategy leaves sequences with length exactly equal to sequence_len unchanged"""
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],  # length 10 == sequence_len
                ]
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "truncate",
                "min_sample_len": 2,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)

        # Should be unchanged - not truncated
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

    def test_empty_dataset(self):
        """Test that an empty dataset is handled gracefully"""
        dataset = Dataset.from_dict({"input_ids": []})
        cfg = DictDefault(
            {
                "excess_length_strategy": "drop",
                "min_sample_len": 2,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
        self.assertEqual(len(result), 0)

    def test_all_sequences_dropped_returns_empty_dataset(self):
        """Test that dropping all sequences results in an empty dataset"""
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1],  # too short
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],  # too long
                ]
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "drop",
                "min_sample_len": 5,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
        self.assertEqual(len(result), 0)

    def test_iterable_dataset_skips_processing(self):
        """Test that streaming datasets (column_names is None) are returned unchanged.

        The skip check in _should_skip_processing triggers when column_names is
        None, which happens with true streaming datasets loaded via
        load_dataset(..., streaming=True).
        """
        # MagicMock stands in for an IterableDataset with no known columns
        mock_dataset = MagicMock()
        mock_dataset.column_names = None
        cfg = DictDefault(
            {
                "excess_length_strategy": "drop",
                "min_sample_len": 2,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(mock_dataset, sequence_len=10, cfg=cfg)

        # Should be returned unchanged (same object)
        self.assertIs(result, mock_dataset)

    def test_truncate_with_partial_auxiliary_fields(self):
        """Test truncation when only some auxiliary fields are present"""
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                ],
                "labels": [
                    [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
                ],
                # No attention_mask or position_ids
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "truncate",
                "min_sample_len": 2,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)

        self.assertEqual(len(result[0]["input_ids"]), 10)
        self.assertEqual(len(result[0]["labels"]), 10)
        self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
        self.assertEqual(result[0]["labels"], [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10])
        # Confirm no extra columns were introduced
        self.assertListEqual(sorted(result.column_names), ["input_ids", "labels"])

    def test_min_sample_len_defaults_to_two_when_not_set(self):
        """Test that min_sample_len defaults to 2 when not specified in config"""
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1],  # length 1 - should be dropped (< default 2)
                    [1, 2],  # length 2 - should be kept (>= default 2)
                    [1, 2, 3],  # length 3 - should be kept
                ]
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "drop",
                # min_sample_len not set
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
        self.assertEqual(len(result), 2)
        self.assertEqual(len(result[0]["input_ids"]), 2)
        self.assertEqual(len(result[1]["input_ids"]), 3)

    def test_invalid_strategy_falls_through_to_drop(self):
        """Test that an unrecognized strategy value falls through to drop behavior"""
        dataset = Dataset.from_dict(
            {
                "input_ids": [
                    [1, 2, 3],  # keep
                    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],  # length 11 - should be dropped
                ]
            }
        )
        cfg = DictDefault(
            {
                "excess_length_strategy": "not_a_real_strategy",
                "min_sample_len": 2,
                "dataset_num_proc": None,
                "is_preprocess": False,
            }
        )

        result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)

        # Should behave like 'drop'
        self.assertEqual(len(result), 1)
        self.assertEqual(len(result[0]["input_ids"]), 3)
# Allow running this test module directly (e.g. `python <file>.py`).
if __name__ == "__main__":
    unittest.main()