Fix: excess_length_strategy truncation method (#3401)
* Add test cases to verify that the problem exists in the underlying * Update the handle_long_sequences function to correctly use Map instead of filter for the truncation strategy. Also remove the minimal length filtering from the truncate_long_samples function, and run it separately and before. * fix: refactor and add test truncate for non-input id fields * fix: refactor long seq handling fn * fix: refactor duplicate fn and simplify route * add additional tests and make them work on mac * handle logging exception on empty datasets --------- Co-authored-by: 2ndset bot <bot@2ndset.ai> Co-authored-by: NanoCode012 <nano@axolotl.ai> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -15,7 +15,7 @@ from datasets import Dataset, IterableDataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.samplers.utils import get_dataset_lengths
|
||||
from axolotl.utils.trainer import drop_long_seq
|
||||
from axolotl.utils.trainer import filter_sequences_by_length
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -148,22 +148,33 @@ def deduplicate_and_log_datasets(
|
||||
return dataset, other_dataset
|
||||
|
||||
|
||||
def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
def keep_min_len(sample, min_sequence_len=2):
|
||||
"""
|
||||
Truncate samples whose sequence length is too long (> sequence_len)
|
||||
or drop those too short (< min_sequence_len).
|
||||
Batched filter function that keeps only samples with sequence length >= min_sequence_len.
|
||||
Returns a list of booleans indicating which samples to keep.
|
||||
"""
|
||||
min_sequence_len = min_sequence_len or 2
|
||||
|
||||
input_ids = sample["input_ids"]
|
||||
|
||||
# Batched (input_ids is a list of lists)
|
||||
results = []
|
||||
for seq in input_ids:
|
||||
results.append(len(seq) >= min_sequence_len)
|
||||
return results
|
||||
|
||||
|
||||
def truncate_long_seq(sample, sequence_len=2048):
|
||||
"""
|
||||
Truncate samples whose sequence length is too long (> sequence_len).
|
||||
Modifies the sample in-place and returns the modified sample.
|
||||
"""
|
||||
input_ids = sample["input_ids"]
|
||||
|
||||
# Batched (input_ids is a list of lists)
|
||||
for i, seq in enumerate(input_ids):
|
||||
length = len(seq)
|
||||
if length < min_sequence_len:
|
||||
results.append(False)
|
||||
elif length > sequence_len:
|
||||
if length > sequence_len:
|
||||
sample["input_ids"][i] = seq[:sequence_len]
|
||||
if "attention_mask" in sample:
|
||||
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
|
||||
@@ -171,10 +182,133 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
sample["labels"][i] = sample["labels"][i][:sequence_len]
|
||||
if "position_ids" in sample:
|
||||
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
|
||||
results.append(True)
|
||||
else:
|
||||
results.append(True)
|
||||
return results
|
||||
return sample
|
||||
|
||||
|
||||
def _should_skip_processing(dataset: Dataset) -> bool:
|
||||
"""Check if dataset should skip long sequence handling."""
|
||||
if (
|
||||
hasattr(dataset, "column_names")
|
||||
and dataset.column_names
|
||||
and "input_ids" not in dataset.column_names
|
||||
):
|
||||
LOG.warning(
|
||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
||||
"expected for reward modeling."
|
||||
)
|
||||
return True
|
||||
elif not hasattr(dataset, "column_names") or dataset.column_names is None:
|
||||
LOG.info(
|
||||
"Dataset is streaming (IterableDataset), skipping long sequence handling"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _log_dataset_stats(dataset: Dataset) -> None:
|
||||
"""Log min/max sequence lengths for debugging."""
|
||||
with contextlib.suppress(AttributeError, ValueError):
|
||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||
LOG.info(f"min_input_len: {np.min(ds_lengths)}")
|
||||
LOG.info(f"max_input_len: {np.max(ds_lengths)}")
|
||||
|
||||
|
||||
def _build_filter_kwargs(dataset: Dataset, cfg: DictDefault) -> dict:
|
||||
"""Build kwargs for dataset filter/map operations."""
|
||||
kwargs = {}
|
||||
if not isinstance(dataset, IterableDataset):
|
||||
kwargs["num_proc"] = cfg.dataset_num_proc
|
||||
kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
||||
return kwargs
|
||||
|
||||
|
||||
def _filter_short_sequences(
|
||||
dataset: Dataset, min_len: int, filter_kwargs: dict
|
||||
) -> tuple[Dataset, int]:
|
||||
"""Filter out sequences shorter than min_len. Returns (dataset, num_dropped)."""
|
||||
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
|
||||
|
||||
desc_kwargs = {}
|
||||
if filter_kwargs:
|
||||
desc_kwargs["desc"] = f"Filtering Short Sequences (<{min_len})"
|
||||
|
||||
dataset = dataset.filter(
|
||||
functools.partial(keep_min_len, min_sequence_len=min_len),
|
||||
batched=True,
|
||||
**filter_kwargs,
|
||||
**desc_kwargs,
|
||||
)
|
||||
|
||||
dropped = 0
|
||||
if prior_len:
|
||||
dropped = prior_len - len(dataset)
|
||||
if dropped > 0:
|
||||
LOG.info(f"Dropped {dropped} short sequences (<{min_len} tokens)")
|
||||
|
||||
return dataset, dropped
|
||||
|
||||
|
||||
def _truncate_long_sequences(
|
||||
dataset: Dataset, max_len: int, map_kwargs: dict
|
||||
) -> Dataset:
|
||||
"""Truncate sequences longer than max_len."""
|
||||
desc_kwargs = {}
|
||||
if map_kwargs:
|
||||
desc_kwargs["desc"] = f"Truncating Sequences (target_len={max_len})"
|
||||
|
||||
dataset = dataset.map(
|
||||
functools.partial(truncate_long_seq, sequence_len=max_len),
|
||||
batched=True,
|
||||
**map_kwargs,
|
||||
**desc_kwargs,
|
||||
)
|
||||
LOG.info(f"Truncated long sequences to max length {max_len}")
|
||||
return dataset
|
||||
|
||||
|
||||
def _drop_outside_range(
|
||||
dataset: Dataset,
|
||||
max_len: int,
|
||||
min_len: int,
|
||||
raise_on_long: bool,
|
||||
filter_kwargs: dict,
|
||||
) -> tuple[Dataset, int]:
|
||||
"""Drop sequences outside valid length range [min_len, max_len].
|
||||
|
||||
Returns (dataset, num_dropped)."""
|
||||
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
|
||||
|
||||
desc_kwargs = {}
|
||||
if filter_kwargs:
|
||||
action = (
|
||||
"Checking Sequence Lengths"
|
||||
if raise_on_long
|
||||
else "Dropping Invalid Sequences"
|
||||
)
|
||||
desc_kwargs["desc"] = f"{action} (<{min_len} or >{max_len})"
|
||||
|
||||
dataset = dataset.filter(
|
||||
functools.partial(
|
||||
filter_sequences_by_length,
|
||||
sequence_len=max_len,
|
||||
min_sequence_len=min_len,
|
||||
raise_on_drop=raise_on_long,
|
||||
),
|
||||
batched=True,
|
||||
**filter_kwargs,
|
||||
**desc_kwargs,
|
||||
)
|
||||
|
||||
dropped = 0
|
||||
if not raise_on_long and prior_len:
|
||||
dropped = prior_len - len(dataset)
|
||||
if dropped > 0:
|
||||
LOG.info(
|
||||
f"Dropped {dropped} sequences outside valid range "
|
||||
f"([{min_len}, {max_len}])"
|
||||
)
|
||||
|
||||
return dataset, dropped
|
||||
|
||||
|
||||
def handle_long_seq_in_dataset(
|
||||
@@ -193,80 +327,25 @@ def handle_long_seq_in_dataset(
|
||||
'truncate' truncates them down to sequence_len
|
||||
'raise' raises a ValueError if any sequence was found that was longer than sequence_len
|
||||
"""
|
||||
if (
|
||||
hasattr(dataset, "column_names")
|
||||
and dataset.column_names
|
||||
and "input_ids" not in dataset.column_names
|
||||
):
|
||||
LOG.warning(
|
||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
||||
"expected for reward modeling."
|
||||
)
|
||||
return dataset
|
||||
elif not hasattr(dataset, "column_names") or dataset.column_names is None:
|
||||
LOG.info(
|
||||
"Dataset is streaming (IterableDataset), skipping long sequence handling"
|
||||
)
|
||||
# Early returns for special cases
|
||||
if _should_skip_processing(dataset):
|
||||
return dataset
|
||||
|
||||
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
||||
|
||||
drop_long = functools.partial(
|
||||
drop_long_seq,
|
||||
sequence_len=sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
raise_on_drop=excess_length_strategy == "raise",
|
||||
)
|
||||
_log_dataset_stats(dataset)
|
||||
|
||||
with contextlib.suppress(AttributeError):
|
||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||
min_input_len = np.min(ds_lengths)
|
||||
LOG.info(f"min_input_len: {min_input_len}")
|
||||
max_input_len = np.max(ds_lengths)
|
||||
LOG.info(f"max_input_len: {max_input_len}")
|
||||
|
||||
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
|
||||
|
||||
filter_map_kwargs = {}
|
||||
if not isinstance(dataset, IterableDataset):
|
||||
filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
|
||||
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
||||
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
action = (
|
||||
"Checking Sequence Lengths"
|
||||
if excess_length_strategy == "raise"
|
||||
else "Dropping Long Sequences"
|
||||
)
|
||||
drop_long_kwargs["desc"] = f"{action} (>{sequence_len})"
|
||||
# Setup kwargs
|
||||
filter_kwargs = _build_filter_kwargs(dataset, cfg)
|
||||
|
||||
# Handle sequences based on strategy
|
||||
if excess_length_strategy == "truncate":
|
||||
process_fn = functools.partial(
|
||||
truncate_long_seq,
|
||||
sequence_len=sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
)
|
||||
drop_long_kwargs["desc"] = (
|
||||
f"Truncating/Filtering Sequences (target_len={sequence_len})"
|
||||
)
|
||||
dataset, _ = _filter_short_sequences(dataset, cfg.min_sample_len, filter_kwargs)
|
||||
dataset = _truncate_long_sequences(dataset, sequence_len, filter_kwargs)
|
||||
else:
|
||||
process_fn = drop_long
|
||||
|
||||
dataset = dataset.filter(
|
||||
process_fn,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
if prior_len:
|
||||
dropped = prior_len - len(dataset)
|
||||
if dropped:
|
||||
action = (
|
||||
"truncated/filtered"
|
||||
if excess_length_strategy == "truncate"
|
||||
else "dropped"
|
||||
)
|
||||
LOG.warning(f"{action.title()} {dropped} samples from dataset")
|
||||
raise_on_long = excess_length_strategy == "raise"
|
||||
dataset, _ = _drop_outside_range(
|
||||
dataset, sequence_len, cfg.min_sample_len, raise_on_long, filter_kwargs
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -205,10 +205,13 @@ def add_length(sample):
|
||||
return sample
|
||||
|
||||
|
||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
|
||||
def filter_sequences_by_length(
|
||||
sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False
|
||||
):
|
||||
"""
|
||||
Drop samples whose sequence length is either too long (> sequence_len)
|
||||
or too short (< min_sequence_len).
|
||||
Filter sequences outside valid length range [min_sequence_len, sequence_len].
|
||||
|
||||
Drops samples that are either too short (< min_sequence_len) or too long (> sequence_len).
|
||||
|
||||
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||
|
||||
@@ -383,10 +386,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
def process_pretraining_datasets_for_packing(
|
||||
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
|
||||
):
|
||||
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
|
||||
drop_outside_range = partial(filter_sequences_by_length, sequence_len=sequence_len)
|
||||
|
||||
train_dataset = train_dataset.filter(
|
||||
drop_long,
|
||||
drop_outside_range,
|
||||
desc="Dropping Long Sequences",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ import unittest
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
from axolotl.utils.data import encode_streaming, md5
|
||||
from axolotl.utils.trainer import drop_long_seq
|
||||
from axolotl.utils.trainer import filter_sequences_by_length
|
||||
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
@@ -70,17 +70,19 @@ class TestEncodePretraining(unittest.TestCase):
|
||||
# -- single sequence --
|
||||
# This should work
|
||||
data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}
|
||||
drop_long_seq(data, 32, raise_on_drop=True)
|
||||
filter_sequences_by_length(data, 32, raise_on_drop=True)
|
||||
|
||||
# This should return True, since data fits
|
||||
dropped = drop_long_seq(data, 32)
|
||||
dropped = filter_sequences_by_length(data, 32)
|
||||
self.assertTrue(dropped)
|
||||
|
||||
# This should raise
|
||||
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
|
||||
self.assertRaises(
|
||||
ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
|
||||
)
|
||||
|
||||
# This should return False, since data doesn't fit
|
||||
dropped = drop_long_seq(data, 15)
|
||||
dropped = filter_sequences_by_length(data, 15)
|
||||
self.assertFalse(dropped)
|
||||
|
||||
# -- batch sequence --
|
||||
@@ -91,13 +93,15 @@ class TestEncodePretraining(unittest.TestCase):
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
||||
]
|
||||
}
|
||||
drop_long_seq(data, 32, raise_on_drop=True)
|
||||
filter_sequences_by_length(data, 32, raise_on_drop=True)
|
||||
|
||||
# This should raise
|
||||
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
|
||||
self.assertRaises(
|
||||
ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
|
||||
)
|
||||
|
||||
# This should keep the first but drop the second entry
|
||||
dropped = drop_long_seq(data, 15)
|
||||
dropped = filter_sequences_by_length(data, 15)
|
||||
self.assertEqual(dropped, [True, False])
|
||||
|
||||
|
||||
|
||||
545
tests/utils/data/test_utils.py
Normal file
545
tests/utils/data/test_utils.py
Normal file
@@ -0,0 +1,545 @@
|
||||
"""
|
||||
Unit tests for data utility functions
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
from axolotl.utils.data.utils import handle_long_seq_in_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestHandleLongSeqInDataset(unittest.TestCase):
|
||||
"""
|
||||
Test class for handle_long_seq_in_dataset function
|
||||
"""
|
||||
|
||||
def test_drop_strategy_removes_long_sequences(self):
|
||||
"""Test that 'drop' strategy removes sequences longer than sequence_len"""
|
||||
# Create dataset with mixed length sequences
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1, 2, 3], # length 3 - keep
|
||||
[1, 2, 3, 4, 5], # length 5 - keep
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 - drop
|
||||
[1, 2], # length 2 - keep
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "drop",
|
||||
"min_sample_len": 2,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
# Should have dropped the sequence with length 11
|
||||
self.assertEqual(len(result), 3)
|
||||
self.assertEqual(len(result[0]["input_ids"]), 3)
|
||||
self.assertEqual(len(result[1]["input_ids"]), 5)
|
||||
self.assertEqual(len(result[2]["input_ids"]), 2)
|
||||
|
||||
def test_drop_strategy_is_default(self):
|
||||
"""Test that 'drop' is the default strategy when not specified"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1, 2, 3],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 - should drop
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"min_sample_len": 2,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
# Should have dropped the long sequence
|
||||
self.assertEqual(len(result), 1)
|
||||
|
||||
def test_truncate_strategy_truncates_long_sequences(self):
|
||||
"""Test that 'truncate' strategy truncates sequences to sequence_len"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1, 2, 3], # length 3 - keep as is
|
||||
[
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
], # length 12 - truncate to 10
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "truncate",
|
||||
"min_sample_len": 2,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
# Should have 2 samples
|
||||
self.assertEqual(len(result), 2)
|
||||
# First sample unchanged
|
||||
self.assertEqual(len(result[0]["input_ids"]), 3)
|
||||
# Second sample truncated to 10
|
||||
self.assertEqual(len(result[1]["input_ids"]), 10)
|
||||
self.assertEqual(result[1]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
|
||||
def test_truncate_strategy_truncates_all_auxiliary_fields(self):
|
||||
"""Test that truncation applies to all auxiliary fields consistently"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
],
|
||||
"attention_mask": [
|
||||
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||
],
|
||||
"labels": [
|
||||
[-100, -100, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
],
|
||||
"position_ids": [
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "truncate",
|
||||
"min_sample_len": 2,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
# All fields should be truncated to 10
|
||||
self.assertEqual(len(result[0]["input_ids"]), 10)
|
||||
self.assertEqual(len(result[0]["attention_mask"]), 10)
|
||||
self.assertEqual(len(result[0]["labels"]), 10)
|
||||
self.assertEqual(len(result[0]["position_ids"]), 10)
|
||||
|
||||
# Verify content is correct
|
||||
self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
self.assertEqual(result[0]["attention_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
|
||||
self.assertEqual(result[0]["labels"], [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
self.assertEqual(result[0]["position_ids"], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
|
||||
def test_raise_strategy_raises_on_long_sequences(self):
|
||||
"""Test that 'raise' strategy raises ValueError when encountering long sequences"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1, 2, 3],
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 - should raise
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "raise",
|
||||
"min_sample_len": 2,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
def test_min_sequence_len_filters_short_sequences(self):
|
||||
"""Test that sequences shorter than min_sample_len are filtered out"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1], # length 1 - drop (< min_sample_len=3)
|
||||
[1, 2], # length 2 - drop
|
||||
[1, 2, 3], # length 3 - keep
|
||||
[1, 2, 3, 4, 5], # length 5 - keep
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "drop",
|
||||
"min_sample_len": 3,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
# Should only keep sequences with length >= 3
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(len(result[0]["input_ids"]), 3)
|
||||
self.assertEqual(len(result[1]["input_ids"]), 5)
|
||||
|
||||
def test_dataset_without_input_ids_column(self):
|
||||
"""Test that datasets without 'input_ids' column are returned unchanged"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"chosen": [1, 2, 3],
|
||||
"rejected": [4, 5, 6],
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "drop",
|
||||
"min_sample_len": 2,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
# Dataset should be unchanged
|
||||
self.assertEqual(len(result), len(dataset))
|
||||
self.assertListEqual(list(result.column_names), ["chosen", "rejected"])
|
||||
|
||||
def test_truncate_filters_short_before_truncating(self):
|
||||
"""Test that truncate strategy filters short sequences before truncating long ones
|
||||
|
||||
This is important for efficiency - we should not waste time truncating
|
||||
sequences that will be filtered out anyway.
|
||||
"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1], # length 1 - filter out first
|
||||
[1, 2, 3], # length 3 - keep, no truncation needed
|
||||
[
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
], # length 12 - keep and truncate
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "truncate",
|
||||
"min_sample_len": 2,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
# Should have filtered out the first (short) sequence
|
||||
self.assertEqual(len(result), 2)
|
||||
# Second sample unchanged
|
||||
self.assertEqual(len(result[0]["input_ids"]), 3)
|
||||
# Third sample truncated to 10
|
||||
self.assertEqual(len(result[1]["input_ids"]), 10)
|
||||
|
||||
def test_case_insensitive_strategy(self):
|
||||
"""Test that excess_length_strategy is case-insensitive"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "TRUNCATE", # uppercase
|
||||
"min_sample_len": 2,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
# Should still truncate
|
||||
self.assertEqual(len(result[0]["input_ids"]), 10)
|
||||
|
||||
def test_raise_strategy_silently_drops_short_sequences(self):
|
||||
"""Test that 'raise' strategy drops short sequences without raising"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1], # length 1 - too short, should be dropped silently
|
||||
[1, 2, 3, 4, 5], # length 5 - keep
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "raise",
|
||||
"min_sample_len": 3,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
# Should NOT raise, just silently drop the short sequence
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(len(result[0]["input_ids"]), 5)
|
||||
|
||||
def test_drop_boundary_sequence_equal_to_sequence_len(self):
|
||||
"""Test that drop strategy keeps sequences with length exactly equal to sequence_len"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], # length 10 == sequence_len
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 > sequence_len
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "drop",
|
||||
"min_sample_len": 2,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
# Exactly equal should be kept, one over should be dropped
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(len(result[0]["input_ids"]), 10)
|
||||
|
||||
def test_truncate_boundary_sequence_equal_to_sequence_len(self):
|
||||
"""Test that truncate strategy leaves sequences with length exactly equal to sequence_len unchanged"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], # length 10 == sequence_len
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "truncate",
|
||||
"min_sample_len": 2,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
# Should be unchanged - not truncated
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
|
||||
def test_empty_dataset(self):
|
||||
"""Test that an empty dataset is handled gracefully"""
|
||||
dataset = Dataset.from_dict({"input_ids": []})
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "drop",
|
||||
"min_sample_len": 2,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
||||
def test_all_sequences_dropped_returns_empty_dataset(self):
|
||||
"""Test that dropping all sequences results in an empty dataset"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1], # too short
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # too long
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "drop",
|
||||
"min_sample_len": 5,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
self.assertEqual(len(result), 0)
|
||||
|
||||
def test_iterable_dataset_skips_processing(self):
|
||||
"""Test that streaming datasets (column_names is None) are returned unchanged.
|
||||
|
||||
The skip check in _should_skip_processing triggers when column_names is
|
||||
None, which happens with true streaming datasets loaded via
|
||||
load_dataset(..., streaming=True).
|
||||
"""
|
||||
mock_dataset = MagicMock()
|
||||
mock_dataset.column_names = None
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "drop",
|
||||
"min_sample_len": 2,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(mock_dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
# Should be returned unchanged (same object)
|
||||
self.assertIs(result, mock_dataset)
|
||||
|
||||
def test_truncate_with_partial_auxiliary_fields(self):
|
||||
"""Test truncation when only some auxiliary fields are present"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
],
|
||||
"labels": [
|
||||
[-100, -100, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
],
|
||||
# No attention_mask or position_ids
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "truncate",
|
||||
"min_sample_len": 2,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
self.assertEqual(len(result[0]["input_ids"]), 10)
|
||||
self.assertEqual(len(result[0]["labels"]), 10)
|
||||
self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
self.assertEqual(result[0]["labels"], [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
# Confirm no extra columns were introduced
|
||||
self.assertListEqual(sorted(result.column_names), ["input_ids", "labels"])
|
||||
|
||||
def test_min_sample_len_defaults_to_two_when_not_set(self):
|
||||
"""Test that min_sample_len defaults to 2 when not specified in config"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1], # length 1 - should be dropped (< default 2)
|
||||
[1, 2], # length 2 - should be kept (>= default 2)
|
||||
[1, 2, 3], # length 3 - should be kept
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "drop",
|
||||
# min_sample_len not set
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(len(result[0]["input_ids"]), 2)
|
||||
self.assertEqual(len(result[1]["input_ids"]), 3)
|
||||
|
||||
def test_invalid_strategy_falls_through_to_drop(self):
|
||||
"""Test that an unrecognized strategy value falls through to drop behavior"""
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"input_ids": [
|
||||
[1, 2, 3], # keep
|
||||
[
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
11,
|
||||
], # length 11 - should be dropped
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"excess_length_strategy": "not_a_real_strategy",
|
||||
"min_sample_len": 2,
|
||||
"dataset_num_proc": None,
|
||||
"is_preprocess": False,
|
||||
}
|
||||
)
|
||||
|
||||
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||
|
||||
# Should behave like 'drop'
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(len(result[0]["input_ids"]), 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user