Fix: excess_length_strategy truncation method (#3401)
* Add test cases to verify that the problem exists in the underlying * Update the handle_long_sequences function to correctly use Map instead of filter for the truncation strategy. Also remove the minimal length filtering from the truncate_long_samples function, and run it separately and before. * fix: refactor and add test truncate for non-input id fields * fix: refactor long seq handling fn * fix: refactor duplicate fn and simplify route * add additional tests and make them work on mac * handle logging exception on empty datasets --------- Co-authored-by: 2ndset bot <bot@2ndset.ai> Co-authored-by: NanoCode012 <nano@axolotl.ai> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -15,7 +15,7 @@ from datasets import Dataset, IterableDataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.samplers.utils import get_dataset_lengths
|
||||
from axolotl.utils.trainer import drop_long_seq
|
||||
from axolotl.utils.trainer import filter_sequences_by_length
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
@@ -148,22 +148,33 @@ def deduplicate_and_log_datasets(
|
||||
return dataset, other_dataset
|
||||
|
||||
|
||||
def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
def keep_min_len(sample, min_sequence_len=2):
|
||||
"""
|
||||
Truncate samples whose sequence length is too long (> sequence_len)
|
||||
or drop those too short (< min_sequence_len).
|
||||
Batched filter function that keeps only samples with sequence length >= min_sequence_len.
|
||||
Returns a list of booleans indicating which samples to keep.
|
||||
"""
|
||||
min_sequence_len = min_sequence_len or 2
|
||||
|
||||
input_ids = sample["input_ids"]
|
||||
|
||||
# Batched (input_ids is a list of lists)
|
||||
results = []
|
||||
for seq in input_ids:
|
||||
results.append(len(seq) >= min_sequence_len)
|
||||
return results
|
||||
|
||||
|
||||
def truncate_long_seq(sample, sequence_len=2048):
|
||||
"""
|
||||
Truncate samples whose sequence length is too long (> sequence_len).
|
||||
Modifies the sample in-place and returns the modified sample.
|
||||
"""
|
||||
input_ids = sample["input_ids"]
|
||||
|
||||
# Batched (input_ids is a list of lists)
|
||||
for i, seq in enumerate(input_ids):
|
||||
length = len(seq)
|
||||
if length < min_sequence_len:
|
||||
results.append(False)
|
||||
elif length > sequence_len:
|
||||
if length > sequence_len:
|
||||
sample["input_ids"][i] = seq[:sequence_len]
|
||||
if "attention_mask" in sample:
|
||||
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
|
||||
@@ -171,10 +182,133 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
sample["labels"][i] = sample["labels"][i][:sequence_len]
|
||||
if "position_ids" in sample:
|
||||
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
|
||||
results.append(True)
|
||||
else:
|
||||
results.append(True)
|
||||
return results
|
||||
return sample
|
||||
|
||||
|
||||
def _should_skip_processing(dataset: Dataset) -> bool:
|
||||
"""Check if dataset should skip long sequence handling."""
|
||||
if (
|
||||
hasattr(dataset, "column_names")
|
||||
and dataset.column_names
|
||||
and "input_ids" not in dataset.column_names
|
||||
):
|
||||
LOG.warning(
|
||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
||||
"expected for reward modeling."
|
||||
)
|
||||
return True
|
||||
elif not hasattr(dataset, "column_names") or dataset.column_names is None:
|
||||
LOG.info(
|
||||
"Dataset is streaming (IterableDataset), skipping long sequence handling"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _log_dataset_stats(dataset: Dataset) -> None:
|
||||
"""Log min/max sequence lengths for debugging."""
|
||||
with contextlib.suppress(AttributeError, ValueError):
|
||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||
LOG.info(f"min_input_len: {np.min(ds_lengths)}")
|
||||
LOG.info(f"max_input_len: {np.max(ds_lengths)}")
|
||||
|
||||
|
||||
def _build_filter_kwargs(dataset: Dataset, cfg: DictDefault) -> dict:
|
||||
"""Build kwargs for dataset filter/map operations."""
|
||||
kwargs = {}
|
||||
if not isinstance(dataset, IterableDataset):
|
||||
kwargs["num_proc"] = cfg.dataset_num_proc
|
||||
kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
||||
return kwargs
|
||||
|
||||
|
||||
def _filter_short_sequences(
|
||||
dataset: Dataset, min_len: int, filter_kwargs: dict
|
||||
) -> tuple[Dataset, int]:
|
||||
"""Filter out sequences shorter than min_len. Returns (dataset, num_dropped)."""
|
||||
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
|
||||
|
||||
desc_kwargs = {}
|
||||
if filter_kwargs:
|
||||
desc_kwargs["desc"] = f"Filtering Short Sequences (<{min_len})"
|
||||
|
||||
dataset = dataset.filter(
|
||||
functools.partial(keep_min_len, min_sequence_len=min_len),
|
||||
batched=True,
|
||||
**filter_kwargs,
|
||||
**desc_kwargs,
|
||||
)
|
||||
|
||||
dropped = 0
|
||||
if prior_len:
|
||||
dropped = prior_len - len(dataset)
|
||||
if dropped > 0:
|
||||
LOG.info(f"Dropped {dropped} short sequences (<{min_len} tokens)")
|
||||
|
||||
return dataset, dropped
|
||||
|
||||
|
||||
def _truncate_long_sequences(
|
||||
dataset: Dataset, max_len: int, map_kwargs: dict
|
||||
) -> Dataset:
|
||||
"""Truncate sequences longer than max_len."""
|
||||
desc_kwargs = {}
|
||||
if map_kwargs:
|
||||
desc_kwargs["desc"] = f"Truncating Sequences (target_len={max_len})"
|
||||
|
||||
dataset = dataset.map(
|
||||
functools.partial(truncate_long_seq, sequence_len=max_len),
|
||||
batched=True,
|
||||
**map_kwargs,
|
||||
**desc_kwargs,
|
||||
)
|
||||
LOG.info(f"Truncated long sequences to max length {max_len}")
|
||||
return dataset
|
||||
|
||||
|
||||
def _drop_outside_range(
|
||||
dataset: Dataset,
|
||||
max_len: int,
|
||||
min_len: int,
|
||||
raise_on_long: bool,
|
||||
filter_kwargs: dict,
|
||||
) -> tuple[Dataset, int]:
|
||||
"""Drop sequences outside valid length range [min_len, max_len].
|
||||
|
||||
Returns (dataset, num_dropped)."""
|
||||
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
|
||||
|
||||
desc_kwargs = {}
|
||||
if filter_kwargs:
|
||||
action = (
|
||||
"Checking Sequence Lengths"
|
||||
if raise_on_long
|
||||
else "Dropping Invalid Sequences"
|
||||
)
|
||||
desc_kwargs["desc"] = f"{action} (<{min_len} or >{max_len})"
|
||||
|
||||
dataset = dataset.filter(
|
||||
functools.partial(
|
||||
filter_sequences_by_length,
|
||||
sequence_len=max_len,
|
||||
min_sequence_len=min_len,
|
||||
raise_on_drop=raise_on_long,
|
||||
),
|
||||
batched=True,
|
||||
**filter_kwargs,
|
||||
**desc_kwargs,
|
||||
)
|
||||
|
||||
dropped = 0
|
||||
if not raise_on_long and prior_len:
|
||||
dropped = prior_len - len(dataset)
|
||||
if dropped > 0:
|
||||
LOG.info(
|
||||
f"Dropped {dropped} sequences outside valid range "
|
||||
f"([{min_len}, {max_len}])"
|
||||
)
|
||||
|
||||
return dataset, dropped
|
||||
|
||||
|
||||
def handle_long_seq_in_dataset(
|
||||
@@ -193,80 +327,25 @@ def handle_long_seq_in_dataset(
|
||||
'truncate' truncates them down to sequence_len
|
||||
'raise' raises a ValueError if any sequence was found that was longer than sequence_len
|
||||
"""
|
||||
if (
|
||||
hasattr(dataset, "column_names")
|
||||
and dataset.column_names
|
||||
and "input_ids" not in dataset.column_names
|
||||
):
|
||||
LOG.warning(
|
||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
||||
"expected for reward modeling."
|
||||
)
|
||||
return dataset
|
||||
elif not hasattr(dataset, "column_names") or dataset.column_names is None:
|
||||
LOG.info(
|
||||
"Dataset is streaming (IterableDataset), skipping long sequence handling"
|
||||
)
|
||||
# Early returns for special cases
|
||||
if _should_skip_processing(dataset):
|
||||
return dataset
|
||||
|
||||
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
||||
|
||||
drop_long = functools.partial(
|
||||
drop_long_seq,
|
||||
sequence_len=sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
raise_on_drop=excess_length_strategy == "raise",
|
||||
)
|
||||
_log_dataset_stats(dataset)
|
||||
|
||||
with contextlib.suppress(AttributeError):
|
||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||
min_input_len = np.min(ds_lengths)
|
||||
LOG.info(f"min_input_len: {min_input_len}")
|
||||
max_input_len = np.max(ds_lengths)
|
||||
LOG.info(f"max_input_len: {max_input_len}")
|
||||
|
||||
prior_len = len(dataset) if hasattr(dataset, "__len__") else None
|
||||
|
||||
filter_map_kwargs = {}
|
||||
if not isinstance(dataset, IterableDataset):
|
||||
filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
|
||||
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
||||
|
||||
drop_long_kwargs = {}
|
||||
if filter_map_kwargs:
|
||||
action = (
|
||||
"Checking Sequence Lengths"
|
||||
if excess_length_strategy == "raise"
|
||||
else "Dropping Long Sequences"
|
||||
)
|
||||
drop_long_kwargs["desc"] = f"{action} (>{sequence_len})"
|
||||
# Setup kwargs
|
||||
filter_kwargs = _build_filter_kwargs(dataset, cfg)
|
||||
|
||||
# Handle sequences based on strategy
|
||||
if excess_length_strategy == "truncate":
|
||||
process_fn = functools.partial(
|
||||
truncate_long_seq,
|
||||
sequence_len=sequence_len,
|
||||
min_sequence_len=cfg.min_sample_len,
|
||||
)
|
||||
drop_long_kwargs["desc"] = (
|
||||
f"Truncating/Filtering Sequences (target_len={sequence_len})"
|
||||
)
|
||||
dataset, _ = _filter_short_sequences(dataset, cfg.min_sample_len, filter_kwargs)
|
||||
dataset = _truncate_long_sequences(dataset, sequence_len, filter_kwargs)
|
||||
else:
|
||||
process_fn = drop_long
|
||||
|
||||
dataset = dataset.filter(
|
||||
process_fn,
|
||||
batched=True,
|
||||
**filter_map_kwargs,
|
||||
**drop_long_kwargs,
|
||||
)
|
||||
if prior_len:
|
||||
dropped = prior_len - len(dataset)
|
||||
if dropped:
|
||||
action = (
|
||||
"truncated/filtered"
|
||||
if excess_length_strategy == "truncate"
|
||||
else "dropped"
|
||||
)
|
||||
LOG.warning(f"{action.title()} {dropped} samples from dataset")
|
||||
raise_on_long = excess_length_strategy == "raise"
|
||||
dataset, _ = _drop_outside_range(
|
||||
dataset, sequence_len, cfg.min_sample_len, raise_on_long, filter_kwargs
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -205,10 +205,13 @@ def add_length(sample):
|
||||
return sample
|
||||
|
||||
|
||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
|
||||
def filter_sequences_by_length(
|
||||
sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False
|
||||
):
|
||||
"""
|
||||
Drop samples whose sequence length is either too long (> sequence_len)
|
||||
or too short (< min_sequence_len).
|
||||
Filter sequences outside valid length range [min_sequence_len, sequence_len].
|
||||
|
||||
Drops samples that are either too short (< min_sequence_len) or too long (> sequence_len).
|
||||
|
||||
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||
|
||||
@@ -383,10 +386,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
def process_pretraining_datasets_for_packing(
|
||||
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
|
||||
):
|
||||
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
|
||||
drop_outside_range = partial(filter_sequences_by_length, sequence_len=sequence_len)
|
||||
|
||||
train_dataset = train_dataset.filter(
|
||||
drop_long,
|
||||
drop_outside_range,
|
||||
desc="Dropping Long Sequences",
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user