Fix: excess_length_strategy truncation method (#3401)
* Add test cases to verify that the problem exists in the underlying implementation
* Update the handle_long_sequences function to correctly use map instead of filter for the truncation strategy. Also remove the minimum-length filtering from the truncate_long_samples function, and run it separately, before truncation.
* fix: refactor and add truncation test for non-input-id fields
* fix: refactor long seq handling fn
* fix: refactor duplicate fn and simplify route
* add additional tests and make them work on mac
* handle logging exception on empty datasets
---------
Co-authored-by: 2ndset bot <bot@2ndset.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -15,7 +15,7 @@ from datasets import Dataset, IterableDataset
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.samplers.utils import get_dataset_lengths
|
from axolotl.utils.samplers.utils import get_dataset_lengths
|
||||||
from axolotl.utils.trainer import drop_long_seq
|
from axolotl.utils.trainer import filter_sequences_by_length
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
@@ -148,22 +148,33 @@ def deduplicate_and_log_datasets(
|
|||||||
return dataset, other_dataset
|
return dataset, other_dataset
|
||||||
|
|
||||||
|
|
||||||
def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
def keep_min_len(sample, min_sequence_len=2):
|
||||||
"""
|
"""
|
||||||
Truncate samples whose sequence length is too long (> sequence_len)
|
Batched filter function that keeps only samples with sequence length >= min_sequence_len.
|
||||||
or drop those too short (< min_sequence_len).
|
Returns a list of booleans indicating which samples to keep.
|
||||||
"""
|
"""
|
||||||
min_sequence_len = min_sequence_len or 2
|
min_sequence_len = min_sequence_len or 2
|
||||||
|
|
||||||
input_ids = sample["input_ids"]
|
input_ids = sample["input_ids"]
|
||||||
|
|
||||||
|
# Batched (input_ids is a list of lists)
|
||||||
results = []
|
results = []
|
||||||
|
for seq in input_ids:
|
||||||
|
results.append(len(seq) >= min_sequence_len)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_long_seq(sample, sequence_len=2048):
|
||||||
|
"""
|
||||||
|
Truncate samples whose sequence length is too long (> sequence_len).
|
||||||
|
Modifies the sample in-place and returns the modified sample.
|
||||||
|
"""
|
||||||
|
input_ids = sample["input_ids"]
|
||||||
|
|
||||||
# Batched (input_ids is a list of lists)
|
# Batched (input_ids is a list of lists)
|
||||||
for i, seq in enumerate(input_ids):
|
for i, seq in enumerate(input_ids):
|
||||||
length = len(seq)
|
length = len(seq)
|
||||||
if length < min_sequence_len:
|
if length > sequence_len:
|
||||||
results.append(False)
|
|
||||||
elif length > sequence_len:
|
|
||||||
sample["input_ids"][i] = seq[:sequence_len]
|
sample["input_ids"][i] = seq[:sequence_len]
|
||||||
if "attention_mask" in sample:
|
if "attention_mask" in sample:
|
||||||
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
|
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
|
||||||
@@ -171,10 +182,133 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
|||||||
sample["labels"][i] = sample["labels"][i][:sequence_len]
|
sample["labels"][i] = sample["labels"][i][:sequence_len]
|
||||||
if "position_ids" in sample:
|
if "position_ids" in sample:
|
||||||
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
|
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
|
||||||
results.append(True)
|
return sample
|
||||||
else:
|
|
||||||
results.append(True)
|
|
||||||
return results
|
def _should_skip_processing(dataset: Dataset) -> bool:
|
||||||
|
"""Check if dataset should skip long sequence handling."""
|
||||||
|
if (
|
||||||
|
hasattr(dataset, "column_names")
|
||||||
|
and dataset.column_names
|
||||||
|
and "input_ids" not in dataset.column_names
|
||||||
|
):
|
||||||
|
LOG.warning(
|
||||||
|
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
||||||
|
"expected for reward modeling."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
elif not hasattr(dataset, "column_names") or dataset.column_names is None:
|
||||||
|
LOG.info(
|
||||||
|
"Dataset is streaming (IterableDataset), skipping long sequence handling"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _log_dataset_stats(dataset: Dataset) -> None:
|
||||||
|
"""Log min/max sequence lengths for debugging."""
|
||||||
|
with contextlib.suppress(AttributeError, ValueError):
|
||||||
|
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||||
|
LOG.info(f"min_input_len: {np.min(ds_lengths)}")
|
||||||
|
LOG.info(f"max_input_len: {np.max(ds_lengths)}")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_filter_kwargs(dataset: Dataset, cfg: DictDefault) -> dict:
|
||||||
|
"""Build kwargs for dataset filter/map operations."""
|
||||||
|
kwargs = {}
|
||||||
|
if not isinstance(dataset, IterableDataset):
|
||||||
|
kwargs["num_proc"] = cfg.dataset_num_proc
|
||||||
|
kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_short_sequences(
|
||||||
|
dataset: Dataset, min_len: int, filter_kwargs: dict
|
||||||
|
) -> tuple[Dataset, int]:
|
||||||
|
"""Filter out sequences shorter than min_len. Returns (dataset, num_dropped)."""
|
||||||
|
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
|
||||||
|
|
||||||
|
desc_kwargs = {}
|
||||||
|
if filter_kwargs:
|
||||||
|
desc_kwargs["desc"] = f"Filtering Short Sequences (<{min_len})"
|
||||||
|
|
||||||
|
dataset = dataset.filter(
|
||||||
|
functools.partial(keep_min_len, min_sequence_len=min_len),
|
||||||
|
batched=True,
|
||||||
|
**filter_kwargs,
|
||||||
|
**desc_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
dropped = 0
|
||||||
|
if prior_len:
|
||||||
|
dropped = prior_len - len(dataset)
|
||||||
|
if dropped > 0:
|
||||||
|
LOG.info(f"Dropped {dropped} short sequences (<{min_len} tokens)")
|
||||||
|
|
||||||
|
return dataset, dropped
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate_long_sequences(
|
||||||
|
dataset: Dataset, max_len: int, map_kwargs: dict
|
||||||
|
) -> Dataset:
|
||||||
|
"""Truncate sequences longer than max_len."""
|
||||||
|
desc_kwargs = {}
|
||||||
|
if map_kwargs:
|
||||||
|
desc_kwargs["desc"] = f"Truncating Sequences (target_len={max_len})"
|
||||||
|
|
||||||
|
dataset = dataset.map(
|
||||||
|
functools.partial(truncate_long_seq, sequence_len=max_len),
|
||||||
|
batched=True,
|
||||||
|
**map_kwargs,
|
||||||
|
**desc_kwargs,
|
||||||
|
)
|
||||||
|
LOG.info(f"Truncated long sequences to max length {max_len}")
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def _drop_outside_range(
|
||||||
|
dataset: Dataset,
|
||||||
|
max_len: int,
|
||||||
|
min_len: int,
|
||||||
|
raise_on_long: bool,
|
||||||
|
filter_kwargs: dict,
|
||||||
|
) -> tuple[Dataset, int]:
|
||||||
|
"""Drop sequences outside valid length range [min_len, max_len].
|
||||||
|
|
||||||
|
Returns (dataset, num_dropped)."""
|
||||||
|
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
|
||||||
|
|
||||||
|
desc_kwargs = {}
|
||||||
|
if filter_kwargs:
|
||||||
|
action = (
|
||||||
|
"Checking Sequence Lengths"
|
||||||
|
if raise_on_long
|
||||||
|
else "Dropping Invalid Sequences"
|
||||||
|
)
|
||||||
|
desc_kwargs["desc"] = f"{action} (<{min_len} or >{max_len})"
|
||||||
|
|
||||||
|
dataset = dataset.filter(
|
||||||
|
functools.partial(
|
||||||
|
filter_sequences_by_length,
|
||||||
|
sequence_len=max_len,
|
||||||
|
min_sequence_len=min_len,
|
||||||
|
raise_on_drop=raise_on_long,
|
||||||
|
),
|
||||||
|
batched=True,
|
||||||
|
**filter_kwargs,
|
||||||
|
**desc_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
dropped = 0
|
||||||
|
if not raise_on_long and prior_len:
|
||||||
|
dropped = prior_len - len(dataset)
|
||||||
|
if dropped > 0:
|
||||||
|
LOG.info(
|
||||||
|
f"Dropped {dropped} sequences outside valid range "
|
||||||
|
f"([{min_len}, {max_len}])"
|
||||||
|
)
|
||||||
|
|
||||||
|
return dataset, dropped
|
||||||
|
|
||||||
|
|
||||||
def handle_long_seq_in_dataset(
|
def handle_long_seq_in_dataset(
|
||||||
@@ -193,80 +327,25 @@ def handle_long_seq_in_dataset(
|
|||||||
'truncate' truncates them down to sequence_len
|
'truncate' truncates them down to sequence_len
|
||||||
'raise' raises a ValueError if any sequence was found that was longer than sequence_len
|
'raise' raises a ValueError if any sequence was found that was longer than sequence_len
|
||||||
"""
|
"""
|
||||||
if (
|
# Early returns for special cases
|
||||||
hasattr(dataset, "column_names")
|
if _should_skip_processing(dataset):
|
||||||
and dataset.column_names
|
|
||||||
and "input_ids" not in dataset.column_names
|
|
||||||
):
|
|
||||||
LOG.warning(
|
|
||||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
|
||||||
"expected for reward modeling."
|
|
||||||
)
|
|
||||||
return dataset
|
|
||||||
elif not hasattr(dataset, "column_names") or dataset.column_names is None:
|
|
||||||
LOG.info(
|
|
||||||
"Dataset is streaming (IterableDataset), skipping long sequence handling"
|
|
||||||
)
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
||||||
|
|
||||||
drop_long = functools.partial(
|
_log_dataset_stats(dataset)
|
||||||
drop_long_seq,
|
|
||||||
sequence_len=sequence_len,
|
|
||||||
min_sequence_len=cfg.min_sample_len,
|
|
||||||
raise_on_drop=excess_length_strategy == "raise",
|
|
||||||
)
|
|
||||||
|
|
||||||
with contextlib.suppress(AttributeError):
|
# Setup kwargs
|
||||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
filter_kwargs = _build_filter_kwargs(dataset, cfg)
|
||||||
min_input_len = np.min(ds_lengths)
|
|
||||||
LOG.info(f"min_input_len: {min_input_len}")
|
|
||||||
max_input_len = np.max(ds_lengths)
|
|
||||||
LOG.info(f"max_input_len: {max_input_len}")
|
|
||||||
|
|
||||||
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
|
|
||||||
|
|
||||||
filter_map_kwargs = {}
|
|
||||||
if not isinstance(dataset, IterableDataset):
|
|
||||||
filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
|
|
||||||
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
|
||||||
|
|
||||||
drop_long_kwargs = {}
|
|
||||||
if filter_map_kwargs:
|
|
||||||
action = (
|
|
||||||
"Checking Sequence Lengths"
|
|
||||||
if excess_length_strategy == "raise"
|
|
||||||
else "Dropping Long Sequences"
|
|
||||||
)
|
|
||||||
drop_long_kwargs["desc"] = f"{action} (>{sequence_len})"
|
|
||||||
|
|
||||||
|
# Handle sequences based on strategy
|
||||||
if excess_length_strategy == "truncate":
|
if excess_length_strategy == "truncate":
|
||||||
process_fn = functools.partial(
|
dataset, _ = _filter_short_sequences(dataset, cfg.min_sample_len, filter_kwargs)
|
||||||
truncate_long_seq,
|
dataset = _truncate_long_sequences(dataset, sequence_len, filter_kwargs)
|
||||||
sequence_len=sequence_len,
|
|
||||||
min_sequence_len=cfg.min_sample_len,
|
|
||||||
)
|
|
||||||
drop_long_kwargs["desc"] = (
|
|
||||||
f"Truncating/Filtering Sequences (target_len={sequence_len})"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
process_fn = drop_long
|
raise_on_long = excess_length_strategy == "raise"
|
||||||
|
dataset, _ = _drop_outside_range(
|
||||||
dataset = dataset.filter(
|
dataset, sequence_len, cfg.min_sample_len, raise_on_long, filter_kwargs
|
||||||
process_fn,
|
)
|
||||||
batched=True,
|
|
||||||
**filter_map_kwargs,
|
|
||||||
**drop_long_kwargs,
|
|
||||||
)
|
|
||||||
if prior_len:
|
|
||||||
dropped = prior_len - len(dataset)
|
|
||||||
if dropped:
|
|
||||||
action = (
|
|
||||||
"truncated/filtered"
|
|
||||||
if excess_length_strategy == "truncate"
|
|
||||||
else "dropped"
|
|
||||||
)
|
|
||||||
LOG.warning(f"{action.title()} {dropped} samples from dataset")
|
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|||||||
@@ -205,10 +205,13 @@ def add_length(sample):
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
|
def filter_sequences_by_length(
|
||||||
|
sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Drop samples whose sequence length is either too long (> sequence_len)
|
Filter sequences outside valid length range [min_sequence_len, sequence_len].
|
||||||
or too short (< min_sequence_len).
|
|
||||||
|
Drops samples that are either too short (< min_sequence_len) or too long (> sequence_len).
|
||||||
|
|
||||||
Works for both single-example (list[int]) or batched (list[list[int]]).
|
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||||
|
|
||||||
@@ -383,10 +386,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
def process_pretraining_datasets_for_packing(
|
def process_pretraining_datasets_for_packing(
|
||||||
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
|
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
|
||||||
):
|
):
|
||||||
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
|
drop_outside_range = partial(filter_sequences_by_length, sequence_len=sequence_len)
|
||||||
|
|
||||||
train_dataset = train_dataset.filter(
|
train_dataset = train_dataset.filter(
|
||||||
drop_long,
|
drop_outside_range,
|
||||||
desc="Dropping Long Sequences",
|
desc="Dropping Long Sequences",
|
||||||
load_from_cache_file=False,
|
load_from_cache_file=False,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import unittest
|
|||||||
from transformers import LlamaTokenizer
|
from transformers import LlamaTokenizer
|
||||||
|
|
||||||
from axolotl.utils.data import encode_streaming, md5
|
from axolotl.utils.data import encode_streaming, md5
|
||||||
from axolotl.utils.trainer import drop_long_seq
|
from axolotl.utils.trainer import filter_sequences_by_length
|
||||||
|
|
||||||
from tests.hf_offline_utils import enable_hf_offline
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
@@ -70,17 +70,19 @@ class TestEncodePretraining(unittest.TestCase):
|
|||||||
# -- single sequence --
|
# -- single sequence --
|
||||||
# This should work
|
# This should work
|
||||||
data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}
|
data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}
|
||||||
drop_long_seq(data, 32, raise_on_drop=True)
|
filter_sequences_by_length(data, 32, raise_on_drop=True)
|
||||||
|
|
||||||
# This should return True, since data fits
|
# This should return True, since data fits
|
||||||
dropped = drop_long_seq(data, 32)
|
dropped = filter_sequences_by_length(data, 32)
|
||||||
self.assertTrue(dropped)
|
self.assertTrue(dropped)
|
||||||
|
|
||||||
# This should raise
|
# This should raise
|
||||||
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
|
self.assertRaises(
|
||||||
|
ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
|
||||||
|
)
|
||||||
|
|
||||||
# This should return False, since data doesn't fit
|
# This should return False, since data doesn't fit
|
||||||
dropped = drop_long_seq(data, 15)
|
dropped = filter_sequences_by_length(data, 15)
|
||||||
self.assertFalse(dropped)
|
self.assertFalse(dropped)
|
||||||
|
|
||||||
# -- batch sequence --
|
# -- batch sequence --
|
||||||
@@ -91,13 +93,15 @@ class TestEncodePretraining(unittest.TestCase):
|
|||||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
drop_long_seq(data, 32, raise_on_drop=True)
|
filter_sequences_by_length(data, 32, raise_on_drop=True)
|
||||||
|
|
||||||
# This should raise
|
# This should raise
|
||||||
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
|
self.assertRaises(
|
||||||
|
ValueError, filter_sequences_by_length, data, 15, raise_on_drop=True
|
||||||
|
)
|
||||||
|
|
||||||
# This should keep the first but drop the second entry
|
# This should keep the first but drop the second entry
|
||||||
dropped = drop_long_seq(data, 15)
|
dropped = filter_sequences_by_length(data, 15)
|
||||||
self.assertEqual(dropped, [True, False])
|
self.assertEqual(dropped, [True, False])
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
545
tests/utils/data/test_utils.py
Normal file
545
tests/utils/data/test_utils.py
Normal file
@@ -0,0 +1,545 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for data utility functions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
|
from axolotl.utils.data.utils import handle_long_seq_in_dataset
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandleLongSeqInDataset(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test class for handle_long_seq_in_dataset function
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_drop_strategy_removes_long_sequences(self):
|
||||||
|
"""Test that 'drop' strategy removes sequences longer than sequence_len"""
|
||||||
|
# Create dataset with mixed length sequences
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1, 2, 3], # length 3 - keep
|
||||||
|
[1, 2, 3, 4, 5], # length 5 - keep
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 - drop
|
||||||
|
[1, 2], # length 2 - keep
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "drop",
|
||||||
|
"min_sample_len": 2,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
# Should have dropped the sequence with length 11
|
||||||
|
self.assertEqual(len(result), 3)
|
||||||
|
self.assertEqual(len(result[0]["input_ids"]), 3)
|
||||||
|
self.assertEqual(len(result[1]["input_ids"]), 5)
|
||||||
|
self.assertEqual(len(result[2]["input_ids"]), 2)
|
||||||
|
|
||||||
|
def test_drop_strategy_is_default(self):
|
||||||
|
"""Test that 'drop' is the default strategy when not specified"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1, 2, 3],
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 - should drop
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"min_sample_len": 2,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
# Should have dropped the long sequence
|
||||||
|
self.assertEqual(len(result), 1)
|
||||||
|
|
||||||
|
def test_truncate_strategy_truncates_long_sequences(self):
|
||||||
|
"""Test that 'truncate' strategy truncates sequences to sequence_len"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1, 2, 3], # length 3 - keep as is
|
||||||
|
[
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
5,
|
||||||
|
6,
|
||||||
|
7,
|
||||||
|
8,
|
||||||
|
9,
|
||||||
|
10,
|
||||||
|
11,
|
||||||
|
12,
|
||||||
|
], # length 12 - truncate to 10
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "truncate",
|
||||||
|
"min_sample_len": 2,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
# Should have 2 samples
|
||||||
|
self.assertEqual(len(result), 2)
|
||||||
|
# First sample unchanged
|
||||||
|
self.assertEqual(len(result[0]["input_ids"]), 3)
|
||||||
|
# Second sample truncated to 10
|
||||||
|
self.assertEqual(len(result[1]["input_ids"]), 10)
|
||||||
|
self.assertEqual(result[1]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||||
|
|
||||||
|
def test_truncate_strategy_truncates_all_auxiliary_fields(self):
|
||||||
|
"""Test that truncation applies to all auxiliary fields consistently"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||||
|
],
|
||||||
|
"attention_mask": [
|
||||||
|
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||||
|
],
|
||||||
|
"labels": [
|
||||||
|
[-100, -100, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||||
|
],
|
||||||
|
"position_ids": [
|
||||||
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "truncate",
|
||||||
|
"min_sample_len": 2,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
# All fields should be truncated to 10
|
||||||
|
self.assertEqual(len(result[0]["input_ids"]), 10)
|
||||||
|
self.assertEqual(len(result[0]["attention_mask"]), 10)
|
||||||
|
self.assertEqual(len(result[0]["labels"]), 10)
|
||||||
|
self.assertEqual(len(result[0]["position_ids"]), 10)
|
||||||
|
|
||||||
|
# Verify content is correct
|
||||||
|
self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||||
|
self.assertEqual(result[0]["attention_mask"], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
|
||||||
|
self.assertEqual(result[0]["labels"], [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||||
|
self.assertEqual(result[0]["position_ids"], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||||
|
|
||||||
|
def test_raise_strategy_raises_on_long_sequences(self):
|
||||||
|
"""Test that 'raise' strategy raises ValueError when encountering long sequences"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1, 2, 3],
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 - should raise
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "raise",
|
||||||
|
"min_sample_len": 2,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
def test_min_sequence_len_filters_short_sequences(self):
|
||||||
|
"""Test that sequences shorter than min_sample_len are filtered out"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1], # length 1 - drop (< min_sample_len=3)
|
||||||
|
[1, 2], # length 2 - drop
|
||||||
|
[1, 2, 3], # length 3 - keep
|
||||||
|
[1, 2, 3, 4, 5], # length 5 - keep
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "drop",
|
||||||
|
"min_sample_len": 3,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
# Should only keep sequences with length >= 3
|
||||||
|
self.assertEqual(len(result), 2)
|
||||||
|
self.assertEqual(len(result[0]["input_ids"]), 3)
|
||||||
|
self.assertEqual(len(result[1]["input_ids"]), 5)
|
||||||
|
|
||||||
|
def test_dataset_without_input_ids_column(self):
|
||||||
|
"""Test that datasets without 'input_ids' column are returned unchanged"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"chosen": [1, 2, 3],
|
||||||
|
"rejected": [4, 5, 6],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "drop",
|
||||||
|
"min_sample_len": 2,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
# Dataset should be unchanged
|
||||||
|
self.assertEqual(len(result), len(dataset))
|
||||||
|
self.assertListEqual(list(result.column_names), ["chosen", "rejected"])
|
||||||
|
|
||||||
|
def test_truncate_filters_short_before_truncating(self):
|
||||||
|
"""Test that truncate strategy filters short sequences before truncating long ones
|
||||||
|
|
||||||
|
This is important for efficiency - we should not waste time truncating
|
||||||
|
sequences that will be filtered out anyway.
|
||||||
|
"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1], # length 1 - filter out first
|
||||||
|
[1, 2, 3], # length 3 - keep, no truncation needed
|
||||||
|
[
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
5,
|
||||||
|
6,
|
||||||
|
7,
|
||||||
|
8,
|
||||||
|
9,
|
||||||
|
10,
|
||||||
|
11,
|
||||||
|
12,
|
||||||
|
], # length 12 - keep and truncate
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "truncate",
|
||||||
|
"min_sample_len": 2,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
# Should have filtered out the first (short) sequence
|
||||||
|
self.assertEqual(len(result), 2)
|
||||||
|
# Second sample unchanged
|
||||||
|
self.assertEqual(len(result[0]["input_ids"]), 3)
|
||||||
|
# Third sample truncated to 10
|
||||||
|
self.assertEqual(len(result[1]["input_ids"]), 10)
|
||||||
|
|
||||||
|
def test_case_insensitive_strategy(self):
|
||||||
|
"""Test that excess_length_strategy is case-insensitive"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "TRUNCATE", # uppercase
|
||||||
|
"min_sample_len": 2,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
# Should still truncate
|
||||||
|
self.assertEqual(len(result[0]["input_ids"]), 10)
|
||||||
|
|
||||||
|
def test_raise_strategy_silently_drops_short_sequences(self):
|
||||||
|
"""Test that 'raise' strategy drops short sequences without raising"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1], # length 1 - too short, should be dropped silently
|
||||||
|
[1, 2, 3, 4, 5], # length 5 - keep
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "raise",
|
||||||
|
"min_sample_len": 3,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should NOT raise, just silently drop the short sequence
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 1)
|
||||||
|
self.assertEqual(len(result[0]["input_ids"]), 5)
|
||||||
|
|
||||||
|
def test_drop_boundary_sequence_equal_to_sequence_len(self):
|
||||||
|
"""Test that drop strategy keeps sequences with length exactly equal to sequence_len"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], # length 10 == sequence_len
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # length 11 > sequence_len
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "drop",
|
||||||
|
"min_sample_len": 2,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
# Exactly equal should be kept, one over should be dropped
|
||||||
|
self.assertEqual(len(result), 1)
|
||||||
|
self.assertEqual(len(result[0]["input_ids"]), 10)
|
||||||
|
|
||||||
|
def test_truncate_boundary_sequence_equal_to_sequence_len(self):
|
||||||
|
"""Test that truncate strategy leaves sequences with length exactly equal to sequence_len unchanged"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], # length 10 == sequence_len
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "truncate",
|
||||||
|
"min_sample_len": 2,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
# Should be unchanged - not truncated
|
||||||
|
self.assertEqual(len(result), 1)
|
||||||
|
self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||||
|
|
||||||
|
def test_empty_dataset(self):
|
||||||
|
"""Test that an empty dataset is handled gracefully"""
|
||||||
|
dataset = Dataset.from_dict({"input_ids": []})
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "drop",
|
||||||
|
"min_sample_len": 2,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 0)
|
||||||
|
|
||||||
|
def test_all_sequences_dropped_returns_empty_dataset(self):
|
||||||
|
"""Test that dropping all sequences results in an empty dataset"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1], # too short
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # too long
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "drop",
|
||||||
|
"min_sample_len": 5,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 0)
|
||||||
|
|
||||||
|
def test_iterable_dataset_skips_processing(self):
|
||||||
|
"""Test that streaming datasets (column_names is None) are returned unchanged.
|
||||||
|
|
||||||
|
The skip check in _should_skip_processing triggers when column_names is
|
||||||
|
None, which happens with true streaming datasets loaded via
|
||||||
|
load_dataset(..., streaming=True).
|
||||||
|
"""
|
||||||
|
mock_dataset = MagicMock()
|
||||||
|
mock_dataset.column_names = None
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "drop",
|
||||||
|
"min_sample_len": 2,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(mock_dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
# Should be returned unchanged (same object)
|
||||||
|
self.assertIs(result, mock_dataset)
|
||||||
|
|
||||||
|
def test_truncate_with_partial_auxiliary_fields(self):
|
||||||
|
"""Test truncation when only some auxiliary fields are present"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||||
|
],
|
||||||
|
"labels": [
|
||||||
|
[-100, -100, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||||
|
],
|
||||||
|
# No attention_mask or position_ids
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "truncate",
|
||||||
|
"min_sample_len": 2,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
self.assertEqual(len(result[0]["input_ids"]), 10)
|
||||||
|
self.assertEqual(len(result[0]["labels"]), 10)
|
||||||
|
self.assertEqual(result[0]["input_ids"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||||
|
self.assertEqual(result[0]["labels"], [-100, -100, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||||
|
# Confirm no extra columns were introduced
|
||||||
|
self.assertListEqual(sorted(result.column_names), ["input_ids", "labels"])
|
||||||
|
|
||||||
|
def test_min_sample_len_defaults_to_two_when_not_set(self):
|
||||||
|
"""Test that min_sample_len defaults to 2 when not specified in config"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1], # length 1 - should be dropped (< default 2)
|
||||||
|
[1, 2], # length 2 - should be kept (>= default 2)
|
||||||
|
[1, 2, 3], # length 3 - should be kept
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "drop",
|
||||||
|
# min_sample_len not set
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
self.assertEqual(len(result), 2)
|
||||||
|
self.assertEqual(len(result[0]["input_ids"]), 2)
|
||||||
|
self.assertEqual(len(result[1]["input_ids"]), 3)
|
||||||
|
|
||||||
|
def test_invalid_strategy_falls_through_to_drop(self):
|
||||||
|
"""Test that an unrecognized strategy value falls through to drop behavior"""
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"input_ids": [
|
||||||
|
[1, 2, 3], # keep
|
||||||
|
[
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
5,
|
||||||
|
6,
|
||||||
|
7,
|
||||||
|
8,
|
||||||
|
9,
|
||||||
|
10,
|
||||||
|
11,
|
||||||
|
], # length 11 - should be dropped
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"excess_length_strategy": "not_a_real_strategy",
|
||||||
|
"min_sample_len": 2,
|
||||||
|
"dataset_num_proc": None,
|
||||||
|
"is_preprocess": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handle_long_seq_in_dataset(dataset, sequence_len=10, cfg=cfg)
|
||||||
|
|
||||||
|
# Should behave like 'drop'
|
||||||
|
self.assertEqual(len(result), 1)
|
||||||
|
self.assertEqual(len(result[0]["input_ids"]), 3)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user